diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ce29aa0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,34 @@ +command +pip_list + +__pycache__/ +*.py[cod] +*$py.class +.idea/ + +stable-diffusion/CompVis +stable-diffusion/src +stable-diffusion/latent_diffusion.egg-info + + +3DPortraitGAN_pyramid/models/*.pkl +3DPortraitGAN_pyramid/models/*.ckpt +3DPortraitGAN_pyramid/models/*.pt +3DPortraitGAN_pyramid/models/*.json +3DPortraitGAN_pyramid/out +3DPortraitGAN_pyramid/smplx_models/smpl/*.pkl +3DPortraitGAN_pyramid/training-runs + +stable-dreamfusion-3DPortrait/pretrained +stable-dreamfusion-3DPortrait/output +stable-dreamfusion-3DPortrait/smplx_models +stable-dreamfusion-3DPortrait/transfer_data + +command.py +temp.py + +*.pkl +*.pth +*.pt +*.pth.tar +./data_processing/data/J_regressor_extra.npy \ No newline at end of file diff --git a/3DPortraitGAN_pyramid/.gitignore b/3DPortraitGAN_pyramid/.gitignore new file mode 100644 index 0000000..6041f0d --- /dev/null +++ b/3DPortraitGAN_pyramid/.gitignore @@ -0,0 +1,8 @@ +models +smplx_models +transfer_data +generate_inversion_results.py +training/dataset-ref.py +training-runs +.vscode +*.out \ No newline at end of file diff --git a/3DPortraitGAN_pyramid/README.md b/3DPortraitGAN_pyramid/README.md new file mode 100644 index 0000000..a57b8c3 --- /dev/null +++ b/3DPortraitGAN_pyramid/README.md @@ -0,0 +1,59 @@ +# 3DPortraitGAN_pyramid Training + +**Note: Upon the acceptance of our [3DPortraitGAN](https://arxiv.org/abs/2307.14770), we plan to release our 360°PHQ dataset to facilitate reproducibility of research. We encourage you to utilize our provided pre-trained models. Stay tuned for updates! 
** + + + +## Training + +```shell +cd 3DPortraitGAN_pyramid + +# stage 1 +python train.py \ + --outdir=./training-runs/stage1 --cfg=full-head \ + --data=$DATASET_PATH/360PHQ-512.zip --seg_data=$DATASET_PATH/360PHQ-512-mask.zip \ + --gpus=8 --batch=32 --gamma=5.0 --cbase=18432 --cmax=144 \ + --gamma_seg=5.0 --use_torgb_raw=1 --decoder_activation="none" \ + --bcg_reg_prob 0.2 --triplane_depth 3 --density_noise_fade_kimg 200 --density_reg 0 --back_repeat=1 \ + --gen_pose_cond=True --gpc_reg_prob=0.7 --mirror=True --data_rebalance=False --image-snap=25 --kimg=20000 \ + --neural_rendering_resolution_initial=64 \ + --pose_loss_weight=10 --input_pose_params_reg_loss_weight=5 --input_pose_params_reg_loss_kimg=200 \ + --train_g_pose_branch=True \ + --explicitly_symmetry=True \ + --metric_pose_sample_mode=G_predict + + +# stage 2 +python train.py \ + --outdir=./training-runs/stage2 --cfg=full-head \ + --data=$DATASET_PATH/360PHQ-512.zip --seg_data=$DATASET_PATH/360PHQ-512-mask.zip \ + --gpus=8 --batch=32 --gamma=5.0 --cbase=18432 --cmax=144 \ + --gamma_seg=5.0 --use_torgb_raw=1 --decoder_activation="none" \ + --bcg_reg_prob 0.2 --triplane_depth 3 --density_noise_fade_kimg 200 --density_reg 0 --back_repeat=1 \ + --gen_pose_cond=True --gpc_reg_prob=0.7 --mirror=True --data_rebalance=False --image-snap=25 --kimg=20000 \ + --neural_rendering_resolution_initial=64 \ + --pose_loss_weight=10 --input_pose_params_reg_loss_weight=5 --input_pose_params_reg_loss_kimg=200 \ + --train_g_pose_branch=False \ + --explicitly_symmetry=True \ + --metric_pose_sample_mode=D_predict \ + --resume=stage1.pkl --resume_kimg=NUM_KIMGS + +# stage 3 +python train.py \ + --outdir=./training-runs/stage3 --cfg=full-head \ + --data=$DATASET_PATH/360PHQ-512.zip --seg_data=$DATASET_PATH/360PHQ-512-mask.zip \ + --gpus=8 --batch=32 --gamma=5.0 --cbase=18432 --cmax=144 \ + --gamma_seg=5.0 --use_torgb_raw=1 --decoder_activation="none" \ + --bcg_reg_prob 0.2 --triplane_depth 3 --density_noise_fade_kimg 200 
--density_reg 0 --back_repeat=1 \ + --gen_pose_cond=True --gpc_reg_prob=0.7 --mirror=True --data_rebalance=False --image-snap=25 --kimg=20000 \ + --neural_rendering_resolution_initial=64 --neural_rendering_resolution_final=128 \ + --neural_rendering_resolution_fade_kimg=1000 \ + --pose_loss_weight=10 --input_pose_params_reg_loss_weight=5 --input_pose_params_reg_loss_kimg=200 \ + --train_g_pose_branch=False \ + --explicitly_symmetry=True \ + --metric_pose_sample_mode=D_predict \ + --resume=stage2.pkl --resume_kimg=NUM_KIMGS + +``` + diff --git a/3DPortraitGAN_pyramid/calc_metrics.py b/3DPortraitGAN_pyramid/calc_metrics.py new file mode 100644 index 0000000..9df0e5b --- /dev/null +++ b/3DPortraitGAN_pyramid/calc_metrics.py @@ -0,0 +1,232 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Calculate quality metrics for previous training run or pretrained network pickle.""" + +import os +import click +import json +import tempfile +import copy +import torch + +import dnnlib +import legacy +from metrics import metric_main +from metrics import metric_utils +from torch_utils import training_stats +from torch_utils import custom_ops +from torch_utils import misc +from torch_utils.ops import conv2d_gradfix + +#---------------------------------------------------------------------------- + +def subprocess_fn(rank, args, temp_dir): + dnnlib.util.Logger(should_flush=True) + + # Init torch.distributed. 
+ if args.num_gpus > 1: + init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) + if os.name == 'nt': + init_method = 'file:///' + init_file.replace('\\', '/') + torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) + else: + init_method = f'file://{init_file}' + torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) + + # Init torch_utils. + sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None + training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) + if rank != 0 or not args.verbose: + custom_ops.verbosity = 'none' + + # Configure torch. + device = torch.device('cuda', rank) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + conv2d_gradfix.enabled = True + + # Print network summary. + G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device) + D = copy.deepcopy(args.D).eval().requires_grad_(False).to(device) if args.metric_pose_sample_mode == 'D_predict' else None + + + resample_filter = args.pose_predict_kwargs['resample_filter'] + resample_filter = torch.tensor(resample_filter, device=device).to(torch.float32) + + if rank == 0 and args.verbose: + z = torch.empty([1, G.z_dim], device=device) + c = torch.empty([1, G.c_dim], device=device) + misc.print_module_summary(G, [z, c]) + + # Calculate each metric. 
+ for metric in args.metrics: + if rank == 0 and args.verbose: + print(f'Calculating {metric}...') + progress = metric_utils.ProgressMonitor(verbose=args.verbose) + # result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs, + # num_gpus=args.num_gpus, rank=rank, device=device, progress=progress) + + + result_dict = metric_main.calc_metric(metric=metric, + G=G, + dataset_kwargs=args.dataset_kwargs, + num_gpus=args.num_gpus, + rank=rank, + device=device, + metric_pose_sample_mode = args.metric_pose_sample_mode, + progress=progress, + identical_c_p = args.identical_c_p, + D = D, + pose_predict_kwargs = { + 'neural_rendering_resolution':args.pose_predict_kwargs['neural_rendering_resolution'], + 'blur_sigma':args.pose_predict_kwargs['blur_sigma'], + 'resample_filter':resample_filter, + 'filter_mode':args.pose_predict_kwargs['filter_mode'] + } if args.metric_pose_sample_mode == 'D_predict' else None + ) + if rank == 0: + metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl) + if rank == 0 and args.verbose: + print() + + # Done. 
# Command-line entry point. The function docstring below doubles as the
# `--help` text rendered by click, so it is kept user-facing; developer notes
# live in `#` comments instead.
#
# Options:
#   --network              pickle (file or URL) read via legacy.load_network_pkl.
#   --pose_predict_kwargs  JSON file; subprocess_fn reads the keys
#                          neural_rendering_resolution, blur_sigma,
#                          resample_filter and filter_mode from it.
#   --metric_pose_sample_mode  'G_predict' (default) or 'D_predict'.
#   --identical_c_p        forwarded untouched to metric_main.calc_metric.
#   --metrics              comma-separated metric names.
#   --data / --seg_data    image dataset and its paired segmentation dataset.
#   --mirror               overrides dataset x-flips when given.
#   --gpus / --verbose     process count and logging switch.
@click.command()
@click.pass_context
@click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
@click.option('pose_predict_kwargs', '--pose_predict_kwargs', help='JSON file with pose prediction settings (neural_rendering_resolution, blur_sigma, resample_filter, filter_mode)', metavar='PATH', required=True)
@click.option('--metric_pose_sample_mode', help='Type of metric_pose_sample ', metavar='STR', type=click.Choice(['D_predict', 'G_predict']), required=False, default='G_predict')
@click.option('--identical_c_p', help='Value forwarded to the metrics as identical_c_p', type=bool, metavar='BOOL')
@click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
@click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]')
@click.option('--seg_data', help='Segmentation dataset paired with --data', metavar='[ZIP|DIR]')
@click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL')
@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
@click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
def calc_metrics(ctx, network_pkl, pose_predict_kwargs, metric_pose_sample_mode, identical_c_p, metrics, data, seg_data, mirror, gpus, verbose):
    """Calculate quality metrics for previous training run or pretrained network pickle.

    Examples:

    \b
    # Previous training run: look up options automatically, save result to JSONL file.
    python calc_metrics.py --metrics=eqt50k_int,eqr50k \\
        --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl

    \b
    # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
    python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl

    \b
    Recommended metrics:
      fid50k_full  Frechet inception distance against the full dataset.
      kid50k_full  Kernel inception distance against the full dataset.
      pr50k3_full  Precision and recall againt the full dataset.
      ppl2_wend    Perceptual path length in W, endpoints, full image.
      eqt50k_int   Equivariance w.r.t. integer translation (EQ-T).
      eqt50k_frac  Equivariance w.r.t. fractional translation (EQ-T_frac).
      eqr50k       Equivariance w.r.t. rotation (EQ-R).

    \b
    Legacy metrics:
      fid50k       Frechet inception distance against 50k real images.
      kid50k       Kernel inception distance against 50k real images.
      pr50k3       Precision and recall against 50k real images.
      is50k        Inception score for CIFAR-10.
    """
    dnnlib.util.Logger(should_flush=True)

    # Validate arguments.
    args = dnnlib.EasyDict(
        metrics=metrics,
        num_gpus=gpus,
        network_pkl=network_pkl,
        verbose=verbose,
        metric_pose_sample_mode=metric_pose_sample_mode,
    )
    if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
        ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    if not args.num_gpus >= 1:
        ctx.fail('--gpus must be at least 1')

    # Fail early with a CLI error (rather than a traceback) when the settings
    # file is missing, mirroring the --network check below.
    if not os.path.isfile(pose_predict_kwargs):
        ctx.fail('--pose_predict_kwargs must point to a JSON file')

    # Load network.
    if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
        ctx.fail('--network must point to a file or URL')
    if args.verbose:
        print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
        network_dict = legacy.load_network_pkl(f)
        args.G = network_dict['G_ema'] # subclass of torch.nn.Module
        args.D = network_dict['D_ema']

    args.identical_c_p = identical_c_p

    with open(pose_predict_kwargs, 'r') as f:
        args.pose_predict_kwargs = json.load(f)

    # Initialize dataset options.
    if data is not None:
        args.dataset_kwargs = dnnlib.EasyDict(
            class_name='training.dataset.MaskLabeledDataset',
            img_path=data,
            seg_path=seg_data,
            back_repeat=1,
            use_labels=True,
            max_size=None,
            xflip=True,
            data_rebalance=False,
            data_rebalance_idx_file=None,
        )
    elif network_dict['training_set_kwargs'] is not None:
        args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
    else:
        ctx.fail('Could not look up dataset options; please specify --data')

    # Finalize dataset options.
    args.dataset_kwargs.resolution = args.G.img_resolution
    args.dataset_kwargs.use_labels = True
    # Honour an explicit --mirror; when the flag is omitted (None) the value
    # chosen above / stored in the pickle is left untouched.
    if mirror is not None:
        args.dataset_kwargs.xflip = mirror

    # Print dataset options.
    if args.verbose:
        print('Dataset options:')
        print(json.dumps(args.dataset_kwargs, indent=2))

    # Locate run dir: only set when the pickle sits next to training_options.json.
    args.run_dir = None
    if os.path.isfile(network_pkl):
        pkl_dir = os.path.dirname(network_pkl)
        if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
            args.run_dir = pkl_dir

    # Launch processes.
    if args.verbose:
        print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
class LookAtPoseSampler:
    """
    Samples a camera position on a sphere of the given radius around the world
    origin and orients the camera towards 'lookat_position', a 3-vector.

    Yaw and pitch are drawn from normal distributions with the given means and
    standard deviations (radians); with zero stddev the pose is deterministic.

    Example:
    For a camera pose looking at the origin with the camera at position [0, 0, 1]:
        cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
    """

    @staticmethod
    def sample(horizontal_mean, vertical_mean, lookat_position, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
        """
        Return the cam2world matrices built by create_cam2world_matrix for
        batch_size sampled poses.

        horizontal_mean / horizontal_stddev: yaw distribution, radians.
        vertical_mean / vertical_stddev: pitch distribution, radians; samples are
            clamped to [1e-5, pi - 1e-5].
        lookat_position: point the camera faces; a tensor or anything
            torch.as_tensor accepts (e.g. a list of three numbers). It is moved to
            the device and dtype of the camera origins before use.
        radius: distance of the camera from the world origin.
        batch_size: number of poses to sample.
        device: device on which the poses are created.
        """
        # Draw yaw first, then pitch, so the random stream is consumed in a fixed order.
        h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
        v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
        v = torch.clamp(v, 1e-5, math.pi - 1e-5)

        theta = h
        # The polar angle actually used is arccos(1 - 2*v/pi); it equals v at pi/2.
        v = v / math.pi
        phi = torch.arccos(1 - 2*v)

        camera_origins = torch.zeros((batch_size, 3), device=device)

        camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
        camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
        camera_origins[:, 1:2] = radius*torch.cos(phi)

        # A CPU / integer / list look-at point would otherwise fail (device
        # mismatch) or silently change dtype in the subtraction below.
        lookat_position = torch.as_tensor(lookat_position, dtype=camera_origins.dtype, device=camera_origins.device)

        forward_vectors = math_utils.normalize_vecs(lookat_position - camera_origins)
        return create_cam2world_matrix(forward_vectors, camera_origins)
def create_cam2world_matrix(forward_vector, origin):
    """
    Builds batched 4x4 cam2world matrices from viewing directions and camera
    positions. The world up axis is +y and the camera has no roll.

    forward_vector: batch of directions the cameras point in (normalized here).
    origin: batch of camera positions; its device is used for all new tensors.
    Returns a tensor whose trailing shape is (4, 4) (asserted below).
    """
    dev = origin.device
    fwd = math_utils.normalize_vecs(forward_vector)
    world_up = torch.tensor([0, 1, 0], dtype=torch.float, device=dev).expand_as(fwd)

    # Orthonormal camera frame: right from world-up x forward (negated), then
    # the true up from forward x right.
    right = -math_utils.normalize_vecs(torch.cross(world_up, fwd, dim=-1))
    up = math_utils.normalize_vecs(torch.cross(fwd, right, dim=-1))

    count = fwd.shape[0]

    # The columns of the 3x3 rotation are right, up, forward.
    rotation = torch.eye(4, device=dev).unsqueeze(0).repeat(count, 1, 1)
    rotation[:, :3, :3] = torch.stack((right, up, fwd), dim=-1)

    translation = torch.eye(4, device=dev).unsqueeze(0).repeat(count, 1, 1)
    translation[:, :3, 3] = origin

    cam2world = translation @ rotation
    assert(cam2world.shape[1:] == (4, 4))
    return cam2world


def FOV_to_intrinsics(fov_degrees, device='cpu'):
    """
    Creates a 3x3 camera intrinsics matrix from a field of view in degrees.
    The focal length is normalized by image size rather than given in pixels,
    and the principal point is fixed at (0.5, 0.5).
    """
    # The literal constants 3.14159 and 1.414 are part of the numeric contract
    # and are kept exactly as they are.
    half_fov = fov_degrees * 3.14159 / 360
    focal = float(1 / (math.tan(half_fov) * 1.414))
    rows = [
        [focal, 0, 0.5],
        [0, focal, 0.5],
        [0, 0, 1],
    ]
    return torch.tensor(rows, device=device)
def error(msg):
    """Print msg prefixed with 'Error: ' and terminate with exit status 1."""
    message = 'Error: ' + msg
    print(message)
    sys.exit(1)

#----------------------------------------------------------------------------

def maybe_min(a: int, b: Optional[int]) -> int:
    """Return min(a, b), or a unchanged when b is None."""
    return a if b is None else min(a, b)

#----------------------------------------------------------------------------

def file_ext(name: Union[str, Path]) -> str:
    """Return the text after the last '.' in name (all of name if it has no dot)."""
    _head, _dot, tail = str(name).rpartition('.')
    return tail

#----------------------------------------------------------------------------

def is_image_ext(fname: Union[str, Path]) -> bool:
    """Return True when the lower-cased extension of fname is registered in PIL.Image.EXTENSION."""
    suffix = '.' + file_ext(fname).lower()
    return suffix in PIL.Image.EXTENSION # type: ignore
def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
    """Open an LMDB image database for conversion.

    Returns a (max_idx, iterator) pair. max_idx is the number of entries in the
    database, capped by max_images when that is not None. The iterator yields
    dict(img=<RGB array>, label=None) for each entry that can be decoded and
    stops after the entry at index max_idx - 1 has been yielded.

    Entries that cannot be decoded are reported on stdout and skipped.
    """
    import cv2  # pip install opencv-python # pylint: disable=import-error
    import lmdb  # pip install lmdb # pylint: disable=import-error

    with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
        max_idx = maybe_min(txn.stat()['entries'], max_images)

    def decode(value):
        # Try OpenCV first; fall back to PIL when it cannot decode the buffer.
        try:
            img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
            if img is None:
                raise IOError('cv2.imdecode failed')
            return img[:, :, ::-1] # BGR => RGB
        except IOError:
            return np.array(PIL.Image.open(io.BytesIO(value)))

    def iterate_images():
        with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
            for idx, (_key, value) in enumerate(txn.cursor()):
                # Only the decode step is guarded. The yield stays outside the
                # handler so that closing the generator (GeneratorExit) or a
                # KeyboardInterrupt is never swallowed.
                try:
                    img = decode(value)
                except Exception as exc:
                    print(exc)
                    continue
                yield dict(img=img, label=None)
                if idx >= max_idx-1:
                    break

    return max_idx, iterate_images()
def make_transform(
    transform: Optional[str],
    output_width: Optional[int],
    output_height: Optional[int]
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
    """Build the per-image crop/resize function used during conversion.

    transform:
        None               plain resize to the requested size; a dimension given
                           as None keeps the input's size for that dimension.
        'center-crop'      centered square crop, then resize.
        'center-crop-wide' centered wide crop, resize, then paste onto a black
                           width x width canvas; images that are too small are
                           dropped (the function returns None for them).
    output_width / output_height: target size; both are required for the two
        crop modes, otherwise error() is called, which exits the program.

    Returns a callable mapping an image array to the transformed array, or to
    None when the image should be skipped.

    Raises AssertionError for an unrecognised transform name.
    """
    def scale(width, height, img):
        w = img.shape[1]
        h = img.shape[0]
        ww = width if width is not None else w
        hh = height if height is not None else h
        # Nothing to do when the effective target equals the input size
        # (this also covers dimensions left as None).
        if ww == w and hh == h:
            return img
        img = PIL.Image.fromarray(img)
        img = img.resize((ww, hh), PIL.Image.LANCZOS)
        return np.array(img)

    def center_crop(width, height, img):
        crop = np.min(img.shape[:2])
        img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.LANCZOS)
        return np.array(img)

    def center_crop_wide(width, height, img):
        # Height the image would have once scaled to the target width.
        ch = int(np.round(width * img.shape[0] / img.shape[1]))
        if img.shape[1] < width or ch < height:
            return None

        img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.LANCZOS)
        img = np.array(img)

        # Letterbox: center the wide image vertically on a square black canvas.
        canvas = np.zeros([width, width, 3], dtype=np.uint8)
        canvas[(width - height) // 2 : (width + height) // 2, :] = img
        return canvas

    if transform is None:
        return functools.partial(scale, output_width, output_height)

    crop_modes = {
        'center-crop': center_crop,
        'center-crop-wide': center_crop_wide,
    }
    if transform not in crop_modes:
        # Explicit raise so the check survives `python -O`.
        raise AssertionError('unknown transform')
    if (output_width is None) or (output_height is None):
        error('must specify --resolution=WxH when using ' + transform + ' transform')
    return functools.partial(crop_modes[transform], output_width, output_height)
@click.command()
@click.pass_context
@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
@click.option('--resolution', help='Output width and height in pixels (a single integer, e.g. 512)', type=int)
def convert_dataset(
    ctx: click.Context,
    source: str,
    dest: str,
    max_images: Optional[int],
    transform: Optional[str],
    resolution: Optional[int]
):
    """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.

    The input dataset format is guessed from the --source argument:

    \b
    --source *_lmdb/                    Load LSUN dataset
    --source cifar-10-python.tar.gz     Load CIFAR-10 dataset
    --source train-images-idx3-ubyte.gz Load MNIST dataset
    --source path/                      Recursively load all images from path/
    --source dataset.zip                Recursively load all images from dataset.zip

    Specifying the output format and path:

    \b
    --dest /path/to/dir                 Save output files under /path/to/dir
    --dest /path/to/dataset.zip         Save output files into /path/to/dataset.zip

    The output dataset format can be either an image folder or an uncompressed zip archive.
    Zip archives make it easier to move datasets around file servers and clusters, and may
    offer better training performance on network file systems.

    Images within the dataset archive will be stored as uncompressed PNG.
    Uncompressed PNGs can be efficiently decoded in the training loop.

    Class labels are stored in a file called 'dataset.json' that is stored at the
    dataset root folder.  This file has the following structure:

    \b
    {
        "labels": [
            ["0000000000/0000000000.png",6],
            ["0000000000/0000000001.png",9],
            ... repeated for every image in the dataset
            ["0000000049/0000049999.png",1]
        ]
    }

    If the 'dataset.json' file cannot be found, the dataset is interpreted as
    not containing class labels.

    Image scale/crop and resolution requirements:

    Output images must be square-shaped and they must all have the same power-of-two
    dimensions.

    To scale arbitrary input images to a specific size, use the --resolution option,
    which takes a single integer that is used as both the output width and height.
    Output resolution will be either the original input resolution (if resolution
    was not specified) or the one specified with the --resolution option.

    Use the --transform=center-crop or --transform=center-crop-wide options to apply a
    center crop transform on the input image.  These options should be used with the
    --resolution option.  For example:

    \b
    python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
        --transform=center-crop-wide --resolution=512
    """
    # NOTE: the docstring above is the text click prints for --help, so
    # maintainer notes live in comments instead.
    #
    # ctx:        click context, used only to report an empty --dest.
    # source:     input directory/archive, resolved by open_dataset().
    # dest:       output directory or .zip file, resolved by open_dest().
    # max_images: forwarded to open_dataset().
    # transform:  None, 'center-crop' or 'center-crop-wide'.
    # resolution: a single int (the option is declared type=int); it is used as
    #             both the output width and height. None keeps the input size.

    PIL.Image.init() # type: ignore

    if dest == '':
        ctx.fail('--dest output filename or directory must not be an empty string')

    num_files, input_iter = open_dataset(source, max_images=max_images)
    archive_root_dir, save_bytes, close_dest = open_dest(dest)

    # Same value for width and height: the output is always square.
    transform_image = make_transform(transform, resolution, resolution)

    dataset_attrs = None

    labels = []
    for idx, image in tqdm(enumerate(input_iter), total=num_files):
        idx_str = f'{idx:010d}'
        # Images are grouped into sub-folders of 1000 files each.
        archive_fname = f'{idx // 1000:010d}/{idx_str}.png'

        # Apply crop and resize.
        img = transform_image(image['img'])

        # Transform may drop images.
        if img is None:
            continue

        # Error check to require uniform image attributes across
        # the whole dataset.
        channels = img.shape[2] if img.ndim == 3 else 1
        cur_image_attrs = {
            'width': img.shape[1],
            'height': img.shape[0],
            'channels': channels
        }
        if dataset_attrs is None:
            # The first kept image defines the attributes of the whole dataset.
            dataset_attrs = cur_image_attrs
            width = dataset_attrs['width']
            height = dataset_attrs['height']
            if width != height:
                error(f'Image dimensions after scale and crop are required to be square.  Got {width}x{height}')
            if dataset_attrs['channels'] not in [1, 3, 4]:
                error('Input images must be stored as RGB or grayscale')
            if width != 2 ** int(np.floor(np.log2(width))):
                error('Image width/height after scale and crop are required to be power-of-two')
        elif dataset_attrs != cur_image_attrs:
            err = [f'  dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] # pylint: disable=unsubscriptable-object
            error(f'Image {archive_fname} attributes must be equal across all images of the dataset.  Got:\n' + '\n'.join(err))

        # Save the image as an uncompressed PNG.
        img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB', 4: 'RGBA'}[channels])
        # Four-channel inputs are written without their alpha channel.
        if channels == 4: img = img.convert('RGB')
        image_bits = io.BytesIO()
        img.save(image_bits, format='png', compress_level=0, optimize=False)
        save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
        labels.append([archive_fname, image['label']] if image['label'] is not None else None)

    # Labels are only stored when every saved image has one.
    metadata = {
        'labels': labels if all(x is not None for x in labels) else None
    }
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    close_dest()
class EasyDict(dict):
    """A dict whose items can also be read, written and deleted as attributes."""

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails; map a missing key to
        # AttributeError so getattr()/hasattr() behave as usual.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        else:
            return value

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute assignment stores a dict item; nothing goes to __dict__.
        self.__setitem__(name, value)

    def __delattr__(self, name: str) -> None:
        # Attribute deletion removes the dict item (KeyError if it is absent).
        self.__delitem__(name)
# Explicit override installed by set_cache_dir(); None means "not set".
_dnnlib_cache_dir = None


def set_cache_dir(path: str) -> None:
    """Set the cache root that make_cache_dir_path() uses ahead of any environment variable."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path


def make_cache_dir_path(*paths: str) -> str:
    """Join `paths` onto the cache root and return the resulting path.

    The root is the first of these that is available:
      1. the directory given to set_cache_dir();
      2. the ``DNNLIB_CACHE_DIR`` environment variable;
      3. ``<HOME>/.cache/dnnlib``;
      4. ``<USERPROFILE>/.cache/dnnlib``;
      5. ``<system temp dir>/.cache/dnnlib``.

    Nothing is created on disk.
    """
    env = os.environ

    if _dnnlib_cache_dir is not None:
        root_parts = [_dnnlib_cache_dir]
    elif 'DNNLIB_CACHE_DIR' in env:
        root_parts = [env['DNNLIB_CACHE_DIR']]
    else:
        # Fall back to a per-user location, then to the temp directory.
        base = next((env[key] for key in ('HOME', 'USERPROFILE') if key in env), None)
        if base is None:
            base = tempfile.gettempdir()
        root_parts = [base, '.cache', 'dnnlib']

    return os.path.join(*root_parts, *paths)
def ask_yes_no(question: str) -> bool:
    """Prompt on stdout and read stdin until the user gives a yes/no answer.

    Accepted answers (case-insensitive, surrounding whitespace ignored) are the
    ones ``distutils.util.strtobool`` recognised: y/yes/t/true/on/1 for True and
    n/no/f/false/off/0 for False. Anything else repeats the prompt.

    Returns:
        True for an affirmative answer, False for a negative one. A real
        ``bool`` is returned, matching the annotation (``strtobool`` returned
        the ints 1 and 0).
    """
    # The answer table is spelled out here so this helper no longer depends on
    # distutils, which was removed from the standard library in Python 3.12.
    affirmative = {"y", "yes", "t", "true", "on", "1"}
    negative = {"n", "no", "f", "false", "off", "0"}

    while True:
        print("{0} [y/n]".format(question))
        answer = input().strip().lower()
        if answer in affirmative:
            return True
        if answer in negative:
            return False
def is_pickleable(obj: Any) -> bool:
    """Return True if `obj` can be serialized with ``pickle.dump``, else False.

    The object is pickled into a throw-away in-memory buffer. Any ordinary
    exception raised while pickling counts as "not pickleable".
    ``KeyboardInterrupt`` and ``SystemExit`` are deliberately not swallowed, so
    the probe cannot hide a user interrupt.
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
    except Exception:
        return False
    return True
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.

    The check is by name: `obj` must be callable and ``obj.__name__`` must be a
    key in the namespace of the module named by ``obj.__module__``.

    Callables that have no ``__name__`` (for example ``functools.partial``
    objects), or whose module is not present in ``sys.modules``, are reported
    as False instead of raising ``AttributeError``/``KeyError``.
    """
    if not callable(obj):
        return False

    name = getattr(obj, "__name__", None)
    module = sys.modules.get(getattr(obj, "__module__", None))
    if name is None or module is None:
        return False

    return name in vars(module)
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """Walk `dir_path` and collect every file that does not match an ignore pattern.

    Args:
        dir_path: Existing directory to walk.
        ignores: ``fnmatch`` patterns tested against bare file and directory
            names. A matching directory is not descended into.
        add_base_to_relative: When True, prefix each relative path with the
            final component of `dir_path`.

    Returns:
        A list of ``(absolute_path, relative_path)`` tuples, where
        ``absolute_path`` is `dir_path` joined with the location of the file and
        ``relative_path`` is taken relative to `dir_path`.
    """
    assert os.path.isdir(dir_path)
    patterns = [] if ignores is None else ignores
    base_name = os.path.basename(os.path.normpath(dir_path))

    def _is_ignored(entry: str) -> bool:
        return any(fnmatch.fnmatch(entry, pattern) for pattern in patterns)

    collected = []
    for root, dirs, files in os.walk(dir_path, topdown=True):
        # Prune in place so os.walk() skips the ignored sub-directories.
        dirs[:] = [d for d in dirs if not _is_ignored(d)]

        for file_name in files:
            if _is_ignored(file_name):
                continue
            absolute_path = os.path.join(root, file_name)
            relative_path = os.path.relpath(absolute_path, dir_path)
            if add_base_to_relative:
                relative_path = os.path.join(base_name, relative_path)
            collected.append((absolute_path, relative_path))

    return collected
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url: A URL, a ``file://`` URL, or a plain local filename. Anything that
            does not start with ``<scheme>://`` is treated as a local filename.
        cache_dir: Directory for cached downloads; defaults to
            ``make_cache_dir_path('downloads')``.
        num_attempts: Maximum number of download attempts (must be >= 1).
        verbose: Print progress to stdout.
        return_filename: Return the local filename instead of a file object.
            Requires ``cache=True``.
        cache: Look the URL up in, and save it to, `cache_dir`.

    Returns:
        A binary file object (an open file for local/cached data, a ``BytesIO``
        for uncached downloads), or the filename when `return_filename` is set.

    Raises:
        The exception raised by the last download attempt once all attempts
        are exhausted. ``KeyboardInterrupt`` and ``SystemExit`` propagate
        immediately without retrying.
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs.  This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid.  Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname()
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.  Entries are keyed by the MD5 of the requested URL.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be a Google Drive interstitial page
                    # rather than the payload, so inspect them as text.
                    if len(res.content) < 8192:
                        # errors="replace": a small *binary* payload is not valid
                        # UTF-8; a strict decode would raise on every attempt
                        # and make such files impossible to download.
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Retry against the confirmation link.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # Only ordinary errors are retried; KeyboardInterrupt and
                # SystemExit are not Exception subclasses and propagate.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        # The name falls back to the full URL, which can exceed the file-name
        # length limit of common file systems. Lookups use only the MD5 prefix,
        # so truncating keeps existing cache entries usable.
        safe_name = safe_name[:128]
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
+ +name: 3DPortraitGAN +channels: + - pytorch + - nvidia +dependencies: + - python=3.9.16 + - pip + - numpy>=1.20 + - click>=8.0 + - pillow=8.3.1 + - scipy=1.7.1 + - requests=2.26.0 + - tqdm=4.62.2 + - ninja=1.10.2 + - matplotlib=3.4.2 + - imageio=2.9.0 + - pip: + - imgui==1.3.0 + - glfw==2.2.0 + - pyopengl==3.1.5 + - imageio-ffmpeg==0.4.3 + - pyspng + - psutil + - mrcfile + - tensorboard + - smplx==0.1.28 + - trimesh==3.21.4 + - chumpy==0.70 \ No newline at end of file diff --git a/3DPortraitGAN_pyramid/gen_quality_improve_data_from_triplane.py b/3DPortraitGAN_pyramid/gen_quality_improve_data_from_triplane.py new file mode 100644 index 0000000..7410cb5 --- /dev/null +++ b/3DPortraitGAN_pyramid/gen_quality_improve_data_from_triplane.py @@ -0,0 +1,331 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
def create_samples(N=256, voxel_origin=(0, 0, 0), cube_length=2.0):
    """Build a regular N x N x N lattice of 3D sample points covering a cube.

    Args:
        N: Number of lattice points along each axis.
        voxel_origin: Centre of the cube; it is converted below to the cube's
            minimum corner.
        cube_length: Edge length of the cube.

    Returns:
        A tuple ``(samples, voxel_origin, voxel_size)``:
          * ``samples``: float tensor of shape ``(1, N**3, 3)`` with the point
            coordinates; the last column varies fastest with the flat index.
          * ``voxel_origin``: NumPy array holding the cube's minimum corner.
          * ``voxel_size``: spacing between neighbouring lattice points.
    """
    # Shift from the cube centre to its minimum corner.
    voxel_origin = np.array(voxel_origin) - cube_length / 2
    voxel_size = cube_length / (N - 1)

    flat_index = torch.arange(0, N ** 3, 1, dtype=torch.long)
    samples = torch.zeros(N ** 3, 3)

    # Decompose the flat index into three integer lattice indices.
    # Integer floor division is required here: true (float) division leaves
    # fractional "indices", which shears the points off the lattice and pushes
    # some of them outside the cube.
    samples[:, 2] = flat_index % N
    samples[:, 1] = (flat_index // N) % N
    samples[:, 0] = (flat_index // (N * N)) % N

    # Convert lattice indices to coordinates.
    samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
    samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
    samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]

    return samples.unsqueeze(0), voxel_origin, voxel_size
+#---------------------------------------------------------------------------- + +def gen_interp_video(G, mp4: str, trigrid=None,ws=None, shuffle_seed=None, w_frames=60*4, kind='cubic', + grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, truncation_cutoff=14, + image_mode='image', gen_shapes=False, device=torch.device('cuda'), large_pose= False, + **video_kwargs): + grid_w = grid_dims[0] + grid_h = grid_dims[1] + + num_keyframes = 1 + + + camera_lookat_point = torch.tensor([0, 0.0649, 0], device=device) + + cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, camera_lookat_point, radius=2.7, device=device) + focal_length = 6.5104166 + intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) + c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) + c = c.repeat(len(ws), 1) + + p = torch.zeros([len(ws), 6], device=device) + + ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:]) + + # Interpolation. + grid = [] + for yi in range(grid_h): + row = [] + for xi in range(grid_w): + x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1)) + y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1]) + interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0) + row.append(interp) + grid.append(row) + + # Render video. 
+ video_out = imageio.get_writer(mp4, mode='I', fps=30, codec='libx264', **video_kwargs) + + + all_poses = [] + + if large_pose: + image_row = [] + + + for frame_idx in tqdm(range(num_keyframes * w_frames)): + imgs = [] + for yi in range(grid_h): + for xi in range(grid_w): + if large_pose: + # 0 - 2pi + cam2world_pose = LookAtPoseSampler.sample(np.pi / 2 + (frame_idx / w_frames) * 2 * np.pi, + np.pi / 2, + camera_lookat_point, radius=2.7, device=device) + else: + pitch_range = 0.25 + yaw_range = 0.35 + cam2world_pose = LookAtPoseSampler.sample(np.pi/2 + yaw_range * np.sin(2 * np.pi * frame_idx / (num_keyframes * w_frames)), + np.pi/2 -0.05 + pitch_range * np.cos(2 * np.pi * frame_idx / (num_keyframes * w_frames)), + camera_lookat_point, radius=2.7, device=device) + all_poses.append(cam2world_pose.squeeze().cpu().numpy()) + focal_length = 6.5104166 + intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) + c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) + + interp = grid[yi][xi] + w = torch.from_numpy(interp(frame_idx / w_frames)).to(device) + + img = G.render_planes(ws=w.unsqueeze(0), planes=trigrid, c=c[0:1], noise_mode='const', neural_rendering_resolution=512,chunk = 4096)[image_mode][0] + + if image_mode == 'image_depth': + img = -img + img = (img - img.min()) / (img.max() - img.min()) * 2 - 1 + + imgs.append(img) + if large_pose and frame_idx % int(num_keyframes * w_frames//8) == 0: + image_row.append((img.permute(1, 2, 0) * 127.5 + 128).clamp(0, 255).to(torch.uint8)) + + video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h)) + video_out.close() + all_poses = np.stack(all_poses) + + if large_pose: + import PIL + image_row = torch.cat(image_row, 1).cpu().numpy() + PIL.Image.fromarray(image_row.astype(np.uint8)).save(mp4.replace('.mp4', '_large_pose.png')) + + + if gen_shapes: + print(all_poses.shape) + with open(mp4.replace('.mp4', '_trajectory.npy'), 'wb') 
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--data_dir', help='Run directory containing checkpoints/df.pth; images and camera_info.json are written to <data_dir>/data', required=True)
@click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
@click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
@click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None)
@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True)
@click.option('--image_mode', help='Image mode', type=click.Choice(['image_depth', 'image_raw']), required=False, metavar='STR', default='image_raw', show_default=True)
@click.option('--sample_mult', 'sampling_multiplier', type=float, help='Multiplier for depth sampling in volume rendering', default=2, show_default=True)
@click.option('--nrr', type=int, help='Neural rendering resolution override', default=None, show_default=True)
def generate_images(
    network_pkl: str,
    data_dir: str,
    shuffle_seed: Optional[int],
    truncation_psi: float,
    truncation_cutoff: int,
    grid: Tuple[int,int],
    num_keyframes: Optional[int],
    w_frames: int,
    image_mode: str,
    sampling_multiplier: float,
    nrr: Optional[int],
):
    # Render a fixed set of views from the tri-grids stored in
    # <data_dir>/checkpoints/df.pth and save them under <data_dir>/data, along
    # with camera_info.json, which maps each image name to its camera
    # parameters (flattened 4x4 cam2world pose followed by 3x3 intrinsics).
    #
    # 7 yaw angles x 3 pitch angles = 21 images are written. The command is a
    # no-op when camera_info.json already exists or the checkpoint is missing.
    #
    # --shuffle-seed, --grid, --num-keyframes, --w-frames, --trunc and
    # --trunc-cutoff are accepted but not used by this command.
    #
    # (Kept as comments rather than a docstring so the --help output is unchanged.)

    # The file header only does `import PIL`, which does not guarantee that the
    # PIL.Image submodule is loaded; import it explicitly before use.
    import PIL.Image

    res_dir = data_dir

    outdir = os.path.join(res_dir, 'data')
    os.makedirs(outdir, exist_ok=True)

    # Do the cheap "nothing to do" checks before loading the generator.
    if os.path.exists(os.path.join(outdir, 'camera_info.json')):
        print('Camera info already exists, skipping generation.')
        return

    ckpt_path = os.path.join(res_dir, 'checkpoints/df.pth')
    if not os.path.exists(ckpt_path):
        print('No checkpoints found, skipping generation.')
        return

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    # Scale the number of depth samples per ray.
    G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'] * sampling_multiplier)
    G.rendering_kwargs['depth_resolution_importance'] = int(
        G.rendering_kwargs['depth_resolution_importance'] * sampling_multiplier)

    G.rendering_kwargs['ray_start'] = 2.35

    # Rebuild the generator from the current source code and copy the pickled
    # weights into it.
    print("Reloading Modules!")
    from training.neural_renderer import TriPlaneGenerator
    G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
    misc.copy_params_and_buffers(G, G_new, require_all=False)
    G_new.neural_rendering_resolution = G.neural_rendering_resolution
    G_new.rendering_kwargs = G.rendering_kwargs
    G = G_new

    G.set_batch_size(1)

    if nrr is not None: G.neural_rendering_resolution = nrr

    print('Loading checkpoints from "%s"...' % ckpt_path)
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['model']
    # One tri-grid per pyramid level, keyed by its resolution.
    trigrid = {
        res: ckpt[f'trigrids_{res}'].to(device)
        for res in (8, 16, 32, 64, 128, 256, 512)
    }
    ws = ckpt['ws'].to(device)

    intrinsics = FOV_to_intrinsics(12.447863, device=device)

    cam_pivot = torch.tensor([0, 0.0649, 0], device=device)
    cam_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7)

    camera_info = {}

    # Yaw: `step` evenly spaced values over [0, 2*pi]; the last one (2*pi) is
    # skipped in the loop below, leaving step - 1 = 7 yaw angles.
    # Pitch: pi/2 and pi/2 +/- 30 degrees, each jittered by up to +/- 5 degrees.
    sample_idx = 0
    step = 8
    angle_ys = np.linspace(0, np.pi * 2, step)
    angle_ps = [np.pi / 2 - 30 / 180 * np.pi, np.pi / 2, np.pi / 2 + 30 / 180 * np.pi]
    for i in tqdm(range(step - 1)):
        angle_y = angle_ys[i]
        for base_angle_p in angle_ps:
            angle_p = base_angle_p + np.random.uniform(-np.pi / 180 * 5, np.pi / 180 * 5)

            cam2world_pose = LookAtPoseSampler.sample(angle_y, angle_p, cam_pivot, radius=cam_radius, device=device)

            camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

            img = G.render_planes(ws=ws, planes=trigrid, c=camera_params, noise_mode='const',
                                  neural_rendering_resolution=512, chunk=4096)[image_mode]

            # Map the rendered output to 8-bit HWC for saving.
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            image_name = f'{sample_idx:04d}.png'
            PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/{image_name}')

            camera_info[image_name] = camera_params.cpu().numpy().tolist()

            sample_idx += 1

    with open(os.path.join(outdir, 'camera_info.json'), 'w') as f:
        json.dump(camera_info, f)
def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
    '''Convert an 'a,b' string into a pair of floats.

    A value that is already a tuple is handed back untouched, which lets
    this function be used as a click option type whose default is a tuple.

    Example:
        '0,1' returns (0.0, 1.0)

    Raises:
        ValueError: if the string does not consist of exactly two
            comma-separated fields, or a field is not a valid float.
    '''
    if isinstance(s, tuple):
        return s

    fields = s.split(',')
    if len(fields) != 2:
        raise ValueError(f'cannot parse 2-vector {s}')

    first, second = fields
    return (float(first), float(second))
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds_num', type=int, help='Number of random seeds to draw when --seeds is not given', required=False)
@click.option('--seeds', type=parse_range, help="List of random seeds (e.g., '0,1,4-6')", required=False, metavar='LIST', default=[])
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
@click.option('--shapes', help='Export shapes as .mrc files viewable in ChimeraX', type=bool, required=False, metavar='BOOL', default=False, show_default=True)
@click.option('--images', help='Export rendered images as .png files', type=bool, required=False, metavar='BOOL', default=True, show_default=True)
@click.option('--shape-res', help='', type=int, required=False, metavar='int', default=512, show_default=True)
# NOTE: this option used to be declared with type=int, which truncated the
# default 12.447863 to 12 and rejected fractional values on the command line.
@click.option('--fov-deg', help='Field of View of camera in degrees', type=float, required=False, metavar='float', default=12.447863, show_default=True)
@click.option('--shape-format', help='Shape Format', type=click.Choice(['.mrc', '.ply']), default='.mrc')
@click.option('--reload_modules', help='Overload persistent modules?', type=bool, required=False, metavar='BOOL', default=False, show_default=True)
def generate_images(
    network_pkl: str,
    seeds_num: Optional[int],
    seeds: List[int],
    truncation_psi: float,
    truncation_cutoff: int,
    outdir: str,
    shapes: bool,
    images: bool,
    shape_res: int,
    fov_deg: float,
    shape_format: str,
    class_idx: Optional[int],
    reload_modules: bool,
):
    """Generate images using pretrained network pickle.

    For every seed, three views (left, front, right) are rendered side by side
    into OUTDIR/seedXXXX.png. With --shapes=True the density volume is also
    sampled and written as OUTDIR/seedXXXX.mrc or .ply.

    Either --seeds or --seeds_num must be given; with only --seeds_num that
    many seeds are drawn at random.

    Examples:

    \b
    # Generate an image using pre-trained FFHQ model.
    python gen_samples.py --outdir=output --trunc=0.7 --seeds=0-5 --shapes=True\\
        --network=ffhq-rebalanced-128.pkl
    """
    import random

    if not seeds:
        # Previously a missing --seeds_num crashed inside random.sample().
        if seeds_num is None:
            raise click.UsageError('Either --seeds or --seeds_num must be specified')
        seeds = random.sample(range(1000000), seeds_num)

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    # Sample 4x more depth points per ray than the pickled configuration.
    G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'] * 4)
    G.rendering_kwargs['depth_resolution_importance'] = int(
        G.rendering_kwargs['depth_resolution_importance'] * 4)

    G.rendering_kwargs['ray_start'] = 2.35

    # Specify reload_modules=True if you want code modifications to take effect; otherwise uses pickled code
    if reload_modules:
        print("Reloading Modules!")
        G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
        misc.copy_params_and_buffers(G, G_new, require_all=True)
        G_new.neural_rendering_resolution = G.neural_rendering_resolution
        G_new.rendering_kwargs = G.rendering_kwargs
        G = G_new

    os.makedirs(outdir, exist_ok=True)

    intrinsics = FOV_to_intrinsics(fov_deg, device=device)

    # Conditioning inputs for the mapping network: a fixed camera looking at
    # the pivot and an all-zero 6-vector passed as the third mapping argument.
    cond_p = torch.zeros([1, 6], device=device)
    cam_pivot = torch.tensor([0, 0.0649, 0], device=device)
    cam_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7)
    conditioning_cam2world_pose = LookAtPoseSampler.sample(np.pi / 2, np.pi / 2, cam_pivot,
                                                           radius=cam_radius, device=device)
    conditioning_params = torch.cat([conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

    # (yaw, pitch) pairs of the views concatenated into each output image.
    view_angles = [(-np.pi / 2, np.pi / 2), (0, np.pi / 2), (np.pi / 2, np.pi / 2)]

    # Generate images.
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

        if images:
            # The latent code does not depend on the rendering camera, so it is
            # mapped once per seed instead of once per view.
            ws = G.mapping(z, conditioning_params, cond_p, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)

            views = []
            for angle_y, angle_p in view_angles:
                cam2world_pose = LookAtPoseSampler.sample(angle_y, angle_p, cam_pivot, radius=cam_radius, device=device)
                camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

                img = G.synthesis(ws, c=camera_params, noise_mode='const', apply_def=False, pose_params=None)['image']
                img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
                views.append(img)

            # Views are NHWC at this point; dim=2 is the width axis.
            img = torch.cat(views, dim=2)
            PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')

        if shapes:
            # extract a shape.mrc with marching cubes. You can view the .mrc file using ChimeraX from UCSF.
            max_batch = 1000000

            samples, voxel_origin, voxel_size = create_samples(N=shape_res, voxel_origin=[0, 0, 0], cube_length=G.rendering_kwargs['box_warp'] * 0.8)
            samples = samples.to(z.device)
            sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1), device=z.device)
            transformed_ray_directions_expanded = torch.zeros((samples.shape[0], max_batch, 3), device=z.device)
            transformed_ray_directions_expanded[..., -1] = -1

            num_points = samples.shape[1]
            head = 0
            with tqdm(total=num_points) as pbar:
                with torch.no_grad():
                    while head < num_points:
                        torch.manual_seed(0)
                        sigma = G.sample(samples[:, head:head+max_batch], transformed_ray_directions_expanded[:, :num_points-head], z, conditioning_params, cond_p, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, noise_mode='const')['sigma']
                        sigmas[:, head:head+max_batch] = sigma
                        # Report the real chunk size so the bar ends exactly at total.
                        pbar.update(min(max_batch, num_points - head))
                        head += max_batch

            sigmas = sigmas.reshape((shape_res, shape_res, shape_res)).cpu().numpy()
            sigmas = np.flip(sigmas, 0)

            # Trim the border of the extracted cube
            pad = int(30 * shape_res / 256)
            pad_value = -1000
            # With pad == 0 the slice [-0:] would cover the whole axis and
            # overwrite every voxel, so only trim when there is a border.
            if pad > 0:
                sigmas[:pad] = pad_value
                sigmas[-pad:] = pad_value
                sigmas[:, :pad] = pad_value
                sigmas[:, -pad:] = pad_value
                sigmas[:, :, :pad] = pad_value
                sigmas[:, :, -pad:] = pad_value

            if shape_format == '.ply':
                from shape_utils import convert_sdf_samples_to_ply
                convert_sdf_samples_to_ply(np.transpose(sigmas, (2, 1, 0)), [0, 0, 0], 1, os.path.join(outdir, f'seed{seed:04d}.ply'), level=15)
            elif shape_format == '.mrc': # output mrc
                with mrcfile.new_mmap(os.path.join(outdir, f'seed{seed:04d}.mrc'), overwrite=True, shape=sigmas.shape, mrc_mode=2) as mrc:
                    mrc.data[:] = sigmas
def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Ranges are written 'a-b' and are inclusive on both ends. Whitespace
    around individual items is ignored. A value that is already a list is
    returned unchanged, so the function can be used as a click option type
    whose default is a list.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

    Raises:
        ValueError: if an item is neither an integer nor an 'a-b' range.
    '''
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        p = p.strip()  # accept '1, 2, 5-10' as well as '1,2,5-10'
        if m := range_re.match(p):
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges
def create_samples(N=256, voxel_origin=(0, 0, 0), cube_length=2.0):
    """Build a dense N x N x N grid of query points covering a cube.

    The points are laid out so that the flat index runs fastest over the last
    coordinate column: ``index = (i0 * N + i1) * N + i2`` where ``i0, i1, i2``
    are the per-axis grid indices stored in columns 0, 1, 2.

    Args:
        N: number of samples along each axis (must be at least 2, otherwise
            the voxel size is a division by zero).
        voxel_origin: centre of the cube; it is shifted internally by half
            the cube length to obtain the corner of the grid.
        cube_length: side length of the cube.

    Returns:
        A tuple ``(samples, voxel_origin, voxel_size)`` where ``samples`` is a
        float tensor of shape ``(1, N**3, 3)``, ``voxel_origin`` is the shifted
        corner as a NumPy array and ``voxel_size`` is the grid spacing.
    """
    # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle
    voxel_origin = np.array(voxel_origin) - cube_length/2
    voxel_size = cube_length / (N - 1)

    overall_index = torch.arange(N ** 3, dtype=torch.long)
    samples = torch.zeros(N ** 3, 3)

    # transform first 3 columns
    # to be the x, y, z index
    # Integer floor division is required here: true division yields
    # fractional "indices" (every point is sheared by up to one voxel) and
    # loses precision in float32 once N**3 exceeds 2**24.
    samples[:, 2] = overall_index % N
    samples[:, 1] = torch.div(overall_index, N, rounding_mode='floor') % N
    samples[:, 0] = torch.div(overall_index, N * N, rounding_mode='floor') % N

    # transform first 3 columns
    # to be the x, y, z coordinate
    samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
    samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
    samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]

    return samples.unsqueeze(0), voxel_origin, voxel_size
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--test_data', help='Real data dir', required=True)
@click.option('--outdir', help='output dir', required=True)
@click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
# NOTE: declared as float; type=int truncated the default 12.447863 to 12.
@click.option('--fov-deg', help='Field of View of camera in degrees', type=float, required=False, metavar='float', default=12.447863, show_default=True)
@click.option('--shape-format', help='Shape Format', type=click.Choice(['.mrc', '.ply']), default='.mrc')
@click.option('--reload_modules', help='Overload persistent modules?', type=bool, required=False, metavar='BOOL', default=False, show_default=True)
def generate_images(
    network_pkl: str,
    test_data: str,
    outdir: str,
    seeds: List[int],
    truncation_psi: float,
    truncation_cutoff: int,
    fov_deg: float,
    shape_format: str,
    class_idx: Optional[int],
    reload_modules: bool,
):
    """Render generated portraits in the poses predicted from real images.

    For every file in TEST_DATA/aligned_images the discriminator predicts pose
    parameters, using the camera pose stored for that file in
    TEST_DATA/result.json. Each seed is then rendered with every predicted
    pose and saved next to the real image as OUTDIR/<seed>_<image name>.png.

    The pose prediction settings are read from a JSON file located next to the
    network pickle, named like the pickle with '.pkl' replaced by
    '-pose_predict_kwargs.json'.

    --fov-deg, --shape-format, --class and --reload_modules are accepted but
    not used by this script.

    Examples:

    \b
    python gen_samples_with_pose_prediction.py --outdir=output --trunc=0.7 --seeds=0-5 \\
        --test_data=DIR --network=network-snapshot.pkl
    """
    import glob
    import json

    os.makedirs(outdir, exist_ok=True)

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        resume_data = legacy.load_network_pkl(f)
        print('resume_data',resume_data.keys())
        G = resume_data['G_ema'].to(device) # type: ignore
        D = resume_data['D_ema'].to(device) # type: ignore

    G.set_batch_size(1)
    # Sample 2x more depth points per ray than the pickled configuration.
    G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'] * 2)
    G.rendering_kwargs['depth_resolution_importance'] = int(
        G.rendering_kwargs['depth_resolution_importance'] * 2)

    G.rendering_kwargs['ray_start'] = 2.35

    # get pose prediction kwargs
    pose_prediction_kwargs_path = network_pkl.replace('.pkl','-pose_predict_kwargs.json') # network-snapshot-001400.pkl
    print('Loading pose_prediction_kwargs from "%s"...' % pose_prediction_kwargs_path)
    with open(pose_prediction_kwargs_path, 'r') as f:
        pose_predict_kwargs = json.load(f)

    # read images (sorted so the processing order is reproducible)
    real_image_paths = sorted(glob.glob(os.path.join(test_data, 'aligned_images/*')))

    with open(os.path.join(test_data, 'result.json'), 'r') as f:
        labels = json.load(f)

    intrinsics = np.reshape(
        np.array([6.510416666666667,
                  0.0,
                  0.5,
                  0.0,
                  6.510416666666667,
                  0.5,
                  0.0,
                  0.0,
                  1.0]), (1, 9)
    )

    # The resample filter is the same for every image; build the tensor once.
    resample_filter = torch.tensor(pose_predict_kwargs['resample_filter'], device=device).to(torch.float32)

    poses = []
    cameras = []
    names = []
    real_images = []  # resized RGB arrays, reused when composing the outputs

    with torch.no_grad():
        for real_image_path in real_image_paths:
            file_name = os.path.basename(real_image_path)

            pil_image = PIL.Image.open(real_image_path).convert('RGB')
            pil_image = pil_image.resize((G.img_resolution, G.img_resolution), PIL.Image.BILINEAR)
            real_rgb = np.array(pil_image)

            # HWC uint8 -> 1xCxHxW float in [-1, 1]
            image = torch.tensor(real_rgb.transpose(2, 0, 1), device=device)
            image = image.to(torch.float32) / 127.5 - 1
            image = image.unsqueeze(0)

            # 16 camera-pose values from the label file followed by the 9 intrinsics
            c = labels[file_name]['camera_pose']
            c = np.reshape(np.array(c),(1,16))
            c = np.concatenate((c, intrinsics), axis=1)
            c = torch.tensor(c, device=device).to(torch.float32)

            p = get_pose_params(real_img=image,
                                real_c=c,
                                D=D,
                                neural_rendering_resolution=pose_predict_kwargs['neural_rendering_resolution'],
                                blur_sigma=pose_predict_kwargs['blur_sigma'],
                                resample_filter=resample_filter,
                                filter_mode=pose_predict_kwargs['filter_mode'])

            poses.append(p)
            cameras.append(c)
            names.append(file_name.split('.')[0])
            real_images.append(real_rgb)

    # The generator is always rebuilt from the current source code here,
    # regardless of --reload_modules (the original guarded this with `if True:`).
    print("Reloading Modules!")
    G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
    misc.copy_params_and_buffers(G, G_new, require_all=False)
    G_new.neural_rendering_resolution = G.neural_rendering_resolution
    G_new.rendering_kwargs = G.rendering_kwargs
    G = G_new
    G.set_batch_size(1)

    # Conditioning inputs for the mapping network: a fixed camera and an
    # all-zero 6-vector passed as the third mapping argument.
    camera_lookat_point = torch.tensor([0, 0.0649, 0], device=device)
    cam2world_pose = LookAtPoseSampler.sample(np.pi / 2, np.pi / 2, camera_lookat_point, radius=2.7, device=device)
    focal_length = 6.5104166
    cond_intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
    cond_c = torch.cat([cam2world_pose.reshape(-1, 16), cond_intrinsics.reshape(-1, 9)], 1)

    cond_p = torch.zeros([1, 6], device=device)

    for seed in seeds:
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        ws = G.mapping(z, cond_c, cond_p, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)

        for p, c, name, real_rgb in zip(poses, cameras, names, real_images):
            img = G.synthesis(ws, c=c, noise_mode='const', apply_def=True, pose_params=p)['image']
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()

            # real image on the left, generated image on the right
            img = np.concatenate((real_rgb, img), axis=1)
            PIL.Image.fromarray(img, 'RGB').save(f'{outdir}/{seed}_{name}.png')
def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into a single grid image.

    Args:
        img: tensor of shape (batch, channels, height, width). The batch is
            read row by row: image ``y * grid_w + x`` lands in grid cell (y, x).
        grid_w: number of columns; derived from the batch size and ``grid_h``
            when omitted.
        grid_h: number of rows.
        float_to_uint8: map values with ``v * 127.5 + 128``, clamp to
            [0, 255] and convert to uint8.
        chw_to_hwc: return the grid channels-last instead of channels-first.
        to_numpy: move the result to the CPU and return a NumPy array.

    Returns:
        The grid of size (grid_h * height) x (grid_w * width), in the layout
        and container selected by the flags above.
    """
    n, ch, h, w = img.shape
    cols = n // grid_h if grid_w is None else grid_w
    assert n == cols * grid_h

    tiles = img
    if float_to_uint8:
        tiles = (tiles * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    # (rows, cols, C, H, W) -> (C, rows, H, cols, W) -> (C, rows*H, cols*W)
    canvas = (
        tiles.reshape(grid_h, cols, ch, h, w)
        .permute(2, 0, 3, 1, 4)
        .reshape(ch, grid_h * h, cols * w)
    )

    if chw_to_hwc:
        canvas = canvas.permute(1, 2, 0)
    return canvas.cpu().numpy() if to_numpy else canvas
def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic',
                     grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, truncation_cutoff=14,
                     cfg='FFHQ', image_mode='image', gen_shapes=False, device=torch.device('cuda'), large_pose= False,
                     **video_kwargs):
    """Render a latent-interpolation video with a moving camera.

    Latent codes for the given seeds are used as keyframes and interpolated
    over time; every frame is rendered from a camera that either orbits the
    subject (``large_pose``) or sways around the frontal view.

    Args:
        G: generator exposing ``z_dim``, ``mapping``, ``synthesis``,
            ``sample_mixed`` and ``rendering_kwargs`` as used below.
        mp4: output video path. With ``gen_shapes`` the camera trajectory is
            saved next to it with '.mp4' replaced by '_trajectory.npy'.
        seeds: non-empty sequence of integer seeds; reused cyclically to fill
            ``num_keyframes * grid_w * grid_h`` latents.
        shuffle_seed: if given, shuffle the seed order with this RNG seed.
        w_frames: number of frames between two consecutive keyframes.
        kind: interpolation kind forwarded to ``scipy.interpolate.interp1d``.
        grid_dims: ``(grid_w, grid_h)`` layout of the video grid.
        num_keyframes: keyframes per grid cell; derived from ``seeds`` when None.
        wraps: how many times the keyframes are tiled on each side so the
            interpolation is periodic.
        psi, truncation_cutoff: truncation settings for the mapping network.
        cfg: 'Shapenet' selects a different focal length for rendering.
        image_mode: key of the synthesis output to write into the video.
        gen_shapes: also export a .ply shape per frame into
            'interpolation_<seed0>_<seed1>/'.
        device: torch device used for all tensors.
        large_pose: orbit the camera by a full turn per keyframe interval.
        **video_kwargs: forwarded to ``imageio.get_writer``.

    Raises:
        ValueError: if ``seeds`` is empty, or ``num_keyframes`` is None and the
            number of seeds is not divisible by the grid size.
    """
    grid_w = grid_dims[0]
    grid_h = grid_dims[1]

    if len(seeds) == 0:
        raise ValueError('At least one seed is required')

    if num_keyframes is None:
        if len(seeds) % (grid_w*grid_h) != 0:
            raise ValueError('Number of input seeds must be divisible by grid W*H')
        num_keyframes = len(seeds) // (grid_w*grid_h)

    num_latents = num_keyframes*grid_h*grid_w
    all_seeds = np.array([seeds[idx % len(seeds)] for idx in range(num_latents)], dtype=np.int64)

    if shuffle_seed is not None:
        rng = np.random.RandomState(seed=shuffle_seed)
        rng.shuffle(all_seeds)

    camera_lookat_point = torch.tensor([0, 0.0649, 0], device=device)
    zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device)

    # Conditioning camera for the mapping network (always the frontal view).
    cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, camera_lookat_point, radius=2.7, device=device)
    focal_length = 6.5104166 # if cfg != 'Shapenet' else 1.7074 # shapenet has higher FOV
    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
    c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    c = c.repeat(len(zs), 1)

    p = torch.zeros([len(zs), 6], device=device)

    ws = G.mapping(z=zs, c=c, p=p,truncation_psi=psi, truncation_cutoff=truncation_cutoff)
    _ = G.synthesis(ws[:1], c[:1]) # warm up
    ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])

    # Interpolation: one periodic interpolator per grid cell.
    grid = []
    for yi in range(grid_h):
        row = []
        for xi in range(grid_w):
            x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
            y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
            interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
            row.append(interp)
        grid.append(row)

    # Render video.
    max_batch = 10000000
    voxel_resolution = 512
    total_frames = num_keyframes * w_frames

    if gen_shapes:
        # Fall back to the first seed when only one latent exists; indexing
        # all_seeds[1] unconditionally raised IndexError in that case.
        second_seed = all_seeds[1] if len(all_seeds) > 1 else all_seeds[0]
        outdir = 'interpolation_{}_{}/'.format(all_seeds[0], second_seed)
        os.makedirs(outdir, exist_ok=True)

    # Rendering intrinsics do not change between frames; build them once.
    render_focal_length = 6.5104166 if cfg != 'Shapenet' else 1.7074 # shapenet has higher FOV
    render_intrinsics = torch.tensor([[render_focal_length, 0, 0.5], [0, render_focal_length, 0.5], [0, 0, 1]], device=device)

    all_poses = []
    video_out = imageio.get_writer(mp4, mode='I', fps=30, codec='libx264', **video_kwargs)
    try:
        for frame_idx in tqdm(range(total_frames)):
            imgs = []
            for yi in range(grid_h):
                for xi in range(grid_w):
                    if large_pose:
                        # 0 - 2pi
                        cam2world_pose = LookAtPoseSampler.sample(np.pi / 2 + (frame_idx / w_frames) * 2 * np.pi,
                                                                  np.pi / 2,
                                                                  camera_lookat_point, radius=2.7, device=device)
                    else:
                        pitch_range = 0.25
                        yaw_range = 0.35
                        cam2world_pose = LookAtPoseSampler.sample(np.pi/2 + yaw_range * np.sin(2 * np.pi * frame_idx / total_frames),
                                                                  np.pi/2 -0.05 + pitch_range * np.cos(2 * np.pi * frame_idx / total_frames),
                                                                  camera_lookat_point, radius=2.7, device=device)
                    all_poses.append(cam2world_pose.squeeze().cpu().numpy())
                    frame_c = torch.cat([cam2world_pose.reshape(-1, 16), render_intrinsics.reshape(-1, 9)], 1)

                    interp = grid[yi][xi]
                    w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)

                    # 'camera' is the only mode in use; the other two are kept
                    # for experimentation and now pass `p` to G.mapping like
                    # the call that produces `ws` above.
                    entangle = 'camera'
                    if entangle == 'conditioning':
                        c_forward = torch.cat([LookAtPoseSampler.sample(np.pi/2,
                                                                        np.pi/2,
                                                                        camera_lookat_point,
                                                                        radius=2.7, device=device).reshape(-1, 16), render_intrinsics.reshape(-1, 9)], 1)
                        w_c = G.mapping(z=zs[0:1], c=frame_c, p=p[0:1], truncation_psi=psi, truncation_cutoff=truncation_cutoff)
                        img = G.synthesis(ws=w_c, c=c_forward, noise_mode='const')[image_mode][0]
                    elif entangle == 'camera':
                        img = G.synthesis(ws=w.unsqueeze(0), c=frame_c, noise_mode='const')[image_mode][0]
                    elif entangle == 'both':
                        w_c = G.mapping(z=zs[0:1], c=frame_c, p=p[0:1], truncation_psi=psi, truncation_cutoff=truncation_cutoff)
                        img = G.synthesis(ws=w_c, c=frame_c, noise_mode='const')[image_mode][0]

                    if image_mode == 'image_depth':
                        # Negate and rescale depth to [-1, 1] for display.
                        img = -img
                        img = (img - img.min()) / (img.max() - img.min()) * 2 - 1

                    imgs.append(img)

                    if gen_shapes:
                        # generate shapes
                        print('Generating shape for frame %d / %d ...' % (frame_idx, total_frames))

                        samples, voxel_origin, voxel_size = create_samples(N=voxel_resolution, voxel_origin=[0, 0, 0], cube_length=G.rendering_kwargs['box_warp'])
                        samples = samples.to(device)
                        sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1), device=device)
                        transformed_ray_directions_expanded = torch.zeros((samples.shape[0], max_batch, 3), device=device)
                        transformed_ray_directions_expanded[..., -1] = -1

                        num_points = samples.shape[1]
                        head = 0
                        with tqdm(total=num_points) as pbar:
                            with torch.no_grad():
                                while head < num_points:
                                    torch.manual_seed(0)
                                    sigma = G.sample_mixed(samples[:, head:head+max_batch], transformed_ray_directions_expanded[:, :num_points-head], w.unsqueeze(0), truncation_psi=psi, noise_mode='const')['sigma']
                                    sigmas[:, head:head+max_batch] = sigma
                                    # Report the real chunk size so the bar ends at total.
                                    pbar.update(min(max_batch, num_points - head))
                                    head += max_batch

                        sigmas = sigmas.reshape((voxel_resolution, voxel_resolution, voxel_resolution)).cpu().numpy()
                        sigmas = np.flip(sigmas, 0)

                        # Zero out the border of the volume before meshing.
                        pad = int(30 * voxel_resolution / 256)
                        pad_top = int(38 * voxel_resolution / 256)
                        sigmas[:pad] = 0
                        sigmas[-pad:] = 0
                        sigmas[:, :pad] = 0
                        sigmas[:, -pad_top:] = 0
                        sigmas[:, :, :pad] = 0
                        sigmas[:, :, -pad:] = 0

                        output_ply = True
                        if output_ply:
                            from shape_utils import convert_sdf_samples_to_ply
                            convert_sdf_samples_to_ply(np.transpose(sigmas, (2, 1, 0)), [0, 0, 0], 1, os.path.join(outdir, f'{frame_idx:04d}_shape.ply'), level=10)
                        else: # output mrc
                            with mrcfile.new_mmap(outdir + f'{frame_idx:04d}_shape.mrc', overwrite=True, shape=sigmas.shape, mrc_mode=2) as mrc:
                                mrc.data[:] = sigmas

            video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
    finally:
        # Always finalize the file, even when rendering fails part-way.
        video_out.close()

    if gen_shapes:
        all_poses = np.stack(all_poses)
        print(all_poses.shape)
        with open(mp4.replace('.mp4', '_trajectory.npy'), 'wb') as f:
            np.save(f, all_poses)
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
@click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
@click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
@click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None)
@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True)
@click.option('--outdir', help='Output directory', type=str, required=True, metavar='DIR')
@click.option('--reload_modules', help='Overload persistent modules?', type=bool, required=False, metavar='BOOL', default=False, show_default=True)
@click.option('--cfg', help='Config', type=click.Choice(['FFHQ', 'AFHQ', 'Shapenet']), required=False, metavar='STR', default='FFHQ', show_default=True)
@click.option('--image_mode', help='Image mode', type=click.Choice(['image', 'image_depth', 'image_raw']), required=False, metavar='STR', default='image', show_default=True)
@click.option('--sample_mult', 'sampling_multiplier', type=float, help='Multiplier for depth sampling in volume rendering', default=2, show_default=True)
@click.option('--nrr', type=int, help='Neural rendering resolution override', default=None, show_default=True)
@click.option('--shapes', type=bool, help='Gen shapes for shape interpolation', default=False, show_default=True)
@click.option('--interpolate', type=bool, help='Interpolate between seeds', default=True, show_default=True)
@click.option('--large_pose', type=bool, help='Render the large-pose video (<seed>_large_pose.mp4) instead of the small-pose one', default=False, show_default=True)
def generate_images(
    network_pkl: str,
    seeds: List[int],
    shuffle_seed: Optional[int],
    truncation_psi: float,
    truncation_cutoff: int,
    grid: Tuple[int,int],
    num_keyframes: Optional[int],
    w_frames: int,
    outdir: str,
    reload_modules: bool,
    cfg: str,
    image_mode: str,
    sampling_multiplier: float,
    nrr: Optional[int],
    shapes: bool,
    interpolate: bool,
    large_pose: bool,
):
    """Render one video per seed using a pretrained network pickle.

    Examples:

    \b
    # Render a small-pose video for each of the seeds 0 through 3.
    python SCRIPT.py --outdir=out --trunc=1 --seeds=0-3 --interpolate=False \\
        --network=NETWORK.pkl

    Output files:

    For every seed given with --seeds one file is written to --outdir, named
    '<seed>_large_pose.mp4' when --large_pose is true and
    '<seed>_small_pose.mp4' otherwise.

    Seed interpolation (--interpolate=True, the default) is not implemented
    and aborts with NotImplementedError, so pass --interpolate=False.
    """
    # Fail fast: the interpolation path is not implemented, so reject it before
    # spending time loading the network. --reload_modules and --shapes are
    # accepted for command-line compatibility but are not used by the
    # per-seed path below.
    if interpolate:
        raise NotImplementedError(
            'Seed interpolation is not implemented; pass --interpolate=False '
            'to render one video per seed.')

    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(outdir, exist_ok=True)

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    # Scale the number of depth samples used by the volume renderer.
    G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'] * sampling_multiplier)
    G.rendering_kwargs['depth_resolution_importance'] = int(G.rendering_kwargs['depth_resolution_importance'] * sampling_multiplier)

    G.rendering_kwargs['ray_start'] = 2.35

    G.set_batch_size(1)

    if nrr is not None:
        G.neural_rendering_resolution = nrr

    if truncation_cutoff == 0:
        truncation_psi = 1.0 # truncation cutoff of 0 means no truncation anyways
    if truncation_psi == 1.0:
        truncation_cutoff = 14 # no truncation so doesn't matter where we cutoff

    pose_tag = 'large_pose' if large_pose else 'small_pose'
    for seed in seeds:
        output = os.path.join(outdir, f'{seed}_{pose_tag}.mp4')
        gen_interp_video(G=G, mp4=output, bitrate='10M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=[seed], shuffle_seed=shuffle_seed, psi=truncation_psi, truncation_cutoff=truncation_cutoff, cfg=cfg, image_mode=image_mode, large_pose=large_pose)
def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into a single grid image.

    Args:
        img: Tensor of shape (batch, channels, height, width).
        grid_w: Number of tiles per row; derived from the batch size and
            grid_h when None.
        grid_h: Number of tile rows.
        float_to_uint8: Map values with ``x * 127.5 + 128``, clamp to
            [0, 255] and cast to uint8.
        chw_to_hwc: Move the channel axis last.
        to_numpy: Return a NumPy array (moved to CPU) instead of a tensor.

    Returns:
        The tiled image; the batch must contain exactly grid_w * grid_h images.
    """
    num_imgs, num_ch, tile_h, tile_w = img.shape
    cols = num_imgs // grid_h if grid_w is None else grid_w
    assert num_imgs == cols * grid_h

    tiles = img
    if float_to_uint8:
        tiles = (tiles * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    # (rows, cols, C, H, W) -> (C, rows, H, cols, W) -> (C, rows*H, cols*W)
    canvas = (
        tiles.reshape(grid_h, cols, num_ch, tile_h, tile_w)
        .permute(2, 0, 3, 1, 4)
        .reshape(num_ch, grid_h * tile_h, cols * tile_w)
    )
    if chw_to_hwc:
        canvas = canvas.permute(1, 2, 0)
    return canvas.cpu().numpy() if to_numpy else canvas
def gen_interp_video(G, mp4: str, trigrid=None, ws=None, shuffle_seed=None, w_frames=60*4, kind='cubic',
                     grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, truncation_cutoff=14,
                     image_mode='image', gen_shapes=False, device=torch.device('cuda'), large_pose=False,
                     **video_kwargs):
    """Render a camera-sweep video of fixed tri-grid planes and write it to *mp4*.

    Args:
        G: Generator exposing ``render_planes(ws=, planes=, c=, ...)``.
        mp4: Output video path.
        trigrid: Planes forwarded to ``G.render_planes`` as ``planes``.
        ws: Latent codes, one per grid cell (grid_w * grid_h entries along the
            first axis). Required.
        w_frames: Number of frames rendered.
        kind: Interpolation kind passed to ``scipy.interpolate.interp1d``.
        grid_dims: (grid_w, grid_h) layout of the output frames.
        num_keyframes: Ignored; a single keyframe per grid cell is always used.
        wraps: Number of times the keyframes are tiled on each side for the
            interpolator.
        image_mode: Key of the ``G.render_planes`` output to render; for
            'image_depth' the image is negated and min-max normalised to [-1, 1].
        gen_shapes: When true, also save the camera poses next to the video as
            '<name>_trajectory.npy'.
        device: Device used for the camera tensors and latents.
        large_pose: When true the yaw makes one full turn over the video, and
            two strips of snapshots ('<name>_final_8.png' and
            '<name>_final_4.png') are saved; otherwise yaw and pitch sweep a
            small range around the frontal view.
        shuffle_seed, psi, truncation_cutoff: Accepted for call compatibility;
            not used by this function.
        **video_kwargs: Forwarded to ``imageio.get_writer``.

    Raises:
        ValueError: If *ws* is None.
    """
    if ws is None:
        raise ValueError('ws (latent codes) must be provided')

    grid_w, grid_h = grid_dims[0], grid_dims[1]

    # Only one keyframe per grid cell is supported, so the argument is overridden.
    num_keyframes = 1
    num_frames = num_keyframes * w_frames

    camera_lookat_point = torch.tensor([0, 0.0649, 0], device=device)
    # The intrinsics are the same for every frame, so build them once.
    focal_length = 6.5104166
    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)

    ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])

    # Interpolation.
    grid = []
    for yi in range(grid_h):
        row = []
        for xi in range(grid_w):
            x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
            y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
            row.append(scipy.interpolate.interp1d(x, y, kind=kind, axis=0))
        grid.append(row)

    all_poses = []
    image_row_8 = []
    image_row_4 = []
    # max(1, ...) keeps the modulo below well defined for very short videos
    # (fewer than 8 or 4 frames), where the integer division would yield 0.
    stride_8 = max(1, num_frames // 8)
    stride_4 = max(1, num_frames // 4)

    # Render video.
    video_out = imageio.get_writer(mp4, mode='I', fps=30, codec='libx264', **video_kwargs)
    try:
        for frame_idx in tqdm(range(num_frames)):
            imgs = []
            for yi in range(grid_h):
                for xi in range(grid_w):
                    if large_pose:
                        # 0 - 2pi
                        cam2world_pose = LookAtPoseSampler.sample(np.pi / 2 + (frame_idx / w_frames) * 2 * np.pi,
                                                                  np.pi / 2,
                                                                  camera_lookat_point, radius=2.7, device=device)
                    else:
                        pitch_range = 0.25
                        yaw_range = 0.35
                        cam2world_pose = LookAtPoseSampler.sample(np.pi/2 + yaw_range * np.sin(2 * np.pi * frame_idx / num_frames),
                                                                  np.pi/2 - 0.05 + pitch_range * np.cos(2 * np.pi * frame_idx / num_frames),
                                                                  camera_lookat_point, radius=2.7, device=device)
                    all_poses.append(cam2world_pose.squeeze().cpu().numpy())
                    c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

                    interp = grid[yi][xi]
                    w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)

                    img = G.render_planes(ws=w.unsqueeze(0), planes=trigrid, c=c[0:1], noise_mode='const', neural_rendering_resolution=512, chunk=4096)[image_mode][0]

                    if image_mode == 'image_depth':
                        img = -img
                        img = (img - img.min()) / (img.max() - img.min()) * 2 - 1

                    imgs.append(img)
                    if large_pose:
                        snapshot = (img.permute(1, 2, 0) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
                        if frame_idx % stride_8 == 0:
                            image_row_8.append(snapshot)
                        if frame_idx % stride_4 == 0:
                            image_row_4.append(snapshot)

            video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
    finally:
        # Always release the writer, even if rendering fails part-way.
        video_out.close()

    # Derive side-car file names from the stem so that a path without a
    # '.mp4' suffix can never resolve to the video file itself.
    stem = os.path.splitext(mp4)[0]

    if large_pose:
        # Import the submodule explicitly: a bare `import PIL` does not
        # guarantee that PIL.Image is loaded.
        import PIL.Image
        for snapshots, suffix in ((image_row_8, '_final_8.png'), (image_row_4, '_final_4.png')):
            strip = torch.cat(snapshots, 1).cpu().numpy()
            PIL.Image.fromarray(strip.astype(np.uint8)).save(stem + suffix)

    if gen_shapes:
        all_poses = np.stack(all_poses)
        print(all_poses.shape)
        with open(stem + '_trajectory.npy', 'wb') as f:
            np.save(f, all_poses)
If not specified, determine based on the length of the seeds array given by --seeds.', default=None) +@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) +@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) +@click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True) +@click.option('--image_mode', help='Image mode', type=click.Choice(['image_depth', 'image_raw']), required=False, metavar='STR', default='image_raw', show_default=True) +@click.option('--sample_mult', 'sampling_multiplier', type=float, help='Multiplier for depth sampling in volume rendering', default=2, show_default=True) +@click.option('--nrr', type=int, help='Neural rendering resolution override', default=None, show_default=True) + +def generate_images( + network_pkl: str, + data_dir: str, + shuffle_seed: Optional[int], + truncation_psi: float, + truncation_cutoff: int, + grid: Tuple[int,int], + num_keyframes: Optional[int], + w_frames: int, + image_mode: str, + sampling_multiplier: float, + nrr: Optional[int], +): + print('Loading networks from "%s"...' 
% network_pkl) + device = torch.device('cuda') + with dnnlib.util.open_url(network_pkl) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + + G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'] * sampling_multiplier) + G.rendering_kwargs['depth_resolution_importance'] = int( + G.rendering_kwargs['depth_resolution_importance'] * sampling_multiplier) + + G.rendering_kwargs['ray_start'] = 2.35 + + + + print("Reloading Modules!") + from training.smpl_triplane import TriPlaneGenerator + G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device) + misc.copy_params_and_buffers(G, G_new, require_all=True) + G_new.neural_rendering_resolution = G.neural_rendering_resolution + G_new.rendering_kwargs = G.rendering_kwargs + G = G_new + + G.set_batch_size(1) + res_dir = data_dir + + outdir = os.path.join(res_dir, 'results_final') + os.makedirs(outdir, exist_ok=True) + if not os.path.exists(os.path.join(res_dir, 'log/ckpt')): + print('WARNING: No checkpoints found in "%s"!' % os.path.join(res_dir, 'log/ckpt')) + return + + if nrr is not None: G.neural_rendering_resolution = nrr + + if truncation_cutoff == 0: + truncation_psi = 1.0 # truncation cutoff of 0 means no truncation anyways + if truncation_psi == 1.0: + truncation_cutoff = 14 # no truncation so doesn't matter where we cutoff + + ckpt_path = glob.glob(os.path.join(res_dir, 'log/ckpt/*')) + ckpt_path = sorted(ckpt_path) + ckpt_path = ckpt_path[-1] + if not os.path.exists(ckpt_path): + print('WARNING: No checkpoints found in "%s"!' % ckpt_path) + return + print('Loading checkpoints from "%s"...' 
% ckpt_path) + ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['model'] + trigrid = { + 8: ckpt['trigrids_8'].to(device).detach(), + 16: ckpt['trigrids_16'].to(device).detach(), + 32: ckpt['trigrids_32'].to(device).detach(), + 64: ckpt['trigrids_64'].to(device).detach(), + 128: ckpt['trigrids_128'].to(device).detach(), + 256: ckpt['trigrids_256'].to(device).detach(), + 512: ckpt['trigrids_512'].to(device).detach(), + } + ws = ckpt['ws'].to(device) + + output = os.path.join(outdir, f'large_pose_final.mp4') + print('Generating video "%s"...' % output) + if not os.path.exists(output): + gen_interp_video(G=G, mp4=output, bitrate='10M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=240, + trigrid=trigrid, ws=ws, shuffle_seed=shuffle_seed, psi=truncation_psi, + truncation_cutoff=truncation_cutoff, image_mode=image_mode, large_pose=True) + output = os.path.join(outdir, f'small_pose_final.mp4') + print('Generating video "%s"...' % output) + if not os.path.exists(output): + gen_interp_video(G=G, mp4=output, bitrate='10M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=120, + trigrid=trigrid, ws=ws, shuffle_seed=shuffle_seed, psi=truncation_psi, + truncation_cutoff=truncation_cutoff, image_mode=image_mode, large_pose=False) + + print('Generating shapes...') + + shape_res = 512 + max_batch = 1000000 + shape_format = '.mrc' + + if shape_format == '.ply': + from shape_utils import convert_sdf_samples_to_ply + shape_path = os.path.join(outdir, f'shape.ply') + elif shape_format == '.mrc': # output mrc + shape_path = os.path.join(outdir, f'shape.mrc') + + if not os.path.exists(shape_path): + + samples, voxel_origin, voxel_size = create_samples(N=shape_res, voxel_origin=[0, 0, 0], + cube_length=0.9) # .reshape(1, -1, 3) + samples = samples.to(device) + sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1), device=device) + transformed_ray_directions_expanded = torch.zeros((samples.shape[0], max_batch, 3), device=device) + 
transformed_ray_directions_expanded[..., -1] = -1 + + head = 0 + with tqdm(total=samples.shape[1]) as pbar: + with torch.no_grad(): + while head < samples.shape[1]: + torch.manual_seed(0) + sigma = G.sample_trigrid(samples[:, head:head + max_batch], + transformed_ray_directions_expanded[:, :samples.shape[1] - head], + planes=trigrid, truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, noise_mode='const')['sigma'] + sigmas[:, head:head + max_batch] = sigma + head += max_batch + pbar.update(max_batch) + + sigmas = sigmas.reshape((shape_res, shape_res, shape_res)).cpu().numpy() + sigmas = np.flip(sigmas, 0) + + # Trim the border of the extracted cube + pad = int(30 * shape_res / 256) + pad_value = -1000 + sigmas[:pad] = pad_value + sigmas[-pad:] = pad_value + sigmas[:, :pad] = pad_value + sigmas[:, -pad:] = pad_value + sigmas[:, :, :pad] = pad_value + sigmas[:, :, -pad:] = pad_value + + if shape_format == '.ply': + from shape_utils import convert_sdf_samples_to_ply + convert_sdf_samples_to_ply(np.transpose(sigmas, (2, 1, 0)), [0, 0, 0], 1, + os.path.join(outdir, f'shape.ply'), level=15) + elif shape_format == '.mrc': # output mrc + with mrcfile.new_mmap(os.path.join(outdir, f'shape.mrc'), overwrite=True, shape=sigmas.shape, + mrc_mode=2) as mrc: + mrc.data[:] = sigmas + + + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + generate_images() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- \ No newline at end of file diff --git a/3DPortraitGAN_pyramid/gen_videos_shapes_from_triplane.py b/3DPortraitGAN_pyramid/gen_videos_shapes_from_triplane.py new file mode 100644 index 0000000..2d436a3 --- /dev/null +++ b/3DPortraitGAN_pyramid/gen_videos_shapes_from_triplane.py @@ -0,0 +1,351 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Generate lerp videos using pretrained network pickle.""" + +import os +import re +from typing import List, Optional, Tuple, Union + +import click +import dnnlib +import imageio +import numpy as np +import scipy.interpolate +import torch +from tqdm import tqdm +import mrcfile + +import legacy + +from camera_utils import LookAtPoseSampler +from torch_utils import misc +import glob +#---------------------------------------------------------------------------- + +def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True): + batch_size, channels, img_h, img_w = img.shape + if grid_w is None: + grid_w = batch_size // grid_h + assert batch_size == grid_w * grid_h + if float_to_uint8: + img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) + img = img.reshape(grid_h, grid_w, channels, img_h, img_w) + img = img.permute(2, 0, 3, 1, 4) + img = img.reshape(channels, grid_h * img_h, grid_w * img_w) + if chw_to_hwc: + img = img.permute(1, 2, 0) + if to_numpy: + img = img.cpu().numpy() + return img + +def create_samples(N=256, voxel_origin=[0, 0, 0], cube_length=2.0): + # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle + voxel_origin = np.array(voxel_origin) - cube_length/2 + voxel_size = cube_length / (N - 1) + + overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor()) + samples = torch.zeros(N ** 3, 3) + + # transform first 3 columns + # to be the x, y, z index + samples[:, 2] = overall_index % N + samples[:, 1] = (overall_index.float() / 
N) % N + samples[:, 0] = ((overall_index.float() / N) / N) % N + + # transform first 3 columns + # to be the x, y, z coordinate + samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2] + samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1] + samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0] + + num_samples = N ** 3 + + return samples.unsqueeze(0), voxel_origin, voxel_size + +#---------------------------------------------------------------------------- + +def gen_interp_video(G, mp4: str, trigrid=None,ws=None, shuffle_seed=None, w_frames=60*4, kind='cubic', + grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, truncation_cutoff=14, + image_mode='image', gen_shapes=False, device=torch.device('cuda'), large_pose= False, + **video_kwargs): + grid_w = grid_dims[0] + grid_h = grid_dims[1] + + num_keyframes = 1 + + + camera_lookat_point = torch.tensor([0, 0.0649, 0], device=device) + + cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, camera_lookat_point, radius=2.7, device=device) + focal_length = 6.5104166 + intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) + c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) + c = c.repeat(len(ws), 1) + + p = torch.zeros([len(ws), 6], device=device) + + ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:]) + + # Interpolation. + grid = [] + for yi in range(grid_h): + row = [] + for xi in range(grid_w): + x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1)) + y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1]) + interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0) + row.append(interp) + grid.append(row) + + # Render video. 
+ video_out = imageio.get_writer(mp4, mode='I', fps=30, codec='libx264', **video_kwargs) + + + all_poses = [] + + if large_pose: + image_row = [] + + + for frame_idx in tqdm(range(num_keyframes * w_frames)): + imgs = [] + for yi in range(grid_h): + for xi in range(grid_w): + if large_pose: + # 0 - 2pi + cam2world_pose = LookAtPoseSampler.sample(np.pi / 2 + (frame_idx / w_frames) * 2 * np.pi, + np.pi / 2, + camera_lookat_point, radius=2.7, device=device) + else: + pitch_range = 0.25 + yaw_range = 0.35 + cam2world_pose = LookAtPoseSampler.sample(np.pi/2 + yaw_range * np.sin(2 * np.pi * frame_idx / (num_keyframes * w_frames)), + np.pi/2 -0.05 + pitch_range * np.cos(2 * np.pi * frame_idx / (num_keyframes * w_frames)), + camera_lookat_point, radius=2.7, device=device) + all_poses.append(cam2world_pose.squeeze().cpu().numpy()) + focal_length = 6.5104166 + intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) + c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) + + interp = grid[yi][xi] + w = torch.from_numpy(interp(frame_idx / w_frames)).to(device) + + img = G.render_planes(ws=w.unsqueeze(0), planes=trigrid, c=c[0:1], noise_mode='const', neural_rendering_resolution=512,chunk = 4096)[image_mode][0] + + if image_mode == 'image_depth': + img = -img + img = (img - img.min()) / (img.max() - img.min()) * 2 - 1 + + imgs.append(img) + if large_pose and frame_idx % int(num_keyframes * w_frames//8) == 0: + image_row.append((img.permute(1, 2, 0) * 127.5 + 128).clamp(0, 255).to(torch.uint8)) + + video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h)) + video_out.close() + all_poses = np.stack(all_poses) + + if large_pose: + import PIL + image_row = torch.cat(image_row, 1).cpu().numpy() + PIL.Image.fromarray(image_row.astype(np.uint8)).save(mp4.replace('.mp4', '_large_pose.png')) + + + if gen_shapes: + print(all_poses.shape) + with open(mp4.replace('.mp4', '_trajectory.npy'), 'wb') 
def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
    '''Convert an 'M,N' or 'MxN' string into a pair of integers.

    A value that is already a tuple is handed back untouched.

    Example:
        '4x2' returns (4,2)
        '0,1' returns (0,1)

    Raises:
        ValueError: If the string is not two non-negative integers separated
            by 'x' or ','.
    '''
    if isinstance(s, tuple):
        return s
    matched = re.match(r'^(\d+)[x,](\d+)$', s)
    if matched is None:
        raise ValueError(f'cannot parse tuple {s}')
    first, second = matched.groups()
    return (int(first), int(second))
If not specified, determine based on the length of the seeds array given by --seeds.', default=None) +@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) +@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) +@click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True) +@click.option('--image_mode', help='Image mode', type=click.Choice(['image_depth', 'image_raw']), required=False, metavar='STR', default='image_raw', show_default=True) +@click.option('--sample_mult', 'sampling_multiplier', type=float, help='Multiplier for depth sampling in volume rendering', default=2, show_default=True) +@click.option('--nrr', type=int, help='Neural rendering resolution override', default=None, show_default=True) + +def generate_images( + network_pkl: str, + data_dir: str, + shuffle_seed: Optional[int], + truncation_psi: float, + truncation_cutoff: int, + grid: Tuple[int,int], + num_keyframes: Optional[int], + w_frames: int, + image_mode: str, + sampling_multiplier: float, + nrr: Optional[int], +): + print('Loading networks from "%s"...' 
% network_pkl) + device = torch.device('cuda') + with dnnlib.util.open_url(network_pkl) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + + G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'] * sampling_multiplier) + G.rendering_kwargs['depth_resolution_importance'] = int( + G.rendering_kwargs['depth_resolution_importance'] * sampling_multiplier) + + G.rendering_kwargs['ray_start'] = 2.35 + + + + print("Reloading Modules!") + from training.smpl_triplane import TriPlaneGenerator + G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device) + misc.copy_params_and_buffers(G, G_new, require_all=True) + G_new.neural_rendering_resolution = G.neural_rendering_resolution + G_new.rendering_kwargs = G.rendering_kwargs + G = G_new + + G.set_batch_size(1) + + + for res_dir in glob.glob(data_dir + '/*'): + outdir = os.path.join(res_dir, 'results') + os.makedirs(outdir, exist_ok=True) + + + + if nrr is not None: G.neural_rendering_resolution = nrr + + if truncation_cutoff == 0: + truncation_psi = 1.0 # truncation cutoff of 0 means no truncation anyways + if truncation_psi == 1.0: + truncation_cutoff = 14 # no truncation so doesn't matter where we cutoff + + ckpt_path = os.path.join(res_dir, 'checkpoints/df.pth') + if not os.path.exists(ckpt_path): + continue + print('Loading checkpoints from "%s"...' % ckpt_path) + ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['model'] + trigrid = { + 8:ckpt['trigrids_8'].to(device), + 16:ckpt['trigrids_16'].to(device), + 32:ckpt['trigrids_32'].to(device), + 64:ckpt['trigrids_64'].to(device), + 128:ckpt['trigrids_128'].to(device), + 256:ckpt['trigrids_256'].to(device), + } + ws = ckpt['ws'].to(device) + + output = os.path.join(outdir, f'large_pose.mp4') + print('Generating video "%s"...' 
% output) + if not os.path.exists(output): + gen_interp_video(G=G, mp4=output, bitrate='10M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, + trigrid=trigrid,ws=ws, shuffle_seed=shuffle_seed, psi=truncation_psi, + truncation_cutoff=truncation_cutoff, image_mode=image_mode, large_pose=True) + output = os.path.join(outdir, f'small_pose.mp4') + print('Generating video "%s"...' % output) + if not os.path.exists(output): + gen_interp_video(G=G, mp4=output, bitrate='10M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, + trigrid=trigrid,ws=ws, shuffle_seed=shuffle_seed, psi=truncation_psi, + truncation_cutoff=truncation_cutoff, image_mode=image_mode, large_pose=False) + + print('Generating shapes...') + + shape_res = 512 + max_batch = 1000000 + shape_format = '.mrc' + + if shape_format == '.ply': + from shape_utils import convert_sdf_samples_to_ply + shape_path =os.path.join(outdir, f'shape.ply') + elif shape_format == '.mrc': # output mrc + shape_path = os.path.join(outdir, f'shape.mrc') + + if not os.path.exists(shape_path): + + samples, voxel_origin, voxel_size = create_samples(N=shape_res, voxel_origin=[0, 0, 0], + cube_length=0.9) # .reshape(1, -1, 3) + samples = samples.to(device) + sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1), device=device) + transformed_ray_directions_expanded = torch.zeros((samples.shape[0], max_batch, 3), device=device) + transformed_ray_directions_expanded[..., -1] = -1 + + head = 0 + with tqdm(total=samples.shape[1]) as pbar: + with torch.no_grad(): + while head < samples.shape[1]: + torch.manual_seed(0) + sigma = G.sample_trigrid(samples[:, head:head + max_batch], + transformed_ray_directions_expanded[:, :samples.shape[1] - head], planes = trigrid, truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, noise_mode='const')['sigma'] + sigmas[:, head:head + max_batch] = sigma + head += max_batch + pbar.update(max_batch) + + sigmas = sigmas.reshape((shape_res, shape_res, 
shape_res)).cpu().numpy() + sigmas = np.flip(sigmas, 0) + + # Trim the border of the extracted cube + pad = int(30 * shape_res / 256) + pad_value = -1000 + sigmas[:pad] = pad_value + sigmas[-pad:] = pad_value + sigmas[:, :pad] = pad_value + sigmas[:, -pad:] = pad_value + sigmas[:, :, :pad] = pad_value + sigmas[:, :, -pad:] = pad_value + + if shape_format == '.ply': + from shape_utils import convert_sdf_samples_to_ply + convert_sdf_samples_to_ply(np.transpose(sigmas, (2, 1, 0)), [0, 0, 0], 1, + os.path.join(outdir, f'shape.ply'), level=15) + elif shape_format == '.mrc': # output mrc + with mrcfile.new_mmap(os.path.join(outdir, f'shape.mrc'), overwrite=True, shape=sigmas.shape, + mrc_mode=2) as mrc: + mrc.data[:] = sigmas + + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + generate_images() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- \ No newline at end of file diff --git a/3DPortraitGAN_pyramid/get_decoder_ckpt.py b/3DPortraitGAN_pyramid/get_decoder_ckpt.py new file mode 100644 index 0000000..f9057c3 --- /dev/null +++ b/3DPortraitGAN_pyramid/get_decoder_ckpt.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into a single grid image.

    Args:
        img: 4-D tensor unpacked as (batch, channels, height, width).
        grid_w: Number of grid columns; derived from the batch size when None.
        grid_h: Number of grid rows.
        float_to_uint8: Map values with ``x * 127.5 + 128``, clamp to
            [0, 255] and cast to ``torch.uint8``.
        chw_to_hwc: Move the channel axis last.
        to_numpy: Return a NumPy array (copied to CPU) instead of a tensor.

    Returns:
        The tiled image; rows are stacked vertically, columns horizontally.
    """
    batch_size, channels, img_h, img_w = img.shape
    if grid_w is None:
        grid_w = batch_size // grid_h
    assert batch_size == grid_w * grid_h
    if float_to_uint8:
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
    img = img.permute(2, 0, 3, 1, 4)
    img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
    if chw_to_hwc:
        img = img.permute(1, 2, 0)
    if to_numpy:
        img = img.cpu().numpy()
    return img


def create_samples(N=256, voxel_origin=(0, 0, 0), cube_length=2.0):
    """Build a regular N x N x N grid of 3-D sample positions.

    Args:
        N: Number of samples along each axis (N >= 2).
        voxel_origin: Centre of the cube; ``cube_length / 2`` is subtracted
            from it to obtain the corner the grid starts from.
        cube_length: Edge length of the sampled cube.

    Returns:
        Tuple ``(samples, voxel_origin, voxel_size)`` where ``samples`` is a
        float tensor of shape (1, N**3, 3), ``voxel_origin`` is the shifted
        corner as a NumPy array and ``voxel_size`` is the grid spacing.
        The last column varies fastest and the first column slowest.
    """
    # NOTE: after this shift voxel_origin is the (bottom, left, down) corner, not the middle.
    voxel_origin = np.array(voxel_origin) - cube_length / 2
    voxel_size = cube_length / (N - 1)

    overall_index = torch.arange(0, N ** 3, dtype=torch.long)
    samples = torch.zeros(N ** 3, 3)

    # Decompose the flat index into integer grid indices. Exact integer floor
    # division is required here: float true division yields fractional indices
    # (a sheared grid) and loses precision once N**3 exceeds 2**24.
    samples[:, 2] = overall_index % N
    samples[:, 1] = torch.div(overall_index, N, rounding_mode='floor') % N
    samples[:, 0] = torch.div(overall_index, N * N, rounding_mode='floor') % N

    # Convert grid indices to coordinates.
    samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
    samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
    samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]

    return samples.unsqueeze(0), voxel_origin, voxel_size
def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic',
                     grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, truncation_cutoff=14,
                     cfg='FFHQ', image_mode='image', gen_shapes=False, device=torch.device('cuda'), large_pose= False,
                     **video_kwargs):
    """Render a latent-interpolation video while moving the camera.

    Latents for the given seeds are mapped to W space, interpolated between
    keyframes with ``scipy.interpolate.interp1d`` and synthesized frame by
    frame; every frame is a ``grid_dims`` grid of images appended to ``mp4``.

    Args:
        G: Generator exposing ``z_dim``, ``mapping``, ``synthesis`` and, when
            ``gen_shapes`` is set, ``sample_mixed`` and ``rendering_kwargs``.
        mp4: Output video path.
        seeds: Random seeds used to draw the latent codes.
        shuffle_seed: If given, seed used to shuffle the expanded seed list.
        w_frames: Number of frames between two consecutive keyframes.
        kind: Interpolation kind forwarded to ``interp1d``.
        grid_dims: ``(width, height)`` of the image grid.
        num_keyframes: Keyframes per grid cell; derived from ``seeds`` if None.
        wraps: How many times the keyframes are tiled on each side so that the
            interpolation wraps around.
        psi: Truncation psi forwarded to the mapping network.
        truncation_cutoff: Truncation cutoff forwarded to the mapping network.
        cfg: Only ``'Shapenet'`` changes behavior (different focal length).
        image_mode: Key of the synthesis output dict to render.
        gen_shapes: Also export a density volume per frame and the camera
            trajectory.
        device: Torch device used for all tensors.
        large_pose: Orbit the camera a full turn per keyframe instead of the
            small yaw/pitch wobble.
        **video_kwargs: Extra arguments forwarded to ``imageio.get_writer``.

    Raises:
        ValueError: If ``num_keyframes`` is None and the number of seeds is
            not divisible by the grid size.
    """
    grid_w = grid_dims[0]
    grid_h = grid_dims[1]

    if num_keyframes is None:
        if len(seeds) % (grid_w*grid_h) != 0:
            raise ValueError('Number of input seeds must be divisible by grid W*H')
        num_keyframes = len(seeds) // (grid_w*grid_h)

    all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64)
    for idx in range(num_keyframes*grid_h*grid_w):
        all_seeds[idx] = seeds[idx % len(seeds)]

    if shuffle_seed is not None:
        rng = np.random.RandomState(seed=shuffle_seed)
        rng.shuffle(all_seeds)

    camera_lookat_point = torch.tensor([0, 0.0649, 0], device=device)
    zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device)

    # Conditioning camera used for the mapping network (always the default focal length).
    cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, camera_lookat_point, radius=2.7, device=device)
    focal_length = 6.5104166
    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
    c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    c = c.repeat(len(zs), 1)

    p = torch.zeros([len(zs), 6], device=device)

    ws = G.mapping(z=zs, c=c, p=p,truncation_psi=psi, truncation_cutoff=truncation_cutoff)
    _ = G.synthesis(ws[:1], c[:1]) # warm up
    ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])

    # Interpolation.
    grid = []
    for yi in range(grid_h):
        row = []
        for xi in range(grid_w):
            x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
            y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
            interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
            row.append(interp)
        grid.append(row)

    # Render video.
    max_batch = 10000000
    voxel_resolution = 512
    num_frames = num_keyframes * w_frames

    # Rendering intrinsics do not depend on the frame, so build them once.
    focal_length = 6.5104166 if cfg != 'Shapenet' else 1.7074 # shapenet has higher FOV
    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)

    video_out = imageio.get_writer(mp4, mode='I', fps=30, codec='libx264', **video_kwargs)
    all_poses = []
    try:
        if gen_shapes:
            outdir = 'interpolation_{}_{}/'.format(all_seeds[0], all_seeds[1])
            os.makedirs(outdir, exist_ok=True)
        for frame_idx in tqdm(range(num_frames)):
            imgs = []
            for yi in range(grid_h):
                for xi in range(grid_w):
                    if large_pose:
                        # 0 - 2pi
                        cam2world_pose = LookAtPoseSampler.sample(np.pi / 2 + (frame_idx / w_frames) * 2 * np.pi,
                                                                  np.pi / 2,
                                                                  camera_lookat_point, radius=2.7, device=device)
                    else:
                        pitch_range = 0.25
                        yaw_range = 0.35
                        cam2world_pose = LookAtPoseSampler.sample(np.pi/2 + yaw_range * np.sin(2 * np.pi * frame_idx / (num_keyframes * w_frames)),
                                                                  np.pi/2 -0.05 + pitch_range * np.cos(2 * np.pi * frame_idx / (num_keyframes * w_frames)),
                                                                  camera_lookat_point, radius=2.7, device=device)
                    all_poses.append(cam2world_pose.squeeze().cpu().numpy())
                    c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

                    interp = grid[yi][xi]
                    w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)

                    # Only the moving-camera mode was ever reachable; the latent is
                    # the interpolated one and the camera is the per-frame camera.
                    img = G.synthesis(ws=w.unsqueeze(0), c=c[0:1], noise_mode='const')[image_mode][0]

                    if image_mode == 'image_depth':
                        img = -img
                        img = (img - img.min()) / (img.max() - img.min()) * 2 - 1

                    imgs.append(img)

                    if gen_shapes:
                        # generate shapes
                        print('Generating shape for frame %d / %d ...' % (frame_idx, num_keyframes * w_frames))

                        samples, _, _ = create_samples(N=voxel_resolution, voxel_origin=[0, 0, 0], cube_length=G.rendering_kwargs['box_warp'])
                        samples = samples.to(device)
                        sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1), device=device)
                        transformed_ray_directions_expanded = torch.zeros((samples.shape[0], max_batch, 3), device=device)
                        transformed_ray_directions_expanded[..., -1] = -1

                        # Query the density in chunks of at most max_batch points.
                        head = 0
                        with tqdm(total = samples.shape[1]) as pbar:
                            with torch.no_grad():
                                while head < samples.shape[1]:
                                    torch.manual_seed(0)
                                    sigma = G.sample_mixed(samples[:, head:head+max_batch], transformed_ray_directions_expanded[:, :samples.shape[1]-head], w.unsqueeze(0), truncation_psi=psi, noise_mode='const')['sigma']
                                    sigmas[:, head:head+max_batch] = sigma
                                    head += max_batch
                                    pbar.update(max_batch)

                        sigmas = sigmas.reshape((voxel_resolution, voxel_resolution, voxel_resolution)).cpu().numpy()
                        sigmas = np.flip(sigmas, 0)

                        # Zero out the border of the extracted cube.
                        pad = int(30 * voxel_resolution / 256)
                        pad_top = int(38 * voxel_resolution / 256)
                        sigmas[:pad] = 0
                        sigmas[-pad:] = 0
                        sigmas[:, :pad] = 0
                        sigmas[:, -pad_top:] = 0
                        sigmas[:, :, :pad] = 0
                        sigmas[:, :, -pad:] = 0

                        output_ply = True
                        if output_ply:
                            from shape_utils import convert_sdf_samples_to_ply
                            convert_sdf_samples_to_ply(np.transpose(sigmas, (2, 1, 0)), [0, 0, 0], 1, os.path.join(outdir, f'{frame_idx:04d}_shape.ply'), level=10)
                        else: # output mrc
                            with mrcfile.new_mmap(outdir + f'{frame_idx:04d}_shape.mrc', overwrite=True, shape=sigmas.shape, mrc_mode=2) as mrc:
                                mrc.data[:] = sigmas

            video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
    finally:
        # Always release the writer, even if rendering fails part-way.
        video_out.close()
    all_poses = np.stack(all_poses)

    if gen_shapes:
        print(all_poses.shape)
        with open(mp4.replace('.mp4', '_trajectory.npy'), 'wb') as f:
            np.save(f, all_poses)

#----------------------------------------------------------------------------

def parse_range(s: Union[str, List[int]]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Ranges are inclusive on both ends. A list argument is returned unchanged.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
    '''
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        if m := range_re.match(p):
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges

#----------------------------------------------------------------------------

def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
    '''Parse a 'M,N' or 'MxN' integer tuple.

    A tuple argument is returned unchanged.

    Example:
        '4x2' returns (4,2)
        '0,1' returns (0,1)

    Raises:
        ValueError: If the string matches neither form.
    '''
    if isinstance(s, tuple): return s
    if m := re.match(r'^(\d+)[x,](\d+)$', s):
        return (int(m.group(1)), int(m.group(2)))
    raise ValueError(f'cannot parse tuple {s}')
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
def extract_decoder_ckpt(
    network_pkl: str,
):
    """Extract the decoder weights of a trained generator into a checkpoint.

    Loads ``G_ema`` from the network pickle, collects the state dicts of its
    ``decoder``, ``torgb`` and ``pose_branch`` sub-modules (keys prefixed with
    the sub-module name) and saves them, together with the generator
    hyper-parameters, to ``./models/<network name>_decoder.ckpt``.

    Examples:

    \b
    python get_decoder_ckpt.py --network=./models/model.pkl
    """
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    # Report the size of the parts that are NOT exported.
    print('Number of parameters of superresolution: %d' % sum(p.numel() for p in G.superresolution.parameters()))
    print('Number of parameters of backbone: %d' % sum(p.numel() for p in G.backbone.parameters()))

    # Merge the exported sub-modules into one flat state dict, prefixing each
    # key with the name of the sub-module it came from.
    state_dict = {}
    for prefix, module in (('decoder', G.decoder), ('torgb', G.torgb), ('pose_branch', G.pose_branch)):
        for k, v in module.state_dict().items():
            state_dict[prefix + '.' + k] = v

    # Hyper-parameters stored alongside the weights.
    params = {'z_dim': G.z_dim,
              'c_dim': G.c_dim,
              'w_dim': G.w_dim,
              'img_resolution': G.img_resolution,
              'img_channels': G.img_channels,
              'rendering_kwargs': G.rendering_kwargs,
              'batch_size': G.batch_size,
              'thickness': 0.25}

    print(params)
    ckpt = {'params': params,
            'state_dict': state_dict}

    # splitext (rather than str.replace) guarantees the output name always
    # differs from the input name, even when the input lacks a '.pkl' suffix.
    stem = os.path.splitext(os.path.basename(network_pkl))[0]
    out_path = './models/' + stem + '_decoder.ckpt'
    os.makedirs('./models', exist_ok=True)
    print('save decoder ckpt to ' + out_path)
    torch.save(ckpt, out_path)
def init_egl():
    """Create an off-screen OpenGL context through EGL and make it current.

    Requires the environment variable ``PYOPENGL_PLATFORM`` to be ``'egl'``.
    Every step is checked with ``assert``, so any failure surfaces as an
    ``AssertionError`` (or ``KeyError`` if the variable is not set at all).
    """
    assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL.
    import OpenGL.EGL as egl
    import ctypes

    # Initialize EGL.
    display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
    assert display != egl.EGL_NO_DISPLAY
    major = ctypes.c_int32()
    minor = ctypes.c_int32()
    ok = egl.eglInitialize(display, major, minor)
    assert ok
    # Version is encoded as major * 10 + minor, i.e. at least 1.4 is required.
    assert major.value * 10 + minor.value >= 14

    # Choose config.
    config_attribs = [
        egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT,
        egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT,
        egl.EGL_NONE
    ]
    configs = (ctypes.c_int32 * 1)()
    num_configs = ctypes.c_int32()
    ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs)
    assert ok
    assert num_configs.value == 1
    config = configs[0]

    # Create dummy pbuffer surface.
    surface_attribs = [
        egl.EGL_WIDTH, 1,
        egl.EGL_HEIGHT, 1,
        egl.EGL_NONE
    ]
    surface = egl.eglCreatePbufferSurface(display, config, surface_attribs)
    assert surface != egl.EGL_NO_SURFACE

    # Setup GL context.
    ok = egl.eglBindAPI(egl.EGL_OPENGL_API)
    assert ok
    context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None)
    assert context != egl.EGL_NO_CONTEXT
    ok = egl.eglMakeCurrent(display, surface, surface, context)
    assert ok

#----------------------------------------------------------------------------

# Maps (dtype name, channel count) to the GL pixel type, pixel format and
# internal texture format used for that combination.
_texture_formats = {
    ('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8),
    ('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8),
    ('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8),
    ('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8),
    ('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB),
    ('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB),
    ('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F),
    ('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F),
}

def get_texture_format(dtype, channels):
    """Look up the GL format entry for a dtype and channel count.

    Raises ``KeyError`` for any combination other than uint8/float32 with
    1 to 4 channels.
    """
    return _texture_formats[(np.dtype(dtype).name, int(channels))]

#----------------------------------------------------------------------------

def prepare_texture_data(image):
    """Convert ``image`` into an array suitable for upload as a texture.

    2-D input gets a trailing channel axis and float64 data is converted to
    float32; everything else is passed through ``np.asarray`` unchanged.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.dtype.name == 'float64':
        image = image.astype('float32')
    return image
def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True):
    """Draw ``image`` with ``glDrawPixels``.

    Args:
        image: Array accepted by ``prepare_texture_data`` (HxW or HxWxC).
        pos: Raster position; scalar or 2-vector.
        zoom: Pixel zoom; scalar or 2-vector. The vertical zoom is negated so
            that image row 0 ends up at the top.
        align: Fraction of the zoomed size subtracted from ``pos``.
        rint: Round the final position to whole pixels.
    """
    pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
    zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
    align = np.broadcast_to(np.asarray(align, dtype='float32'), [2])
    image = prepare_texture_data(image)
    height, width, channels = image.shape
    size = zoom * [width, height]
    pos = pos - size * align
    if rint:
        pos = np.rint(pos)
    fmt = get_texture_format(image.dtype, channels)

    gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT)
    gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
    gl.glRasterPos2f(pos[0], pos[1])
    gl.glPixelZoom(zoom[0], -zoom[1])
    gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
    gl.glDrawPixels(width, height, fmt.format, fmt.type, image)
    gl.glPopClientAttrib()
    gl.glPopAttrib()

#----------------------------------------------------------------------------

def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3):
    """Read a ``width`` x ``height`` region with ``glReadPixels``.

    Returns an array of shape (height, width, channels) flipped vertically
    relative to what GL returned.
    """
    pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
    dtype = np.dtype(dtype)
    fmt = get_texture_format(dtype, channels)
    image = np.empty([height, width, channels], dtype=dtype)

    gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
    gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
    gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image)
    gl.glPopClientAttrib()
    return np.flipud(image)

#----------------------------------------------------------------------------

class Texture:
    """Wrapper around a 2-D OpenGL texture object.

    The size, channel count and dtype are taken from ``image`` when given,
    otherwise from the explicit keyword arguments (channels default to 3 and
    dtype to uint8).
    """

    def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True):
        self.gl_id = None
        self.bilinear = bilinear
        self.mipmap = mipmap

        # Determine size and dtype.
        if image is not None:
            image = prepare_texture_data(image)
            self.height, self.width, self.channels = image.shape
            self.dtype = image.dtype
        else:
            assert width is not None and height is not None
            self.width = width
            self.height = height
            self.channels = channels if channels is not None else 3
            self.dtype = np.dtype(dtype) if dtype is not None else np.uint8

        # Validate size and dtype.
        assert isinstance(self.width, int) and self.width >= 0
        assert isinstance(self.height, int) and self.height >= 0
        assert isinstance(self.channels, int) and self.channels >= 1
        assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype)

        # Create texture object.
        self.gl_id = gl.glGenTextures(1)
        with self.bind():
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST)
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST)
        self.update(image)

    def delete(self):
        """Delete the GL texture; safe to call more than once."""
        if self.gl_id is not None:
            gl.glDeleteTextures([self.gl_id])
            self.gl_id = None

    def __del__(self):
        # Best-effort cleanup: the GL context may already be gone.
        try:
            self.delete()
        except Exception:
            pass

    @contextlib.contextmanager
    def bind(self):
        """Bind the texture for the duration of the ``with`` block."""
        prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D)
        gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id)
        try:
            yield
        finally:
            # Restore the previous binding even if the body raised.
            gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id)

    def update(self, image):
        """Upload ``image`` (or allocate uninitialized storage when None)."""
        if image is not None:
            image = prepare_texture_data(image)
            assert self.is_compatible(image=image)
        with self.bind():
            fmt = get_texture_format(self.dtype, self.channels)
            gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
            gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
            gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image)
            if self.mipmap:
                gl.glGenerateMipmap(gl.GL_TEXTURE_2D)
            gl.glPopClientAttrib()

    def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0):
        """Draw the texture as a rectangle of size ``zoom * (width, height)``."""
        zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
        size = zoom * [self.width, self.height]
        with self.bind():
            gl.glPushAttrib(gl.GL_ENABLE_BIT)
            gl.glEnable(gl.GL_TEXTURE_2D)
            draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding)
            gl.glPopAttrib()

    def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements
        """Return True if every given property matches this texture.

        Properties left as None are not checked; ``image`` must be 3-D and is
        checked through its shape and dtype.
        """
        if image is not None:
            if image.ndim != 3:
                return False
            ih, iw, ic = image.shape
            if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype):
                return False
        if width is not None and self.width != width:
            return False
        if height is not None and self.height != height:
            return False
        if channels is not None and self.channels != channels:
            return False
        if dtype is not None and self.dtype != dtype:
            return False
        return True

#----------------------------------------------------------------------------

class Framebuffer:
    """Wrapper around an OpenGL framebuffer object.

    The color attachment is either the given ``texture`` or a freshly
    created renderbuffer (4 channels, float32 by default); a combined
    depth/stencil renderbuffer is always attached.
    """

    def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0):
        self.texture = texture
        self.gl_id = None
        self.gl_color = None
        self.gl_depth_stencil = None
        self.msaa = msaa

        # Determine size and dtype.
        if texture is not None:
            assert isinstance(self.texture, Texture)
            self.width = texture.width
            self.height = texture.height
            self.channels = texture.channels
            self.dtype = texture.dtype
        else:
            assert width is not None and height is not None
            self.width = width
            self.height = height
            self.channels = channels if channels is not None else 4
            self.dtype = np.dtype(dtype) if dtype is not None else np.float32

        # Validate size and dtype.
        assert isinstance(self.width, int) and self.width >= 0
        assert isinstance(self.height, int) and self.height >= 0
        assert isinstance(self.channels, int) and self.channels >= 1
        assert width is None or width == self.width
        assert height is None or height == self.height
        assert channels is None or channels == self.channels
        assert dtype is None or dtype == self.dtype

        # Create framebuffer object.
        self.gl_id = gl.glGenFramebuffers(1)
        with self.bind():

            # Setup color buffer.
            if self.texture is not None:
                assert self.msaa == 0
                gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0)
            else:
                fmt = get_texture_format(self.dtype, self.channels)
                self.gl_color = gl.glGenRenderbuffers(1)
                gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color)
                gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height)
                gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color)

            # Setup depth/stencil buffer.
            self.gl_depth_stencil = gl.glGenRenderbuffers(1)
            gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil)
            gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height)
            gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil)

    def delete(self):
        """Delete the framebuffer and its renderbuffers; safe to call twice."""
        if self.gl_id is not None:
            gl.glDeleteFramebuffers([self.gl_id])
            self.gl_id = None
        if self.gl_color is not None:
            gl.glDeleteRenderbuffers(1, [self.gl_color])
            self.gl_color = None
        if self.gl_depth_stencil is not None:
            gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil])
            self.gl_depth_stencil = None

    def __del__(self):
        # Best-effort cleanup: the GL context may already be gone.
        try:
            self.delete()
        except Exception:
            pass

    @contextlib.contextmanager
    def bind(self):
        """Bind the framebuffer and set the viewport to its size.

        The previous framebuffer and renderbuffer bindings are restored on
        exit; the viewport is not.
        """
        prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
        prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING)
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id)
        if self.width is not None and self.height is not None:
            gl.glViewport(0, 0, self.width, self.height)
        try:
            yield
        finally:
            # Restore the previous bindings even if the body raised.
            gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo)
            gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo)

    def blit(self, dst=None):
        """Copy the color buffer to ``dst`` (or to framebuffer 0 when None)."""
        assert dst is None or isinstance(dst, Framebuffer)
        with self.bind():
            # The framebuffer id is stored in `gl_id`; there is no `fbo` attribute.
            gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.gl_id)
            gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST)

#----------------------------------------------------------------------------

def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1):
    """Draw a 2-D vertex array translated by ``pos`` and scaled by ``size``.

    The vertices double as texture coordinates, and the color is multiplied
    by ``alpha`` (clipped to [0, 1]) before being passed to GL.
    """
    assert vertices.ndim == 2 and vertices.shape[1] == 2
    pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
    size = np.broadcast_to(np.asarray(size, dtype='float32'), [2])
    color = np.broadcast_to(np.asarray(color, dtype='float32'), [3])
    alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1)

    gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT)
    gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT)
    gl.glMatrixMode(gl.GL_MODELVIEW)
    gl.glPushMatrix()

    gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
    gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY)
    gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices)
    gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices)
    gl.glTranslate(pos[0], pos[1], 0)
    gl.glScale(size[0], size[1], 1)
    gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha)
    gl.glDrawArrays(mode, 0, vertices.shape[0])

    gl.glPopMatrix()
    gl.glPopAttrib()
    gl.glPopClientAttrib()

#----------------------------------------------------------------------------

def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0):
    """Draw an optionally rounded rectangle.

    The extent is given by ``size`` or by the opposite corner ``pos2`` (not
    both); with neither, a unit square is drawn. ``rounding`` is a corner
    radius in the same units as ``size``, limited to half of each side; if
    it is zero on either axis no rounding is applied.
    """
    assert pos2 is None or size is None
    pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
    pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None
    size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None
    size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32')
    pos = pos - size * align
    if rint:
        pos = np.rint(pos)
    rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2])
    # Express the radius as a fraction of each side.
    rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5)
    if np.min(rounding) == 0:
        rounding *= 0
    vertices = _setup_rect(float(rounding[0]), float(rounding[1]))
    draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha)
np.maximum(np.abs(size), 1e-8), 0.5) + if np.min(rounding) == 0: + rounding *= 0 + vertices = _setup_rect(float(rounding[0]), float(rounding[1])) + draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha) + +@functools.lru_cache(maxsize=10000) +def _setup_rect(rx, ry): + t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64) + s = 1 - np.sin(t); c = 1 - np.cos(t) + x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx] + y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry] + v = np.stack([x, y], axis=-1).reshape(-1, 2) + return v.astype('float32') + +#---------------------------------------------------------------------------- + +def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1): + hole = np.broadcast_to(np.asarray(hole, dtype='float32'), []) + vertices = _setup_circle(float(hole)) + draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha) + +@functools.lru_cache(maxsize=10000) +def _setup_circle(hole): + t = np.linspace(0, np.pi * 2, 128) + s = np.sin(t); c = np.cos(t) + v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2) + return v.astype('float32') + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/gui_utils/glfw_window.py b/3DPortraitGAN_pyramid/gui_utils/glfw_window.py new file mode 100644 index 0000000..aeb96e8 --- /dev/null +++ b/3DPortraitGAN_pyramid/gui_utils/glfw_window.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
class GlfwWindow: # pylint: disable=too-many-public-methods
    """A single GLFW window with frame pacing, frame capture and drag-and-drop.

    Subclasses normally override `draw_frame()` and bracket their rendering
    with `begin_frame()` / `end_frame()`.
    """

    def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
        """Create a hidden window, make its GL context current and size it.

        With `deferred_show=True` the window is first shown by `end_frame()`
        when the first frame is presented; otherwise it is shown immediately.
        """
        self._glfw_window           = None
        self._drawing_frame         = False
        self._frame_start_time      = None
        self._frame_delta           = 0
        self._fps_limit             = None
        self._vsync                 = None
        self._skip_frames           = 0
        self._deferred_show         = deferred_show
        self._close_on_esc          = close_on_esc
        self._esc_pressed           = False
        self._drag_and_drop_paths   = None
        self._capture_next_frame    = False
        self._captured_frame        = None

        # Create window.
        glfw.init()
        glfw.window_hint(glfw.VISIBLE, False)
        self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
        self._attach_glfw_callbacks()
        self.make_context_current()

        # Adjust window.
        self.set_vsync(False)
        self.set_window_size(window_width, window_height)
        if not self._deferred_show:
            glfw.show_window(self._glfw_window)

    def close(self):
        """Finish any frame in progress and destroy the window (idempotent)."""
        if self._drawing_frame:
            self.end_frame()
        if self._glfw_window is not None:
            glfw.destroy_window(self._glfw_window)
            self._glfw_window = None
        # glfw.terminate() is intentionally not called, to play nice with
        # other glfw clients in the same process.

    def __del__(self):
        # Best-effort cleanup: the object may be partially constructed or the
        # interpreter may be shutting down. Only ordinary errors are
        # suppressed; KeyboardInterrupt / SystemExit still propagate.
        try:
            self.close()
        except Exception: # pylint: disable=broad-except
            pass

    @property
    def window_width(self):
        """Window width; identical to the content width."""
        return self.content_width

    @property
    def window_height(self):
        """Content height plus the title bar height."""
        return self.content_height + self.title_bar_height

    @property
    def content_width(self):
        """Width reported by glfw.get_window_size()."""
        width, _height = glfw.get_window_size(self._glfw_window)
        return width

    @property
    def content_height(self):
        """Height reported by glfw.get_window_size()."""
        _width, height = glfw.get_window_size(self._glfw_window)
        return height

    @property
    def title_bar_height(self):
        """Top component of glfw.get_window_frame_size()."""
        _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
        return top

    @property
    def monitor_width(self):
        """Work-area width of the primary monitor."""
        _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
        return width

    @property
    def monitor_height(self):
        """Work-area height of the primary monitor."""
        _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
        return height

    @property
    def frame_delta(self):
        """Seconds between the starts of the two most recent frames (0 initially)."""
        return self._frame_delta

    def set_title(self, title):
        glfw.set_window_title(self._glfw_window, title)

    def set_window_size(self, width, height):
        """Resize the window (title bar included), clamped to the monitor work area.

        If the clamped size fills the whole work area the window is maximized.
        """
        width = min(width, self.monitor_width)
        height = min(height, self.monitor_height)
        glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
        if width == self.monitor_width and height == self.monitor_height:
            self.maximize()

    def set_content_size(self, width, height):
        """Resize so that the content area (excluding title bar) has the given size."""
        self.set_window_size(width, height + self.title_bar_height)

    def maximize(self):
        glfw.maximize_window(self._glfw_window)

    def set_position(self, x, y):
        """Move the window; `y` refers to the top of the title bar."""
        glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)

    def center(self):
        """Center the window within the primary monitor's work area."""
        self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)

    def set_vsync(self, vsync):
        """Enable/disable vsync; the swap interval is only touched on change."""
        vsync = bool(vsync)
        if vsync != self._vsync:
            glfw.swap_interval(1 if vsync else 0)
            self._vsync = vsync

    def set_fps_limit(self, fps_limit):
        """Set the frame-rate cap in frames per second.

        `None`, or any value that truncates to an integer <= 0, removes the
        cap. Previously `0` was stored as-is and made `begin_frame()` divide
        by zero, and `None` raised TypeError so a cap could not be cleared.
        """
        if fps_limit is None:
            self._fps_limit = None
            return
        fps_limit = int(fps_limit)
        self._fps_limit = fps_limit if fps_limit > 0 else None

    def should_close(self):
        """True if GLFW wants the window closed, or Esc was pressed and close_on_esc is set."""
        return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)

    def skip_frame(self):
        self.skip_frames(1)

    def skip_frames(self, num): # Do not update window for the next N frames.
        self._skip_frames = max(self._skip_frames, int(num))

    def is_skipping_frames(self):
        return self._skip_frames > 0

    def capture_next_frame(self):
        """Request that the next presented frame be read back into memory."""
        self._capture_next_frame = True

    def pop_captured_frame(self):
        """Return the captured frame (or None) and clear it."""
        frame = self._captured_frame
        self._captured_frame = None
        return frame

    def pop_drag_and_drop_paths(self):
        """Return the most recently dropped paths (or None) and clear them."""
        paths = self._drag_and_drop_paths
        self._drag_and_drop_paths = None
        return paths

    def draw_frame(self): # To be overridden by subclass.
        self.begin_frame()
        # Rendering code goes here.
        self.end_frame()

    def make_context_current(self):
        if self._glfw_window is not None:
            glfw.make_context_current(self._glfw_window)

    def begin_frame(self):
        """Start a frame: pace to the FPS cap, poll events and reset GL state."""
        # End previous frame.
        if self._drawing_frame:
            self.end_frame()

        # Apply FPS limit.
        if self._frame_start_time is not None and self._fps_limit is not None:
            delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
            if delay > 0:
                time.sleep(delay)
        cur_time = time.perf_counter()
        if self._frame_start_time is not None:
            self._frame_delta = cur_time - self._frame_start_time
        self._frame_start_time = cur_time

        # Process events.
        glfw.poll_events()

        # Begin frame.
        self._drawing_frame = True
        self.make_context_current()

        # Initialize GL state: pixel coordinates with the origin at top-left.
        gl.glViewport(0, 0, self.content_width, self.content_height)
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        gl.glTranslate(-1, 1, 0)
        gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glEnable(gl.GL_BLEND)
        gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.

        # Clear.
        gl.glClearColor(0, 0, 0, 1)
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

    def end_frame(self):
        """Finish the frame: honor skip/capture requests, then swap buffers."""
        assert self._drawing_frame
        self._drawing_frame = False

        # Skip frames if requested.
        if self._skip_frames > 0:
            self._skip_frames -= 1
            return

        # Capture frame if requested.
        if self._capture_next_frame:
            self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
            self._capture_next_frame = False

        # Update window.
        if self._deferred_show:
            glfw.show_window(self._glfw_window)
            self._deferred_show = False
        glfw.swap_buffers(self._glfw_window)

    def _attach_glfw_callbacks(self):
        glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
        glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)

    def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
        if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
            self._esc_pressed = True

    def _glfw_drop_callback(self, _window, paths):
        self._drag_and_drop_paths = paths

#----------------------------------------------------------------------------
def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
    """Apply the application's default imgui style to the current context.

    `color_scheme` selects `imgui.style_colors_<color_scheme>`; the popup
    background is then replaced by an opaque 70/30 blend of the menu-bar and
    frame background colors.
    """
    s = imgui.get_style()
    s.window_padding        = [spacing, spacing]
    s.item_spacing          = [spacing, spacing]
    s.item_inner_spacing    = [spacing, spacing]
    s.columns_min_spacing   = spacing
    s.indent_spacing        = indent
    s.scrollbar_size        = scrollbar
    s.frame_padding         = [4, 3]
    s.window_border_size    = 1
    s.child_border_size     = 1
    s.popup_border_size     = 1
    s.frame_border_size     = 1
    s.window_rounding       = 0
    s.child_rounding        = 0
    s.popup_rounding        = 3
    s.frame_rounding        = 3
    s.scrollbar_rounding    = 3
    s.grab_rounding         = 3

    getattr(imgui, f'style_colors_{color_scheme}')(s)
    c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
    c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND]
    s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1]

#----------------------------------------------------------------------------

@contextlib.contextmanager
def grayed_out(cond=True):
    """Context manager that renders the enclosed widgets in a disabled look.

    When `cond` is false this is a no-op. The pushed style colors are popped
    even if the enclosed code raises, so the imgui style stack stays balanced.
    """
    if not cond:
        yield
        return

    s = imgui.get_style()
    text = s.colors[imgui.COLOR_TEXT_DISABLED]
    grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
    back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
    overrides = [
        (imgui.COLOR_TEXT, text),
        (imgui.COLOR_CHECK_MARK, grab),
        (imgui.COLOR_SLIDER_GRAB, grab),
        (imgui.COLOR_SLIDER_GRAB_ACTIVE, grab),
        (imgui.COLOR_FRAME_BACKGROUND, back),
        (imgui.COLOR_FRAME_BACKGROUND_HOVERED, back),
        (imgui.COLOR_FRAME_BACKGROUND_ACTIVE, back),
        (imgui.COLOR_BUTTON, back),
        (imgui.COLOR_BUTTON_HOVERED, back),
        (imgui.COLOR_BUTTON_ACTIVE, back),
        (imgui.COLOR_HEADER, back),
        (imgui.COLOR_HEADER_HOVERED, back),
        (imgui.COLOR_HEADER_ACTIVE, back),
        (imgui.COLOR_POPUP_BACKGROUND, back),
    ]
    for color_id, rgba in overrides:
        imgui.push_style_color(color_id, *rgba)
    try:
        yield
    finally:
        imgui.pop_style_color(len(overrides))

#----------------------------------------------------------------------------

@contextlib.contextmanager
def item_width(width=None):
    """Context manager that pushes an imgui item width (no-op when `width` is None)."""
    if width is None:
        yield
        return
    imgui.push_item_width(width)
    try:
        yield
    finally:
        # Pop even on error so the item-width stack stays balanced.
        imgui.pop_item_width()

#----------------------------------------------------------------------------

def scoped_by_object_id(method):
    """Decorate a method so its imgui calls run inside an ID scope unique to `self`."""
    import functools # Local import: only needed here.

    @functools.wraps(method) # Preserve the wrapped method's name and docstring.
    def decorator(self, *args, **kwargs):
        imgui.push_id(str(id(self)))
        try:
            return method(self, *args, **kwargs)
        finally:
            imgui.pop_id()
    return decorator

#----------------------------------------------------------------------------

def button(label, width=0, enabled=True):
    """Draw a button; a disabled button is grayed out and never reports a click."""
    with grayed_out(not enabled):
        clicked = imgui.button(label, width=width)
    clicked = clicked and enabled
    return clicked

#----------------------------------------------------------------------------

def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
    """Draw a collapsing header and return `(expanded, visible)`.

    With `show=False` nothing is drawn and `(False, visible)` is returned.
    A disabled header is drawn as a grayed-out leaf and never reports expanded.
    """
    expanded = False
    if show:
        if default:
            flags |= imgui.TREE_NODE_DEFAULT_OPEN
        if not enabled:
            flags |= imgui.TREE_NODE_LEAF
        with grayed_out(not enabled):
            expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
        expanded = expanded and enabled
    return expanded, visible

#----------------------------------------------------------------------------

def popup_button(label, width=0, enabled=True):
    """Draw a button that opens a popup named `label`; return whether the popup is open.

    When this returns a true value the caller is expected to end the popup
    (NOTE(review): via imgui.end_popup(), per the imgui begin/end convention).
    """
    if button(label, width, enabled):
        imgui.open_popup(label)
    opened = imgui.begin_popup(label)
    return opened

#----------------------------------------------------------------------------

def input_text(label, value, buffer_length, flags, width=None, help_text=''):
    """Text input that shows `help_text` at half opacity while the value is empty.

    Returns `(changed, value)`. Unless INPUT_TEXT_ENTER_RETURNS_TRUE is set in
    `flags`, `changed` is recomputed by comparing against the incoming value.
    """
    old_value = value
    color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
    if value == '':
        color[-1] *= 0.5
    with item_width(width):
        imgui.push_style_color(imgui.COLOR_TEXT, *color)
        try:
            value = value if value != '' else help_text
            changed, value = imgui.input_text(label, value, buffer_length, flags)
            value = value if value != help_text else ''
        finally:
            imgui.pop_style_color(1)
    if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
        changed = (value != old_value)
    return changed, value

#----------------------------------------------------------------------------

def drag_previous_control(enabled=True):
    """Treat the previously drawn control as a drag source.

    Returns `(dragging, dx, dy)`, where the deltas are the mouse movement
    since the previous call; the imgui drag delta is reset on every read.
    """
    dragging = False
    dx = 0
    dy = 0
    if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
        if enabled:
            dragging = True
            dx, dy = imgui.get_mouse_drag_delta()
            imgui.reset_mouse_drag_delta()
        imgui.end_drag_drop_source()
    return dragging, dx, dy

#----------------------------------------------------------------------------

def drag_button(label, width=0, enabled=True):
    """Draw a button that can also be dragged; returns `(clicked, dragging, dx, dy)`."""
    clicked = button(label, width=width, enabled=enabled)
    dragging, dx, dy = drag_previous_control(enabled=enabled)
    return clicked, dragging, dx, dy

#----------------------------------------------------------------------------

def drag_hidden_window(label, x, y, width, height, enabled=True):
    """Create an invisible, fixed window at the given rectangle and query it for dragging.

    Returns `(dragging, dx, dy)` as reported by drag_previous_control().
    """
    imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
    imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
    imgui.set_next_window_position(x, y)
    imgui.set_next_window_size(width, height)
    imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
    dragging, dx, dy = drag_previous_control(enabled=enabled)
    imgui.end()
    imgui.pop_style_color(2)
    return dragging, dx, dy

#----------------------------------------------------------------------------
class ImguiWindow(glfw_window.GlfwWindow):
    """GlfwWindow that hosts a Dear ImGui context with a set of preloaded font sizes."""

    def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
        """Create the window, then the imgui context, renderer and fonts.

        `font` is a TTF path (defaults to text_utils.get_default_font());
        one imgui font is loaded per entry in `font_sizes`.
        """
        font_path = text_utils.get_default_font() if font is None else font
        sizes = {int(size) for size in font_sizes}
        super().__init__(title=title, **glfw_kwargs)

        # Init fields.
        self._imgui_context  = None
        self._imgui_renderer = None
        self._imgui_fonts    = None
        self._cur_font_size  = max(sizes)

        # A stale imgui.ini in the working directory would restore old layout.
        if os.path.isfile('imgui.ini'):
            os.remove('imgui.ini')

        # Init ImGui.
        self._imgui_context = imgui.create_context()
        self._imgui_renderer = _GlfwRenderer(self._glfw_window)
        self._attach_glfw_callbacks()
        io = imgui.get_io()
        io.ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
        io.mouse_drag_threshold = 0 # Improve behavior of the drag helpers in imgui_utils.
        self._imgui_fonts = {size: io.fonts.add_font_from_file_ttf(font_path, size) for size in sizes}
        self._imgui_renderer.refresh_font_texture()

    def close(self):
        """Shut down the imgui renderer, then close the underlying window."""
        self.make_context_current()
        self._imgui_fonts = None
        renderer = self._imgui_renderer
        if renderer is not None:
            renderer.shutdown()
            self._imgui_renderer = None
        # The context is only forgotten, not destroyed: destroying it would
        # write imgui.ini on exit.
        if self._imgui_context is not None:
            self._imgui_context = None
        super().close()

    def _glfw_key_callback(self, *args):
        # Let the base class see the key first (Esc handling), then imgui.
        super()._glfw_key_callback(*args)
        self._imgui_renderer.keyboard_callback(*args)

    @property
    def font_size(self):
        """Currently selected font size."""
        return self._cur_font_size

    @property
    def spacing(self):
        """UI spacing derived from the font size (40%, rounded)."""
        return round(self._cur_font_size * 0.4)

    def set_font_size(self, target): # Applied on next frame.
        """Select the loaded font size closest to `target` (ties go to the smaller size)."""
        self._cur_font_size = min(self._imgui_fonts, key=lambda size: (abs(size - target), size))

    def begin_frame(self):
        """Begin the glfw frame, feed input to imgui and open an imgui frame."""
        # Begin glfw frame.
        super().begin_frame()

        # Process imgui events (skipped while the content area is empty).
        self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
        has_content = self.content_width > 0 and self.content_height > 0
        if has_content:
            self._imgui_renderer.process_inputs()

        # Begin imgui frame.
        imgui.new_frame()
        imgui.push_font(self._imgui_fonts[self._cur_font_size])
        imgui_utils.set_default_style(
            spacing=self.spacing,
            indent=self.font_size,
            scrollbar=self.font_size+4,
        )

    def end_frame(self):
        """Close the imgui frame, render its draw data and present the glfw frame."""
        imgui.pop_font()
        imgui.render()
        imgui.end_frame()
        draw_data = imgui.get_draw_data()
        self._imgui_renderer.render(draw_data)
        super().end_frame()

#----------------------------------------------------------------------------
class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
    """GlfwRenderer wrapper working around a mouse wheel bug on Linux.

    Vertical scroll offsets are scaled by `mouse_wheel_multiplier` and
    accumulated into imgui's mouse wheel value; horizontal offsets are ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Scale factor applied to every vertical scroll event.
        self.mouse_wheel_multiplier = 1

    def scroll_callback(self, window, x_offset, y_offset):
        scaled_step = y_offset * self.mouse_wheel_multiplier
        self.io.mouse_wheel += scaled_step

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

def get_default_font():
    """Return a local filename for the default font (Open Sans regular), fetched via dnnlib."""
    font_url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
    return dnnlib.util.open_url(font_url, return_filename=True)

#----------------------------------------------------------------------------

@functools.lru_cache(maxsize=None)
def get_pil_font(font=None, size=32):
    """Return a cached PIL TrueType font; `font=None` selects the default font."""
    font_file = get_default_font() if font is None else font
    return PIL.ImageFont.truetype(font=font_file, size=size)

#----------------------------------------------------------------------------

def get_array(string, *, dropshadow_radius: int=None, **kwargs):
    """Render `string` to an array, optionally with a drop shadow.

    When `dropshadow_radius` is given, the shadow offset in both axes is
    ceil(2/3 * radius). Remaining keyword arguments are forwarded unchanged.
    """
    if dropshadow_radius is None:
        return _get_array_priv(string, **kwargs)
    shadow_offset_x = int(np.ceil(dropshadow_radius*2/3))
    shadow_offset_y = int(np.ceil(dropshadow_radius*2/3))
    return _get_array_priv(
        string,
        dropshadow_radius=dropshadow_radius,
        offset_x=shadow_offset_x,
        offset_y=shadow_offset_y,
        **kwargs,
    )

@functools.lru_cache(maxsize=10000)
def _get_array_priv(
    string: str, *,
    size: int = 32,
    max_width: Optional[int]=None,
    max_height: Optional[int]=None,
    min_size=10,
    shrink_coef=0.8,
    dropshadow_radius: int=None,
    offset_x: int=None,
    offset_y: int=None,
    **kwargs
):
    """Render `string`, shrinking the font until it fits the size limits.

    Starting at `size`, the font size is multiplied by `shrink_coef`
    (never going below `min_size`) until the rendered array fits within
    `max_width` x `max_height`, or `min_size` has been reached.
    """
    font_size = size
    rendered = None
    while True:
        if dropshadow_radius is None:
            rendered = _get_array_impl(string, size=font_size, **kwargs)
        else:
            # Drop-shadow text has its own rendering implementation.
            rendered = _get_array_impl_dropshadow(string, size=font_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
        height, width, _ = rendered.shape
        fits_width = max_width is None or width <= max_width
        fits_height = max_height is None or height <= max_height
        if (fits_width and fits_height) or (font_size <= min_size):
            break
        font_size = max(int(font_size * shrink_coef), min_size)
    return rendered

#----------------------------------------------------------------------------
def _rasterize_text_mask(string, font, size, line_pad):
    """Rasterize `string` line by line and stack the lines into one uint8 mask.

    Lines are right-padded to the widest line; every line except the last is
    followed by `line_pad` blank rows (`size // 2` when `line_pad` is None).
    """
    pil_font = get_pil_font(font=font, size=size)
    glyph_masks = [pil_font.getmask(text_line, 'L') for text_line in string.split('\n')]
    rows = [np.array(m, dtype=np.uint8).reshape([m.size[1], m.size[0]]) for m in glyph_masks]
    widest = max(row.shape[1] for row in rows)
    rows = [np.pad(row, ((0, 0), (0, widest - row.shape[1])), mode='constant') for row in rows]
    gap = line_pad if line_pad is not None else size // 2
    spaced = [np.pad(row, ((0, gap), (0, 0)), mode='constant') for row in rows[:-1]]
    return np.concatenate(spaced + rows[-1:], axis=0)

@functools.lru_cache(maxsize=10000)
def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
    """Render text to an (H, W, 2) uint8 array of [mask, alpha].

    With `outline > 0` the mask is padded and alpha becomes a blurred,
    boosted copy of the mask (never below the mask itself); otherwise alpha
    equals the mask.
    """
    mask = _rasterize_text_mask(string, font, size, line_pad)
    if outline > 0:
        mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
        glow = mask.astype(np.float32) / 255
        glow = scipy.ndimage.gaussian_filter(glow, outline)
        glow = 1 - np.maximum(1 - glow * outline_coef, 0) ** outline_exp
        glow = (glow * 255 + 0.5).clip(0, 255).astype(np.uint8)
        alpha = np.maximum(glow, mask)
    else:
        alpha = mask
    return np.stack([mask, alpha], axis=-1)

#----------------------------------------------------------------------------

@functools.lru_cache(maxsize=10000)
def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
    """Render text to an (H, W, 2) uint8 array of [mask, alpha] with a drop shadow.

    Alpha is a blurred copy of the mask shifted by (`offset_x`, `offset_y`)
    pixels towards the bottom-right, combined with the mask by maximum.
    Both offsets must be positive. Extra keyword arguments are accepted and
    ignored.
    """
    assert (offset_x > 0) and (offset_y > 0)
    mask = _rasterize_text_mask(string, font, size, line_pad)

    mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
    shadow = mask.astype(np.float32) / 255
    shadow = scipy.ndimage.gaussian_filter(shadow, radius)
    shadow = 1 - np.maximum(1 - shadow * 1.5, 0) ** 1.4
    shadow = (shadow * 255 + 0.5).clip(0, 255).astype(np.uint8)
    # Shift the shadow: pad on the top/left, crop the same amount bottom/right.
    shadow = np.pad(shadow, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
    alpha = np.maximum(shadow, mask)
    return np.stack([mask, alpha], axis=-1)

#----------------------------------------------------------------------------

@functools.lru_cache(maxsize=10000)
def get_texture(string, bilinear=True, mipmap=True, **kwargs):
    """Return a cached gl_utils.Texture of the rendered text; kwargs go to get_array()."""
    text_image = get_array(string, **kwargs)
    return gl_utils.Texture(image=text_image, bilinear=bilinear, mipmap=mipmap)

#----------------------------------------------------------------------------
# ----------------------------------------------------------------------------
class Space_Regulizer:
    """Locality regularizer comparing a tuned generator against the original one.

    Latents are sampled, moved a fixed distance from the pivot latent towards
    each sample, and the images synthesized by both generators at those
    latents are compared with L2 and/or LPIPS losses.
    """

    def __init__(self, original_G, lpips_net):
        self.original_G = original_G
        self.morphing_regulizer_alpha = hyperparameters.regulizer_alpha
        self.lpips_loss = lpips_net

    def get_morphed_w_code(self, new_w_code, fixed_w):
        """Return `fixed_w` moved by `regulizer_alpha` (L2 distance) towards `new_w_code`."""
        interpolation_direction = new_w_code - fixed_w
        interpolation_direction_norm = torch.norm(interpolation_direction, p=2)
        direction_to_move = hyperparameters.regulizer_alpha * interpolation_direction / interpolation_direction_norm
        # A stray expression statement that computed an interpolation and
        # discarded the result used to follow here; it had no effect and was
        # removed.
        return fixed_w + direction_to_move

    def get_image_from_ws(self, w_codes, G):
        """Synthesize one image per latent in `w_codes` and concatenate them along dim 0."""
        return torch.cat([G.synthesis(w_code, noise_mode='none', force_fp32=True) for w_code in w_codes])

    def ball_holder_loss_lazy(self, new_G, num_of_sampled_latents, w_batch, use_wandb=False):
        """Average regularization loss over `num_of_sampled_latents` morphed latents.

        `use_wandb` is accepted for interface compatibility and is not used.
        """
        loss = 0.0

        z_samples = np.random.randn(num_of_sampled_latents, self.original_G.z_dim)
        w_samples = self.original_G.mapping(torch.from_numpy(z_samples).to(global_config.device), None,
                                            truncation_psi=0.5)
        territory_indicator_ws = [self.get_morphed_w_code(w_code.unsqueeze(0), w_batch) for w_code in w_samples]

        for w_code in territory_indicator_ws:
            new_img = new_G.synthesis(w_code, noise_mode='none', force_fp32=True)
            with torch.no_grad():
                # The original generator is only a reference; no gradients needed.
                old_img = self.original_G.synthesis(w_code, noise_mode='none', force_fp32=True)

            if hyperparameters.regulizer_l2_lambda > 0:
                # Fixed: this used to call `l2_loss.l2_loss(...)`, but `l2_loss`
                # is the module-level function below, not a module, so the
                # call raised AttributeError.
                l2_loss_val = l2_loss(old_img, new_img)
                loss += l2_loss_val * hyperparameters.regulizer_l2_lambda

            if hyperparameters.regulizer_lpips_lambda > 0:
                loss_lpips = self.lpips_loss(old_img, new_img)
                loss_lpips = torch.mean(torch.squeeze(loss_lpips))
                loss += loss_lpips * hyperparameters.regulizer_lpips_lambda

        return loss / len(territory_indicator_ws)

    def space_regulizer_loss(self, new_G, w_batch, use_wandb):
        """Regularization loss using `hyperparameters.latent_ball_num_of_samples` samples."""
        return self.ball_holder_loss_lazy(new_G, hyperparameters.latent_ball_num_of_samples, w_batch, use_wandb)


def l2_loss(real_images, generated_images):
    """Mean squared error between two tensors."""
    l2_criterion = torch.nn.MSELoss(reduction='mean')
    return l2_criterion(real_images, generated_images)


def toogle_grad(model, flag=True):
    """Set `requires_grad` to `flag` on every parameter of `model`."""
    for p in model.parameters():
        p.requires_grad = flag
def run_D_pose_prediction(img, c, blur_sigma=0, D=None):
    """Return the pose predicted by `D.predict_pose(img, c)`.

    When `floor(3 * blur_sigma) > 0`, `img['image']` is first blurred in
    place (the dict entry is replaced) with a kernel built from `blur_sigma`.
    """
    blur_size = np.floor(blur_sigma * 3)
    if blur_size > 0:
        with torch.autograd.profiler.record_function('blur'):
            f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(
                blur_sigma).square().neg().exp2()
            img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum())
    pose, _ = D.predict_pose(img, c)
    return pose


def get_pose_params(real_img, real_seg, real_c, D=None, neural_rendering_resolution=None, blur_sigma=None,
                    resample_filter=None, filter_mode=None):
    """Predict pose parameters for a real image with discriminator `D`.

    The image and its segmentation mask are resized to
    `neural_rendering_resolution`, the resized image is optionally blurred,
    and the full-resolution image, resized image and resized mask are handed
    to run_D_pose_prediction().

    NOTE(review): `blur_sigma` defaults to None but is multiplied below, so
    callers have to pass a number.
    """
    real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=resample_filter,
                                     filter_mode=filter_mode)
    real_seg_raw = filtered_resizing(real_seg, size=neural_rendering_resolution, f=resample_filter,
                                     filter_mode=filter_mode)

    blur_size = np.floor(blur_sigma * 3)
    if blur_size > 0:
        f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div(
            blur_sigma).square().neg().exp2()
        real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum())

    # Detached copies, so the prediction cannot touch the caller's tensors' graphs.
    real_img_tmp = {
        'image': real_img.detach().requires_grad_(True),
        'image_raw': real_img_raw.detach().requires_grad_(True),
        'image_mask': real_seg_raw.detach().requires_grad_(True),
    }
    return run_D_pose_prediction(real_img_tmp, real_c, blur_sigma=blur_sigma, D=D)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='input')
    parser.add_argument('--model_pkl', type=str, default='input')
    parser.add_argument('--pose_prediction_kwargs_path', type=str, default='input')
    args = parser.parse_args()  # Parse once instead of once per option.
    input_dir = args.input_dir
    model_pkl = args.model_pkl
    pose_prediction_kwargs_path = args.pose_prediction_kwargs_path
    # ----------------------------------------------------------------------------
    sampling_multiplier = 2.0

    print('Loading networks from "%s"...' % model_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(model_pkl) as f:
        resume_data = legacy.load_network_pkl(f)
        print('resume_data', resume_data.keys())
        G = resume_data['G_ema'].to(device)  # type: ignore
        D = resume_data['D_ema'].to(device)  # type: ignore

    G.set_batch_size(1)
    # Sample more densely along each ray than during training.
    G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'] * sampling_multiplier)
    G.rendering_kwargs['depth_resolution_importance'] = int(
        G.rendering_kwargs['depth_resolution_importance'] * sampling_multiplier)

    G.rendering_kwargs['ray_start'] = 2.35

    print('Loading pose_prediction_kwargs from "%s"...' % pose_prediction_kwargs_path)
    with open(pose_prediction_kwargs_path, 'r') as f:
        pose_predict_kwargs = json.load(f)

    camera_path = os.path.join(input_dir, 'result.json')
    print('Loading camera pose from "%s"...' % camera_path)
    with open(camera_path, 'r') as f:
        camera_poses = json.load(f)

    print('Loading images from "%s"...' % input_dir)
    image_base_dir = os.path.join(input_dir, 'aligned_images')
    mask_base_path = os.path.join(input_dir, 'mask')

    images = glob.glob(os.path.join(image_base_dir, '*'))

    print('images', images)
    for image_path in images:
        image_name = os.path.basename(image_path)
        mask_path = os.path.join(mask_base_path, image_name)
        print('projecting image: "%s"' % image_path)
        image = Image.open(image_path).convert('RGB')
        mask = Image.open(mask_path)
        camera_pose = camera_poses[image_name]
        cam2world_pose = torch.tensor(camera_pose['camera_pose'], device=device)
        focal_length = 6.5104166
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        # Conditioning vector: flattened 4x4 cam2world followed by flattened 3x3 intrinsics.
        c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

        with torch.no_grad():
            # Image -> float tensor in [-1, 1] with a leading batch axis.
            image_p = image.resize((G.img_resolution, G.img_resolution), Image.BILINEAR)
            image_p = np.array(image_p)
            image_p = image_p.transpose(2, 0, 1)
            image_p = torch.tensor(image_p, device=device)
            image_p = image_p.to(device).to(torch.float32) / 127.5 - 1
            image_p = image_p.unsqueeze(0)

            # Mask -> float tensor in [0, 1] with a leading batch axis.
            mask_p = np.array(mask)[:, :, None]
            mask_p = mask_p.transpose(2, 0, 1)
            mask_p = torch.tensor(mask_p, device=device)
            mask_p = mask_p.to(device).to(torch.float32) / 255.0
            mask_p = mask_p.unsqueeze(0)

            resample_filter = pose_predict_kwargs['resample_filter']
            resample_filter = torch.tensor(resample_filter, device=device).to(torch.float32)

            p = get_pose_params(real_img=image_p,
                                real_seg=mask_p,
                                real_c=c,
                                D=D,
                                neural_rendering_resolution=pose_predict_kwargs['neural_rendering_resolution'],
                                blur_sigma=pose_predict_kwargs['blur_sigma'],
                                resample_filter=resample_filter,
                                filter_mode=pose_predict_kwargs['filter_mode'])

        # ----------------------------------------------------------------------------
        # Strip the extension whatever its length; slicing off four characters
        # mangled names such as "photo.jpeg".
        image_name = os.path.splitext(image_name)[0]
        w_path_dir = f'{input_dir}/inversion'
        os.makedirs(w_path_dir, exist_ok=True)

        embedding_dir = f'{w_path_dir}/{image_name}'
        os.makedirs(embedding_dir, exist_ok=True)
        image.save(f'{embedding_dir}/original.png')

        # NOTE(review): the cache check looks for "0.pt" while the result is
        # saved below as "inversion.pt" -- confirm which file the projector
        # writes before relying on this cache.
        if os.path.exists(f'{embedding_dir}/0.pt'):
            w_pivot = torch.load(f'{embedding_dir}/0.pt').to(global_config.device)
        else:
            # Reuse the tensor prepared above, mapped back to [0, 255].
            id_image = torch.squeeze((image_p.to(global_config.device) + 1) / 2) * 255
            w_pivot = w_projector.project(G, c, p, embedding_dir, id_image, device=torch.device('cuda'), w_avg_samples=600,
                                          num_steps=500,
                                          w_name=image_name)
        w_pivot = w_pivot.to(global_config.device)
        torch.save(w_pivot, f'{embedding_dir}/inversion.pt')
#----------------------------------------------------------------------------

def load_network_pkl(f, force_fp16=False):
    """Load a network pickle from file object `f` and normalize it to a dict.

    Legacy TensorFlow pickles (a 3-tuple of network stubs) are converted to
    PyTorch modules. The result always has the keys 'G', 'D', 'G_ema',
    'training_set_kwargs' and 'augment_pipe'. With `force_fp16=True` the three
    networks are rebuilt with FP16 settings and their weights copied over.

    NOTE: this unpickles `f`; unpickling can execute arbitrary code, so only
    load files from trusted sources.
    """
    data = _LegacyUnpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    is_legacy_tf = (
        isinstance(data, tuple)
        and len(data) == 3
        and all(isinstance(net, _TFNetworkStub) for net in data)
    )
    if is_legacy_tf:
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    for optional_key in ('training_set_kwargs', 'augment_pipe'):
        if optional_key not in data:
            data[optional_key] = None

    # Validate contents.
    assert isinstance(data['G'], torch.nn.Module)
    assert isinstance(data['D'], torch.nn.Module)
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if not force_fp16:
        return data
    for key in ['G', 'D', 'G_ema']:
        old = data[key]
        kwargs = copy.deepcopy(old.init_kwargs)
        # The FP16 settings live under 'synthesis_kwargs' when present,
        # otherwise at the top level of the init kwargs.
        fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
        fp16_kwargs.num_fp16_res = 4
        fp16_kwargs.conv_clamp = 256
        if kwargs != old.init_kwargs:
            # Rebuild only when the settings actually changed.
            new = type(old)(**kwargs).eval().requires_grad_(False)
            misc.copy_params_and_buffers(old, new, require_all=True)
            data[key] = new
    return data

#----------------------------------------------------------------------------

class _TFNetworkStub(dnnlib.EasyDict):
    """Placeholder standing in for pickled `dnnlib.tflib.network.Network` objects."""

class _LegacyUnpickler(pickle.Unpickler):
    """Unpickler that maps the legacy TensorFlow Network class to _TFNetworkStub."""

    def find_class(self, module, name):
        is_tf_network = (module == 'dnnlib.tflib.network' and name == 'Network')
        if is_tf_network:
            return _TFNetworkStub
        return super().find_class(module, name)

#----------------------------------------------------------------------------

def _collect_tf_params(tf_net):
    """Flatten the variables of `tf_net` and its sub-components into one dict.

    Keys are variable names prefixed by the '/'-joined component path.
    """
    # pylint: disable=protected-access
    collected = dict()

    def visit(path_prefix, net):
        for var_name, var_value in net.variables:
            collected[path_prefix + var_name] = var_value
        for comp_name, comp in net.components.items():
            visit(path_prefix + comp_name + '/', comp)

    visit('', tf_net)
    return collected

#----------------------------------------------------------------------------

def _populate_module_params(module, *patterns):
    """Fill every parameter/buffer of `module` from the first matching pattern.

    `patterns` alternates regex, value function. Each parameter name must
    fully match at least one regex; the paired function (if not None) is
    called with the match groups and a non-None result is copied into the
    tensor. On failure the parameter name and shape are printed before the
    error is re-raised.
    """
    rules = list(zip(patterns[0::2], patterns[1::2]))
    for name, tensor in misc.named_params_and_buffers(module):
        found = False
        value = None
        for pattern, value_fn in rules:
            match = re.fullmatch(pattern, name)
            if not match:
                continue
            found = True
            if value_fn is not None:
                value = value_fn(*match.groups())
            break
        try:
            assert found
            if value is not None:
                tensor.copy_(torch.from_numpy(np.array(value)))
        except:
            print(name, list(tensor.shape))
            raise

#----------------------------------------------------------------------------

def convert_tf_generator(tf_G):
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
+ tf_kwargs = tf_G.static_kwargs + known_kwargs = set() + def kwarg(tf_name, default=None, none=None): + known_kwargs.add(tf_name) + val = tf_kwargs.get(tf_name, default) + return val if val is not None else none + + # Convert kwargs. + from training import networks_stylegan2 + network_class = networks_stylegan2.Generator + kwargs = dnnlib.EasyDict( + z_dim = kwarg('latent_size', 512), + c_dim = kwarg('label_size', 0), + w_dim = kwarg('dlatent_size', 512), + img_resolution = kwarg('resolution', 1024), + img_channels = kwarg('num_channels', 3), + channel_base = kwarg('fmap_base', 16384) * 2, + channel_max = kwarg('fmap_max', 512), + num_fp16_res = kwarg('num_fp16_res', 0), + conv_clamp = kwarg('conv_clamp', None), + architecture = kwarg('architecture', 'skip'), + resample_filter = kwarg('resample_kernel', [1,3,3,1]), + use_noise = kwarg('use_noise', True), + activation = kwarg('nonlinearity', 'lrelu'), + mapping_kwargs = dnnlib.EasyDict( + num_layers = kwarg('mapping_layers', 8), + embed_features = kwarg('label_fmaps', None), + layer_features = kwarg('mapping_fmaps', None), + activation = kwarg('mapping_nonlinearity', 'lrelu'), + lr_multiplier = kwarg('mapping_lrmul', 0.01), + w_avg_beta = kwarg('w_avg_beta', 0.995, none=1), + ), + ) + + # Check for unknown kwargs. + kwarg('truncation_psi') + kwarg('truncation_cutoff') + kwarg('style_mixing_prob') + kwarg('structure') + kwarg('conditioning') + kwarg('fused_modconv') + unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) + if len(unknown_kwargs) > 0: + raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) + + # Collect params. 
+ tf_params = _collect_tf_params(tf_G) + for name, value in list(tf_params.items()): + match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name) + if match: + r = kwargs.img_resolution // (2 ** int(match.group(1))) + tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value + kwargs.synthesis.kwargs.architecture = 'orig' + #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') + + # Convert params. + G = network_class(**kwargs).eval().requires_grad_(False) + # pylint: disable=unnecessary-lambda + # pylint: disable=f-string-without-interpolation + _populate_module_params(G, + r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'], + r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(), + r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'], + r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(), + r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'], + r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0], + r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'], + r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0], + r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'], + r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(), + r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1, + r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'], + r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0], + 
r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'], + r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1, + r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'], + r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0], + r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'], + r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1, + r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'], + r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1, + r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1), + r'.*\.resample_filter', None, + r'.*\.act_filter', None, + ) + return G + +#---------------------------------------------------------------------------- + +def convert_tf_discriminator(tf_D): + if tf_D.version < 4: + raise ValueError('TensorFlow pickle version too low') + + # Collect kwargs. 
+ tf_kwargs = tf_D.static_kwargs + known_kwargs = set() + def kwarg(tf_name, default=None): + known_kwargs.add(tf_name) + return tf_kwargs.get(tf_name, default) + + # Convert kwargs. + kwargs = dnnlib.EasyDict( + c_dim = kwarg('label_size', 0), + img_resolution = kwarg('resolution', 1024), + img_channels = kwarg('num_channels', 3), + architecture = kwarg('architecture', 'resnet'), + channel_base = kwarg('fmap_base', 16384) * 2, + channel_max = kwarg('fmap_max', 512), + num_fp16_res = kwarg('num_fp16_res', 0), + conv_clamp = kwarg('conv_clamp', None), + cmap_dim = kwarg('mapping_fmaps', None), + block_kwargs = dnnlib.EasyDict( + activation = kwarg('nonlinearity', 'lrelu'), + resample_filter = kwarg('resample_kernel', [1,3,3,1]), + freeze_layers = kwarg('freeze_layers', 0), + ), + mapping_kwargs = dnnlib.EasyDict( + num_layers = kwarg('mapping_layers', 0), + embed_features = kwarg('mapping_fmaps', None), + layer_features = kwarg('mapping_fmaps', None), + activation = kwarg('nonlinearity', 'lrelu'), + lr_multiplier = kwarg('mapping_lrmul', 0.1), + ), + epilogue_kwargs = dnnlib.EasyDict( + mbstd_group_size = kwarg('mbstd_group_size', None), + mbstd_num_channels = kwarg('mbstd_num_features', 1), + activation = kwarg('nonlinearity', 'lrelu'), + ), + ) + + # Check for unknown kwargs. + kwarg('structure') + kwarg('conditioning') + unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) + if len(unknown_kwargs) > 0: + raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) + + # Collect params. + tf_params = _collect_tf_params(tf_D) + for name, value in list(tf_params.items()): + match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name) + if match: + r = kwargs.img_resolution // (2 ** int(match.group(1))) + tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value + kwargs.architecture = 'orig' + #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') + + # Convert params. 
+ from training import networks_stylegan2 + D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False) + # pylint: disable=unnecessary-lambda + # pylint: disable=f-string-without-interpolation + _populate_module_params(D, + r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1), + r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'], + r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1), + r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'], + r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1), + r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(), + r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'], + r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(), + r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'], + r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1), + r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'], + r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(), + r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'], + r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(), + r'b4\.out\.bias', lambda: tf_params[f'Output/bias'], + r'.*\.resample_filter', None, + ) + return D + +#---------------------------------------------------------------------------- + +@click.command() +@click.option('--source', help='Input pickle', required=True, metavar='PATH') +@click.option('--dest', help='Output pickle', required=True, metavar='PATH') +@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True) +def convert_network_pickle(source, dest, force_fp16): + """Convert legacy network pickle into the native PyTorch format. 
+ + The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA. + It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks. + + Example: + + \b + python legacy.py \\ + --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\ + --dest=stylegan2-cat-config-f.pkl + """ + print(f'Loading "{source}"...') + with dnnlib.util.open_url(source) as f: + data = load_network_pkl(f, force_fp16=force_fp16) + print(f'Saving "{dest}"...') + with open(dest, 'wb') as f: + pickle.dump(data, f) + print('Done.') + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + convert_network_pickle() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/metrics/__init__.py b/3DPortraitGAN_pyramid/metrics/__init__.py new file mode 100644 index 0000000..dfebd04 --- /dev/null +++ b/3DPortraitGAN_pyramid/metrics/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
#----------------------------------------------------------------------------
# Utilities.

def sinc(x):
    """Return sin(pi*x) / (pi*x) elementwise, with the value 1 at x == 0."""
    scaled = (x * np.pi).abs()
    near_zero = scaled < 1e-30
    # Clamp the denominator so the division itself never sees a zero.
    ratio = torch.sin(scaled) / scaled.clamp(1e-30, float('inf'))
    return torch.where(near_zero, torch.ones_like(x), ratio)

def lanczos_window(x, a):
    """Return sinc(|x| / a) where |x| / a < 1, and zero elsewhere."""
    r = x.abs() / a
    inside = r < 1
    return torch.where(inside, sinc(r), torch.zeros_like(r))

def rotation_matrix(angle):
    """Return a 3x3 float32 matrix [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]."""
    theta = torch.as_tensor(angle).to(torch.float32)
    cos_t = theta.cos()
    sin_t = theta.sin()
    rot = torch.eye(3, device=theta.device)
    rot[0, 0] = cos_t
    rot[0, 1] = sin_t
    rot[1, 0] = -sin_t
    rot[1, 1] = cos_t
    return rot

#----------------------------------------------------------------------------
# Apply integer translation to a batch of 2D images. Corresponds to the
# operator T_x in Appendix E.1.

def apply_integer_translation(x, tx, ty):
    """Shift a batch of images by a whole number of pixels.

    `tx` and `ty` are multiplied by the image width and height respectively
    and rounded to the nearest pixel. Returns `(z, m)`: `z` is the shifted
    batch, zero where no source pixel exists, and `m` is 1 where `z` holds
    copied data and 0 elsewhere. A shift of a full image or more leaves both
    outputs all zero.
    """
    _N, _C, H, W = x.shape
    shift_x = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
    shift_y = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
    ix = shift_x.round().to(torch.int64)
    iy = shift_y.round().to(torch.int64)

    shifted = torch.zeros_like(x)
    valid = torch.zeros_like(x)
    if not (abs(ix) < W and abs(iy) < H):
        return shifted, valid

    # Source window that stays inside the image after the shift, and the
    # destination window it lands in.
    src_rows = slice(max(-iy, 0), H + min(-iy, 0))
    src_cols = slice(max(-ix, 0), W + min(-ix, 0))
    dst_rows = slice(max(iy, 0), H + min(iy, 0))
    dst_cols = slice(max(ix, 0), W + min(ix, 0))
    shifted[:, :, dst_rows, dst_cols] = x[:, :, src_rows, src_cols]
    valid[:, :, dst_rows, dst_cols] = 1
    return shifted, valid

#----------------------------------------------------------------------------
# Apply fractional translation to a batch of 2D images. Corresponds to the
# operator T_x in Appendix E.2.

def apply_fractional_translation(x, tx, ty, a=3):
    """Shift a batch of images by a sub-pixel amount using separable filters.

    `tx` and `ty` are multiplied by the image width and height. The integer
    part of the shift (floor) is applied by cropping, the fractional part by
    filtering with `2 * a` taps per axis. Returns `(z, m)` where `z` is the
    shifted batch and `m` is 1 inside the region written below and 0
    elsewhere.
    """
    _N, _C, H, W = x.shape
    shift_x = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
    shift_y = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
    ix = shift_x.floor().to(torch.int64)
    iy = shift_y.floor().to(torch.int64)
    fx = shift_x - ix
    fy = shift_y - iy
    b = a - 1

    # Filtered image.
    shifted = torch.zeros_like(x)
    zx0, zx1 = max(ix - b, 0), min(ix + a, 0) + W
    zy0, zy1 = max(iy - b, 0), min(iy + a, 0) + H
    if zx0 < zx1 and zy0 < zy1:
        taps = torch.arange(a * 2, device=x.device) - b
        dx = taps - fx
        dy = taps - fy
        filter_x = (sinc(dx) * sinc(dx / a)).unsqueeze(0)
        filter_y = (sinc(dy) * sinc(dy / a)).unsqueeze(1)
        filtered = upfirdn2d.filter2d(x, filter_x / filter_x.sum(), padding=[b,a,0,0])
        filtered = upfirdn2d.filter2d(filtered, filter_y / filter_y.sum(), padding=[0,0,b,a])
        rows = slice(max(b-iy,0), H+b+a+min(-iy-a,0))
        cols = slice(max(b-ix,0), W+b+a+min(-ix-a,0))
        shifted[:, :, zy0:zy1, zx0:zx1] = filtered[:, :, rows, cols]

    # Validity mask.
    valid = torch.zeros_like(x)
    mx0, mx1 = max(ix + a, 0), min(ix - b, 0) + W
    my0, my1 = max(iy + a, 0), min(iy - b, 0) + H
    if mx0 < mx1 and my0 < my1:
        valid[:, :, my0:my1, mx0:mx1] = 1
    return shifted, valid

#----------------------------------------------------------------------------
# Construct an oriented low-pass filter that applies the appropriate
# bandlimit with respect to the input and output of the given affine 2D
# image transformation.

def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
    """Build a windowed low-pass filter for the affine transform `mat`.

    Only the upper-left 2x2 part of `mat` is used. The result is a square
    2D tensor with `amax * 2 * up - 1` taps per side. Requires
    `a <= amax < aflt`.
    """
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)

    # Construct 2D filter taps in input & output coordinate spaces.
    taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
    # NOTE(review): relies on torch.meshgrid's default indexing mode; newer
    # PyTorch releases warn when `indexing` is not passed -- confirm against
    # the PyTorch version this project pins.
    yi, xi = torch.meshgrid(taps, taps)
    xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)

    # Convolution of two oriented 2D sinc filters.
    # (Done as a product in the Fourier domain.)
    fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
    fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
    f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real

    # Convolution of two oriented 2D Lanczos windows.
    wi = lanczos_window(xi, a) * lanczos_window(yi, a)
    wo = lanczos_window(xo, a) * lanczos_window(yo, a)
    w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real

    # Construct windowed FIR filter.
    f = f * w

    # Finalize.
    # Crop from the `aflt` working extent down to `amax`, then normalize each
    # of the `up * up` sub-pixel phases separately before dropping the pad.
    c = (aflt - amax) * up
    f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
    f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
    f = f / f.sum([0,2], keepdim=True) / (up ** 2)
    f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
    return f

#----------------------------------------------------------------------------
# Apply the given affine transformation to a batch of 2D images.

def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    """Resample the batch `x` under the affine transform `mat`.

    The images are first upsampled by `up` with the band-limiting filter
    built from `mat` (extra `filter_kwargs` are forwarded to
    `construct_affine_bandlimit_filter`), then sampled bilinearly on the
    transformed grid. Returns `(z, m)`: the resampled images and a mask
    obtained by sampling, with nearest-neighbor lookup, an all-ones interior
    region of the upsampled image through the same grid.
    """
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)

    # Construct filter.
    f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    p = f.shape[0] // 2

    # Construct sampling grid.
    # The sampling matrix is the inverse of `mat`, adjusted in place for the
    # padding (`p`) and upsampling factor (`up`) used by the resampling below.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    theta[0, 2] += 1 / up / W
    theta[1, 2] += 1 / up / H
    theta[0, :] *= W / (W + p / up * 2)
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)

    # Resample image.
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)

    # Form mask.
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m

#----------------------------------------------------------------------------
# Apply fractional rotation to a batch of 2D images. Corresponds to the
# operator R_\alpha in Appendix E.3.

def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
    """Rotate the batch `x` by `angle`; returns `(images, mask)`.

    Delegates to `apply_affine_transformation` with the matrix from
    `rotation_matrix(angle)` and `amax` fixed at `a * 2`.
    """
    angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    mat = rotation_matrix(angle)
    return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)

#----------------------------------------------------------------------------
# Modify the frequency content of a batch of 2D images as if they had undergone
# fractional rotation -- but without actually rotating them. Corresponds to
# the operator R^*_\alpha in Appendix E.3.

def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
    """Filter `x` with the band-limit filter of a rotation by `-angle`.

    No resampling takes place: the filter is built with `up=1` and applied
    directly. Returns `(y, m)` where `m` is 1 except in a border of
    `f.shape[0] // 2` pixels on every side.
    """
    angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    mat = rotation_matrix(-angle)
    f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
    y = upfirdn2d.filter2d(x=x, f=f)
    m = torch.zeros_like(y)
    c = f.shape[0] // 2
    m[:, :, c:-c, c:-c] = 1
    return y, m

#----------------------------------------------------------------------------
# Compute the selected equivariance metrics for the given generator.

def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
    """Measure how closely the generator follows transforms of its input.

    At least one of `compute_eqt_int`, `compute_eqt_frac` and `compute_eqr`
    must be set. For each batch a reference image is synthesized with the
    identity transform, the generator's input transform is then changed, and
    the new output is compared with the reference transformed in image space.

    Returns:
        One PSNR value when a single metric is selected, otherwise a tuple
        ordered EQ-T, EQ-T_frac, EQ-R (only the selected ones).

    Raises:
        ValueError: If the generator has no `synthesis.input.transform`.
    """
    assert compute_eqt_int or compute_eqt_frac or compute_eqr

    # Setup generator and labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    I = torch.eye(3, device=opts.device)
    # `M` is the generator's own transform buffer; writing into it below
    # changes what the following synthesis calls produce.
    M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
    if M is None:
        raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)

    # Sampling loop.
    sums = None
    progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        # `s` collects, per selected metric, a (masked squared error, mask) pair.
        s = []

        # Randomize noise buffers, if any.
        for name, buf in G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))

        # Run mapping network.
        z = torch.randn([batch_size, G.z_dim], device=opts.device)
        c = next(c_iter)
        ws = G.mapping(z=z, c=c)

        # Generate reference image.
        M[:] = I
        orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)

        # Integer translation (EQ-T).
        if compute_eqt_int:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            # Snap the offset to a whole number of pixels.
            t = (t * G.img_resolution).round() / G.img_resolution
            M[:] = I
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_integer_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Fractional translation (EQ-T_frac).
        if compute_eqt_frac:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            M[:] = I
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_fractional_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Rotation (EQ-R).
        if compute_eqr:
            angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
            M[:] = rotation_matrix(-angle)
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, ref_mask = apply_fractional_rotation(orig, angle)
            pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
            # Compare only where both the rotated reference and the filtered
            # output are valid.
            mask = ref_mask * pseudo_mask
            s += [(ref - pseudo).square() * mask, mask]

        # Accumulate results.
        s = torch.stack([x.to(torch.float64).sum() for x in s])
        sums = sums + s if sums is not None else s
    progress.update(num_samples)

    # Compute PSNRs.
    if opts.num_gpus > 1:
        torch.distributed.all_reduce(sums)
    sums = sums.cpu()
    # Even entries are summed squared errors, odd entries the matching mask sums.
    mses = sums[0::2] / sums[1::2]
    psnrs = np.log10(2) * 20 - mses.log10() * 10
    psnrs = tuple(psnrs.numpy())
    return psnrs[0] if len(psnrs) == 1 else psnrs

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

def compute_fid(opts, max_real, num_gen):
    """Compute the Frechet Inception Distance between real and generated images.

    Args:
        opts: Metric options; forwarded to the feature-statistics helpers and
            read for `rank`.
        max_real: Maximum number of dataset images to use (passed on as
            `max_items`).
        num_gen: Number of generated images to use (passed on as `max_items`).

    Returns:
        The FID as a float on rank 0, and NaN on every other rank.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    # The detector is a pickle, so it must come straight from the NVIDIA
    # endpoint rather than through any intermediate host.
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.

    # presumably rel_lo/rel_hi set the progress-reporting range -- confirm in metric_utils
    mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()

    mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()

    if opts.rank != 0:
        return float('nan')

    # FID = |mu_gen - mu_real|^2 + Tr(sigma_gen + sigma_real - 2 * sqrt(sigma_gen @ sigma_real))
    m = np.square(mu_gen - mu_real).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
    # The matrix square root may carry a small imaginary part; keep the real part.
    fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
    return float(fid)

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

def compute_is(opts, num_gen, num_splits):
    """Compute the Inception Score of generated images.

    Args:
        opts: Metric options; forwarded to the feature-statistics helper and
            read for `rank`.
        num_gen: Number of generated images to score.
        num_splits: Number of consecutive, equally sized slices of the
            images that are scored separately.

    Returns:
        `(mean, std)` of the per-split scores on rank 0, and `(NaN, NaN)` on
        every other rank.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    # The detector is a pickle, so it must come straight from the NVIDIA
    # endpoint rather than through any intermediate host.
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.

    gen_probs = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        capture_all=True, max_items=num_gen).get_all()

    if opts.rank != 0:
        return float('nan'), float('nan')

    scores = []
    for i in range(num_splits):
        part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
        # exp of the mean KL divergence between each row and the split's mean row.
        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
        kl = np.mean(np.sum(kl, axis=1))
        scores.append(np.exp(kl))
    return float(np.mean(scores)), float(np.std(scores))

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
    """Compute the Kernel Inception Distance between real and generated images.

    Args:
        opts: Metric options; forwarded to the feature-statistics helpers and
            read for `rank`.
        max_real: Maximum number of dataset images to use.
        num_gen: Number of generated images to use.
        num_subsets: Number of random subset pairs to average over.
        max_subset_size: Upper bound on the size of each subset.

    Returns:
        The KID as a float on rank 0, and NaN on every other rank.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    # The detector is a pickle, so it must come straight from the NVIDIA
    # endpoint rather than through any intermediate host.
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.

    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()

    if opts.rank != 0:
        return float('nan')

    n = real_features.shape[1]
    # Subset size: limited by the smaller feature set and by max_subset_size.
    m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
    t = 0
    for _subset_idx in range(num_subsets):
        x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
        y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
        # Cubic polynomial kernel (u.v / n + 1)^3; `a` holds the two
        # within-set terms, `b` the cross term.
        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        b = (x @ y.T / n + 1) ** 3
        # The diagonal (self-similarity) is excluded from the within-set sums.
        t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
    kid = t / num_subsets / m
    return float(kid)

#----------------------------------------------------------------------------
_metric_dict = dict() # name => fn

def register_metric(fn):
    """Decorator that files `fn` in the metric registry under `fn.__name__`.

    Returns `fn` itself, so the decorated name stays bound to the function.
    """
    assert callable(fn)
    _metric_dict[fn.__name__] = fn
    return fn

def is_valid_metric(metric):
    """Return True when `metric` is the name of a registered metric."""
    return metric in _metric_dict

def list_valid_metrics():
    """Return the names of all registered metrics as a new list."""
    return list(_metric_dict)

#----------------------------------------------------------------------------

def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
    """Run the registered metric `metric` and return its results plus metadata.

    `kwargs` are forwarded to `metric_utils.MetricOptions`. When more than one
    GPU is in use, every result value is broadcast from rank 0 so that all
    processes end up holding the same floats.

    Returns:
        A `dnnlib.EasyDict` with the keys `results`, `metric`, `total_time`,
        `total_time_str` and `num_gpus`.
    """
    assert is_valid_metric(metric)
    opts = metric_utils.MetricOptions(**kwargs)

    # Calculate.
    tick = time.time()
    results = _metric_dict[metric](opts)
    total_time = time.time() - tick

    # Broadcast results.
    multi_gpu = opts.num_gpus > 1
    for key in list(results):
        value = results[key]
        if multi_gpu:
            value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
            torch.distributed.broadcast(tensor=value, src=0)
            value = float(value.cpu())
        results[key] = value

    # Decorate with metadata.
    return dnnlib.EasyDict(
        results         = dnnlib.EasyDict(results),
        metric          = metric,
        total_time      = total_time,
        total_time_str  = dnnlib.util.format_time(total_time),
        num_gpus        = opts.num_gpus
    )

#----------------------------------------------------------------------------

def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
    """Print a metric result as one JSON line and append it to the run's log file.

    Args:
        result_dict: Mapping that contains at least the key `'metric'`.
        run_dir: Optional run directory. When it exists, the line is appended
            to `metric-<name>.jsonl` inside it.
        snapshot_pkl: Optional snapshot path; stored relative to `run_dir`
            when both are given.
    """
    metric = result_dict['metric']
    assert is_valid_metric(metric)
    has_run_dir = run_dir is not None
    if has_run_dir and snapshot_pkl is not None:
        snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)

    record = dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())
    jsonl_line = json.dumps(record)
    print(jsonl_line)
    if has_run_dir and os.path.isdir(run_dir):
        log_path = os.path.join(run_dir, f'metric-{metric}.jsonl')
        with open(log_path, 'at') as f:
            f.write(jsonl_line + '\n')

#----------------------------------------------------------------------------
# Recommended metrics.

@register_metric
def fid50k_full(opts):
    """FID with 50k generated images against the full dataset (x-flips on, back_repeat=1)."""
    opts.dataset_kwargs.update(max_size=None, xflip=True, back_repeat=1)
    fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
    return dict(fid50k_full=fid)

@register_metric
def kid50k_full(opts):
    """KID with 50k generated images against up to 1M dataset images."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    kid = kernel_inception_distance.compute_kid(
        opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return dict(kid50k_full=kid)

@register_metric
def pr50k3_full(opts):
    """Precision/recall (neighborhood size 3) with 50k generated and up to 200k dataset images."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    precision, recall = precision_recall.compute_pr(
        opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)

@register_metric
def ppl2_wend(opts):
    """Perceptual path length in W, sampled at path endpoints, without cropping."""
    ppl = perceptual_path_length.compute_ppl(
        opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
    return dict(ppl2_wend=ppl)

@register_metric
def eqt50k_int(opts):
    """Equivariance metric with `compute_eqt_int=True`, 50k samples, generator forced to fp32."""
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(
        opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
    return dict(eqt50k_int=psnr)

@register_metric
def eqt50k_frac(opts):
    """Equivariance metric with `compute_eqt_frac=True`, 50k samples, generator forced to fp32."""
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(
        opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
    return dict(eqt50k_frac=psnr)

@register_metric
def eqr50k(opts):
    """Equivariance metric with `compute_eqr=True`, 50k samples, generator forced to fp32."""
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(
        opts, num_samples=50000, batch_size=4, compute_eqr=True)
    return dict(eqr50k=psnr)

#----------------------------------------------------------------------------
# Legacy metrics.

@register_metric
def fid50k(opts):
    """FID with 50k generated images against at most 50k dataset images."""
    opts.dataset_kwargs.update(max_size=None)
    fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
    return dict(fid50k=fid)

@register_metric
def kid50k(opts):
    """KID with 50k generated images against at most 50k dataset images."""
    opts.dataset_kwargs.update(max_size=None)
    kid = kernel_inception_distance.compute_kid(
        opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return dict(kid50k=kid)

@register_metric
def pr50k3(opts):
    """Precision/recall (neighborhood size 3) with 50k generated and at most 50k dataset images."""
    opts.dataset_kwargs.update(max_size=None)
    precision, recall = precision_recall.compute_pr(
        opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return dict(pr50k3_precision=precision, pr50k3_recall=recall)

@register_metric
def is50k(opts):
    """Inception score (mean and std) over 50k generated images in 10 splits."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
    return dict(is50k_mean=mean, is50k_std=std)
class MetricOptions:
    """Bundle of settings handed to every metric function.

    Args:
        G: Generator to evaluate.
        G_kwargs: Extra keyword arguments for the generator; copied into a
            `dnnlib.EasyDict`. `None` means "no extra arguments".
        dataset_kwargs: Keyword arguments describing the dataset; copied into
            a `dnnlib.EasyDict`. `None` means an empty mapping.
        num_gpus: Number of participating processes.
        rank: Index of this process; must satisfy `0 <= rank < num_gpus`.
        device: Torch device; defaults to `torch.device('cuda', rank)`.
        progress: Parent progress monitor. Only rank 0 derives a sub-monitor
            from it; every other case gets a fresh `ProgressMonitor()`.
        identical_c_p: Stored as-is; read by the generator feature loop.
        cache: Stored as-is; read by the dataset feature computation.
        metric_pose_sample_mode: Stored as-is; selects how pose parameters are
            sampled for generated images.
        D: Discriminator, stored as-is.
        pose_predict_kwargs: Stored as-is; settings for pose prediction.
    """
    def __init__(self, G=None, G_kwargs=None, dataset_kwargs=None, num_gpus=1, rank=0, device=None, progress=None, identical_c_p = True,
                 cache=True, metric_pose_sample_mode = None,D = None,pose_predict_kwargs = None):
        assert 0 <= rank < num_gpus
        self.G              = G
        # Defaults are None rather than `{}` so that no dict object is shared between calls.
        self.G_kwargs       = dnnlib.EasyDict({} if G_kwargs is None else G_kwargs)
        self.dataset_kwargs = dnnlib.EasyDict({} if dataset_kwargs is None else dataset_kwargs)
        self.num_gpus       = num_gpus
        self.rank           = rank
        self.device         = device if device is not None else torch.device('cuda', rank)
        self.progress       = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
        self.cache          = cache

        self.metric_pose_sample_mode = metric_pose_sample_mode
        self.D = D
        self.pose_predict_kwargs = pose_predict_kwargs

        self.identical_c_p = identical_c_p

#----------------------------------------------------------------------------

_feature_detector_cache = dict()

def get_feature_detector_name(url):
    """Return the last path component of `url` with its final extension removed."""
    return os.path.splitext(url.split('/')[-1])[0]
def run_D_pose_prediction(img, c, blur_sigma=0,D = None):
    """Predict pose parameters for `img` with the discriminator `D`.

    When `floor(blur_sigma * 3)` is positive, `img['image']` is filtered with a
    normalized 1-D kernel via `upfirdn2d.filter2d` before prediction. The
    filtered tensor is stored in a shallow copy of `img`, so the mapping passed
    in by the caller is left untouched (previously its `'image'` entry was
    overwritten, which would blur again on a second call with the same dict).

    Args:
        img: Mapping with at least the key `'image'`; handed to `D.predict_pose`.
        c: Conditioning labels, forwarded to `D.predict_pose`.
        blur_sigma: Blur strength; 0 disables blurring.
        D: Object providing `predict_pose(img, c)` that returns a pair whose
            first element is the pose.

    Returns:
        The first element of `D.predict_pose(img, c)`.
    """
    blur_size = np.floor(blur_sigma * 3)
    if blur_size > 0:
        with torch.autograd.profiler.record_function('blur'):
            f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(blur_sigma).square().neg().exp2()
            # Shallow copy: only the 'image' entry is replaced, the other entries are shared.
            img = dict(img)
            img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum())
    pose, _ = D.predict_pose(img, c)
    return pose
def iterate_random_labels_predicted_pose(opts, batch_size,G):
    """Endlessly yield random dataset labels with poses predicted by `G`.

    For an unconditional generator (`opts.G.c_dim == 0`) the same empty label
    tensor of shape `[batch_size, 0]` is yielded forever -- on its own, not as
    a pair. Otherwise each step draws `batch_size` random dataset indices,
    stacks their labels, samples a fresh latent batch and yields the pair
    `(c, G.get_pose_params(z, c))`.

    NOTE(review): the two branches yield different shapes of result (bare
    tensor vs. `(c, p)` pair); callers that unpack `c, p` presumably never hit
    the unconditional branch -- confirm.
    """
    device = opts.device

    # Unconditional generator: nothing to sample.
    if opts.G.c_dim == 0:
        empty_labels = torch.zeros([batch_size, opts.G.c_dim], device=device)
        while True:
            yield empty_labels

    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
    while True:
        picks = []
        for _ in range(batch_size):
            picks.append(np.random.randint(len(dataset)))

        labels = np.stack([dataset.get_label(i) for i in picks])
        c = torch.from_numpy(labels).pin_memory().to(device)

        # Latents are drawn after the indices, then turned into pose parameters by G.
        z = torch.randn([batch_size, opts.G.z_dim], device=device)
        yield c, G.get_pose_params(z, c)
class FeatureStats:
    """Accumulator for feature vectors produced by a detector network.

    Depending on the flags it keeps every appended batch (`capture_all`)
    and/or running sums from which mean and covariance are derived
    (`capture_mean_cov`). `max_items` caps the number of rows accepted.
    """

    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0
        # The fields below are allocated by the first set_num_features() call.
        self.num_features = None
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features):
        """Fix the feature width on first use; afterwards only check it is unchanged."""
        if self.num_features is not None:
            assert num_features == self.num_features
            return
        self.num_features = num_features
        self.all_features = []
        self.raw_mean = np.zeros([num_features], dtype=np.float64)
        self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)

    def is_full(self):
        """Return True once `max_items` rows have been collected (never when uncapped)."""
        if self.max_items is None:
            return False
        return self.num_items >= self.max_items

    def append(self, x):
        """Add a 2-D batch of features (rows = items), truncated to the remaining capacity."""
        batch = np.asarray(x, dtype=np.float32)
        assert batch.ndim == 2

        if self.max_items is not None:
            room = self.max_items - self.num_items
            if batch.shape[0] > room:
                if room <= 0:
                    return
                batch = batch[:room]

        self.set_num_features(batch.shape[1])
        self.num_items += batch.shape[0]
        if self.capture_all:
            self.all_features.append(batch)
        if self.capture_mean_cov:
            # Sums are accumulated in float64.
            batch64 = batch.astype(np.float64)
            self.raw_mean += batch64.sum(axis=0)
            self.raw_cov += batch64.T @ batch64

    def append_torch(self, x, num_gpus=1, rank=0):
        """Append a 2-D torch tensor, first gathering it from all processes when `num_gpus > 1`."""
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            gathered = []
            for src in range(num_gpus):
                # Each process in turn broadcasts its own batch to everyone.
                piece = x.clone()
                torch.distributed.broadcast(piece, src=src)
                gathered.append(piece)
            x = torch.stack(gathered, dim=1).flatten(0, 1) # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self):
        """Return all captured features as one array (requires `capture_all`)."""
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self):
        """Return all captured features as a torch tensor."""
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self):
        """Return `(mean, cov)` derived from the running sums (requires `capture_mean_cov`)."""
        assert self.capture_mean_cov
        count = self.num_items
        mean = self.raw_mean / count
        second_moment = self.raw_cov / count
        return mean, second_moment - np.outer(mean, mean)

    def save(self, pkl_file):
        """Pickle the instance's attribute dict to `pkl_file`."""
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file):
        """Rebuild a FeatureStats from a file written by `save`."""
        with open(pkl_file, 'rb') as f:
            state = dnnlib.EasyDict(pickle.load(f))
        restored = FeatureStats(capture_all=state.capture_all, max_items=state.max_items)
        restored.__dict__.update(state)
        return restored
self.start_time = time.time() + self.batch_time = self.start_time + self.batch_items = 0 + if self.progress_fn is not None: + self.progress_fn(self.pfn_lo, self.pfn_total) + + def update(self, cur_items): + assert (self.num_items is None) or (cur_items <= self.num_items) + if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items): + return + cur_time = time.time() + total_time = cur_time - self.start_time + time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1) + if (self.verbose) and (self.tag is not None): + print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}') + self.batch_time = cur_time + self.batch_items = cur_items + + if (self.progress_fn is not None) and (self.num_items is not None): + self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total) + + def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1): + return ProgressMonitor( + tag = tag, + num_items = num_items, + flush_interval = flush_interval, + verbose = self.verbose, + progress_fn = self.progress_fn, + pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo, + pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi, + pfn_total = self.pfn_total, + ) + +#---------------------------------------------------------------------------- + +def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs): + dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) + if data_loader_kwargs is None: + data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2) + + # Try to lookup from cache. + cache_file = None + if opts.cache: + # Choose cache file name. 
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs) + md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8')) + cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}' + cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl') + + # Check if the file exists (all processes must agree). + flag = os.path.isfile(cache_file) if opts.rank == 0 else False + if opts.num_gpus > 1: + flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device) + torch.distributed.broadcast(tensor=flag, src=0) + flag = (float(flag.cpu()) != 0) + + # Load. + if flag: + return FeatureStats.load(cache_file) + + # Initialize. + num_items = len(dataset) + if max_items is not None: + num_items = min(num_items, max_items) + stats = FeatureStats(max_items=num_items, **stats_kwargs) + progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi) + detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) + + # Main loop. + item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] + for images, masks, _labels,_poses in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs): + if images.shape[1] == 1: + images = images.repeat([1, 3, 1, 1]) + features = detector(images.to(opts.device), **detector_kwargs) + stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) + progress.update(stats.num_items) + + # Save to cache. + if cache_file is not None and opts.rank == 0: + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + temp_file = cache_file + '.' 
def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs):
    """Collect detector features for images synthesized by `opts.G`.

    Labels and pose parameters come from one of two iterators, selected by
    `opts.metric_pose_sample_mode`: `'G_predict'` uses poses predicted by the
    generator, anything else uses poses predicted by the discriminator
    `opts.D`. When `opts.identical_c_p` is False, a second, independent
    iterator supplies the labels/poses for the mapping network, while the
    first one drives synthesis.

    Args:
        opts: Metric options (generator, discriminator, device, progress, ...).
        detector_url: Location of the pickled feature detector.
        detector_kwargs: Keyword arguments passed to the detector on every call.
        rel_lo, rel_hi: Sub-range of the parent progress monitor to report into.
        batch_size: Number of images per detector call.
        batch_gen: Number of images per generator call; defaults to
            `min(batch_size, 8)` and must divide `batch_size`.
        **stats_kwargs: Forwarded to `FeatureStats`; must set `max_items`.

    Returns:
        The filled `FeatureStats` object.

    Raises:
        ValueError: If discriminator-predicted poses are requested but
            `opts.D` is None.
    """
    if batch_gen is None:
        batch_gen = min(batch_size, 8)
    assert batch_size % batch_gen == 0

    # Setup generator and labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    G.set_batch_size(batch_gen)

    pose_from_G = (opts.metric_pose_sample_mode == 'G_predict')
    D = None
    if not pose_from_G:
        if opts.D is None:
            raise ValueError(
                f'metric_pose_sample_mode={opts.metric_pose_sample_mode!r} needs a discriminator, but opts.D is None')
        # One frozen copy is enough: both label iterators only run it forward.
        D = copy.deepcopy(opts.D).eval().requires_grad_(False).to(opts.device)

    def make_label_iter():
        # Build a fresh, independent (c, p) iterator for the selected pose source.
        if pose_from_G:
            return iterate_random_labels_predicted_pose(opts=opts, batch_size=batch_gen, G=G)
        return iterate_random_labels_predicted_pose_D(opts=opts, batch_size=batch_gen, D=D)

    label_iter = make_label_iter()
    cond_label_iter = None if opts.identical_c_p else make_label_iter()

    # Initialize.
    stats = FeatureStats(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop.
    while not stats.is_full():
        images = []
        for _i in range(batch_size // batch_gen):
            z = torch.randn([batch_gen, G.z_dim], device=opts.device)
            c, p = next(label_iter)

            if opts.identical_c_p:
                img = G(z=z, c=c, pose_params=p, apply_def=True, **opts.G_kwargs)['image']
            else:
                # Condition the mapping network on a different label/pose draw than the one used for synthesis.
                cond_c, cond_p = next(cond_label_iter)
                ws = G.mapping(z, cond_c, cond_p)
                img = G.synthesis(ws, c=c, apply_def=True, pose_params=p, **opts.G_kwargs)['image']

            # Scale by 127.5, shift by 128, clamp to [0, 255] and store as uint8 for the detector.
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            images.append(img)
        images = torch.cat(images)
        if images.shape[1] == 1:
            # Single-channel output: replicate to three channels.
            images = images.repeat([1, 3, 1, 1])
        features = detector(images, **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats
# Spherical interpolation of a batch of vectors.
def slerp(a, b, t):
    """Spherically interpolate between the directions of `a` and `b`.

    Both inputs are normalized along the last dimension first, so only their
    directions matter. `t` is the interpolation position (0 gives the
    direction of `a`, 1 the direction of `b`) and must broadcast against the
    leading dimensions. The result is unit-length along the last dimension.
    """
    a = a / a.norm(dim=-1, keepdim=True)
    b = b / b.norm(dim=-1, keepdim=True)
    cos_angle = (a * b).sum(dim=-1, keepdim=True)
    angle = t * torch.acos(cos_angle)
    # Component of b orthogonal to a, normalized: second basis vector of the rotation plane.
    ortho = b - cos_angle * a
    ortho = ortho / ortho.norm(dim=-1, keepdim=True)
    out = a * torch.cos(angle) + ortho * torch.sin(angle)
    return out / out.norm(dim=-1, keepdim=True)

#----------------------------------------------------------------------------

class PPLSampler(torch.nn.Module):
    """Module that draws one batch of perceptual-path-length samples per call.

    Holds private deep copies of the generator and of the VGG16/LPIPS network.

    Args:
        G: Generator; deep-copied.
        G_kwargs: Extra keyword arguments for `G.synthesis`.
        epsilon: Step between the two interpolation points that are compared.
        space: `'w'` (linear interpolation of mapped latents) or `'z'`
            (spherical interpolation of input latents).
        sampling: `'full'` draws the position uniformly in [0, 1); `'end'`
            fixes it at 0.
        crop: Whether to crop the images before measuring the distance.
        vgg16: Network returning LPIPS features; deep-copied.
    """

    def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
        assert space in ['z', 'w']
        assert sampling in ['full', 'end']
        super().__init__()
        self.G = copy.deepcopy(G)
        self.G_kwargs = G_kwargs
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.crop = crop
        self.vgg16 = copy.deepcopy(vgg16)

    def forward(self, c):
        """Return one squared, epsilon-normalized LPIPS distance per row of the label tensor `c`."""
        # Generate random latents and interpolation t-values.
        t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
        z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)

        # Interpolate in W or Z.
        # NOTE(review): elsewhere in this file set G.mapping is called with pose parameters as a
        # third argument; confirm that this two-argument call matches the generator's signature.
        if self.space == 'w':
            w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
            wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
            wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
        else: # space == 'z'
            zt0 = slerp(z0, z1, t.unsqueeze(1))
            zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
            wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)

        # Randomize noise buffers.
        for name, buf in self.G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))

        # Generate images.
        img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)

        # Center crop.
        if self.crop:
            assert img.shape[2] == img.shape[3]
            unit = img.shape[2] // 8
            img = img[:, :, unit*3 : unit*7, unit*2 : unit*6]

        # Downsample to 256x256.
        factor = self.G.img_resolution // 256
        if factor > 1:
            img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])

        # Scale dynamic range from [-1,1] to [0,255].
        img = (img + 1) * (255 / 2)
        if self.G.img_channels == 1:
            img = img.repeat([1, 3, 1, 1])

        # Evaluate differential LPIPS.
        lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
        dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
        return dist

#----------------------------------------------------------------------------

def _percentile(values, q, mode):
    """`np.percentile` with the given selection mode, for both new and old NumPy keyword names."""
    try:
        return np.percentile(values, q, method=mode)
    except TypeError:
        # Older NumPy releases only accept the `interpolation` keyword.
        return np.percentile(values, q, interpolation=mode)

def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
    """Compute the perceptual path length of `opts.G`.

    Distances are sampled in batches on every process, gathered via broadcast,
    and the mean is taken over the values lying between the 1st and 99th
    percentile (both inclusive).

    Args:
        opts: Metric options.
        num_samples: Total number of distance samples to use.
        epsilon, space, sampling, crop: Forwarded to `PPLSampler`.
        batch_size: Samples drawn per process and step.

    Returns:
        The PPL as a Python float on rank 0, NaN on every other rank.
    """
    vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)

    # Setup sampler and labels.
    sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
    sampler.eval().requires_grad_(False).to(opts.device)
    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)

    # Sampling loop.
    dist = []
    progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        labels = next(c_iter)
        # For conditional generators iterate_random_labels yields a (c, p) pair; the sampler
        # reads `.shape`/`.device` from its argument, so it must receive the label tensor only.
        c = labels[0] if isinstance(labels, (tuple, list)) else labels
        x = sampler(c)
        for src in range(opts.num_gpus):
            y = x.clone()
            if opts.num_gpus > 1:
                torch.distributed.broadcast(y, src=src)
            dist.append(y)
    progress.update(num_samples)

    # Compute PPL.
    if opts.rank != 0:
        return float('nan')
    dist = torch.cat(dist)[:num_samples].cpu().numpy()
    lo = _percentile(dist, 1, 'lower')
    hi = _percentile(dist, 99, 'higher')
    ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
    return float(ppl)
def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
    """Pairwise distances between row and column features, sharded over GPUs.

    The column features are zero-padded and cut into a number of chunks that is
    a multiple of `num_gpus`; each rank runs `torch.cdist` on every
    `num_gpus`-th chunk starting at its own rank, and the per-chunk results are
    exchanged with broadcasts.

    Returns:
        On rank 0 a CPU tensor with one row per row feature and one column per
        (unpadded) column feature; `None` on all other ranks.
    """
    assert 0 <= rank < num_gpus
    is_main = (rank == 0)
    total_cols = col_features.shape[0]

    # Chunk count: enough chunks of at most `col_batch_size` rows, rounded up
    # to a multiple of the GPU count.
    rounds = (total_cols - 1) // col_batch_size // num_gpus + 1
    chunk_count = rounds * num_gpus
    row_padding = -total_cols % chunk_count
    padded_cols = torch.nn.functional.pad(col_features, [0, 0, 0, row_padding])
    chunks = padded_cols.chunk(chunk_count)

    gathered = []
    for chunk in chunks[rank::num_gpus]:
        local = torch.cdist(row_features.unsqueeze(0), chunk.unsqueeze(0))[0]
        for src in range(num_gpus):
            shared = local.clone()
            if num_gpus > 1:
                torch.distributed.broadcast(shared, src=src)
            if is_main:
                gathered.append(shared.cpu())

    if not is_main:
        return None
    # Cut off the columns that only exist because of the padding.
    return torch.cat(gathered, dim=1)[:, :total_cols]

#----------------------------------------------------------------------------

def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
    """Compute (precision, recall) between real and generated feature sets.

    Features of up to `max_real` dataset images and `num_gen` generated images
    are extracted with the VGG16 detector. For each direction, the
    `nhood_size + 1`-th smallest distance inside the manifold set gives a
    per-sample radius, and a probe counts as covered when it lies within the
    radius of at least one manifold sample.

    Returns:
        `(precision, recall)` as floats on rank 0, `(nan, nan)` elsewhere.
    """
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    detector_kwargs = dict(return_features=True)
    is_main = (opts.rank == 0)

    real_stats = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real)
    real_features = real_stats.get_all_torch().to(torch.float16).to(opts.device)

    gen_stats = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen)
    gen_features = gen_stats.get_all_torch().to(torch.float16).to(opts.device)

    def _distances(rows, cols):
        # Shared call shape for both passes below.
        return compute_distances(row_features=rows, col_features=cols, num_gpus=opts.num_gpus,
                                 rank=opts.rank, col_batch_size=col_batch_size)

    directions = (
        ('precision', real_features, gen_features),
        ('recall', gen_features, real_features),
    )
    results = {}
    for name, manifold, probes in directions:
        # Pass 1: radius of every manifold sample (k-th nearest neighbour).
        radii = []
        for manifold_batch in manifold.split(row_batch_size):
            dist = _distances(manifold_batch, manifold)
            if is_main:
                radii.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16))
        kth = torch.cat(radii) if is_main else None

        # Pass 2: is each probe inside at least one manifold sample's radius?
        covered = []
        for probes_batch in probes.split(row_batch_size):
            dist = _distances(probes_batch, manifold)
            if is_main:
                covered.append((dist <= kth).any(dim=1))

        if is_main:
            results[name] = float(torch.cat(covered).to(torch.float32).mean())
        else:
            results[name] = float('nan')
    return results['precision'], results['recall']
class Dataset(torch.utils.data.Dataset):
    """Image / camera pairs read from one result directory.

    Reads `<path>/data/camera_info.json` (a mapping from image file name to
    camera parameters) and keeps only the images that also exist under
    `<path>/update_data`.
    """

    def __init__(self, path):
        camera_info_path = os.path.join(path, 'data', 'camera_info.json')
        with open(camera_info_path, 'r') as f:
            self.camera_info = json.load(f)

        self.image_dir = os.path.join(path, 'update_data')
        # Only images that were actually written to update_data are usable.
        self.image_list = [
            img_name for img_name in self.camera_info
            if os.path.exists(os.path.join(self.image_dir, img_name))
        ]

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        """Return `(image, camera)` for the index-th usable image.

        The image is a float32 tensor laid out as read from disk ([H, W, C])
        and rescaled from [0, 255] to [-1, 1]; the camera entry is converted
        to a float tensor with singleton dimensions squeezed away.
        """
        img_name = self.image_list[index]
        img_path = os.path.join(self.image_dir, img_name)

        img = np.array(imageio.imread(img_path)).astype(np.float32)
        img = img / 255.0 * 2 - 1  # [0, 255] -> [-1, 1]
        img = torch.from_numpy(img)  # [H, W, C]

        camera_info = torch.from_numpy(np.array(self.camera_info[img_name])).float().squeeze()
        return img, camera_info


def parse_range(s: Union[str, List[int]]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Ranges are inclusive on both ends, and whitespace around items is ignored.
    A list argument is returned unchanged.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
    '''
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        p = p.strip()  # accept '1, 5-7' as well as '1,5-7'
        if m := range_re.match(p):
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges

#----------------------------------------------------------------------------

def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
    '''Parse a 'M,N' or 'MxN' integer tuple.

    A tuple argument is returned unchanged; any string that is not of the
    form 'M,N' or 'MxN' raises ValueError.

    Example:
        '4x2' returns (4,2)
        '0,1' returns (0,1)
    '''
    if isinstance(s, tuple): return s
    if m := re.match(r'^(\d+)[x,](\d+)$', s):
        return (int(m.group(1)), int(m.group(2)))
    raise ValueError(f'cannot parse tuple {s}')
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--data_dir', help='Result directory containing data/, update_data/ and checkpoints/df.pth', required=True)
@click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
@click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
@click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None)
@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True)
@click.option('--image_mode', help='Image mode', type=click.Choice(['image_depth', 'image_raw']), required=False, metavar='STR', default='image_raw', show_default=True)
@click.option('--sample_mult', 'sampling_multiplier', type=float, help='Multiplier for depth sampling in volume rendering', default=2, show_default=True)
@click.option('--nrr', type=int, help='Neural rendering resolution override', default=None, show_default=True)

def generate_images(
    network_pkl: str,
    data_dir: str,
    shuffle_seed: Optional[int],
    truncation_psi: float,
    truncation_cutoff: int,
    grid: Tuple[int,int],
    num_keyframes: Optional[int],
    w_frames: int,
    image_mode: str,
    sampling_multiplier: float,
    nrr: Optional[int],
):
    """Fine-tune the tri-grids stored in a result directory against its images.

    Loads `G_ema` from `network_pkl`, loads the tri-grids and `ws` from
    `<data_dir>/checkpoints/df.pth`, and minimises a masked L2 loss between
    rendered patches and the images in `<data_dir>/update_data` with the Adan
    optimizer. A preview from a fixed camera is written to
    `<data_dir>/log/img` after every epoch, and the optimised tri-grids are
    written to `<data_dir>/log/ckpt` after the last epoch.

    Returns early (after printing a message) when `update_data` or the
    checkpoint file is missing. `shuffle_seed`, `grid`, `num_keyframes`,
    `w_frames`, `image_mode` and `sampling_multiplier` are accepted on the
    command line but not read by this function; the truncation options are
    only normalised, not applied.
    """
    # `import PIL` alone does not guarantee that the PIL.Image submodule is
    # loaded, so bind it explicitly for the preview writer below.
    from PIL import Image

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'])
    G.rendering_kwargs['depth_resolution_importance'] = int(
        G.rendering_kwargs['depth_resolution_importance'])

    G.rendering_kwargs['ray_start'] = 2.35

    # Rebuild the generator from the current source code and copy the pickled
    # weights into it.
    print("Reloading Modules!")
    from training.neural_renderer import TriPlaneGenerator
    G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
    misc.copy_params_and_buffers(G, G_new, require_all=False)
    G_new.neural_rendering_resolution = G.neural_rendering_resolution
    G_new.rendering_kwargs = G.rendering_kwargs
    G = G_new

    G.set_batch_size(1)

    # Fixed camera used for the per-epoch preview image.
    intrinsics = FOV_to_intrinsics(12.447863, device=device)
    cam_pivot = torch.tensor([0, 0.0649, 0], device=device)
    cam_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7)
    default_cam2world_pose = LookAtPoseSampler.sample(np.pi / 2, np.pi / 2, cam_pivot,
                                                      radius=cam_radius, device=device)
    default_cam_params = torch.cat([default_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

    res_dir = data_dir

    update_data_dir = os.path.join(res_dir, 'update_data')
    if not os.path.exists(update_data_dir):
        print('update data not found for ', res_dir)
        return

    print('optimize for ', res_dir)

    log_dir = os.path.join(res_dir, 'log')
    os.makedirs(log_dir, exist_ok=True)
    log_img_dir = os.path.join(log_dir, 'img')
    os.makedirs(log_img_dir, exist_ok=True)

    log_ckpt_dir = os.path.join(log_dir, 'ckpt')
    os.makedirs(log_ckpt_dir, exist_ok=True)

    if nrr is not None: G.neural_rendering_resolution = nrr

    if truncation_cutoff == 0:
        truncation_psi = 1.0 # truncation cutoff of 0 means no truncation anyways
    if truncation_psi == 1.0:
        truncation_cutoff = 14 # no truncation so doesn't matter where we cutoff

    ckpt_path = os.path.join(res_dir, 'checkpoints/df.pth')
    if not os.path.exists(ckpt_path):
        print('checkpoints not found for ', res_dir)
        return

    print('Loading checkpoints from "%s"...' % ckpt_path)
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['model']

    # One tri-grid per pyramid level, each optimised as a leaf tensor.
    # detach() keeps them leaves even if the checkpoint stored tensors that
    # still carry autograd history.
    resolutions = (8, 16, 32, 64, 128, 256, 512)
    trigrid = {
        res: ckpt[f'trigrids_{res}'].detach().to(device).requires_grad_(True)
        for res in resolutions
    }
    ws = ckpt['ws'].to(device)

    epoch_num = 19
    patch_resolution = 256
    lr = 1.0
    params = [{'params': trigrid[res], 'lr': lr} for res in resolutions]
    # optimizer = torch.optim.Adam(params, betas=(0.9, 0.999))

    from optimizer import Adan

    # Adan usually requires a larger LR
    optimizer = Adan(params, eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)

    dataset = Dataset(res_dir)

    data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, drop_last=True)

    for epoch in range(epoch_num):
        print('epoch: ', epoch)

        for i, data in enumerate(data_loader):
            print('iter: ', i)
            image, cam = data

            gt_img = image.clone().detach().to(device).permute(0, 3, 1, 2) # 1, 3, 512, 512 [-1,1]
            cam = cam.clone().detach().to(device)

            # render a random patch
            output = G.render_planes(ws=ws, planes=trigrid, c=cam, noise_mode='const',
                                     neural_rendering_resolution=512, chunk=4096, render_bg=False,
                                     patch_resolution=patch_resolution)

            img = output['image_raw'] # 1, 3, 512, 512 [-1,1]
            mask = output['image_mask'] # 1, 1, 512, 512 [0,1]
            patch_info = output['patch_info']

            # Masked L2 loss against the matching crop of the target image.
            top, left = patch_info[0]
            gt_img = gt_img[:, :, top:top + patch_resolution, left:left + patch_resolution]

            loss = torch.mean((img - gt_img) ** 2 * mask)*1e3
            print('loss: ', loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # save checkpoint
        if epoch == epoch_num - 1:
            ckpt = {f'trigrids_{res}': trigrid[res].clone().detach() for res in resolutions}
            ckpt['ws'] = ws

            torch.save({'model': ckpt}, f'{log_ckpt_dir}/epoch_{epoch:05d}.pth')

        # Full-frame preview from the fixed camera.
        with torch.no_grad():
            output = G.render_planes(ws=ws, planes=trigrid, c=default_cam_params, noise_mode='const',
                                     neural_rendering_resolution=512, chunk=4096, render_bg=False,
                                     patch_resolution=None)

            img = output['image_raw'] # 1, 3, 512, 512 [-1,1]
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            print('save image to ', f'{log_img_dir}/epoch_{epoch}.png')
            Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{log_img_dir}/epoch_{epoch}.png')

#----------------------------------------------------------------------------

if __name__ == "__main__":
    generate_images() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
class Adan(Optimizer):
    """
    Implements a pytorch variant of Adan
    Adan was proposed in
    Adan: Adaptive Nesterov Momentum Algorithm for
        Faster Optimizing Deep Models[J].arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677
    Arguments:
        params (iterable): iterable of parameters to optimize or
            dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for
            the running averages of the gradient, the gradient difference
            and the squared term, in that order. (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay
            (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip
            global grad norm (default: 0.0 no clip)
        no_prox (bool): how to perform the decoupled weight decay.
            True multiplies the parameters by `1 - lr * weight_decay`
            before the update, False divides them by
            `1 + lr * weight_decay` after it. (default: False)
        foreach (bool): if True would use torch._foreach implementation.
            It's faster but uses slightly more memory. (default: True)
    """
    def __init__(self,
                 params,
                 lr=1e-3,
                 betas=(0.98, 0.92, 0.99),
                 eps=1e-8,
                 weight_decay=0.0,
                 max_grad_norm=0.0,
                 no_prox=False,
                 foreach: bool = True):
        # Reject negative hyper-parameters and betas outside [0, 1).
        if not 0.0 <= max_grad_norm:
            raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm))
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= eps:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(
                betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(
                betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError('Invalid beta parameter at index 2: {}'.format(
                betas[2]))
        defaults = dict(lr=lr,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm,
                        no_prox=no_prox,
                        foreach=foreach)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super(Adan, self).__setstate__(state)
        # Restored groups that lack a 'no_prox' entry get the default.
        for group in self.param_groups:
            group.setdefault('no_prox', False)

    @torch.no_grad()
    def restart_opt(self):
        """Reset the step counter and zero the three moment buffers.

        Only parameters with `requires_grad` are touched. `neg_pre_grad` is
        not reset here; `step` re-initialises it whenever a group's step
        counter is 1.
        """
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.requires_grad:
                    state = self.state[p]
                    # State initialization

                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    # Exponential moving average of gradient difference
                    state['exp_avg_diff'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): re-evaluates the model under
                `torch.enable_grad()` and returns the loss.

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Global gradient clipping: one scale factor shared by every
        # parameter in every group, capped at 1.0.
        if self.defaults['max_grad_norm'] > 0:
            device = self.param_groups[0]['params'][0].device
            global_grad_norm = torch.zeros(1, device=device)

            max_grad_norm = torch.tensor(self.defaults['max_grad_norm'],
                                         device=device)
            for group in self.param_groups:

                for p in group['params']:
                    if p.grad is not None:
                        grad = p.grad
                        global_grad_norm.add_(grad.pow(2).sum())

            global_grad_norm = torch.sqrt(global_grad_norm)

            # NOTE: `group` is the variable left over from the loop above,
            # so the eps of the last parameter group is the one used here.
            clip_global_grad_norm = torch.clamp(
                max_grad_norm / (global_grad_norm + group['eps']),
                max=1.0).item()
        else:
            clip_global_grad_norm = 1.0

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            exp_avg_diffs = []
            neg_pre_grads = []

            beta1, beta2, beta3 = group['betas']
            # assume same step across group now to simplify things
            # per parameter step can be easily support
            # by making it tensor, or pass list into kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            bias_correction1 = 1.0 - beta1**group['step']
            bias_correction2 = 1.0 - beta2**group['step']
            bias_correction3 = 1.0 - beta3**group['step']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg_diff'] = torch.zeros_like(p)

                # On the first step (or after restart_opt) the stored value is
                # the negated, clipped current gradient, so the gradient
                # difference formed in the kernel starts at zero.
                if 'neg_pre_grad' not in state or group['step'] == 1:
                    state['neg_pre_grad'] = p.grad.clone().mul_(
                        -clip_global_grad_norm)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                exp_avg_diffs.append(state['exp_avg_diff'])
                neg_pre_grads.append(state['neg_pre_grad'])

            kwargs = dict(
                params=params_with_grad,
                grads=grads,
                exp_avgs=exp_avgs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avg_diffs=exp_avg_diffs,
                neg_pre_grads=neg_pre_grads,
                beta1=beta1,
                beta2=beta2,
                beta3=beta3,
                bias_correction1=bias_correction1,
                bias_correction2=bias_correction2,
                bias_correction3_sqrt=math.sqrt(bias_correction3),
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                eps=group['eps'],
                no_prox=group['no_prox'],
                clip_global_grad_norm=clip_global_grad_norm,
            )

            # Both kernels take the same keyword arguments and update the
            # parameters and state tensors in place.
            if group['foreach']:
                _multi_tensor_adan(**kwargs)
            else:
                _single_tensor_adan(**kwargs)

        return loss


def _single_tensor_adan(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    exp_avg_diffs: List[Tensor],
    neg_pre_grads: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    beta3: float,
    bias_correction1: float,
    bias_correction2: float,
    bias_correction3_sqrt: float,
    lr: float,
    weight_decay: float,
    eps: float,
    no_prox: bool,
    clip_global_grad_norm: Tensor,
):
    """Apply one Adan update to each parameter in turn, in place.

    All list arguments are parallel: entry `i` of each list belongs to
    `params[i]`. `grads` is scaled in place by `clip_global_grad_norm`
    (`Adan.step` passes a Python float despite the `Tensor` annotation), and
    each `neg_pre_grads[i]` ends up holding the negated scaled gradient for
    use by the next step.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        exp_avg_diff = exp_avg_diffs[i]
        neg_grad_or_diff = neg_pre_grads[i]

        grad.mul_(clip_global_grad_norm)

        # for memory saving, we use `neg_grad_or_diff`
        # to get some temp variable in a inplace way
        # (it enters holding the negated previous gradient, so after the add
        # it is the difference between the current and previous gradient)
        neg_grad_or_diff.add_(grad)

        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # m_t
        exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff,
                                      alpha=1 - beta2)  # diff_t

        # Reuse the buffer again: grad + beta2 * (gradient difference).
        neg_grad_or_diff.mul_(beta2).add_(grad)
        exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff,
                                        neg_grad_or_diff,
                                        value=1 - beta3)  # n_t

        denom = ((exp_avg_sq).sqrt() / bias_correction3_sqrt).add_(eps)
        step_size_diff = lr * beta2 / bias_correction2
        step_size = lr / bias_correction1

        if no_prox:
            # Weight decay applied multiplicatively before the update.
            param.mul_(1 - lr * weight_decay)
            param.addcdiv_(exp_avg, denom, value=-step_size)
            param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
        else:
            # Weight decay applied as a division after the update.
            param.addcdiv_(exp_avg, denom, value=-step_size)
            param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
            param.div_(1 + lr * weight_decay)

        # Store the negated current gradient for the next step.
        neg_grad_or_diff.zero_().add_(grad, alpha=-1.0)


def _multi_tensor_adan(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    exp_avg_diffs: List[Tensor],
    neg_pre_grads: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    beta3: float,
    bias_correction1: float,
    bias_correction2: float,
    bias_correction3_sqrt: float,
    lr: float,
    weight_decay: float,
    eps: float,
    no_prox: bool,
    clip_global_grad_norm: Tensor,
):
    """Apply one Adan update to all parameters with `torch._foreach_*` ops.

    Performs the same sequence of in-place operations as
    `_single_tensor_adan`, but on whole tensor lists at once. Returns
    immediately when `params` is empty.
    """
    if len(params) == 0:
        return

    torch._foreach_mul_(grads, clip_global_grad_norm)

    # for memory saving, we use `neg_pre_grads`
    # to get some temp variable in a inplace way
    torch._foreach_add_(neg_pre_grads, grads)

    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)  # m_t

    torch._foreach_mul_(exp_avg_diffs, beta2)
    torch._foreach_add_(exp_avg_diffs, neg_pre_grads,
                        alpha=1 - beta2)  # diff_t

    # Reuse the buffers again: grad + beta2 * (gradient difference).
    torch._foreach_mul_(neg_pre_grads, beta2)
    torch._foreach_add_(neg_pre_grads, grads)
    torch._foreach_mul_(exp_avg_sqs, beta3)
    torch._foreach_addcmul_(exp_avg_sqs,
                            neg_pre_grads,
                            neg_pre_grads,
                            value=1 - beta3)  # n_t

    denom = torch._foreach_sqrt(exp_avg_sqs)
    torch._foreach_div_(denom, bias_correction3_sqrt)
    torch._foreach_add_(denom, eps)

    step_size_diff = lr * beta2 / bias_correction2
    step_size = lr / bias_correction1

    if no_prox:
        # Weight decay applied multiplicatively before the update.
        torch._foreach_mul_(params, 1 - lr * weight_decay)
        torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
        torch._foreach_addcdiv_(params,
                                exp_avg_diffs,
                                denom,
                                value=-step_size_diff)
    else:
        # Weight decay applied as a division after the update.
        torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
        torch._foreach_addcdiv_(params,
                                exp_avg_diffs,
                                denom,
                                value=-step_size_diff)
        torch._foreach_div_(params, 1 + lr * weight_decay)
    # Store the negated current gradients for the next step.
    torch._foreach_zero_(neg_pre_grads)
    torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0)
--- /dev/null +++ b/3DPortraitGAN_pyramid/proj/configs/evaluation_config.py @@ -0,0 +1 @@ +evaluated_methods = ['e4e', 'SG2', 'SG2Plus'] \ No newline at end of file diff --git a/3DPortraitGAN_pyramid/proj/configs/global_config.py b/3DPortraitGAN_pyramid/proj/configs/global_config.py new file mode 100644 index 0000000..c3cdaa4 --- /dev/null +++ b/3DPortraitGAN_pyramid/proj/configs/global_config.py @@ -0,0 +1,12 @@ +## Device +cuda_visible_devices = '0' +device = 'cuda:0' + +## Logs +training_step = 1 +image_rec_result_log_snapshot = 100 +pivotal_training_steps = 0 +model_snapshot_interval = 400 + +## Run name to be updated during PTI +run_name = 'test_pti' diff --git a/3DPortraitGAN_pyramid/proj/configs/hyperparameters.py b/3DPortraitGAN_pyramid/proj/configs/hyperparameters.py new file mode 100644 index 0000000..f5b2d01 --- /dev/null +++ b/3DPortraitGAN_pyramid/proj/configs/hyperparameters.py @@ -0,0 +1,28 @@ +## Architecture +lpips_type = 'alex' +first_inv_type = 'w' +optim_type = 'adam' + +## Locality regularization +latent_ball_num_of_samples = 1 +locality_regularization_interval = 1 +use_locality_regularization = False +regulizer_l2_lambda = 0.1 +regulizer_lpips_lambda = 0.1 +regulizer_alpha = 30 + +## Loss +pt_l2_lambda = 1 +pt_lpips_lambda = 1 + +## Steps +LPIPS_value_threshold = 0.06 +max_pti_steps = 350 +first_inv_steps = 450 +max_images_to_invert = 30 + +## Optimization +pti_learning_rate = 15e-4 #3e-4 +first_inv_lr = 5e-3 +train_batch_size = 1 +use_last_w_pivots = False diff --git a/3DPortraitGAN_pyramid/proj/projector/camera_utils.py b/3DPortraitGAN_pyramid/proj/projector/camera_utils.py new file mode 100644 index 0000000..d97a59b --- /dev/null +++ b/3DPortraitGAN_pyramid/proj/projector/camera_utils.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
def _spherical_camera_origins(h, v, radius, batch_size, device):
    """Turn sampled yaw `h` and pitch `v` (radians) into camera positions.

    `v` is clamped to [1e-5, pi - 1e-5], rescaled to (0, 1) and mapped
    through `arccos(1 - 2 * v)` to obtain the polar angle; the positions lie
    on a sphere of the given radius around the origin. Shared by all three
    pose samplers so their geometry cannot drift apart.

    Returns:
        Tensor of shape (batch_size, 3) holding x, y, z per camera.
    """
    v = torch.clamp(v, 1e-5, math.pi - 1e-5)

    theta = h
    v = v / math.pi
    phi = torch.arccos(1 - 2*v)

    camera_origins = torch.zeros((batch_size, 3), device=device)

    camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
    camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
    camera_origins[:, 1:2] = radius*torch.cos(phi)
    return camera_origins


class GaussianCameraPoseSampler:
    """
    Samples pitch and yaw from a Gaussian distribution and returns a camera pose.
    Camera is specified as looking at the origin.
    If horizontal and vertical stddev (specified in radians) are zero, gives a
    deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean.
    The coordinate system is specified with y-up, z-forward, x-left.
    Horizontal mean is the azimuthal angle (rotation around y axis) in radians,
    vertical mean is the polar angle (angle from the y axis) in radians.
    A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2.

    Example:
    For a camera pose looking at the origin with the camera at position [0, 0, 1]:
    cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1)
    """

    @staticmethod
    def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
        """Return a (batch_size, 4, 4) batch of cam2world matrices."""
        # Yaw is drawn before pitch; keep this order so seeded runs reproduce.
        h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
        v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean

        camera_origins = _spherical_camera_origins(h, v, radius, batch_size, device)

        forward_vectors = math_utils.normalize_vecs(-camera_origins)
        return create_cam2world_matrix(forward_vectors, camera_origins)


class LookAtPoseSampler:
    """
    Same as GaussianCameraPoseSampler, except the
    camera is specified as looking at 'lookat_position', a 3-vector.

    Example:
    For a camera pose looking at the origin with the camera at position [0, 0, 1]:
    cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
    """

    @staticmethod
    def sample(horizontal_mean, vertical_mean, lookat_position, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
        """Return a (batch_size, 4, 4) batch of cam2world matrices."""
        h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
        v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean

        camera_origins = _spherical_camera_origins(h, v, radius, batch_size, device)

        # The camera sits on the sphere around the origin but is aimed at
        # lookat_position rather than at the origin itself.
        forward_vectors = math_utils.normalize_vecs(lookat_position - camera_origins)
        return create_cam2world_matrix(forward_vectors, camera_origins)

class UniformCameraPoseSampler:
    """
    Same as GaussianCameraPoseSampler, except the
    pose is sampled from a uniform distribution with range +-[horizontal/vertical]_stddev
    around the respective mean.

    Example:
    For a batch of random camera poses looking at the origin with yaw sampled
    within pi/2 radians of horizontal_mean=pi/2 (i.e. from [0, pi]):
    cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16)
    """

    @staticmethod
    def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
        """Return a (batch_size, 4, 4) batch of cam2world matrices."""
        # rand() is in [0, 1); `* 2 - 1` recentres it to [-1, 1).
        h = (torch.rand((batch_size, 1), device=device) * 2 - 1) * horizontal_stddev + horizontal_mean
        v = (torch.rand((batch_size, 1), device=device) * 2 - 1) * vertical_stddev + vertical_mean

        camera_origins = _spherical_camera_origins(h, v, radius, batch_size, device)

        forward_vectors = math_utils.normalize_vecs(-camera_origins)
        return create_cam2world_matrix(forward_vectors, camera_origins)

def create_cam2world_matrix(forward_vector, origin):
    """
    Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.
    Works on batches of forward_vectors, origins. Assumes y-axis is up and that there is no camera roll.

    Returns a tensor of shape (N, 4, 4) where N is the batch size of forward_vector.
    """

    forward_vector = math_utils.normalize_vecs(forward_vector)
    up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=origin.device).expand_as(forward_vector)

    # Build the camera basis from the world up vector, then recompute up from
    # forward and right.
    right_vector = -math_utils.normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
    up_vector = math_utils.normalize_vecs(torch.cross(forward_vector, right_vector, dim=-1))

    rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
    # Basis vectors become the columns of the rotation block.
    rotation_matrix[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), axis=-1)

    translation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
    translation_matrix[:, :3, 3] = origin
    cam2world = (translation_matrix @ rotation_matrix)[:, :, :]
    assert(cam2world.shape[1:] == (4, 4))
    return cam2world


def FOV_to_intrinsics(fov_degrees, device='cpu'):
    """
    Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
    Note the intrinsics are returned as normalized by image size, rather than in pixel units.
    Assumes principal point is at image center.
    """

    # The rounded constants 3.14159 and 1.414 are kept as-is: swapping in
    # math.pi / math.sqrt(2) would change the focal length in the low digits.
    focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))
    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
    return intrinsics
Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Project given image to the latent space of pretrained network pickle."""
+
+import copy
+import os
+import numpy as np
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+import dnnlib
+import PIL.Image
+from camera_utils import LookAtPoseSampler
+
+def project(
+        G,  # generator; deep-copied below, so the caller's instance is not modified
+        c,  # camera conditioning passed to G.synthesis
+        p,  # pose parameters passed to G.synthesis as pose_params (kept fixed here)
+        outdir,  # preview images are written to {outdir}/{w_name}_w
+        target: torch.Tensor,  # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
+        *,
+        num_steps=1000,  # number of optimizer steps
+        w_avg_samples=10000,  # number of latents sampled to estimate the W mean and stddev
+        initial_learning_rate=0.01,  # peak learning rate of the ramp-up / ramp-down schedule
+        initial_noise_factor=0.05,  # scale (relative to w_std) of the noise added to w at step 0
+        lr_rampdown_length=0.25,
+        lr_rampup_length=0.05,
+        noise_ramp_length=0.75,  # fraction of the run after which the w noise reaches zero
+        regularize_noise_weight=1e5,  # weight of the noise-buffer regularizer in the loss
+        verbose=False,  # print per-step dist/loss when True
+        device: torch.device,
+        initial_w=None,  # optional starting latent; the sampled mean w_avg is used when None
+        image_log_step=100,  # save a preview image every this many steps (and on the last step)
+        w_name: str,
+        no_sr = False  # if True, optimize against the 256px 'image_raw' output instead of 'image'
+):
+    os.makedirs(f'{outdir}/{w_name}_w', exist_ok=True)
+    outdir = f'{outdir}/{w_name}_w'
+    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
+
+    def logprint(*args):
+        if verbose:
+            print(*args)
+
+    G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float()  # type: ignore
+
+    # Compute w stats.
+    w_avg_path = './w_avg.npy'
+    w_std_path = './w_std.npy'
+    if (not os.path.exists(w_avg_path)) or (not os.path.exists(w_std_path)):
+        print(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
+        z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
+        # c_samples = c.repeat(w_avg_samples, 1)
+
+        # use avg look at point
+
+        camera_lookat_point = torch.tensor([0, 0.0649, 0], device=device)
+        cam2world_pose = LookAtPoseSampler.sample(3.14 / 2, 3.14 / 2, camera_lookat_point,
+                                                  radius=2.7, device=device)
+        focal_length = 6.5104166
+        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+        c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+        c_samples = c_samples.repeat(w_avg_samples, 1)
+        p_samples = p.repeat(w_avg_samples, 1)
+        w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples, p_samples, )  # [N, L, C]
+        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
+        w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
+        # print('save w_avg to ./w_avg.npy')
+        # np.save('./w_avg.npy',w_avg)
+        # w_avg stays a NumPy array here; it is turned into a tensor when w_opt is created below.
+        w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
+
+        # np.save(w_avg_path, w_avg)
+        # np.save(w_std_path, w_std)
+    else:
+        # w_avg = np.load(w_avg_path)
+        # w_std = np.load(w_std_path)
+        raise RuntimeError(f'Found cached W statistics ({w_avg_path}, {w_std_path}), but loading them is disabled; delete these files and re-run.')
+
+    # z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
+    # c_samples = c.repeat(w_avg_samples, 1)
+    # w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples)  # [N, L, C]
+    # w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
+    # w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
+    # w_avg_tensor = torch.from_numpy(w_avg).cuda()
+    # w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
+
+    start_w = initial_w if initial_w is not None else w_avg
+
+    # Setup noise inputs.
+    noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name}
+
+    # Load VGG16 feature detector.
+    #url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+    url = './models/vgg16.pt'
+    with dnnlib.util.open_url(url) as f:
+        vgg16 = torch.jit.load(f).eval().to(device)
+
+    # Features for target image.
+    target_images = target.unsqueeze(0).to(device).to(torch.float32)
+    if target_images.shape[2] > 256:
+        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
+    target_features = vgg16(target_images, resize_images=False, return_lpips=True)
+
+    start_w = np.repeat(start_w, G.backbone.mapping.num_ws, axis=1)
+    w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
+                         requires_grad=True)  # pylint: disable=not-callable
+
+    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
+                                 lr=0.1)
+
+    # Init noise.
+    for buf in noise_bufs.values():
+        buf[:] = torch.randn_like(buf)
+        buf.requires_grad = True
+
+    for step in tqdm(range(num_steps)):
+
+        # Learning rate schedule.
+        t = step / num_steps
+        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
+        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
+        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
+        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
+        lr = initial_learning_rate * lr_ramp
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr  # replaces the lr=0.1 given to Adam above
+
+        # Synth images from opt_w.
+        w_noise = torch.randn_like(w_opt) * w_noise_scale
+        ws = (w_opt + w_noise)
+        # synth_images = G.synthesis(ws,c, noise_mode='const')['image']
+        if no_sr:
+            synth_images = G.synthesis(ws, c=c, neural_rendering_resolution = 256, noise_mode='const', apply_def=True, pose_params=p)['image_raw']
+            assert synth_images.shape[2] == 256
+        else:
+            synth_images = G.synthesis(ws, c=c, noise_mode='const', apply_def=True, pose_params=p)['image']
+
+        if step % image_log_step == 0 or step == num_steps - 1:
+            with torch.no_grad():
+                vis_img = (synth_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+
+                PIL.Image.fromarray(vis_img[0].cpu().numpy(), 'RGB').save(f'{outdir}/{step}.png')
+
+        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
+        synth_images = (synth_images + 1) * (255 / 2)
+        if synth_images.shape[2] > 256:
+            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
+
+        # Features for synth images.
+        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
+        dist = (target_features - synth_features).square().sum()
+
+        # Noise regularization.
+        reg_loss = 0.0
+        for v in noise_bufs.values():
+            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
+            while True:
+                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
+                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
+                if noise.shape[2] <= 8:
+                    break
+                noise = F.avg_pool2d(noise, kernel_size=2)
+        loss = dist + reg_loss * regularize_noise_weight
+
+        # if step % 10 == 0:
+        #     with torch.no_grad():
+        #         print({f'step {step}, first projection _{w_name}': loss.detach().cpu()})
+
+        # Step
+        optimizer.zero_grad(set_to_none=True)
+        loss.backward()
+        optimizer.step()
+        logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
+
+        # Normalize noise.
+        with torch.no_grad():
+            for buf in noise_bufs.values():
+                buf -= buf.mean()
+                buf *= buf.square().mean().rsqrt()
+
+    del G
+    return w_opt
\ No newline at end of file
diff --git a/3DPortraitGAN_pyramid/proj/projector/w_projector_with_pose_optim.py b/3DPortraitGAN_pyramid/proj/projector/w_projector_with_pose_optim.py
new file mode 100644
index 0000000..69b4b6f
--- /dev/null
+++ b/3DPortraitGAN_pyramid/proj/projector/w_projector_with_pose_optim.py
@@ -0,0 +1,221 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Project given image to the latent space of pretrained network pickle."""
+
+import copy
+import os
+import numpy as np
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+import dnnlib
+import PIL.Image
+from camera_utils import LookAtPoseSampler
+
+def project(
+        G,
+        c,
+        p,
+        outdir,
+        target: torch.Tensor,  # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
+        *,
+        num_steps=1000,
+        w_avg_samples=10000,
+        initial_learning_rate=0.01,
+        initial_noise_factor=0.05,
+        lr_rampdown_length=0.25,
+        lr_rampup_length=0.05,
+        noise_ramp_length=0.75,
+        regularize_noise_weight=1e5,
+        verbose=False,
+        device: torch.device,
+        initial_w=None,
+        image_log_step=100,
+        w_name: str,
+        no_sr = False
+):
+    """Project `target` into the generator's latent space while also optimizing pose.
+
+    Adam jointly updates the per-layer latent codes, the synthesis noise buffers
+    and a private copy of the pose parameters `p`, minimizing the VGG16 feature
+    distance to `target` plus a noise regularizer. Intermediate renders are saved
+    as `{outdir}/{w_name}_w/{step}.png`; after the last step a render from a fixed
+    frontal camera (apply_def=False, no pose parameters) is saved as `canonical.png`.
+
+    Returns:
+        Tuple `(w_opt, p_opt)` with the optimized latent codes and pose parameters.
+    """
+    os.makedirs(f'{outdir}/{w_name}_w', exist_ok=True)
+    outdir = f'{outdir}/{w_name}_w'
+    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
+
+    def logprint(*args):
+        if verbose:
+            print(*args)
+
+    G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float()  # type: ignore
+
+    # Compute w stats.
+    w_avg_path = './w_avg.npy'
+    w_std_path = './w_std.npy'
+    if (not os.path.exists(w_avg_path)) or (not os.path.exists(w_std_path)):
+        print(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
+        z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
+        # c_samples = c.repeat(w_avg_samples, 1)
+
+        # use avg look at point
+
+        camera_lookat_point = torch.tensor([0, 0.0649, 0], device=device)
+        cam2world_pose = LookAtPoseSampler.sample(3.14 / 2, 3.14 / 2, camera_lookat_point,
+                                                  radius=2.7, device=device)
+        focal_length = 6.5104166
+        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+        c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+        c_samples = c_samples.repeat(w_avg_samples, 1)
+        p_samples = p.repeat(w_avg_samples, 1)
+        w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples, p_samples, )  # [N, L, C]
+        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
+        w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
+        # print('save w_avg to ./w_avg.npy')
+        # np.save('./w_avg.npy',w_avg)
+        w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
+
+        # np.save(w_avg_path, w_avg)
+        # np.save(w_std_path, w_std)
+    else:
+        # w_avg = np.load(w_avg_path)
+        # w_std = np.load(w_std_path)
+        raise RuntimeError(f'Found cached W statistics ({w_avg_path}, {w_std_path}), but loading them is disabled; delete these files and re-run.')
+
+    # z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
+    # c_samples = c.repeat(w_avg_samples, 1)
+    # w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples)  # [N, L, C]
+    # w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
+    # w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
+    # w_avg_tensor = torch.from_numpy(w_avg).cuda()
+    # w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
+
+    start_w = initial_w if initial_w is not None else w_avg
+
+    # Setup noise inputs.
+    noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name}
+
+    # Load VGG16 feature detector.
+    #url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+    url = './models/vgg16.pt'
+    with dnnlib.util.open_url(url) as f:
+        vgg16 = torch.jit.load(f).eval().to(device)
+
+    # Features for target image.
+    target_images = target.unsqueeze(0).to(device).to(torch.float32)
+    if target_images.shape[2] > 256:
+        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
+    target_features = vgg16(target_images, resize_images=False, return_lpips=True)
+
+    start_w = np.repeat(start_w, G.backbone.mapping.num_ws, axis=1)
+    w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
+                         requires_grad=True)  # pylint: disable=not-callable
+
+    # Optimize a private leaf copy so the caller's tensor is not modified in place.
+    p_opt = p.detach().clone().requires_grad_(True)
+
+    params = [{'params': w_opt, 'lr': 0.1},
+              {'params': list(noise_bufs.values()), 'lr': 0.1},
+              {'params': p_opt, 'lr': 0.002}
+              ]
+
+    optimizer = torch.optim.Adam(params, betas=(0.9, 0.999))
+
+    # Init noise.
+    for buf in noise_bufs.values():
+        buf[:] = torch.randn_like(buf)
+        buf.requires_grad = True
+
+    for step in tqdm(range(num_steps)):
+
+        # Learning rate schedule.
+        t = step / num_steps
+        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
+        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
+        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
+        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
+        lr = initial_learning_rate * lr_ramp
+        # NOTE(review): this assigns the same scheduled lr to every param group, so the
+        # per-group rates given to Adam above (0.1 / 0.1 / 0.002) never take effect.
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+
+        # Synth images from opt_w.
+        w_noise = torch.randn_like(w_opt) * w_noise_scale
+        ws = (w_opt + w_noise)
+        # synth_images = G.synthesis(ws,c, noise_mode='const')['image']
+        if no_sr:
+            synth_images = G.synthesis(ws, c=c, neural_rendering_resolution = 256, noise_mode='const', apply_def=True, pose_params=p_opt)['image_raw']
+            assert synth_images.shape[2] == 256
+        else:
+            synth_images = G.synthesis(ws, c=c, noise_mode='const', apply_def=True, pose_params=p_opt)['image']
+
+        if step % image_log_step == 0 or step == num_steps - 1:
+            with torch.no_grad():
+                vis_img = (synth_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+            PIL.Image.fromarray(vis_img[0].cpu().numpy(), 'RGB').save(f'{outdir}/{step}.png')
+
+        if step == num_steps - 1:
+            # Visualization only: render from a fixed frontal camera with apply_def=False.
+            # Separate names under no_grad keep this render out of the loss below, which
+            # must still be computed on the posed `synth_images`.
+            frontal_c = torch.tensor([[1.0000e+00, 1.0505e-09, 4.3685e-08, -1.1805e-07, 0.0000e+00,
+                                       -9.9951e-01, 2.4033e-02, -1.1805e-07, 4.3714e-08, -2.4033e-02,
+                                       -9.9951e-01, 2.6992e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
+                                       1.0000e+00, 6.7287e+00, 0.0000e+00, 5.0000e-01, 0.0000e+00,
+                                       6.7287e+00, 5.0000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+00]], device=device, dtype=torch.float32)
+            with torch.no_grad():
+                canonical_images = G.synthesis(ws, c=frontal_c, noise_mode='const', apply_def=False, pose_params=None)['image']
+                canonical_vis = (canonical_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+            PIL.Image.fromarray(canonical_vis[0].cpu().numpy(), 'RGB').save(f'{outdir}/canonical.png')
+
+        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
+        synth_images = (synth_images + 1) * (255 / 2)
+        if synth_images.shape[2] > 256:
+            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
+
+        # Features for synth images.
+        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
+        dist = (target_features - synth_features).square().sum()
+
+        # Noise regularization.
+        reg_loss = 0.0
+        for v in noise_bufs.values():
+            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
+            while True:
+                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
+                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
+                if noise.shape[2] <= 8:
+                    break
+                noise = F.avg_pool2d(noise, kernel_size=2)
+        loss = dist + reg_loss * regularize_noise_weight
+
+        # if step % 10 == 0:
+        #     with torch.no_grad():
+        #         print({f'step {step}, first projection _{w_name}': loss.detach().cpu()})
+
+        # Step
+        optimizer.zero_grad(set_to_none=True)
+        loss.backward()
+        optimizer.step()
+        logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
+
+        # Normalize noise.
+        with torch.no_grad():
+            for buf in noise_bufs.values():
+                buf -= buf.mean()
+                buf *= buf.square().mean().rsqrt()
+
+    del G
+    return w_opt,p_opt
\ No newline at end of file
diff --git a/3DPortraitGAN_pyramid/pyramid_trigrid_visualizer.py b/3DPortraitGAN_pyramid/pyramid_trigrid_visualizer.py
new file mode 100644
index 0000000..c09afed
--- /dev/null
+++ b/3DPortraitGAN_pyramid/pyramid_trigrid_visualizer.py
@@ -0,0 +1,321 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import click
+import os
+
+import multiprocessing
+import numpy as np
+import imgui
+import dnnlib
+from gui_utils import imgui_window
+from gui_utils import imgui_utils
+from gui_utils import gl_utils
+from gui_utils import text_utils
+from viz import renderer
+from viz import pickle_widget
+from viz import pyramid_trigrid_widget
+from viz import performance_widget
+from viz import capture_widget
+from viz import backbone_cache_widget
+from viz import layer_widget
+from viz import pose_widget
+# from viz import body_pose_widget
+from viz import zoom_widget
+from viz import conditioning_pose_widget
+from viz import render_type_widget
+from viz import render_depth_sample_widget
+
+#----------------------------------------------------------------------------
+
+class Visualizer(imgui_window.ImguiWindow):  # GUI window: control pane on the left, rendered result in the remaining area
+    def __init__(self, capture_dir=None):
+        super().__init__(title='Cat Machine', window_width=3840, window_height=2160)
+
+        # Internals.
+        self._last_error_print = None
+        self._async_renderer = AsyncRenderer()
+        self._defer_rendering = 0
+        self._tex_img = None
+        self._tex_obj = None
+
+        # Widget interface.
+        self.args = dnnlib.EasyDict()
+        self.result = dnnlib.EasyDict()
+        self.pane_w = 0
+        self.label_w = 0
+        self.button_w = 0
+
+        # Widgets.
+        self.pickle_widget = pickle_widget.PickleWidget(self)
+        self.pyramid_trigrid_widget = pyramid_trigrid_widget.PyramidTrigridWidget(self)
+        self.perf_widget = performance_widget.PerformanceWidget(self)
+        self.capture_widget = capture_widget.CaptureWidget(self)
+        self.backbone_cache_widget = backbone_cache_widget.BackboneCacheWidget(self)
+        self.layer_widget = layer_widget.LayerWidget(self)
+        self.pose_widget = pose_widget.PoseWidget(self)
+        # self.body_pose_widget = body_pose_widget.BodyPoseWidget(self)
+        self.zoom_widget = zoom_widget.ZoomWidget(self)
+        self.conditioning_pose_widget = conditioning_pose_widget.ConditioningPoseWidget(self)
+        self.render_type_widget = render_type_widget.RenderTypeWidget(self)
+        self.render_depth_sample_widget = render_depth_sample_widget.RenderDepthSampleWidget(self)
+
+        if capture_dir is not None:
+            self.capture_widget.path = capture_dir
+
+        # Initialize window.
+        self.set_position(0, 0)
+        self._adjust_font_size()
+        self.skip_frame() # Layout may change after first frame.
+
+    def close(self):  # close the window, then shut down the renderer if it is still alive
+        super().close()
+        if self._async_renderer is not None:
+            self._async_renderer.close()
+            self._async_renderer = None
+
+    def add_recent_pickle(self, pkl, ignore_errors=False):
+        self.pickle_widget.add_recent(pkl, ignore_errors=ignore_errors)
+
+    def load_pickle(self, pkl, ignore_errors=False):
+        self.pickle_widget.load(pkl, ignore_errors=ignore_errors)
+
+    def print_error(self, error):  # print an error once; identical consecutive errors are suppressed
+        error = str(error)
+        if error != self._last_error_print:
+            print('\n' + error + '\n')
+            self._last_error_print = error
+
+    def defer_rendering(self, num_frames=1):  # skip rendering for at least num_frames upcoming frames
+        self._defer_rendering = max(self._defer_rendering, num_frames)
+
+    def clear_result(self):
+        self._async_renderer.clear_result()
+
+    def set_async(self, is_async):  # switch between in-process and worker-process rendering
+        if is_async != self._async_renderer.is_async:
+            self._async_renderer.set_async(is_async)
+            self.clear_result()
+            if 'image' in self.result:
+                self.result.message = 'Switching rendering process...'
+                self.defer_rendering()
+
+    def _adjust_font_size(self):  # derive the font size from the current content size
+        old = self.font_size
+        self.set_font_size(min(self.content_width / 120, self.content_height / 60))
+        if self.font_size != old:
+            self.skip_frame() # Layout changed.
+
+    def draw_frame(self):  # one GUI frame: widgets -> render request -> display of the latest result
+        self.begin_frame()
+        self.args = dnnlib.EasyDict()
+        self.pane_w = self.font_size * 50
+        self.button_w = self.font_size * 5
+        self.label_w = round(self.font_size * 5.5)
+
+        # Detect mouse dragging in the result area.
+        dragging, dx, dy = imgui_utils.drag_hidden_window('##result_area', x=self.pane_w, y=0, width=self.content_width-self.pane_w, height=self.content_height)
+        if dragging:
+            self.pose_widget.drag(dx, dy)
+
+        # Begin control pane.
+        imgui.set_next_window_position(0, 0)
+        imgui.set_next_window_size(self.pane_w, self.content_height)
+        imgui.begin('##control_pane', closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
+
+        # Widgets.
+        expanded, _visible = imgui_utils.collapsing_header('Network & latent', default=True)
+        self.pickle_widget(expanded)
+        self.pyramid_trigrid_widget(expanded)
+        self.pose_widget(expanded)
+        self.zoom_widget(expanded)
+        self.conditioning_pose_widget(expanded)
+        # self.body_pose_widget(expanded)
+        self.render_type_widget(expanded)
+        self.render_depth_sample_widget(expanded)
+        expanded, _visible = imgui_utils.collapsing_header('Performance & capture', default=True)
+        self.perf_widget(expanded)
+        self.capture_widget(expanded)
+        expanded, _visible = imgui_utils.collapsing_header('Layers & channels', default=True)
+        self.backbone_cache_widget(expanded)
+        self.layer_widget(expanded)
+
+        # Render.
+        if self.is_skipping_frames():
+            pass
+        elif self._defer_rendering > 0:
+            self._defer_rendering -= 1
+        elif self.args.pkl is not None:
+            self._async_renderer.set_args(**self.args)
+            result = self._async_renderer.get_result()
+            if result is not None:
+                self.result = result
+
+        # Display.
+        max_w = self.content_width - self.pane_w
+        max_h = self.content_height
+        pos = np.array([self.pane_w + max_w / 2, max_h / 2])  # center of the result area
+        if 'image' in self.result:
+            if self._tex_img is not self.result.image:  # only touch the texture when a new image object arrives
+                self._tex_img = self.result.image
+                if self._tex_obj is None or not self._tex_obj.is_compatible(image=self._tex_img):
+                    self._tex_obj = gl_utils.Texture(image=self._tex_img, bilinear=False, mipmap=False)
+                else:
+                    self._tex_obj.update(self._tex_img)
+            zoom = min(max_w / self._tex_obj.width, max_h / self._tex_obj.height)
+            # print(zoom)
+            zoom = np.floor(zoom) if zoom >= 1 else zoom  # integer zoom when enlarging
+            # zoom = 1
+            self._tex_obj.draw(pos=pos, zoom=zoom, align=0.5, rint=True)
+        if 'error' in self.result:
+            self.print_error(self.result.error)
+            if 'message' not in self.result:
+                self.result.message = str(self.result.error)
+        if 'message' in self.result:
+            tex = text_utils.get_texture(self.result.message, size=self.font_size, max_width=max_w, max_height=max_h, outline=2)
+            tex.draw(pos=pos, align=0.5, rint=True, color=1)
+
+        # End frame.
+        self._adjust_font_size()
+        imgui.end()
+        self.end_frame()
+
+#----------------------------------------------------------------------------
+
+class AsyncRenderer:  # runs renderer.Renderer either in this process or in a separate worker process
+    def __init__(self):
+        self._closed = False
+        self._is_async = False
+        self._cur_args = None
+        self._cur_result = None
+        self._cur_stamp = 0
+        self._renderer_obj = None
+        self._args_queue = None
+        self._result_queue = None
+        self._process = None
+
+    def close(self):  # terminate the worker process (if any) and drop the queues
+        self._closed = True
+        self._renderer_obj = None
+        if self._process is not None:
+            self._process.terminate()
+        self._process = None
+        self._args_queue = None
+        self._result_queue = None
+
+    @property
+    def is_async(self):
+        return self._is_async
+
+    def set_async(self, is_async):
+        self._is_async = is_async
+
+    def set_args(self, **args):  # submit a render request; identical consecutive args are ignored
+        assert not self._closed
+        if args != self._cur_args:
+            if self._is_async:
+                self._set_args_async(**args)
+            else:
+                self._set_args_sync(**args)
+            self._cur_args = args
+
+    def _set_args_async(self, **args):  # lazily start the worker process, then enqueue the request
+        if self._process is None:
+            self._args_queue = multiprocessing.Queue()
+            self._result_queue = multiprocessing.Queue()
+            try:
+                multiprocessing.set_start_method('spawn')
+            except RuntimeError:
+                pass
+            self._process = multiprocessing.Process(target=self._process_fn, args=(self._args_queue, self._result_queue), daemon=True)
+            self._process.start()
+        self._args_queue.put([args, self._cur_stamp])
+
+    def _set_args_sync(self, **args):  # render immediately in this process
+        if self._renderer_obj is None:
+            self._renderer_obj = renderer.Renderer()
+        self._cur_result = self._renderer_obj.render(**args)
+
+    def get_result(self):  # return the latest result, draining any finished worker results first
+        assert not self._closed
+        if self._result_queue is not None:
+            while self._result_queue.qsize() > 0:
+                result, stamp = self._result_queue.get()
+                if stamp == self._cur_stamp:  # results carrying a different stamp are discarded
+                    self._cur_result = result
+        return self._cur_result
+
+    def clear_result(self):
+        assert not self._closed
+        self._cur_args = None
+        self._cur_result = None
+        self._cur_stamp += 1  # makes get_result() ignore results of requests sent before this call
+
+    @staticmethod
+    def _process_fn(args_queue, result_queue):  # worker loop; never returns
+        renderer_obj = renderer.Renderer()
+        cur_args = None
+        cur_stamp = None
+        while True:
+            args, stamp = args_queue.get()
+            while args_queue.qsize() > 0:  # skip ahead to the newest queued request
+                args, stamp = args_queue.get()
+            if args != cur_args or stamp != cur_stamp:
+                result = renderer_obj.render(**args)
+                if 'error' in result:
+                    result.error = renderer.CapturedException(result.error)
+                result_queue.put([result, stamp])
+                cur_args = args
+                cur_stamp = stamp
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.argument('pkls', metavar='PATH', nargs=-1)
+@click.option('--capture-dir', help='Where to save screenshot captures', metavar='PATH', default=None)
+@click.option('--browse-dir', help='Specify model path for the \'Browse...\' button', metavar='PATH')
+def main(
+    pkls,  # NOTE(review): accepted from the command line but not referenced anywhere in main()
+    capture_dir,
+    browse_dir
+):
+    """Interactive model visualizer.
+
+    Optional PATH argument can be used specify which .pkl file to load.
+    """
+    viz = Visualizer(capture_dir=capture_dir)
+
+    if browse_dir is not None:
+        viz.pickle_widget.search_dirs = [browse_dir]
+
+    # List pickles.
+    pretrained = [
+        'https://api.ngc.nvidia.com/v2/models/nvidia/research/eg3d/versions/1/files/ffhq512-128.pkl',
+        'https://api.ngc.nvidia.com/v2/models/nvidia/research/eg3d/versions/1/files/afhqcats512-128.pkl',
+        'https://api.ngc.nvidia.com/v2/models/nvidia/research/eg3d/versions/1/files/ffhqrebalanced512-64.pkl',
+        'https://api.ngc.nvidia.com/v2/models/nvidia/research/eg3d/versions/1/files/ffhqrebalanced512-128.pkl',
+        'https://api.ngc.nvidia.com/v2/models/nvidia/research/eg3d/versions/1/files/shapenetcars128-64.pkl',
+    ]
+
+    # Populate recent pickles list with pretrained model URLs.
+    for url in pretrained:
+        viz.add_recent_pickle(url)
+
+    # Run.
+    while not viz.should_close():
+        viz.draw_frame()
+    viz.close()
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    main()
+
+#----------------------------------------------------------------------------
diff --git a/3DPortraitGAN_pyramid/run_inversion_with_pose_optimization.py b/3DPortraitGAN_pyramid/run_inversion_with_pose_optimization.py
new file mode 100644
index 0000000..17ed6b6
--- /dev/null
+++ b/3DPortraitGAN_pyramid/run_inversion_with_pose_optimization.py
@@ -0,0 +1,284 @@
+import glob
+
+import numpy as np
+import dnnlib
+import legacy
+from proj.projector import w_projector_with_pose_optim
+from proj.configs import global_config, hyperparameters
+from PIL import Image
+import torch
+import json
+import os
+from torch_utils.ops import upfirdn2d
+from training.dual_discriminator import filtered_resizing
+
+
+# ----------------------------------------------------------------------------
+class Space_Regulizer:
+    """Regularizer comparing images from a tuned generator with those of the original generator."""
+    def __init__(self, original_G, lpips_net):
+        self.original_G = original_G
+        self.morphing_regulizer_alpha = hyperparameters.regulizer_alpha
+        self.lpips_loss = lpips_net
+
+    def get_morphed_w_code(self, new_w_code, fixed_w):
+        """Return `fixed_w` moved by `regulizer_alpha` along the unit direction towards `new_w_code`."""
+        interpolation_direction = new_w_code - fixed_w
+        interpolation_direction_norm = torch.norm(interpolation_direction, p=2)
+        direction_to_move = hyperparameters.regulizer_alpha * interpolation_direction / interpolation_direction_norm
+        result_w = fixed_w + direction_to_move
+
+        return result_w
+
+    def get_image_from_ws(self, w_codes, G):
+        """Synthesize one image per latent in `w_codes` with `G` and concatenate the results."""
+        return torch.cat([G.synthesis(w_code, noise_mode='none', force_fp32=True) for w_code in w_codes])
+
+    def ball_holder_loss_lazy(self, new_G, num_of_sampled_latents, w_batch, use_wandb=False):
+        """Average weighted L2/LPIPS difference between `new_G` and the original generator on latents sampled around `w_batch`."""
+        loss = 0.0
+
+        z_samples = np.random.randn(num_of_sampled_latents, self.original_G.z_dim)
+        w_samples = self.original_G.mapping(torch.from_numpy(z_samples).to(global_config.device), None,
+                                            truncation_psi=0.5)
+        territory_indicator_ws = [self.get_morphed_w_code(w_code.unsqueeze(0), w_batch) for w_code in w_samples]
+
+        for w_code in territory_indicator_ws:
+            new_img = new_G.synthesis(w_code, noise_mode='none', force_fp32=True)
+            with torch.no_grad():
+                old_img = self.original_G.synthesis(w_code, noise_mode='none', force_fp32=True)
+
+            if hyperparameters.regulizer_l2_lambda > 0:
+                l2_loss_val = l2_loss(old_img, new_img)  # module-level helper defined below
+
+                loss += l2_loss_val * hyperparameters.regulizer_l2_lambda
+
+            if hyperparameters.regulizer_lpips_lambda > 0:
+                loss_lpips = self.lpips_loss(old_img, new_img)
+                loss_lpips = torch.mean(torch.squeeze(loss_lpips))
+
+                loss += loss_lpips * hyperparameters.regulizer_lpips_lambda
+
+        return loss / len(territory_indicator_ws)
+
+    def space_regulizer_loss(self, new_G, w_batch, use_wandb):
+        """Evaluate the regularizer using `hyperparameters.latent_ball_num_of_samples` sampled latents."""
+        ret_val = self.ball_holder_loss_lazy(new_G, hyperparameters.latent_ball_num_of_samples, w_batch, use_wandb)
+        return ret_val
+
+
+def l2_loss(real_images, generated_images):
+    """Return the mean squared error between the two image batches."""
+    l2_criterion = torch.nn.MSELoss(reduction='mean')
+    loss = l2_criterion(real_images, generated_images)
+    return loss
+
+
+def toogle_grad(model, flag=True):
+    """Set `requires_grad` to `flag` on every parameter of `model`."""
+    for p in model.parameters():
+        p.requires_grad = flag
+
+
+def run_D_pose_prediction(img, c, blur_sigma=0, D=None):
+    """Optionally blur `img['image']` (replacing that dict entry) and return the pose from `D.predict_pose`."""
+    blur_size = np.floor(blur_sigma * 3)
+    if blur_size > 0:
+        with torch.autograd.profiler.record_function('blur'):
+            f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(
+                blur_sigma).square().neg().exp2()
+            img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum())
+    pose, _ = D.predict_pose(img, c)
+    return pose
+
+
+def get_pose_params(real_img, real_seg, real_c, D=None, neural_rendering_resolution=None, blur_sigma=None,
+                    resample_filter=None, filter_mode=None):
+    """Build the image / image_raw / image_mask dict from a real image and return the pose `D` predicts for it."""
+    real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=resample_filter,
+                                     filter_mode=filter_mode)
+
+    real_seg_raw = filtered_resizing(real_seg, size=neural_rendering_resolution, f=resample_filter,
+                                     filter_mode=filter_mode)
+
+    if True:
+        blur_size = np.floor(blur_sigma * 3)
+        if blur_size > 0:
+            f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div(
+                blur_sigma).square().neg().exp2()
+            real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum())
+
+    real_img = {'image': real_img, 'image_raw': real_img_raw, 'image_mask': real_seg_raw}
+
+    # get pose_params from real image
+    real_img_tmp_image = real_img['image'].detach().requires_grad_(True)
+    real_img_tmp_image_raw = real_img['image_raw'].detach().requires_grad_(True)
+    real_img_tmp_image_mask = real_img['image_mask'].detach().requires_grad_(True)
+    real_img_tmp = {'image': real_img_tmp_image, 'image_raw': real_img_tmp_image_raw,
+                    'image_mask': real_img_tmp_image_mask}
+
+    predicted_real_pose = run_D_pose_prediction(real_img_tmp, real_c, blur_sigma=blur_sigma, D=D)
+    return predicted_real_pose
+
+
+if __name__ == '__main__':
+    # input_dir
+    import argparse
+    import random
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--model_pkl', type=str,default='./models/model.pkl')
+    parser.add_argument('--pose_prediction_kwargs_path', type=str,default='./models/model.json')
+    parser.add_argument('--inversion_name', type=str)
+    parser.add_argument('--with_pose_optim', action='store_true')
+    parser.add_argument('--test_data_dir', type=str, default='../test_data')
+
+    opt = parser.parse_args()
+    model_pkl = opt.model_pkl
+    pose_prediction_kwargs_path = opt.pose_prediction_kwargs_path
+    inversion_name = opt.inversion_name
+    with_pose_optim = opt.with_pose_optim  # NOTE(review): parsed but not used; the pose-optimizing projector always runs
+    test_data_dir = opt.test_data_dir
+
+    sampling_multiplier = 2.0
+
+    print('Loading networks from "%s"...' % model_pkl)
+    device = torch.device('cuda')
+    with dnnlib.util.open_url(model_pkl) as f:
+        resume_data = legacy.load_network_pkl(f)
+        print('resume_data', resume_data.keys())
+        G = resume_data['G_ema'].to(device)  # type: ignore
+        D = resume_data['D_ema'].to(device)  # type: ignore
+
+    G.set_batch_size(1)
+    G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'] * sampling_multiplier)
+    G.rendering_kwargs['depth_resolution_importance'] = int(
+        G.rendering_kwargs['depth_resolution_importance'] * sampling_multiplier)
+
+    G.rendering_kwargs['ray_start'] = 2.35
+
+    print('Loading pose_prediction_kwargs from "%s"...' % pose_prediction_kwargs_path)
+    with open(pose_prediction_kwargs_path, 'r') as f:
+        pose_predict_kwargs = json.load(f)
+
+    # teaser
+    todo = glob.glob(os.path.join(test_data_dir, '*'))
+
+    for sub_dir in todo:
+        print(sub_dir)
+        samples_dir = os.path.join(sub_dir, 'samples')
+        new_crop_samples_dir = os.path.join(sub_dir, 'samples_new_crop')
+        new_crop_mask_samples_dir = os.path.join(sub_dir, 'samples_new_crop/mask')
+        res_dir = os.path.join(sub_dir, f'samples_new_crop/{inversion_name}')
+        if os.path.exists(new_crop_samples_dir) and len(
+                glob.glob(os.path.join(new_crop_samples_dir, f'{inversion_name}/*/inversion.pt'))) == 0:
+            input_dir = new_crop_samples_dir
+
+
+            # ----------------------------------------------------------------------------
+
+
+            camera_path = os.path.join(input_dir, 'result.json')
+            print('Loading camera pose from "%s"...' % camera_path)
+            with open(camera_path, 'r') as f:
+                camera_poses = json.load(f)
+
+            print('Loading images from "%s"...' % input_dir)
+            image_base_dir = os.path.join(input_dir, 'aligned_images')
+            mask_base_path = os.path.join(input_dir, 'mask')
+
+            images = glob.glob(os.path.join(image_base_dir, '*'))
+
+            print('images', images)
+
+            for image_path in images[:1]:
+                image_name = os.path.basename(image_path)
+                mask_path = os.path.join(mask_base_path, image_name)
+                print('projecting image: "%s"' % image_path)
+                image = Image.open(image_path).convert('RGB')
+                mask = Image.open(mask_path)
+                # image_name = os.path.basename(paths_config.input_data_path)
+                camera_pose = camera_poses[image_name]
+                cam2world_pose = torch.tensor(camera_pose['camera_pose'], device=device)
+                focal_length = 6.5104166
+                intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+                c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+
+                with torch.no_grad():
+                    image_p = image.resize((G.img_resolution, G.img_resolution), Image.BILINEAR)
+                    image_p = np.array(image_p)
+                    image_p = image_p.transpose(2, 0, 1)
+                    image_p = torch.tensor(image_p, device=device)
+                    image_p = image_p.to(device).to(torch.float32) / 127.5 - 1
+                    image_p = image_p.unsqueeze(0)
+
+                    mask_p = np.array(mask)[:, :, None]
+                    mask_p = mask_p.transpose(2, 0, 1)
+                    mask_p = torch.tensor(mask_p, device=device)
+                    mask_p = mask_p.to(device).to(torch.float32) / 255.0
+                    mask_p = mask_p.unsqueeze(0)
+
+                    resample_filter = pose_predict_kwargs['resample_filter']
+                    resample_filter = torch.tensor(resample_filter, device=device).to(torch.float32)
+
+                    p = get_pose_params(real_img=image_p,
+                                        real_seg=mask_p,
+                                        real_c=c,
+                                        D=D,
+                                        neural_rendering_resolution=pose_predict_kwargs['neural_rendering_resolution'],
+                                        blur_sigma=pose_predict_kwargs['blur_sigma'],
+                                        resample_filter=resample_filter,
+                                        filter_mode=pose_predict_kwargs['filter_mode'])
+
+                # ----------------------------------------------------------------------------
+                image_name = os.path.splitext(image_name)[0]  # strip the extension whatever its length
+                # coach = SingleIDCoach(None, False, c, p)
+                # coach.train(image=image, image_name=image_name[:-4])
+                w_path_dir = res_dir
+                os.makedirs(w_path_dir, exist_ok=True)
+                use_ball_holder = True
+                # for fname, image in tqdm(self.data_loader):
+                # image_name = fname[0]
+
+                embedding_dir = f'{w_path_dir}/{image_name}'
+                os.makedirs(embedding_dir, exist_ok=True)
+                image.save(f'{embedding_dir}/original.png')
+                w_pivot = None
+                # if hyperparameters.use_last_w_pivots:
+                #     w_pivot = self.load_inversions(w_path_dir, image_name)
+                # elif not hyperparameters.use_last_w_pivots or w_pivot is None:
+                #     w_pivot = self.calc_inversions(image, image_name)
+                # image = torch.tensor(image, device=device)
+                if os.path.exists(f'{embedding_dir}/0.pt'):
+                    w_pivot = torch.load(f'{embedding_dir}/0.pt').to(global_config.device)
+                    # No pose was optimized for a cached pivot; fall back to the predicted pose
+                    # so that the pose.json export below still has a value to write.
+                    p_opt = p
+                else:
+                    image = image.resize((G.img_resolution, G.img_resolution), Image.BILINEAR)
+                    image = np.array(image)
+                    image = image.transpose(2, 0, 1)
+                    image = torch.tensor(image, device=device)
+                    image = image.to(device).to(torch.float32) / 127.5 - 1
+                    image = image.unsqueeze(0)
+                    id_image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255
+                    # id_image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255
+                    w_pivot, p_opt = w_projector_with_pose_optim.project(G, c, p, embedding_dir, id_image,
+                                                                         device=torch.device('cuda'),
+                                                                         w_avg_samples=600,
+                                                                         num_steps=500,
+                                                                         w_name=image_name, no_sr=False)
+                    # w_pivot = w_pivot.detach().clone().to(global_config.device)
+                w_pivot = w_pivot.to(global_config.device)
+                torch.save(w_pivot, f'{embedding_dir}/inversion.pt')
+
+                print('p_opt:', p_opt.shape)
+
+                poses = {
+                    'pose': p_opt.detach().cpu().numpy().tolist(),
+                    'camera_pose': c.cpu().numpy().tolist()
+                }
+                with open(f'{embedding_dir}/pose.json', 'w') as f:
+                    json.dump(poses, f)
+
+
diff --git a/3DPortraitGAN_pyramid/run_trigrid_gen.py b/3DPortraitGAN_pyramid/run_trigrid_gen.py
new file mode 100644
index 0000000..2e0c2e2
--- /dev/null
+++ b/3DPortraitGAN_pyramid/run_trigrid_gen.py
@@ -0,0 +1,264 @@
+# 
def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into a single grid image.

    Args:
        img: Tensor of shape (batch, channels, height, width).
        grid_w: Number of grid columns; defaults to ``batch // grid_h``.
        grid_h: Number of grid rows.
        float_to_uint8: Map values with ``x * 127.5 + 128``, clamp to [0, 255]
            and cast to ``torch.uint8``.
        chw_to_hwc: Return the grid as (H, W, C) instead of (C, H, W).
        to_numpy: Return a NumPy array instead of a tensor.

    Raises:
        AssertionError: If ``grid_w * grid_h`` does not equal the batch size.
    """
    batch_size, channels, img_h, img_w = img.shape
    if grid_w is None:
        grid_w = batch_size // grid_h
    assert batch_size == grid_w * grid_h
    if float_to_uint8:
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    # (B, C, H, W) -> (gh, gw, C, H, W) -> (C, gh, H, gw, W) -> (C, gh*H, gw*W)
    img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
    img = img.permute(2, 0, 3, 1, 4)
    img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
    if chw_to_hwc:
        img = img.permute(1, 2, 0)
    if to_numpy:
        img = img.cpu().numpy()
    return img


def create_samples(N=256, voxel_origin=(0, 0, 0), cube_length=2.0):
    """Build the coordinates of a regular ``N x N x N`` sampling grid.

    Args:
        N: Number of samples per axis (must be at least 2).
        voxel_origin: Center of the cube; it is shifted by ``cube_length / 2``
            to obtain the corner that is returned.
        cube_length: Edge length of the sampled cube.

    Returns:
        Tuple ``(samples, voxel_origin, voxel_size)`` where ``samples`` has
        shape (1, N**3, 3), ``voxel_origin`` is the shifted corner as a NumPy
        array and ``voxel_size`` is the spacing between neighbouring samples.
    """
    # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle
    voxel_origin = np.array(voxel_origin) - cube_length / 2
    voxel_size = cube_length / (N - 1)

    overall_index = torch.arange(0, N ** 3, dtype=torch.long)
    samples = torch.zeros(N ** 3, 3)

    # Decompose the flat index into integer grid indices. Floor division is
    # required here: true division leaves a fractional remainder in the two
    # slower axes and shears the grid by up to one voxel.
    samples[:, 2] = overall_index % N
    samples[:, 1] = torch.div(overall_index, N, rounding_mode='floor') % N
    samples[:, 0] = torch.div(overall_index, N * N, rounding_mode='floor') % N

    # Convert grid indices to coordinates.
    # NOTE(review): column 0 uses voxel_origin[2] and column 2 uses
    # voxel_origin[0]; kept as in the original -- confirm the intended axis order.
    samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
    samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
    samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]

    return samples.unsqueeze(0), voxel_origin, voxel_size
def parse_range(s: Union[str, List[int]]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Ranges are inclusive on both ends and whitespace around items is ignored.
    A list argument is returned unchanged.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

    Raises:
        ValueError: If an item is neither an integer nor an 'A-B' range.
    '''
    if isinstance(s, list):
        return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for item in s.split(','):
        item = item.strip()
        if m := range_re.match(item):
            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            ranges.append(int(item))
    return ranges

#----------------------------------------------------------------------------

def parse_tuple(s: Union[str, Tuple[int, int]]) -> Tuple[int, int]:
    '''Parse a 'M,N' or 'MxN' integer tuple.

    A tuple argument is returned unchanged.

    Example:
        '4x2' returns (4,2)
        '0,1' returns (0,1)

    Raises:
        ValueError: If the string is not of the form 'M,N' or 'MxN'.
    '''
    if isinstance(s, tuple):
        return s
    if m := re.match(r'^(\d+)[x,](\d+)$', s):
        return (int(m.group(1)), int(m.group(2)))
    raise ValueError(f'cannot parse tuple {s}')
If not specified, determine based on the length of the seeds array given by --seeds.', default=None) +@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) +@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) +@click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True) +@click.option('--reload_modules', help='Overload persistent modules?', type=bool, required=False, metavar='BOOL', default=False, show_default=True) +@click.option('--cfg', help='Config', type=click.Choice(['FFHQ', 'AFHQ', 'Shapenet']), required=False, metavar='STR', default='FFHQ', show_default=True) +@click.option('--image_mode', help='Image mode', type=click.Choice(['image', 'image_depth', 'image_raw']), required=False, metavar='STR', default='image', show_default=True) +@click.option('--sample_mult', 'sampling_multiplier', type=float, help='Multiplier for depth sampling in volume rendering', default=2, show_default=True) +@click.option('--nrr', type=int, help='Neural rendering resolution override', default=None, show_default=True) +@click.option('--inversion_name', type=str, required=True) + +def generate_images( + network_pkl: str, + test_data_dir: str, + shuffle_seed: Optional[int], + truncation_psi: float, + truncation_cutoff: int, + grid: Tuple[int,int], + num_keyframes: Optional[int], + w_frames: int, + reload_modules: bool, + cfg: str, + image_mode: str, + sampling_multiplier: float, + nrr: Optional[int], + inversion_name: str, +): + """Render a latent vector interpolation video. + + Examples: + + \b + # Render a 4x2 grid of interpolations for seeds 0 through 31. 
+ python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\ + --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl + + Animation length and seed keyframes: + + The animation length is either determined based on the --seeds value or explicitly + specified using the --num-keyframes option. + + When num keyframes is specified with --num-keyframes, the output video length + will be 'num_keyframes*w_frames' frames. + + If --num-keyframes is not specified, the number of seeds given with + --seeds must be divisible by grid size W*H (--grid). In this case the + output video length will be '# seeds/(w*h)*w_frames' frames. + """ + + + + print('Loading networks from "%s"...' % network_pkl) + device = torch.device('cuda') + with dnnlib.util.open_url(network_pkl) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + + + G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'] * sampling_multiplier) + G.rendering_kwargs['depth_resolution_importance'] = int(G.rendering_kwargs['depth_resolution_importance'] * sampling_multiplier) + + G.rendering_kwargs['ray_start'] = 2.35 + + + + G.set_batch_size(1) + + from training.smpl_triplane import TriPlaneGenerator + + if True: + print("Reloading Modules!") + G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device) + misc.copy_params_and_buffers(G, G_new, require_all=True) + G_new.neural_rendering_resolution = G.neural_rendering_resolution + G_new.rendering_kwargs = G.rendering_kwargs + G = G_new + + import glob + + for path in glob.glob(os.path.join(test_data_dir,f'*/samples_new_crop/{inversion_name}/*/inversion.pt')): + outdir = os.path.dirname(path) + latend_code = f'{outdir}/inversion.pt' + + if not os.path.exists(outdir): + os.makedirs(outdir, exist_ok=True) + if nrr is not None: G.neural_rendering_resolution = 
nrr + + if truncation_cutoff == 0: + truncation_psi = 1.0 # truncation cutoff of 0 means no truncation anyways + if truncation_psi == 1.0: + truncation_cutoff = 14 # no truncation so doesn't matter where we cutoff + + + output = os.path.join(outdir, f'inversion_trigrid.pkl') + if os.path.exists(output): + print(f'Already exists: {output}') + continue + gen_trigrids(G=G, output=output, bitrate='10M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, + latend_code_path=latend_code, shuffle_seed=shuffle_seed, psi=truncation_psi, + truncation_cutoff=truncation_cutoff, cfg=cfg, image_mode=image_mode) + + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + generate_images() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- \ No newline at end of file diff --git a/3DPortraitGAN_pyramid/segmentation_example.py b/3DPortraitGAN_pyramid/segmentation_example.py new file mode 100644 index 0000000..6d8628f --- /dev/null +++ b/3DPortraitGAN_pyramid/segmentation_example.py @@ -0,0 +1,108 @@ +import torch +from PIL import Image +from torchvision.transforms import ToPILImage +import glob +import os +from torchvision.models.segmentation import deeplabv3_resnet101 +from torchvision import transforms, utils +from tqdm import tqdm +import tempfile +import dnnlib +from torch_utils import training_stats +from torch_utils import custom_ops + +from torch.utils.data import dataset + + +class LoadData(dataset.Dataset): + + def __init__(self, base_path): + super(LoadData, self).__init__() + #base_path = 'F:/high_quality_3DPortraitGAN/exp/stable-dreamfusion/output/2023-10-28-with-inversion-initialization/samples_new_crop' + paths = sorted(glob.glob(f'{base_path}/aligned_images/*')) + os.makedirs(f'{base_path}/mask', exist_ok=True) + self.paths = paths + + def __getitem__(self,idx): + image_path =self.paths[idx] + image = Image.open(image_path) + # 
def get_mask(model, batch, cid):
    """Return a float mask that is 1 where the predicted class equals ``cid``.

    Args:
        model: Segmentation network whose output dict has an ``'out'`` entry
            holding per-class scores along dimension 1.
        batch: Image batch; it is normalized with the ImageNet mean/std below.
        cid: Index of the class to extract.

    Returns:
        Float tensor with the class dimension removed (1.0 inside the class,
        0.0 elsewhere).
    """
    normalized_batch = transforms.functional.normalize(
        batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    output = model(normalized_batch)['out']
    # argmax over the raw scores selects the same class as argmax over their softmax.
    return (output.argmax(1) == cid).float()


def get_and_save_mask(device, base_path):
    """Segment every image in ``<base_path>/aligned_images`` and save the masks.

    Each mask is written to ``<base_path>/mask`` under the same file name as
    its source image.

    Args:
        device: Torch device on which the segmentation network runs.
        base_path: Directory containing the ``aligned_images`` folder.
    """
    batch_size = 8
    loader = torch.utils.data.DataLoader(
        dataset=LoadData(base_path),
        batch_size=batch_size,
        shuffle=False
    )

    # Load the segmentation network once instead of once per batch.
    seg_net = deeplabv3_resnet101(pretrained=True, progress=False).to(device)
    seg_net.requires_grad_(False)
    seg_net.eval()

    mask_dir = os.path.join(base_path, 'mask')
    to_pil = ToPILImage()
    for input_tensor, image_paths in tqdm(loader):
        with torch.no_grad():
            # 15 means human mask
            mask = get_mask(seg_net, input_tensor.to(device), 15).cpu()

        for single_mask, image_path in zip(mask, image_paths):
            # Build the output path explicitly so that an 'aligned_images'
            # substring elsewhere in the path cannot be rewritten by accident.
            to_pil(single_mask).save(os.path.join(mask_dir, os.path.basename(image_path)))


def run(rank, base_path):
    """Run mask extraction for ``base_path`` on CUDA device number ``rank``."""
    device = torch.device('cuda', rank)
    get_and_save_mask(device, base_path)


def parse_args(argv=None):
    """Parse the command line; ``argv=None`` reads ``sys.argv``."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_path', type=str, required=True,
                        help='Directory that contains the "aligned_images" folder; '
                             'masks are written to its "mask" folder')
    return parser.parse_args(argv)


if __name__ == "__main__":
    # The previous parser defined --image_path/--mask_path but then read
    # `.base_path`, which always raised AttributeError.
    run(0, parse_args().base_path)
def convert_sdf_samples_to_ply(
    numpy_3d_sdf_tensor,
    voxel_grid_origin,
    voxel_size,
    ply_filename_out,
    offset=None,
    scale=None,
    level=0.0
):
    """
    Convert sdf samples to .ply

    :param numpy_3d_sdf_tensor: a 3D numpy array of samples, shape (n,n,n)
    :param voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
    :param voxel_size: float, the size of the voxels
    :param ply_filename_out: string, path of the filename to save to
    :param offset: optional value subtracted from the mesh points (after scaling)
    :param scale: optional value the mesh points are divided by
    :param level: isosurface level passed to marching cubes

    This function adapted from: https://github.com/RobotLocomotion/spartan
    """
    # Normals and per-vertex values are returned by marching cubes but not written out.
    verts, faces, _normals, _values = skimage.measure.marching_cubes(
        numpy_3d_sdf_tensor, level=level, spacing=[voxel_size] * 3
    )

    # transform from voxel coordinates to camera coordinates
    # note x and y are flipped in the output of marching_cubes
    mesh_points = np.zeros_like(verts)
    mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
    mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
    mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]

    # apply additional offset and scale
    if scale is not None:
        mesh_points = mesh_points / scale
    if offset is not None:
        mesh_points = mesh_points - offset

    # Pack vertices and faces into the structured arrays plyfile expects;
    # whole-column assignment replaces the former per-element Python loops.
    verts_tuple = np.zeros((verts.shape[0],), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
    verts_tuple["x"] = mesh_points[:, 0]
    verts_tuple["y"] = mesh_points[:, 1]
    verts_tuple["z"] = mesh_points[:, 2]

    faces_tuple = np.zeros((faces.shape[0],), dtype=[("vertex_indices", "i4", (3,))])
    faces_tuple["vertex_indices"] = faces

    el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
    el_faces = plyfile.PlyElement.describe(faces_tuple, "face")

    ply_data = plyfile.PlyData([el_verts, el_faces])
    ply_data.write(ply_filename_out)
    print(f"wrote to {ply_filename_out}")


def convert_mrc(input_filename, output_filename, isosurface_level=1):
    """Read the volume stored in an .mrc file and write its isosurface as a .ply mesh."""
    with mrcfile.open(input_filename) as mrc:
        # The volume axes are reversed before meshing; origin [0, 0, 0], unit voxel size.
        convert_sdf_samples_to_ply(np.transpose(mrc.data, (2, 1, 0)), [0, 0, 0], 1, output_filename, level=isosurface_level)


if __name__ == '__main__':
    start_time = time.time()
    parser = argparse.ArgumentParser()
    parser.add_argument('input_mrc_path')
    parser.add_argument('--level', type=float, default=10, help="The isosurface level for marching cubes")
    args = parser.parse_args()

    # A single .mrc file. The previous check compared the extension against
    # 'ply', so this branch was never taken for .mrc inputs, and it ignored --level.
    if os.path.isfile(args.input_mrc_path) and args.input_mrc_path.split('.')[-1] == 'mrc':
        output_obj_path = os.path.splitext(args.input_mrc_path)[0] + '.ply'
        convert_mrc(args.input_mrc_path, output_obj_path, isosurface_level=args.level)

        print(f"{time.time() - start_time:02f} s")
    elif os.path.isdir(args.input_mrc_path):
        # A directory: convert every .mrc file directly inside it.
        for mrc_path in tqdm(glob.glob(os.path.join(args.input_mrc_path, '*.mrc'))):
            output_obj_path = os.path.splitext(mrc_path)[0] + '.ply'
            convert_mrc(mrc_path, output_obj_path, isosurface_level=args.level)
    else:
        parser.error(f'"{args.input_mrc_path}" is neither an .mrc file nor a directory')
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +# empty diff --git a/3DPortraitGAN_pyramid/torch_utils/custom_ops.py b/3DPortraitGAN_pyramid/torch_utils/custom_ops.py new file mode 100644 index 0000000..ed2524f --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/custom_ops.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import glob +import hashlib +import importlib +import os +import re +import shutil +import uuid + +import torch +import torch.utils.cpp_extension +from torch.utils.file_baton import FileBaton + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. 
def _find_compiler_bindir():
    """Return the last (sorted) match of the first pattern that matches, else None.

    The patterns cover Visual Studio install locations of the MSVC compiler.
    """
    candidate_globs = (
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    )
    for candidate in candidate_globs:
        found = sorted(glob.glob(candidate))
        if found:
            return found[-1]
    return None

#----------------------------------------------------------------------------

def _get_mangled_gpu_name():
    """Return the lower-cased CUDA device name with every character outside
    ``[a-z0-9_-]`` replaced by ``-``."""
    device_name = torch.cuda.get_device_name().lower()
    return re.sub('[^a-z0-9_-]', '-', device_name)
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either + # break the build or unnecessarily restrict what's available to nvcc. + # Unset it to let nvcc decide based on what's available on the + # machine. + os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + # + # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work + # around the *.cu dependency bug in ninja config. + # + all_source_files = sorted(sources + headers) + all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) + if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): + + # Compute combined hash digest for all source files. + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + + # Select cached build directory name. 
+ source_digest = hash_md5.hexdigest() + build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') + + if not os.path.isdir(cached_build_dir): + tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' + os.makedirs(tmpdir) + for src in all_source_files: + shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) + try: + os.replace(tmpdir, cached_build_dir) # atomic + except OSError: + # source directory already exists, delete tmpdir and its contents. + shutil.rmtree(tmpdir) + if not os.path.isdir(cached_build_dir): raise + + # Compile. + cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, + verbose=verbose_build, sources=cached_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + + # Load. + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache dict. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/torch_utils/misc.py b/3DPortraitGAN_pyramid/torch_utils/misc.py new file mode 100644 index 0000000..2fc93df --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/misc.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached constant tensor built from ``value``.

    Args:
        value: Anything accepted by ``np.asarray``.
        shape: Optional shape the value is broadcast to.
        dtype: Tensor dtype; defaults to ``torch.get_default_dtype()``.
        device: Target device; defaults to CPU.
        memory_format: Defaults to ``torch.contiguous_format``.

    Returns:
        The tensor for this exact combination of arguments. Repeated calls
        with equal arguments return the same tensor object.
    """
    value = np.asarray(value)
    shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    memory_format = torch.contiguous_format if memory_format is None else memory_format

    # The raw bytes of the value are part of the key, so equal contents hit the cache.
    cache_key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(cache_key, None)
    if cached is not None:
        return cached

    result = torch.as_tensor(value.copy(), dtype=dtype, device=device)
    if shape is not None:
        result, _ = torch.broadcast_tensors(result, torch.empty(shape))
    result = result.contiguous(memory_format=memory_format)
    _constant_cache[cache_key] = result
    return result
try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        """Fallback for torch.nan_to_num; only ``nan=0`` is supported."""
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        # nansum over a singleton leading dimension turns NaN into 0; clamp bounds +/-Inf.
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to temporarily suppress known warnings in torch.jit.trace().
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672

@contextlib.contextmanager
def suppress_tracer_warnings():
    """Ignore ``torch.jit.TracerWarning`` inside the ``with`` block.

    The filter is removed again when the block exits, including when it exits
    through an exception.
    """
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    try:
        yield
    finally:
        # The entry may already be gone if the body reset the warning filters.
        with contextlib.suppress(ValueError):
            warnings.filters.remove(flt)
+ +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. 
+ +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. 
+ +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + else: + print(f'{name} is not in src_module, keeping its existing value in dst_module.') + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (tensor == other).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. 
+ +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. 
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/__init__.py b/3DPortraitGAN_pyramid/torch_utils/ops/__init__.py new file mode 100644 index 0000000..dfebd04 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +# empty diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.cpp b/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.cpp new file mode 100644 index 0000000..ee6f6d0 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.cpp @@ -0,0 +1,103 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include <torch/extension.h> +#include <ATen/cuda/CUDAContext.h> +#include <c10/cuda/CUDAGuard.h> +#include "bias_act.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. 
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? 
dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel instantiated for the dispatched scalar type. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&] + { + kernel = choose_bias_act_kernel<scalar_t>(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.cu b/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.cu new file mode 100644 index 0000000..71ca390 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.cu @@ -0,0 +1,177 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include <c10/util/Half.h> +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. 
+ +template <class T> struct InternalType; +template <> struct InternalType<double> { typedef double scalar_t; }; +template <> struct InternalType<float> { typedef float scalar_t; }; +template <> struct InternalType<c10::Half> { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. T = storage type, A = activation index (matches p.act). + +template <class T, int A> +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType<T>::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? 
one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. 
+ +template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel<T, 1>; + if (p.act == 2) return (void*)bias_act_kernel<T, 2>; + if (p.act == 3) return (void*)bias_act_kernel<T, 3>; + if (p.act == 4) return (void*)bias_act_kernel<T, 4>; + if (p.act == 5) return (void*)bias_act_kernel<T, 5>; + if (p.act == 6) return (void*)bias_act_kernel<T, 6>; + if (p.act == 7) return (void*)bias_act_kernel<T, 7>; + if (p.act == 8) return (void*)bias_act_kernel<T, 8>; + if (p.act == 9) return (void*)bias_act_kernel<T, 9>; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.h b/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.h new file mode 100644 index 0000000..8994bfb --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.h @@ -0,0 +1,42 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +//------------------------------------------------------------------------ +// CUDA kernel parameters. 
+ +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.py b/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.py new file mode 100644 index 0000000..b1f4d39 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/bias_act.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Custom PyTorch ops for efficient bias and activation.""" + +import os +import numpy as np +import torch +import dnnlib + +from .. import custom_ops +from .. 
import misc + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_plugin = None +_null_tensor = torch.empty([0]) + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='bias_act_plugin', + sources=['bias_act.cpp', 'bias_act.cu'], + headers=['bias_act.h'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + 
r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. 
+ """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard PyTorch ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + +#---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. 
+ key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. 
+ class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/conv2d_gradfix.py b/3DPortraitGAN_pyramid/torch_utils/ops/conv2d_gradfix.py new file mode 100644 index 0000000..9a177cc --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/conv2d_gradfix.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.conv2d` that supports +arbitrarily high order gradients with zero performance penalty.""" + +import contextlib +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights; restored by no_weight_gradients() even if its body raises. + +@contextlib.contextmanager +def no_weight_gradients(disable=True): + global weight_gradients_disabled + old = weight_gradients_disabled + if disable: + weight_gradients_disabled = True + try: yield + finally: weight_gradients_disabled = old + +#---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, 
groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + return True + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() +_null_tensor = torch.empty([0]) + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. + assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. 
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + ctx.save_for_backward( + input if weight.requires_grad else _null_tensor, + weight if input.requires_grad else _null_tensor, + ) + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). + if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): + a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) + c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) + c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) + c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. 
+ if transpose: + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + input_shape = ctx.input_shape + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad_input = op.apply(grad_output, weight, None) + assert grad_input.shape == input_shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input, weight) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. + class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, weight): + ctx.save_for_backward( + grad_output if input.requires_grad else _null_tensor, + input if grad_output.requires_grad else _null_tensor, + ) + ctx.grad_output_shape = grad_output.shape + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): + a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. + return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1] + + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad_output_shape = ctx.grad_output_shape + input_shape = ctx.input_shape + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output_shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad2_input = op.apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input_shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/conv2d_resample.py b/3DPortraitGAN_pyramid/torch_utils/ops/conv2d_resample.py new file mode 100644 index 0000000..d46f4dd --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/conv2d_resample.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: 
Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""2D convolution with optional up/downsampling.""" + +import torch + +from .. import misc +from . import conv2d_gradfix +from . import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + if not flip_weight and (kw > 1 or kh > 1): + w = w.flip([2, 3]) + + # Execute using conv2d_gradfix. 
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. 
+ if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. 
+ if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu.cpp b/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu.cpp new file mode 100644 index 0000000..4f55466 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu.cpp @@ -0,0 +1,304 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include +#include +#include "filtered_lrelu.h" + +//------------------------------------------------------------------------ + +static std::tuple filtered_lrelu( + torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si, + int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns) +{ + // Set CUDA device. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + // Validate arguments. 
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device"); + TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32"); + TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype"); + TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); + TORCH_CHECK(x.numel() > 0, "x is empty"); + TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2"); + TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large"); + TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large"); + TORCH_CHECK(fu.numel() > 0, "fu is empty"); + TORCH_CHECK(fd.numel() > 0, "fd is empty"); + TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x"); + TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1"); + + // Figure out how much shared memory is available on the device. + int maxSharedBytes = 0; + AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index())); + int sharedKB = maxSharedBytes >> 10; + + // Populate enough launch parameters to check if a CUDA kernel exists. + filtered_lrelu_kernel_params p; + p.up = up; + p.down = down; + p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter. + p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? 
(int)fd.size(0) : 0); + filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel(p, sharedKB); + if (!test_spec.exec) + { + // No kernel found - return empty tensors and indicate missing kernel with return code of -1. + return std::make_tuple(torch::Tensor(), torch::Tensor(), -1); + } + + // Input/output element size. + int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4; + + // Input sizes. + int64_t xw = (int)x.size(3); + int64_t xh = (int)x.size(2); + int64_t fut_w = (int)fu.size(-1) - 1; + int64_t fut_h = (int)fu.size(0) - 1; + int64_t fdt_w = (int)fd.size(-1) - 1; + int64_t fdt_h = (int)fd.size(0) - 1; + + // Logical size of upsampled buffer. + int64_t cw = xw * up + (px0 + px1) - fut_w; + int64_t ch = xh * up + (py0 + py1) - fut_h; + TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter"); + TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large"); + + // Compute output size and allocate. + int64_t yw = (cw - fdt_w + (down - 1)) / down; + int64_t yh = (ch - fdt_h + (down - 1)) / down; + TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1"); + TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format()); + + // Allocate sign tensor. + torch::Tensor so; + torch::Tensor s = si; + bool readSigns = !!s.numel(); + int64_t sw_active = 0; // Active width of sign tensor. + if (writeSigns) + { + sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements. + int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height. + int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16. 
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large"); + s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); + } + else if (readSigns) + sw_active = s.size(3) << 2; + + // Validate sign tensor if in use. + if (readSigns || writeSigns) + { + TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); + TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); + TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); + TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); + TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); + TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large"); + } + + // Populate rest of CUDA kernel parameters. + p.x = x.data_ptr(); + p.y = y.data_ptr(); + p.b = b.data_ptr(); + p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; + p.fu = fu.data_ptr(); + p.fd = fd.data_ptr(); + p.pad0 = make_int2(px0, py0); + p.gain = gain; + p.slope = slope; + p.clamp = clamp; + p.flip = (flip_filters) ? 1 : 0; + p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous. + p.sOfs = make_int2(sx, sy); + p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes. + + // x, y, b strides are in bytes. + p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0)); + p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0)); + p.bStride = sz * b.stride(0); + + // fu, fd strides are in elements. + p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0); + p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? 
fd.stride(0) : 0, 0); + + // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those. + bool index64b = false; + if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true; + if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true; + if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true; + if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true; + if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true; + if (s.numel() > INT_MAX) index64b = true; + + // Choose CUDA kernel. + filtered_lrelu_kernel_spec spec = { 0 }; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&] + { + if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation. + { + // Choose kernel based on index type, datatype and sign read/write modes. 
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + } + }); + TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists. + + // Launch CUDA kernel. + void* args[] = {&p}; + int bx = spec.numWarps * 32; + int gx = (p.yShape.x - 1) / spec.tileOut.x + 1; + int gy = (p.yShape.y - 1) / spec.tileOut.y + 1; + int gz = p.yShape.z * p.yShape.w; + + // Repeat multiple horizontal tiles in a CTA? + if (spec.xrep) + { + p.tilesXrep = spec.xrep; + p.tilesXdim = gx; + + gx = (gx + p.tilesXrep - 1) / p.tilesXrep; + std::swap(gx, gy); + } + else + { + p.tilesXrep = 0; + p.tilesXdim = 0; + } + + // Launch filter setup kernel. + AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream())); + + // Copy kernels to constant memory. + if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + + // Set cache and shared memory configurations for main kernel. + AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared)); + if (spec.dynamicSharedKB) // Need dynamically allocated shared memory? 
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10)); + AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte)); + + // Launch main kernel. + const int maxSubGz = 65535; // CUDA maximum for block z dimension. + for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big. + { + p.blockZofs = zofs; + int subGz = std::min(maxSubGz, gz - zofs); + AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream())); + } + + // Done. + return std::make_tuple(y, so, 0); +} + +//------------------------------------------------------------------------ + +static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns) +{ + // Set CUDA device. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + // Validate arguments. + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); + TORCH_CHECK(x.numel() > 0, "x is empty"); + TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64"); + + // Output signs if we don't have sign input. + torch::Tensor so; + torch::Tensor s = si; + bool readSigns = !!s.numel(); + if (writeSigns) + { + int64_t sw = x.size(3); + sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing. + s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); + } + + // Validate sign tensor if in use. 
+ if (readSigns || writeSigns) + { + TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); + TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); + TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); + TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); + TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); + TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large"); + } + + // Initialize CUDA kernel parameters. + filtered_lrelu_act_kernel_params p; + p.x = x.data_ptr(); + p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; + p.gain = gain; + p.slope = slope; + p.clamp = clamp; + p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); + p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous. + p.sOfs = make_int2(sx, sy); + + // Choose CUDA kernel. + void* func = 0; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&] + { + if (writeSigns) + func = choose_filtered_lrelu_act_kernel(); + else if (readSigns) + func = choose_filtered_lrelu_act_kernel(); + else + func = choose_filtered_lrelu_act_kernel(); + }); + TORCH_CHECK(func, "internal error - CUDA kernel not found"); + + // Launch CUDA kernel. + void* args[] = {&p}; + int bx = 128; // 4 warps per block. + + // Logical size of launch = writeSigns ? p.s : p.x + uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x; + uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y; + uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use. + gx = (gx - 1) / bx + 1; + + // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest. 
+ const uint32_t gmax = 65535; + gy = std::min(gy, gmax); + gz = std::min(gz, gmax); + + // Launch. + AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream())); + return so; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("filtered_lrelu", &filtered_lrelu); // The whole thing. + m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place. +} + +//------------------------------------------------------------------------ diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu.cu b/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu.cu new file mode 100644 index 0000000..aaac954 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu.cu @@ -0,0 +1,1288 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include "filtered_lrelu.h" +#include + +//------------------------------------------------------------------------ +// Helpers. + +enum // Filter modes. +{ + MODE_SUSD = 0, // Separable upsampling, separable downsampling. + MODE_FUSD = 1, // Full upsampling, separable downsampling. + MODE_SUFD = 2, // Separable upsampling, full downsampling. + MODE_FUFD = 3, // Full upsampling, full downsampling. 
+}; + +template struct InternalType; +template <> struct InternalType +{ + typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); } + __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); } +}; +template <> struct InternalType +{ + typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); } + __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); } +}; +template <> struct InternalType +{ + typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); } + __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); } +}; + +#define MIN(A, B) ((A) < (B) ? (A) : (B)) +#define MAX(A, B) ((A) > (B) ? (A) : (B)) +#define CEIL_DIV(A, B) (((B)==1) ? (A) : \ + ((B)==2) ? ((int)((A)+1) >> 1) : \ + ((B)==4) ? ((int)((A)+3) >> 2) : \ + (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B))) + +// This works only up to blocks of size 256 x 256 and for all N that are powers of two. +template __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i) +{ + if ((N & (N-1)) && N <= 256) + y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256. + else + y = i/N; + + x = i - y*N; +} + +// Type cast stride before reading it. 
+template __device__ __forceinline__ T get_stride(const int64_t& x) +{ + return *reinterpret_cast(&x); +} + +//------------------------------------------------------------------------ +// Filters, setup kernel, copying function. + +#define MAX_FILTER_SIZE 32 + +// Combined up/down filter buffers so that transfer can be done with one copy. +__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel. +__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel. + +// Accessors to combined buffers to index up/down filters individually. +#define c_fu (c_fbuf) +#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) +#define g_fu (g_fbuf) +#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) + +// Set up filters into global memory buffer. +static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p) +{ + for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x) + { + int x, y; + fast_div_mod(x, y, idx); + + int fu_x = p.flip ? x : (p.fuShape.x - 1 - x); + int fu_y = p.flip ? y : (p.fuShape.y - 1 - y); + if (p.fuShape.y > 0) + g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y]; + else + g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x]; + + int fd_x = p.flip ? x : (p.fdShape.x - 1 - x); + int fd_y = p.flip ? y : (p.fdShape.y - 1 - y); + if (p.fdShape.y > 0) + g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y]; + else + g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x]; + } +} + +// Host function to copy filters written by setup kernel into constant buffer for main kernel. 
+template static cudaError_t copy_filters(cudaStream_t stream) +{ + void* src = 0; + cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf); + if (err) return err; + return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream); +} + +//------------------------------------------------------------------------ +// Coordinate spaces: +// - Relative to input tensor: inX, inY, tileInX, tileInY +// - Relative to input tile: relInX, relInY, tileInW, tileInH +// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH +// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH +// - Relative to output tensor: outX, outY, tileOutX, tileOutY +// +// Relationships between coordinate spaces: +// - inX = tileInX + relInX +// - inY = tileInY + relInY +// - relUpX = relInX * up + phaseInX +// - relUpY = relInY * up + phaseInY +// - relUpX = relOutX * down +// - relUpY = relOutY * down +// - outX = tileOutX + relOutX +// - outY = tileOutY + relOutY + +extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer. + +template +static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) +{ + // Check that we don't try to support non-existing filter modes. 
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported"); + static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported"); + static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor"); + static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor"); + static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor"); + static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor"); + static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE"); + static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters"); + static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters"); + static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4"); + static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4"); + + // Static definitions. + typedef typename InternalType::scalar_t scalar_t; + typedef typename InternalType::vec2_t vec2_t; + typedef typename InternalType::vec4_t vec4_t; + const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4. + const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height. + const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width. + const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height. + const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up. 
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4. + + // Merge 1x1 downsampling into last upsampling step for upf1 and ups2. + const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD)); + + // Sizes of logical buffers. + const int szIn = tileInH_up * tileInW; + const int szUpX = tileInH_up * tileUpW; + const int szUpXY = downInline ? 0 : (tileUpH * tileUpW); + const int szDownX = tileUpH * tileOutW; + + // Sizes for shared memory arrays. + const int s_buf0_size_base = + (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) : + (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) : + (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) : + (filterMode == MODE_FUFD) ? szIn : + -1; + const int s_buf1_size_base = + (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) : + (filterMode == MODE_FUSD) ? szUpXY : + (filterMode == MODE_SUFD) ? szUpX : + (filterMode == MODE_FUFD) ? szUpXY : + -1; + + // Ensure U128 alignment. + const int s_buf0_size = (s_buf0_size_base + 3) & ~3; + const int s_buf1_size = (s_buf1_size_base + 3) & ~3; + + // Check at compile time that we don't use too much shared memory. + static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow"); + + // Declare shared memory arrays. + scalar_t* s_buf0; + scalar_t* s_buf1; + if (sharedKB <= 48) + { + // Allocate shared memory arrays here. + __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused. + s_buf0 = s_buf0_st; + s_buf1 = s_buf0 + s_buf0_size; + } + else + { + // Use the dynamically allocated shared memory array. + s_buf0 = (scalar_t*)s_buf_raw; + s_buf1 = s_buf0 + s_buf0_size; + } + + // Pointers to the buffers. 
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY] + scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX] + scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX] + scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX] + if (filterMode == MODE_SUSD) + { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + s_tileDownX = s_buf1; + } + else if (filterMode == MODE_FUSD) + { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + s_tileDownX = s_buf0; + } + else if (filterMode == MODE_SUFD) + { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + } + else if (filterMode == MODE_FUFD) + { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + } + + // Allow large grids in z direction via per-launch offset. + int channelIdx = blockIdx.z + p.blockZofs; + int batchIdx = channelIdx / p.yShape.z; + channelIdx -= batchIdx * p.yShape.z; + + // Offset to output feature map. In bytes. + index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w); + + // Sign shift amount. + uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6; + + // Inner tile loop. + #pragma unroll 1 + for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++) + { + // Locate output tile. + int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x; + int tileOutX = tileX * tileOutW; + int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH; + + // Locate input tile. + int tmpX = tileOutX * down - p.pad0.x; + int tmpY = tileOutY * down - p.pad0.y; + int tileInX = CEIL_DIV(tmpX, up); + int tileInY = CEIL_DIV(tmpY, up); + const int phaseInX = tileInX * up - tmpX; + const int phaseInY = tileInY * up - tmpY; + + // Extra sync if input and output buffers are the same and we are not on first tile. 
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline))) + __syncthreads(); + + // Load input tile & apply bias. Unrolled. + scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride))); + index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w); + int idx = threadIdx.x; + const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock); + #pragma unroll + for (int loop = 0; loop < loopCountIN; loop++) + { + int relInX, relInY; + fast_div_mod(relInX, relInY, idx); + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + + if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y) + v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b; + + bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH); + if (!skip) + s_tileIn[idx] = v; + + idx += threadsPerBlock; + } + + if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter. + { + // Horizontal upsampling. 
+ __syncthreads(); + if (up == 4) + { + for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) + { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + scalar_t a = s_tileIn[src0]; + if (phaseInX == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } + else if (phaseInX == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } + else if (phaseInX == 2) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } + else // (phaseInX == 3) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst+0] = v.x; + s_tileUpX[dst+1] = v.y; + s_tileUpX[dst+2] = v.z; + s_tileUpX[dst+3] = v.w; + } + } + else if (up == 2) + { + bool p0 = (phaseInX == 0); + for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) + { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW 
* relInY; + int dst = relInY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + scalar_t a = s_tileIn[src0]; + if (p0) // (phaseInX == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } + else // (phaseInX == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst+0] = v.x; + s_tileUpX[dst+1] = v.y; + } + } + + // Vertical upsampling & nonlinearity. + + __syncthreads(); + int groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. + int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes. + if (up == 4) + { + minY -= 3; // Adjust according to block height. 
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) + { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec4_t v = InternalType::zero_vec4(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } + else if (phaseInY == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } + else if (phaseInY == 2) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } + else // (phaseInY == 3) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + index_t si2 = si0 + p.sShape.x * 2; + 
index_t si3 = si0 + p.sShape.x * 3; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } + if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } + if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } + } + } + else + { + // Determine and write signs. 
+ if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } + if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } + + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } + if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + } + } + else if (signRead) // Read signs and apply. 
+ { + if ((uint32_t)signXb < p.swLimit) + { + int ss = (signX & 3) << 1; + if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } + if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } + if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; } + if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; } + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[dst + 0 * tileUpW] = v.x; + if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y; + if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z; + if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w; + } + } + else if (up == 2) + { + minY -= 1; // Adjust according to block height. 
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) + { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec2_t v = InternalType::zero_vec2(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } + else // (phaseInY == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. 
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + } + } + else + { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + } + } + } + else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) + { + if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } + if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + } + + if (!downInline) + { + // Write into temporary buffer. + s_tileUpXY[dst] = v.x; + if (relUpY0 < tileUpH - 1) + s_tileUpXY[dst + tileUpW] = v.y; + } + else + { + // Write directly into output buffer. 
+ if ((uint32_t)x < p.yShape.x) + { + int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down); + index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut; + if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]); + if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]); + } + } + } + } + } + else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) + { + // Full upsampling filter. + + if (up == 2) + { + // 2 x 2-wide. + __syncthreads(); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs. + for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4) + { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up); + int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up); + int src0 = relInX0 + tileInW * relInY0; + int tap0y = (relInY0 * up + phaseInY - relUpY0); + + #define X_LOOP(TAPY, PX) \ + for (int sx = 0; sx < fuSize / up; sx++) \ + { \ + v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ + v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ + } + + vec4_t v = InternalType::zero_vec4(); + if (tap0y == 0 && phaseInX == 0) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(0, 0) } + if (tap0y == 0 && phaseInX == 1) + #pragma unroll + for (int sy = 
0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(0, 1) } + if (tap0y == 1 && phaseInX == 0) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(1, 0) } + if (tap0y == 1 && phaseInX == 1) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(1, 1) } + + #undef X_LOOP + + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } + if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } + if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } + if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + } + else + { + // Determine and write signs. 
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } + if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } + if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } + if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } + + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + } + } + else if (signRead) // Read sign and apply. + { + if ((uint32_t)signY < p.sShape.y) + { + int s = 0; + if ((uint32_t)signXb < p.swLimit) s = p.s[si]; + if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8; + s >>= (signX & 3) << 1; + if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f; + if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f; + if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f; + if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f; + } + } + else // Forward pass with no sign write. 
+ { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[idx + 0] = v.x; + s_tileUpXY[idx + 1] = v.y; + s_tileUpXY[idx + 2] = v.z; + s_tileUpXY[idx + 3] = v.w; + } + } + else if (up == 1) + { + __syncthreads(); + uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. + for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x) + { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter. + + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + v *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write sign. + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) + { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) + { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } + } + else + { + // Determine and write sign. 
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) + { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) + { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } + else + { + // Just compute the value. + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + } + } + else if (signRead) + { + // Read sign and apply if within sign tensor bounds. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) + { + int s = p.s[si]; + s >>= signXo; + if (s & 1) v *= p.slope; + if (s & 2) v = 0.f; + } + } + else // Forward pass with no sign write. + { + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + + if (!downInline) // Write into temporary buffer. + s_tileUpXY[idx] = v; + else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer + *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]); + } + } + } + + // Downsampling. + if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) + { + // Horizontal downsampling. + __syncthreads(); + if (down == 4 && tileOutW % 4 == 0) + { + // Calculate 4 pixels at a time. 
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + #pragma unroll + for (int step = 0; step < fdSize; step++) + { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step]; + v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step]; + v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx+0] = v.x; + s_tileDownX[idx+1] = v.y; + s_tileDownX[idx+2] = v.z; + s_tileDownX[idx+3] = v.w; + } + } + else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) + { + // Calculate 2 pixels at a time. + for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + #pragma unroll + for (int step = 0; step < fdSize; step++) + { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx+0] = v.x; + s_tileDownX[idx+1] = v.y; + } + } + else + { + // Calculate 1 pixel at a time. + for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src = relUpY * tileUpW + relUpX0; + scalar_t v = 0.f; + #pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileUpXY[src + step] * (scalar_t)c_fd[step]; + s_tileDownX[idx] = v; + } + } + + // Vertical downsampling & store output tile. 
+ __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) + { + int relOutX, relOutY0; + fast_div_mod(relOutX, relOutY0, idx); + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileOutW + relOutX; + scalar_t v = 0; + #pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step]; + + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY0; + + if (outX < p.yShape.x & outY < p.yShape.y) + *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; + } + } + else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) + { + // Full downsampling filter. + if (down == 2) + { + // 2-wide. + __syncthreads(); + for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2) + { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + int relUpX0 = relOutX0 * down; + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + #pragma unroll + for (int sy = 0; sy < fdSize; sy++) + #pragma unroll + for (int sx = 0; sx < fdSize; sx++) + { + v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + } + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outY < p.yShape.y) + { + index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut; + if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x; + if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y; + } + } + } + else if (down == 1 && !downInline) + { + // Thread per pixel. 
//------------------------------------------------------------------------
// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
//
// Three mutually exclusive behaviors, selected by signWrite / signRead:
//   signWrite : apply gain, leaky ReLU and clamp to p.x in place, and record a 2-bit code per
//               element into p.s (0 = untouched, 1 = negative slope applied, 2 = clamped).
//   signRead  : apply gain to p.x in place, then replay the 2-bit codes stored in p.s
//               (bit 0 -> multiply by slope, bit 1 -> force the value to zero).
//   neither   : apply gain, leaky ReLU and clamp to p.x in place; p.s is not touched.
//
// Each thread handles one x coordinate; y and the combined channel*batch index q are covered
// by looping over blockIdx.y / blockIdx.z with gridDim.y / gridDim.z as the step.
//
// NOTE(review): T, signWrite and signRead are used below but the template parameter list is
// missing after `template` in this copy (it looks stripped) -- confirm against the upstream file.

template
static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
{
    typedef typename InternalType::scalar_t scalar_t;

    // Indexing.
    int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
    int32_t ymax = signWrite ? p.sShape.y : p.xShape.y; // In sign write mode rows are iterated up to the sign tensor height.
    int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.

    // Loop to accommodate oversized tensors.
    for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
    for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
    {
        // Extract z and w (channel, minibatch index).
        int32_t w = q / p.xShape.z;
        int32_t z = q - w * p.xShape.z;

        // Choose behavior based on sign read/write mode.
        if (signWrite)
        {
            // Process value if in p.x.
            // s stays 0 for threads outside p.x; they still take part in the shuffles below.
            uint32_t s = 0;
            if (x < p.xShape.x && y < p.xShape.y)
            {
                int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
                T* pv = ((T*)p.x) + ix;
                scalar_t v = (scalar_t)(*pv);

                // Gain, LReLU, clamp.
                v *= p.gain;
                if (v < 0.f)
                {
                    v *= p.slope;
                    s = 1; // Sign.
                }
                if (fabsf(v) > p.clamp)
                {
                    v = InternalType::clamp(v, p.clamp);
                    s = 2; // Clamp. Overwrites (does not OR with) the sign code set above.
                }

                *pv = (T)v; // Write value.
            }

            // Coalesce into threads 0 and 16 of warp.
            // Each group of 16 consecutive threads ORs its 2-bit codes into one 32-bit word:
            // the mask m selects the group's 16 lanes and the xor-shuffles with 1, 2, 4, 8
            // combine all 16 contributions.
            uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
            s <<= ((threadIdx.x & 15) << 1); // Shift into place.
            s |= __shfl_xor_sync(m, s, 1); // Distribute.
            s |= __shfl_xor_sync(m, s, 2);
            s |= __shfl_xor_sync(m, s, 4);
            s |= __shfl_xor_sync(m, s, 8);

            // Write signs if leader and in p.s.
            if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
            {
                uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
                ((uint32_t*)p.s)[is >> 4] = s; // 16 elements per 32-bit word, hence element index >> 4.
            }
        }
        else if (signRead)
        {
            // Process value if in p.x.
            if (x < p.xShape.x) // y is always in.
            {
                int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
                T* pv = ((T*)p.x) + ix;
                scalar_t v = (scalar_t)(*pv);
                v *= p.gain;

                // Apply sign buffer offset.
                // Unsigned on purpose: a negative sum wraps to a large value and fails the range test below.
                uint32_t sx = x + p.sOfs.x;
                uint32_t sy = y + p.sOfs.y;

                // Read and apply signs if we land inside valid region of sign buffer.
                if (sx < p.sShape.x && sy < p.sShape.y)
                {
                    // Four 2-bit codes per byte: byte index is sx >> 2, position within the byte is sx & 3.
                    uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
                    unsigned char s = p.s[is];
                    s >>= (sx & 3) << 1; // Shift into place.
                    if (s & 1) // Sign?
                        v *= p.slope;
                    if (s & 2) // Clamp?
                        v = 0.f;
                }

                *pv = (T)v; // Write value.
            }
        }
        else
        {
            // Forward pass with no sign write. Process value if in p.x.
            if (x < p.xShape.x) // y is always in.
            {
                int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
                T* pv = ((T*)p.x) + ix;
                scalar_t v = (scalar_t)(*pv);
                v *= p.gain;
                if (v < 0.f)
                    v *= p.slope;
                if (fabsf(v) > p.clamp)
                    v = InternalType::clamp(v, p.clamp);
                *pv = (T)v; // Write value.
            }
        }
    }
}

// Returns the activation kernel above as an untyped function pointer.
template void* choose_filtered_lrelu_act_kernel(void)
{
    return (void*)filtered_lrelu_act_kernel;
}

//------------------------------------------------------------------------
// CUDA kernel selection.
//
// Walks the CASE table below in order and returns the spec of the first entry that matches p
// and fits in sharedKB. If nothing matches, the zero-initialized spec is returned, so callers
// can detect "no kernel" from the null setup/exec pointers.

template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
{
    filtered_lrelu_kernel_spec s = { 0 };

    // Return the first matching kernel.
    // A CASE entry matches when all of the following hold:
    //   - sharedKB is at least SH;
    //   - the upsampling filter kind agrees with MODE: p.fuShape.y == 0 pairs with MODE_SUSD / MODE_SUFD,
    //     p.fuShape.y > 0 pairs with MODE_FUSD / MODE_FUFD;
    //   - the downsampling filter kind agrees with MODE: p.fdShape.y == 0 pairs with MODE_SUSD / MODE_FUSD,
    //     p.fdShape.y > 0 pairs with MODE_SUFD / MODE_FUFD;
    //   - p.up == U, p.down == D, and both filter extents are no larger than FU / FD.
    // On a match the spec is filled in and returned immediately; dynamicSharedKB is 0 when SH is 48.
#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
    if (sharedKB >= SH) \
    if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
    if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
    if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
    { \
        static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
        static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
        static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
        s.setup = (void*)setup_filters_kernel; \
        s.exec = (void*)filtered_lrelu_kernel; \
        s.tileOut = make_int2(TW, TH); \
        s.numWarps = W; \
        s.xrep = XR; \
        s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
        return s; \
    }

    // Launch parameters for various kernel specializations.
    // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
    // Kernels that use more shared memory must be listed before those that use less, for the same reason.
    // Row order is therefore significant: do not sort or regroup these entries.

    CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
    CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
    CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
    CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
    CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
    CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
    CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
    CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
    CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
    CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
    CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
    CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
    CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
    CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
    CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
    CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4

    #undef CASE
    return s; // No kernel found.
}

//------------------------------------------------------------------------
//------------------------------------------------------------------------
// CUDA kernel parameters.

// Parameter block for the fused filtered leaky ReLU kernel. The first group of fields is what
// choose_filtered_lrelu_kernel() inspects to pick a specialization; the rest describes the
// tensors and scalar settings of one launch.
struct filtered_lrelu_kernel_params
{
    // These parameters decide which kernel to use.
    int             up;         // upsampling ratio (1, 2, 4)
    int             down;       // downsampling ratio (1, 2, 4)
    // NOTE(review): choose_filtered_lrelu_kernel() treats fuShape.y == 0 / fdShape.y == 0 as the
    // separable case, which does not agree with the "[size, 1]" wording below -- confirm which is right.
    int2            fuShape;    // [size, 1] | [size, size]
    int2            fdShape;    // [size, 1] | [size, size]

    int             _dummy;     // Alignment.

    // Rest of the parameters.
    const void*     x;          // Input tensor.
    void*           y;          // Output tensor.
    const void*     b;          // Bias tensor.
    unsigned char*  s;          // Sign tensor in/out. NULL if unused.
    const float*    fu;         // Upsampling filter.
    const float*    fd;         // Downsampling filter.

    int2            pad0;       // Left/top padding.
    float           gain;       // Additional gain factor.
    float           slope;      // Leaky ReLU slope on negative side.
    float           clamp;      // Clamp after nonlinearity.
    int             flip;       // Filter kernel flip for gradient computation.

    int             tilesXdim;  // Original number of horizontal output tiles.
    int             tilesXrep;  // Number of horizontal tiles per CTA.
    int             blockZofs;  // Block z offset to support large minibatch, channel dimensions.

    int4            xShape;     // [width, height, channel, batch]
    int4            yShape;     // [width, height, channel, batch]
    int2            sShape;     // [width, height] - width is in bytes. Contiguous. Zeros if unused.
    int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
    int             swLimit;    // Active width of sign tensor in bytes.

    longlong4       xStride;    // Strides of all tensors except signs, same component order as shapes.
    longlong4       yStride;    //
    int64_t         bStride;    //
    longlong3       fuStride;   //
    longlong3       fdStride;   //
};

// Parameter block for the standalone activation kernel (filtered_lrelu_act_kernel), which
// modifies x in place and reads or writes the sign tensor s.
// Note the unit difference from the struct above: here sShape width is in elements, not bytes.
struct filtered_lrelu_act_kernel_params
{
    void*           x;          // Input/output, modified in-place.
    unsigned char*  s;          // Sign tensor in/out. NULL if unused.

    float           gain;       // Additional gain factor.
    float           slope;      // Leaky ReLU slope on negative side.
    float           clamp;      // Clamp after nonlinearity.

    int4            xShape;     // [width, height, channel, batch]
    longlong4       xStride;    // Input/output tensor strides, same order as in shape.
    int2            sShape;     // [width, height] - width is in elements. Contiguous. Zeros if unused.
    int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
};

//------------------------------------------------------------------------
// CUDA kernel specialization.

// Result of kernel selection: the two kernel entry points plus the launch configuration that
// goes with them. choose_filtered_lrelu_kernel() returns this zero-initialized when no
// specialization matches.
struct filtered_lrelu_kernel_spec
{
    void*   setup;              // Function for filter kernel setup.
    void*   exec;               // Function for main operation.
    int2    tileOut;            // Width/height of launch tile.
    int     numWarps;           // Number of warps per thread block, determines launch block size.
    int     xrep;               // For processing multiple horizontal tiles per thread block.
    int     dynamicSharedKB;    // How much dynamic shared memory the exec kernel wants.
};

//------------------------------------------------------------------------
// CUDA kernel selection.
// NOTE(review): the template parameter lists of the declarations below appear to be missing
// in this copy -- confirm against the upstream header.

template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
template void* choose_filtered_lrelu_act_kernel(void);
template cudaError_t copy_filters(cudaStream_t stream);

//------------------------------------------------------------------------
#----------------------------------------------------------------------------

# Handle to the compiled extension module; populated lazily by `_init()`.
_plugin = None

def _init():
    """Build/load the `filtered_lrelu_plugin` extension on first use and cache it in `_plugin`.

    Always returns True, which lets callers place the call inside a boolean condition.
    """
    global _plugin
    if _plugin is None:
        _plugin = custom_ops.get_plugin(
            module_name='filtered_lrelu_plugin',
            sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
            headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
            source_dir=os.path.dirname(__file__),
            extra_cuda_cflags=['--use_fast_math'],
        )
    return True

def _get_filter_size(f):
    """Return `(width, height)` of filter `f`, or `(1, 1)` when `f` is None.

    For a 1D filter both values are the number of taps, because `shape[-1]` and
    `shape[0]` refer to the same dimension.
    """
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor)
    assert 1 <= f.ndim <= 2
    return f.shape[-1], f.shape[0] # width, height

def _parse_padding(padding):
    """Normalize `padding` to the 4-tuple `(px0, px1, py0, py1)`.

    Accepts a plain `int` (used for all four sides), a 2-element `[x, y]` list/tuple,
    or a 4-element `[x_before, x_after, y_before, y_after]` list/tuple. Elements may be
    Python or NumPy integers. Any other sequence length fails at the final unpacking.
    """
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, (int, np.integer)) for x in padding)
    padding = [int(x) for x in padding]
    if len(padding) == 2:
        px, py = padding
        padding = [px, px, py, py]
    px0, px1, py0, py1 = padding
    return px0, px1, py0, py1

#----------------------------------------------------------------------------

def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
    r"""Filtered leaky ReLU for a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Add channel-specific bias if provided (`b`).

    2. Upsample the image by inserting N-1 zeros after each pixel (`up`).

    3. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    5. Multiply each value by the provided gain factor (`gain`).

    6. Apply leaky ReLU activation function to each value.

    7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.

    8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
       it so that the footprint of all output pixels lies within the input image.

    9. Downsample the image by keeping every Nth pixel (`down`).

    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    The CUDA path is taken only when `impl == 'cuda'` and `x` lives on a CUDA device;
    in every other case the reference implementation is used.

    Args:
        x:           Float32/float16/float64 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        fu:          Float32 upsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        fd:          Float32 downsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        b:           Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                     as `x`. The length of the vector must match the channel dimension of `x`.
        up:          Integer upsampling factor (default: 1).
        down:        Integer downsampling factor. (default: 1).
        padding:     Padding with respect to the upsampled image. Can be a single number
                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        gain:        Overall scaling factor for signal magnitude (default: sqrt(2)).
        slope:       Slope on the negative side of leaky ReLU (default: 0.2).
        clamp:       Maximum magnitude for leaky ReLU output (default: None).
        flip_filter: False = convolution, True = correlation (default: False).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        # The trailing `None, 0, 0` are the sign tensor and its x/y offsets; none is supplied on this call.
        return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
    return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)

#----------------------------------------------------------------------------

@misc.profiled_function
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
    existing `upfirdn2d()` and `bias_act()` ops.

    Takes the same arguments as `filtered_lrelu()` (minus `impl`), validates them with
    assertions, and checks the shape and dtype of the result before returning it.
    """
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    fu_w, fu_h = _get_filter_size(fu)
    fd_w, fd_h = _get_filter_size(fd)
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
        misc.assert_shape(b, [x.shape[1]])
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    assert slope == float(slope) and slope >= 0
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)

    # Calculate output size.
    # Upsampled and padded extent, minus what the two filters consume, divided by `down` rounding up.
    batch_size, channels, in_h, in_w = x.shape
    in_dtype = x.dtype
    out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
    out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down

    # Compute using existing ops.
    x = bias_act.bias_act(x=x, b=b) # Apply bias.
    x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
    x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
    x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.

    # Check output shape & dtype.
    misc.assert_shape(x, [batch_size, channels, out_h, out_w])
    assert x.dtype == in_dtype
    return x

#----------------------------------------------------------------------------

# Maps a tuple of op settings to the autograd Function class specialized for them.
_filtered_lrelu_cuda_cache = dict()

def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Fast CUDA implementation of `filtered_lrelu()` using custom ops.

    Returns a `torch.autograd.Function` subclass with the given settings baked in; call its
    `.apply(x, fu, fd, b, si, sx, sy)`, where `si` is an optional sign tensor and `sx`, `sy`
    its offsets. One class is created per distinct settings tuple and reused afterwards.
    """
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    gain = float(gain)
    assert slope == float(slope) and slope >= 0
    slope = float(slope)
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)
    clamp = float(clamp if clamp is not None else 'inf') # No clamp is represented as +inf.

    # Lookup from cache.
    key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
    if key in _filtered_lrelu_cuda_cache:
        return _filtered_lrelu_cuda_cache[key]

    # Forward op.
    class FilteredLReluCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4

            # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
            if fu is None:
                fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            if fd is None:
                fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            assert 1 <= fu.ndim <= 2
            assert 1 <= fd.ndim <= 2

            # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
            # Squaring keeps the result equal to applying the single tap once per axis.
            if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
                fu = fu.square()[None]
            if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
                fd = fd.square()[None]

            # Missing sign input tensor.
            # An empty tensor stands for "no sign tensor given"; it is tested via numel() below.
            if si is None:
                si = torch.empty([0])

            # Missing bias tensor.
            if b is None:
                b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)

            # Construct internal sign tensor only if gradients are needed.
            write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)

            # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
            strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
            if any(a < b for a, b in zip(strides[:-1], strides[1:])):
                warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)

            # Call C++/Cuda plugin if datatype is supported.
            if x.dtype in [torch.float16, torch.float32]:
                if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
                    warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
                y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
            else:
                return_code = -1 # Unsupported dtype: force the generic fallback below.

            # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
            # only the bit-packed sign tensor is retained for gradient computation.
            if return_code < 0:
                warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)

                y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
                y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
                so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
                y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.

            # Prepare for gradient computation.
            # Keep the caller's sign tensor if one was given, otherwise the one produced by this call.
            ctx.save_for_backward(fu, fd, (si if si.numel() else so))
            ctx.x_shape = x.shape
            ctx.y_shape = y.shape
            ctx.s_ofs = sx, sy
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            fu, fd, si = ctx.saved_tensors
            _, _, xh, xw = ctx.x_shape
            _, _, yh, yw = ctx.y_shape
            sx, sy = ctx.s_ofs
            # Only x (input 0) and b (input 3) may receive gradients; the asserts reject requests for the others.
            dx  = None # 0
            dfu = None; assert not ctx.needs_input_grad[1]
            dfd = None; assert not ctx.needs_input_grad[2]
            db  = None # 3
            dsi = None; assert not ctx.needs_input_grad[4]
            dsx = None; assert not ctx.needs_input_grad[5]
            dsy = None; assert not ctx.needs_input_grad[6]

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
                # The gradient is computed with the same op: up/down factors and the two filters are
                # exchanged, the filter flip is inverted, no clamp is applied, and the saved sign
                # tensor is passed in with its offsets shifted by the upsampling filter extent and padding.
                pp = [
                    (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
                    xw * up - yw * down + px0 - (up - 1),
                    (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
                    xh * up - yh * down + py0 - (up - 1),
                ]
                gg = gain * (up ** 2) / (down ** 2)
                ff = (not flip_filter)
                sx = sx - (fu.shape[-1] - 1) + px0
                sy = sy - (fu.shape[0]  - 1) + py0
                dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)

            if ctx.needs_input_grad[3]:
                # Bias gradient: reduce dx over batch and both spatial dimensions.
                db = dx.sum([0, 2, 3])

            return dx, dfu, dfd, db, dsi, dsx, dsy

    # Add to cache.
    _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
    return FilteredLReluCuda

#----------------------------------------------------------------------------
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for no signs mode (no gradients required). + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. +template cudaError_t copy_filters(cudaStream_t stream); diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu_rd.cu b/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu_rd.cu new file mode 100644 index 0000000..3cd43ec --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu_rd.cu @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for sign read mode. + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. +template cudaError_t copy_filters(cudaStream_t stream); diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu_wr.cu b/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu_wr.cu new file mode 100644 index 0000000..bc2fa06 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/filtered_lrelu_wr.cu @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for sign write mode. + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. +template cudaError_t copy_filters(cudaStream_t stream); diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/fma.py b/3DPortraitGAN_pyramid/torch_utils/ops/fma.py new file mode 100644 index 0000000..5458116 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/fma.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""

import torch

#----------------------------------------------------------------------------

def fma(a, b, c): # => a * b + c
    """Return `a * b + c` for broadcastable tensors, differentiable w.r.t. all three.

    Args:
        a, b: Factors of the product.
        c:    Addend; may have any shape that broadcasts with `a * b`,
              including a 0-dim (scalar) tensor.

    Returns:
        Tensor with the broadcast shape of the three inputs.
    """
    return _FusedMultiplyAdd.apply(a, b, c)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    """Autograd function behind `fma()`; only `a` and `b` are kept for the backward pass."""

    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape # The gradient of c needs its shape only, not its values.
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        # Each gradient is reduced back to the shape its input had before broadcasting.
        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    """Sum `x` over the dimensions that broadcasting added or expanded, returning `shape`.

    Args:
        x:     Tensor whose shape is `shape` after broadcasting.
        shape: Target shape; may have fewer dimensions than `x`, including none at all.

    Returns:
        Tensor of exactly `shape`.
    """
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Reduce every leading dimension that `shape` lacks, and every dimension where `shape` has size 1.
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if dim:
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # Drop the leading dimensions, which all have size 1 by now. Slicing the shape also
        # handles a 0-dim target, where `reshape(-1, ...)` would produce shape [1].
        x = x.reshape(x.shape[extra_dims:])
    assert x.shape == shape
    return x

#----------------------------------------------------------------------------
a/3DPortraitGAN_pyramid/torch_utils/ops/grid_sample_gradfix.py b/3DPortraitGAN_pyramid/torch_utils/ops/grid_sample_gradfix.py new file mode 100644 index 0000000..35d9472 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/grid_sample_gradfix.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample` that +supports arbitrarily high order gradients between the input and output. +Only works on 2D images and assumes +`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" + +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. 
+ +#---------------------------------------------------------------------------- + +def grid_sample(input, grid): + if _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(): + return enabled + +#---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + +#---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +#---------------------------------------------------------------------------- diff --git 
a/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.cpp b/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.cpp new file mode 100644 index 0000000..c1769c3 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.cpp @@ -0,0 +1,111 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. 
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.numel() > 0, "x has zero size"); + TORCH_CHECK(f.numel() > 0, "f has zero size"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 
1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. 
+ void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d); +} + +//------------------------------------------------------------------------ diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.cu b/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.cu new file mode 100644 index 0000000..7d182d7 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.cu @@ -0,0 +1,388 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. 
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. 
+ scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. 
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. 
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + // No up/downsampling. + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 
16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x upsampling. 
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 
32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x downsampling. + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = 
{(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + + // 4x upsampling. + if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 
1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 4x downsampling (inefficient). + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. 
+ +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.h b/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.h new file mode 100644 index 0000000..d5de893 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.h @@ -0,0 +1,63 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. 
+ +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.py b/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.py new file mode 100644 index 0000000..5d63471 --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/ops/upfirdn2d.py @@ -0,0 +1,391 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import os +import numpy as np +import torch + +from .. import custom_ops +from .. import misc +from . 
#----------------------------------------------------------------------------

_plugin = None  # CUDA extension module; stays None until _init() loads it.

def _init():
    """Build/load the `upfirdn2d_plugin` CUDA extension on first use.

    Returns:
        Always `True`, so the call can be used inside a boolean condition.
    """
    global _plugin
    if _plugin is None:
        _plugin = custom_ops.get_plugin(
            module_name='upfirdn2d_plugin',
            sources=['upfirdn2d.cpp', 'upfirdn2d.cu'],
            headers=['upfirdn2d.h'],
            source_dir=os.path.dirname(__file__),
            extra_cuda_cflags=['--use_fast_math'],
        )
    return True

def _parse_scaling(scaling):
    """Normalize a scaling factor to an `(sx, sy)` pair of ints, each >= 1.

    Accepts a single int (used for both axes) or a two-element list/tuple.
    """
    if isinstance(scaling, int):
        scaling = [scaling, scaling]
    assert isinstance(scaling, (list, tuple))
    assert all(isinstance(x, int) for x in scaling)
    sx, sy = scaling
    assert sx >= 1 and sy >= 1
    return sx, sy

def _parse_padding(padding):
    """Normalize a padding spec to `(padx0, padx1, pady0, pady1)`.

    Accepts a single int, `[x, y]`, or
    `[x_before, x_after, y_before, y_after]`. No sign check is made here,
    so negative entries pass through unchanged.
    """
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, int) for x in padding)
    if len(padding) == 2:
        padx, pady = padding
        padding = [padx, padx, pady, pady]
    padx0, padx1, pady0, pady1 = padding
    return padx0, padx1, pady0, pady1

def _get_filter_size(f):
    """Return `(filter_width, filter_height)` for a 1D/2D filter tensor.

    `None` is treated as a 1x1 filter. For a 1D filter both values equal the
    number of taps, because width is read from the last dim and height from
    the first.
    """
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    fw = f.shape[-1]
    fh = f.shape[0]
    with misc.suppress_tracer_warnings():
        fw = int(fw)
        fh = int(fh)
    misc.assert_shape(f, [fh, fw][:f.ndim])
    assert fw >= 1 and fh >= 1
    return fw, fh

#----------------------------------------------------------------------------

def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    The input `f` is never modified: every step below produces a new tensor.

    Args:
        f:           Torch tensor, numpy array, or python list of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable),
                     `[]` (impulse), or
                     `None` (identity).
        device:      Result device (default: cpu).
        normalize:   Normalize the filter so that it retains the magnitude
                     for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        separable:   Return a separable filter? (default: select automatically).

    Returns:
        Float32 tensor of the shape
        `[filter_height, filter_width]` (non-separable) or
        `[filter_taps]` (separable).
    """
    # Validate.
    if f is None:
        f = 1
    f = torch.as_tensor(f, dtype=torch.float32)
    assert f.ndim in [0, 1, 2]
    assert f.numel() > 0
    if f.ndim == 0:
        f = f[np.newaxis]

    # Separable? By default only 1D filters with at least 8 taps stay
    # separable; shorter 1D filters are expanded to their outer product.
    if separable is None:
        separable = (f.ndim == 1 and f.numel() >= 8)
    if f.ndim == 1 and not separable:
        f = f.ger(f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        # Out-of-place on purpose: `torch.as_tensor` may have returned the
        # caller's own tensor (or a view sharing memory with a NumPy array),
        # so an in-place `/=` would silently rescale the caller's data.
        f = f / f.sum()
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    # The exponent splits the gain evenly between the two passes of a
    # separable (1D) filter; a 2D filter receives the full gain.
    f = f * (gain ** (f.ndim / 2))
    f = f.to(device=device)
    return f
@misc.profiled_function
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Reference `upfirdn2d()` built only from standard PyTorch ops (slow path).

    Args:
        x:           4-D tensor, unpacked as `[batch, channels, height, width]`.
        f:           Float32 filter tensor of rank 1 (separable) or 2 that does
                     not require grad, or `None` for a 1x1 filter of ones.
        up, down:    Scaling factors, parsed by `_parse_scaling()`.
        padding:     Padding spec, parsed by `_parse_padding()`; negative
                     entries crop instead of pad.
        flip_filter: When False, the filter is flipped before the conv call.
        gain:        Scalar folded into the filter as `gain ** (f.ndim / 2)`.

    Returns:
        The filtered tensor after keeping every `down`-th pixel.
    """
    # Argument validation.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    n, c, h, w = x.shape
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # The upsampled and padded extent must be able to hold the filter.
    padded_w = w * upx + padx0 + padx1
    padded_h = h * upy + pady0 + pady1
    assert padded_w >= f.shape[-1] and padded_h >= f.shape[0]

    pad = torch.nn.functional.pad

    # Zero-stuffing: append (up - 1) zeros after every pixel on both axes.
    x = x.reshape([n, c, h, 1, w, 1])
    x = pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([n, c, h * upy, w * upx])

    # Positive padding amounts add zeros ...
    grow = [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
    x = pad(x, grow)
    # ... negative amounts crop.
    crop_x0, crop_x1 = max(-padx0, 0), max(-padx1, 0)
    crop_y0, crop_y1 = max(-pady0, 0), max(-pady1, 0)
    x = x[:, :, crop_y0 : x.shape[2] - crop_y1, crop_x0 : x.shape[3] - crop_x1]

    # Fold the gain into the filter and match the input dtype.
    filter_rank = f.ndim
    f = f * (gain ** (filter_rank / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        f = f.flip(list(range(filter_rank)))

    # One filter copy per channel; `groups=c` makes the convolution depthwise.
    weight = f[np.newaxis, np.newaxis].repeat([c, 1] + [1] * filter_rank)
    if weight.ndim == 4:
        x = conv2d_gradfix.conv2d(input=x, weight=weight, groups=c)
    else:
        # Separable filter: one pass with the taps laid out along the last
        # (width) axis, then one with the taps along the height axis.
        x = conv2d_gradfix.conv2d(input=x, weight=weight.unsqueeze(2), groups=c)
        x = conv2d_gradfix.conv2d(input=x, weight=weight.unsqueeze(3), groups=c)

    # Downsample by keeping every `down`-th pixel.
    return x[:, :, ::downy, ::downx]
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Filter a batch of 2D images with a 2D FIR filter, without resampling.

    Filter-dependent padding is added so that, with `padding=0`, the result
    has the same shape as the input. User-specified padding is applied on top
    of that; negative values crop. Pixels outside the image count as zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        padding:     Padding with respect to the output. Can be a single number
                     or a list/tuple `[x, y]` or
                     `[x_before, x_after, y_before, y_after]` (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use, `'ref'` or `'cuda'`
                     (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    user_pad = _parse_padding(padding)
    filt_w, filt_h = _get_filter_size(f)
    # Padding that compensates for the filter footprint, in the same
    # (x_before, x_after, y_before, y_after) order as `user_pad`.
    footprint_pad = (filt_w // 2, (filt_w - 1) // 2, filt_h // 2, (filt_h - 1) // 2)
    total_pad = [user + extra for user, extra in zip(user_pad, footprint_pad)]
    return upfirdn2d(x, f, padding=total_pad, flip_filter=flip_filter, gain=gain, impl=impl)
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Downsample a batch of 2D images using the given 2D FIR filter.

    Filter-dependent padding is added so that, with `padding=0`, the result
    shape is a fraction of the input shape. User-specified padding is applied
    on top of that; negative values crop. Pixels outside the image count as
    zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        down:        Integer downsampling factor. Can be a single int or a
                     list/tuple `[x, y]` (default: 2).
        padding:     Padding with respect to the input. Can be a single number
                     or a list/tuple `[x, y]` or
                     `[x_before, x_after, y_before, y_after]` (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use, `'ref'` or `'cuda'`
                     (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    step_x, step_y = _parse_scaling(down)
    user_pad = _parse_padding(padding)
    filt_w, filt_h = _get_filter_size(f)
    # Filter-footprint padding, ordered (x_before, x_after, y_before, y_after).
    footprint_pad = (
        (filt_w - step_x + 1) // 2,
        (filt_w - step_x) // 2,
        (filt_h - step_y + 1) // 2,
        (filt_h - step_y) // 2,
    )
    total_pad = [user + extra for user, extra in zip(user_pad, footprint_pad)]
    return upfirdn2d(x, f, down=down, padding=total_pad, flip_filter=flip_filter, gain=gain, impl=impl)
_module_to_src_dict = dict() # {module: src, ...}
_src_to_module_dict = dict() # {src: module, ...}

#----------------------------------------------------------------------------

def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. A typical use case is to first unpickle a previous
    instance of a persistent class, and then upgrade it to use the latest
    version of the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)

    Args:
        orig_class: The class to decorate. Must be a class whose defining
                    module is present in `sys.modules`.

    Returns:
        `orig_class` itself if it is already persistent, otherwise a new
        subclass with the same `__name__`.
    """
    assert isinstance(orig_class, type)
    # Decorating twice is a no-op.
    if is_persistent(orig_class):
        return orig_class

    # Capture the source of the whole defining module up front.
    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Deep copies, so later mutation of the caller's arguments does
            # not change what gets pickled.
            self._init_args = copy.deepcopy(args)
            self._init_kwargs = copy.deepcopy(kwargs)
            # The class must be reachable by name in its module, because
            # unpickling looks it up as `module.__dict__[class_name]`.
            assert orig_class.__name__ in orig_module.__dict__
            # Fail at construction time rather than at pickling time.
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            # Pad to at least (reconstruct func, reconstruct args, state).
            fields += [None] * max(3 - len(fields), 0)
            # Skip the rewrite if the parent's result already goes through
            # `_reconstruct_persistent_obj`.
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    # Keep the original class name on the subclass and register it so that
    # `is_persistent()` recognizes it.
    Decorator.__name__ = orig_class.__name__
    _decorators.add(Decorator)
    return Decorator
def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of `torch_utils.persistence`.
        module_src: Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta

    Args:
        hook: Callable taking and returning `meta`. Hooks run in
              registration order.

    Returns:
        The same `hook`, so that the function stays bound to its name when
        `import_hook` is applied as a decorator (as in the example above).
    """
    assert callable(hook)
    _import_hooks.append(hook)
    return hook
def _module_to_src(module):
    r"""Query the source code of a given Python module.

    The result is cached in both directions, so a later `_src_to_module()`
    call with the same source returns this very module.
    """
    src = _module_to_src_dict.get(module, None)
    if src is None:
        src = inspect.getsource(module)
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
    return src

def _src_to_module(src):
    r"""Get or create a Python module for the given source code.

    A new module is registered under a unique `_imported_module_<hex>` name
    and the source is executed in its namespace. If execution raises, every
    registration made here is rolled back and the exception is re-raised, so
    a later call with the same source re-executes it instead of returning a
    half-initialized module.

    NOTE: this executes `src` verbatim. During unpickling the source comes
    from the pickle itself, so only unpickle files from trusted sources.
    """
    module = _src_to_module_dict.get(src, None)
    if module is None:
        module_name = "_imported_module_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        # Register BEFORE executing: code in `src` that applies
        # `persistent_class` looks its own module up in `sys.modules` and
        # asks `_module_to_src()` for its source.
        sys.modules[module_name] = module
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
        try:
            exec(src, module.__dict__) # pylint: disable=exec-used
        except BaseException:
            # Do not leave a partially executed module behind in the caches.
            sys.modules.pop(module_name, None)
            _module_to_src_dict.pop(module, None)
            _src_to_module_dict.pop(src, None)
            raise
    return module
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: + return None # NumPy arrays and PyTorch tensors are pickleable. + if is_persistent(obj): + return None # Persistent objects are pickleable, by virtue of the constructor check. + return obj + with io.BytesIO() as f: + pickle.dump(recurse(obj), f) + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/torch_utils/training_stats.py b/3DPortraitGAN_pyramid/torch_utils/training_stats.py new file mode 100644 index 0000000..636dd7f --- /dev/null +++ b/3DPortraitGAN_pyramid/torch_utils/training_stats.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Facilities for reporting and collecting training statistics across +multiple processes and devices. The interface is designed to minimize +synchronization overhead as well as the amount of boilerplate in user +code.""" + +import re +import numpy as np +import torch +import dnnlib + +from . import misc + +#---------------------------------------------------------------------------- + +_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] +_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. +_counter_dtype = torch.float64 # Data type to use for the internal counters. +_rank = 0 # Rank of the current process. +_sync_device = None # Device to use for multiprocess communication. 
_sync_called = False  # Has _sync() been called yet?
_counters = {}        # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative = {}      # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor

#----------------------------------------------------------------------------

def init_multiprocessing(rank, sync_device):
    r"""Configure `torch_utils.training_stats` for multi-process collection.

    Must be called after `torch.distributed.init_process_group()` and before
    `Collector.update()`; the function asserts that no synchronization has
    happened yet. The call can be skipped when statistics are only collected
    within a single process.

    Args:
        rank:        Rank of the current process.
        sync_device: PyTorch device to use for inter-process communication,
                     or None to disable multi-process collection.
                     Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank, _sync_device = rank, sync_device
def report0(name, value):
    r"""Like `report()`, but only the first process (`rank = 0`) contributes
    its scalars; every other process reports an empty list instead.

    Returns:
        The same `value` that was passed in.
    """
    payload = value if _rank == 0 else []
    report(name, payload)
    return value

#----------------------------------------------------------------------------

class Collector:
    r"""Collects the scalars broadcasted by `report()` and `report0()` and
    exposes their averages (mean and standard deviation) over user-defined
    periods of time.

    Scalars first accumulate in internal counters that the user cannot see.
    Calling `update()` copies them into the user-visible state, which is
    then queried through `mean()`, `std()`, `as_dict()`, etc. The
    user-visible state therefore reflects what was collected between the
    last two calls to `update()`.

    Args:
        regex:          Regular expression defining which statistics to
                        collect. The default is to collect everything.
        keep_previous:  Whether to retain the previous averages if no
                        scalars were collected on a given round
                        (default: True).
    """
    def __init__(self, regex='.*', keep_previous=True):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = dict()
        self._moments = dict()
        # Take a baseline snapshot, then drop whatever it produced so that
        # scalars reported before construction are not visible.
        self.update()
        self._moments.clear()

    def names(self):
        r"""Returns the names of all statistics broadcasted so far that
        fully match the regular expression given at construction time.
        """
        matches = self._regex.fullmatch
        return [key for key in _counters if matches(key)]

    def update(self):
        r"""Copies the current internal counters into the user-visible
        state.

        With `keep_previous=True`, statistics that received no scalars
        since the last update keep their previous averages; otherwise the
        user-visible state is cleared first.

        The heavy lifting (device transfers and cross-process reduction) is
        delegated to `_sync()`, so this is meant to be called periodically
        from the main training loop rather than on every step.
        """
        if not self._keep_previous:
            self._moments.clear()
        for stat_name, total in _sync(self.names()):
            seen = self._cumulative.get(stat_name)
            if seen is None:
                seen = torch.zeros([_num_moments], dtype=_counter_dtype)
                self._cumulative[stat_name] = seen
            delta = total - seen
            seen.copy_(total)
            # Element 0 is the scalar count; only publish non-empty rounds.
            if float(delta[0]) != 0:
                self._moments[stat_name] = delta

    def _get_delta(self, name):
        r"""Returns the raw moments accumulated for `name` between the last
        two calls to `update()`, or zeros if no scalars were collected.
        The name must match the collector's regular expression.
        """
        assert self._regex.fullmatch(name)
        try:
            return self._moments[name]
        except KeyError:
            empty = torch.zeros([_num_moments], dtype=_counter_dtype)
            self._moments[name] = empty
            return empty

    def num(self, name):
        r"""Returns the number of scalars accumulated for `name` between
        the last two calls to `update()`, or zero if none were collected.
        """
        return int(self._get_delta(name)[0])

    def mean(self, name):
        r"""Returns the mean of the scalars accumulated for `name` between
        the last two calls to `update()`, or NaN if none were collected.
        """
        delta = self._get_delta(name)
        count = int(delta[0])
        return float('nan') if count == 0 else float(delta[1] / delta[0])

    def std(self, name):
        r"""Returns the standard deviation of the scalars accumulated for
        `name` between the last two calls to `update()`. The result is NaN
        if none were collected or their sum is not finite, and 0 for a
        single scalar.
        """
        delta = self._get_delta(name)
        count = int(delta[0])
        if count == 0:
            return float('nan')
        if not np.isfinite(float(delta[1])):
            return float('nan')
        if count == 1:
            return float(0)
        mean = float(delta[1] / delta[0])
        raw_var = float(delta[2] / delta[0])
        # Clamp at zero: the difference can dip slightly below it numerically.
        variance = max(raw_var - np.square(mean), 0)
        return np.sqrt(variance)

    def as_dict(self):
        r"""Returns the averages accumulated between the last two calls to
        `update()` as a `dnnlib.EasyDict` with one entry per name:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=..., mean=..., std=...),
                ...
            )

        where the three values come from `num()`, `mean()` and `std()`.
        """
        stats = dnnlib.EasyDict()
        for key in self.names():
            entry = dnnlib.EasyDict(num=self.num(key), mean=self.mean(key), std=self.std(key))
            stats[key] = entry
        return stats

    def __getitem__(self, name):
        r"""Convenience getter.
        `collector[name]` is a synonym for `collector.mean(name)`.
        """
        return self.mean(name)
def _sync(names):
    r"""Brings the global cumulative counters up to date for `names` and
    returns them as a list of `(name, cumulative_tensor)` pairs.

    The counters stored under each name are summed and reset, the sums are
    all-reduced across ranks when a sync device is configured, and the
    result is added to the module-level cumulative totals. An empty `names`
    returns an empty list without touching any state. Used by
    `Collector.update()`.
    """
    if len(names) == 0:
        return []
    global _sync_called
    _sync_called = True

    # Fold this rank's counters into one row per statistic, zeroing each
    # counter as soon as it has been read.
    accum_device = torch.device('cpu') if _sync_device is None else _sync_device
    rows = []
    for name in names:
        row = torch.zeros([_num_moments], dtype=_counter_dtype, device=accum_device)
        for counter in _counters[name].values():
            row.add_(counter.to(accum_device))
            counter.zero_()
        rows.append(row)
    pending = torch.stack(rows)

    # Combine the rows of all ranks (only when running distributed).
    if _sync_device is not None:
        torch.distributed.all_reduce(pending)

    # Accumulate into the CPU-side running totals and build the result.
    pending = pending.cpu()
    pairs = []
    for row, name in zip(pending, names):
        total = _cumulative.get(name)
        if total is None:
            total = torch.zeros([_num_moments], dtype=_counter_dtype)
            _cumulative[name] = total
        total.add_(row)
        pairs.append((name, total))
    return pairs
def subprocess_fn(rank, c, temp_dir):
    """Per-process entry point: set up logging, distributed state and
    statistics collection for `rank`, then run the training loop with the
    options in `c`.

    `temp_dir` holds the rendezvous file used to initialize
    `torch.distributed` when more than one GPU is requested.
    """
    log_path = os.path.join(c.run_dir, 'log.txt')
    dnnlib.util.Logger(file_name=log_path, file_mode='a', should_flush=True)

    multi_gpu = c.num_gpus > 1

    # Init torch.distributed (file-based rendezvous; gloo on Windows, nccl elsewhere).
    if multi_gpu:
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            backend = 'gloo'
            init_method = 'file:///' + init_file.replace('\\', '/')
        else:
            backend = 'nccl'
            init_method = f'file://{init_file}'
        torch.distributed.init_process_group(backend=backend, init_method=init_method, rank=rank, world_size=c.num_gpus)

    # Init torch_utils.
    sync_device = torch.device('cuda', rank) if multi_gpu else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0:
        custom_ops.verbosity = 'none'

    # Execute training loop.
    training_loop.training_loop(rank=rank, **c)
def launch_training(c, desc, outdir, dry_run):
    """Choose a fresh run directory under `outdir`, print the training
    options, and (unless `dry_run`) create the directory, save the options
    as JSON and start one training process per GPU.

    The run directory is named `<id>-<desc>`, where `<id>` is one more than
    the largest numeric prefix among the existing sub-directories of
    `outdir` (zero when there are none). The chosen path is stored in
    `c.run_dir`.
    """
    dnnlib.util.Logger(should_flush=True)

    # Pick output directory.
    next_run_id = 0
    if os.path.isdir(outdir):
        for entry in os.listdir(outdir):
            if not os.path.isdir(os.path.join(outdir, entry)):
                continue
            numeric_prefix = re.match(r'^\d+', entry)
            if numeric_prefix is not None:
                next_run_id = max(next_run_id, int(numeric_prefix.group()) + 1)
    c.run_dir = os.path.join(outdir, f'{next_run_id:05d}-{desc}')
    assert not os.path.exists(c.run_dir)

    # Print options.
    dataset = c.training_set_kwargs
    print()
    print('Training options:')
    print(json.dumps(c, indent=2))
    print()
    print(f'Output directory: {c.run_dir}')
    print(f'Number of GPUs: {c.num_gpus}')
    print(f'Batch size: {c.batch_size} images')
    print(f'Training duration: {c.total_kimg} kimg')
    print(f'Dataset path (img): {dataset.img_path}')
    print(f'Dataset path (seg): {dataset.seg_path}')
    print(f'Dataset size: {dataset.max_size} images')
    print(f'Dataset resolution: {dataset.resolution}')
    print(f'Dataset labels: {dataset.use_labels}')
    print(f'Dataset x-flips: {dataset.xflip}')
    print()

    # Dry run?
    if dry_run:
        print('Dry run; exiting.')
        return

    # Create output directory.
    print('Creating output directory...')
    os.makedirs(c.run_dir)
    options_path = os.path.join(c.run_dir, 'training_options.json')
    with open(options_path, 'wt') as f:
        json.dump(c, f, indent=2)

    # Launch processes.
    print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if c.num_gpus != 1:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(c, temp_dir), nprocs=c.num_gpus)
        else:
            # Single GPU: run in the current process.
            subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
def init_dataset_kwargs(data, seg_data, data_rebalance, data_rebalance_idx_file, back_repeat):
    """Build the dataset constructor arguments and probe the dataset once.

    The dataset class is instantiated so that its resolution, label
    availability and length can be written back into the returned kwargs.

    Returns:
        A tuple `(dataset_kwargs, dataset_name)`.

    Raises:
        click.ClickException: if constructing or probing the dataset raises
            `IOError`. Either the image or the segmentation archive may be
            the cause, so the message names both options.
    """
    try:
        dataset_kwargs = dnnlib.EasyDict(
            class_name='training.dataset.MaskLabeledDataset',
            img_path=data,
            seg_path=seg_data,
            back_repeat=back_repeat,
            use_labels=True, max_size=None, xflip=True,
            data_rebalance=data_rebalance,
            data_rebalance_idx_file=data_rebalance_idx_file,
        )
        dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
        dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
        dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
        dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
        return dataset_kwargs, dataset_obj.name
    except IOError as err:
        raise click.ClickException(f'--data/--seg_data: {err}') from err

#----------------------------------------------------------------------------

def parse_comma_separated_list(s):
    """Turn a comma-separated option value into a list of strings.

    A list is returned unchanged. `None`, an empty/blank string and the
    word "none" (any case) give an empty list. Whitespace around items is
    stripped and empty items (e.g. from a trailing comma) are dropped, so
    `"a, b,"` parses the same as `"a,b"`.
    """
    if isinstance(s, list):
        return s
    if s is None:
        return []
    text = s.strip()
    if text == '' or text.lower() == 'none':
        return []
    items = (item.strip() for item in text.split(','))
    return [item for item in items if item]
@click.command()

# Required.
@click.option('--outdir', help='Where to save the results', metavar='DIR', required=True)
@click.option('--cfg', help='Base configuration', type=str, required=True)
@click.option('--data', help='Training data', metavar='[ZIP|DIR]', type=str, required=True)

@click.option('--gpus', help='Number of GPUs to use', metavar='INT', type=click.IntRange(min=1), required=True)
@click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), required=True)
@click.option('--gamma', help='R1 regularization weight', metavar='FLOAT', type=click.FloatRange(min=0), required=True)

# Optional features.
@click.option('--cond', help='Train conditional model', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--mirror', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--aug', help='Augmentation mode', type=click.Choice(['noaug', 'ada', 'fixed']), default='noaug', show_default=True)
@click.option('--resume', help='Resume from given network pickle', metavar='[PATH|URL]', type=str)
@click.option('--resume_kimg', help='Resume from given kimg', metavar='INT', type=int, default=0)
@click.option('--freezed', help='Freeze first layers of D', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
@click.option('--data_rebalance', help='Enable dataset rebalance', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--data_rebalance_idx_file', help='Index file for dataset rebalance', metavar='PATH', type=str, required=False, default=None)

# Misc hyperparameters.
@click.option('--p', help='Probability for --aug=fixed', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.2, show_default=True)
@click.option('--target', help='Target value for --aug=ada', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.6, show_default=True)
@click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1))
@click.option('--cbase', help='Capacity multiplier', metavar='INT', type=click.IntRange(min=1), default=18432, show_default=True)
@click.option('--cmax', help='Max. feature maps', metavar='INT', type=click.IntRange(min=1), default=144, show_default=True)
@click.option('--glr', help='G learning rate [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0))
@click.option('--dlr', help='D learning rate', metavar='FLOAT', type=click.FloatRange(min=0), default=0.002, show_default=True)
@click.option('--map-depth', help='Mapping network depth [default: varies]', metavar='INT', type=click.IntRange(min=1), default=2, show_default=True)
@click.option('--mbstd-group', help='Minibatch std group size', metavar='INT', type=click.IntRange(min=1), default=4, show_default=True)

# Misc settings.
@click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
@click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
@click.option('--kimg', help='Total training duration', metavar='KIMG', type=click.IntRange(min=1), default=25000, show_default=True)
@click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=4, show_default=True)
@click.option('--snap', help='How often to save network snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True)
@click.option('--image-snap', help='How often to save image snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True)
@click.option('--seed', help='Random seed', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
@click.option('--nobench', help='Disable cuDNN benchmarking', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=3, show_default=True)
@click.option('-n','--dry-run', help='Print training options and exit', is_flag=True)

# Rendering and pose conditioning.
@click.option('--neural_rendering_resolution_initial', help='Resolution to render at', metavar='INT', type=click.IntRange(min=1), default=64, required=False)
@click.option('--neural_rendering_resolution_final', help='Final resolution to render at, if blending', metavar='INT', type=click.IntRange(min=1), required=False, default=None)
@click.option('--neural_rendering_resolution_fade_kimg', help='Kimg to blend resolution over', metavar='INT', type=click.IntRange(min=0), required=False, default=1000, show_default=True)

@click.option('--blur_fade_kimg', help='Blur over how many', metavar='INT', type=click.IntRange(min=1), required=False, default=200)
@click.option('--gen_pose_cond', help='If true, enable generator pose conditioning.', metavar='BOOL', type=bool, required=False, default=False)
@click.option('--c-scale', help='Scale factor for generator pose conditioning.', metavar='FLOAT', type=click.FloatRange(min=0), required=False, default=1)
@click.option('--c-noise', help='Add noise for generator pose conditioning.', metavar='FLOAT', type=click.FloatRange(min=0), required=False, default=0)
@click.option('--gpc_reg_prob', help='Strength of swapping regularization. None means no generator pose conditioning, i.e. condition with zeros.', metavar='FLOAT', type=click.FloatRange(min=0), required=False, default=0.5)
@click.option('--gpc_reg_fade_kimg', help='Length of swapping prob fade', metavar='INT', type=click.IntRange(min=0), required=False, default=1000)

@click.option('--disc_c_noise', help='Strength of discriminator pose conditioning regularization, in standard deviations.', metavar='FLOAT', type=click.FloatRange(min=0), required=False, default=0)
@click.option('--sr_noise_mode', help='Type of noise for superresolution', metavar='STR', type=click.Choice(['random', 'none']), required=False, default='none')
@click.option('--resume_blur', help='Enable to blur even on resume', metavar='BOOL', type=bool, required=False, default=False)
@click.option('--sr_num_fp16_res', help='Number of fp16 layers in superresolution', metavar='INT', type=click.IntRange(min=0), default=4, required=False, show_default=True)
@click.option('--g_num_fp16_res', help='Number of fp16 layers in generator', metavar='INT', type=click.IntRange(min=0), default=0, required=False, show_default=True)
@click.option('--d_num_fp16_res', help='Number of fp16 layers in discriminator', metavar='INT', type=click.IntRange(min=0), default=4, required=False, show_default=True)
@click.option('--sr_first_cutoff', help='First cutoff for AF superresolution', metavar='INT', type=click.IntRange(min=2), default=2, required=False, show_default=True)
@click.option('--sr_first_stopband', help='First stopband for AF superresolution', metavar='FLOAT', type=click.FloatRange(min=2), default=2**2.1, required=False, show_default=True)
@click.option('--style_mixing_prob', help='Style-mixing regularization probability for training.', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0, required=False, show_default=True)
@click.option('--sr-module', help='Superresolution module override', metavar='STR', type=str, required=False, default=None)
@click.option('--density_reg', help='Density regularization strength.', metavar='FLOAT', type=click.FloatRange(min=0), default=0.25, required=False, show_default=True)
# The value becomes G_reg_interval, i.e. a step interval, so it is parsed as an integer.
@click.option('--density_reg_every', help='Lazy density regularization interval', metavar='INT', type=click.IntRange(min=1), default=4, required=False, show_default=True)
@click.option('--density_reg_p_dist', help='Distance at which to sample perturbed points for density regularization.', metavar='FLOAT', type=click.FloatRange(min=0), default=0.004, required=False, show_default=True)
@click.option('--reg_type', help='Type of regularization', metavar='STR', type=click.Choice(['l1', 'l1-alt', 'monotonic-detach', 'monotonic-fixed', 'total-variation']), required=False, default='l1')
@click.option('--decoder_lr_mul', help='decoder learning rate multiplier.', metavar='FLOAT', type=click.FloatRange(min=0), default=1, required=False, show_default=True)

@click.option('--thickness', help='sample thickness around head mesh', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.25, show_default=True) # smpl head stride is ~ 0.2

# Pose losses.
@click.option('--pose_loss_weight', help='Weight of the pose loss', metavar='FLOAT', type=click.FloatRange(min=0), default=1.0, show_default=True)
@click.option('--input_pose_params_reg_loss_weight', help='Weight of the input pose parameter regularization loss', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--input_pose_params_reg_loss_kimg', help='Kimg setting of the input pose parameter regularization loss', metavar='KIMG', type=click.FloatRange(min=0), default=0, show_default=True)

@click.option('--explicitly_symmetry', help='Enable explicit symmetry in G and D', metavar='BOOL', type=bool, required=False, default=False)

@click.option('--train_g_pose_branch', help='Train the pose branch of G', metavar='BOOL', type=bool, required=False, default=True)

@click.option('--metric_pose_sample_mode', help='Type of metric_pose_sample ', metavar='STR', type=click.Choice(['D_predict', 'G_predict']), required=False, default='G_predict')

# panohead
@click.option('--seg_channels', help='Channels of masks for discriminator.', metavar='INT', type=click.IntRange(min=1), default=1, required=False, show_default=True)
@click.option('--decoder_activation', help='Activation function for decoder.', metavar='STR', type=click.Choice(['sigmoid', 'lrelu', 'none']), default="sigmoid", required=False, show_default=True)
@click.option('--use_torgb_raw', help='Use ToRGB for raw image output.', metavar='BOOL', type=bool, default=False, required=False, show_default=True)
@click.option('--use_background', help='Use separate background generator.', metavar='BOOL', type=bool, default=True, required=False, show_default=True)
@click.option('--bcg_reg_prob', help='Swapping probability of background w code.', metavar='FLOAT', type=click.FloatRange(min=0), default=0, required=False, show_default=True)
@click.option('--seg_data', help='Segmentation mask data', metavar='[ZIP|DIR]', type=str, required=True)
@click.option('--gamma_seg', help='R1 regularization weight for segmentation masks', metavar='FLOAT', type=click.FloatRange(min=0), required=True)
@click.option('--density_noise_fade_kimg', help='Kimg to add density noise.', metavar='INT', type=click.IntRange(min=0), default=0, required=False, show_default=True)
@click.option('--triplane_depth', help='Grid depth of each of tri-plane', metavar='INT', type=click.IntRange(min=1), default=1, required=False, show_default=True)

@click.option('--back_repeat', help='Repeat abs [max(90, min_yaw), max_yaw] images how many times', metavar='INT', type=click.IntRange(min=1), default=1, required=False, show_default=True)
@click.option('--radius_scale', help='radius scale ratio.', metavar='FLOAT', type=click.FloatRange(min=0.0), default=0.7)

def main(**kwargs):
    """Train a tri-plane portrait GAN from an image archive and a matching
    segmentation-mask archive.

    The command-line options are translated into the config dict that is
    handed to the training loop. Only --cfg=full-head is supported.

    Example:

    \b
    # Train on 512x512 data using 8 GPUs.
    python train.py --outdir=./training-runs/stage1 --cfg=full-head \\
        --data=$DATASET_PATH/360PHQ-512.zip --seg_data=$DATASET_PATH/360PHQ-512-mask.zip \\
        --gpus=8 --batch=32 --gamma=5.0 --gamma_seg=5.0 --mirror=True
    """

    # Initialize config.
    print('>>>>>>>>>>>>>> kwargs:', kwargs)
    opts = dnnlib.EasyDict(kwargs) # Command line arguments.
    c = dnnlib.EasyDict() # Main config dict.
    c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict())
    c.D_kwargs = dnnlib.EasyDict(class_name='training.networks_stylegan2.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
    c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
    c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
    c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss')
    c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)

    # Dataset rebalancing is exposed as an option but not available yet.
    if opts.data_rebalance:
        raise NotImplementedError('data_rebalance is not implemented yet')

    # Training set.
    c.training_set_kwargs, dataset_name = init_dataset_kwargs(
        data=opts.data,
        seg_data=opts.seg_data,
        data_rebalance=opts.data_rebalance,
        data_rebalance_idx_file=opts.data_rebalance_idx_file,
        back_repeat=opts.back_repeat,
    )
    if opts.cond and not c.training_set_kwargs.use_labels:
        raise click.ClickException('--cond=True requires labels specified in dataset.json')
    c.training_set_kwargs.use_labels = opts.cond
    c.training_set_kwargs.xflip = opts.mirror
    c.training_set_kwargs.data_rebalance = opts.data_rebalance
    if opts.data_rebalance:
        c.training_set_kwargs.data_rebalance_idx_file = opts.data_rebalance_idx_file

    # Hyperparameters & settings.
    c.num_gpus = opts.gpus
    c.batch_size = opts.batch
    c.batch_gpu = opts.batch_gpu or opts.batch // opts.gpus
    c.G_kwargs.channel_base = c.D_kwargs.channel_base = opts.cbase
    c.G_kwargs.channel_max = c.D_kwargs.channel_max = opts.cmax
    c.G_kwargs.mapping_kwargs.num_layers = opts.map_depth
    c.G_kwargs.batch_size = c.batch_gpu
    c.D_kwargs.block_kwargs.freeze_layers = opts.freezed
    c.D_kwargs.epilogue_kwargs.mbstd_group_size = opts.mbstd_group
    c.loss_kwargs.r1_gamma = opts.gamma
    c.loss_kwargs.r1_gamma_seg = opts.gamma_seg
    c.G_opt_kwargs.lr = (0.002 if opts.cfg == 'stylegan2' else 0.0025) if opts.glr is None else opts.glr
    c.D_opt_kwargs.lr = opts.dlr
    c.metrics = opts.metrics
    c.total_kimg = opts.kimg
    c.kimg_per_tick = opts.tick
    c.network_snapshot_ticks = opts.snap
    c.image_snapshot_ticks = opts.image_snap
    c.random_seed = c.training_set_kwargs.random_seed = opts.seed
    c.data_loader_kwargs.num_workers = opts.workers

    # Sanity checks.
    if c.batch_size % c.num_gpus != 0:
        raise click.ClickException('--batch must be a multiple of --gpus')
    if c.batch_size % (c.num_gpus * c.batch_gpu) != 0:
        raise click.ClickException('--batch must be a multiple of --gpus times --batch-gpu')
    if c.batch_gpu < c.D_kwargs.epilogue_kwargs.mbstd_group_size:
        raise click.ClickException('--batch-gpu cannot be smaller than --mbstd-group')
    if any(not metric_main.is_valid_metric(metric) for metric in c.metrics):
        raise click.ClickException('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))

    # Base configuration.
    c.ema_kimg = c.batch_size * 10 / 32
    c.G_kwargs.class_name = 'training.smpl_triplane.TriPlaneGenerator'
    c.D_kwargs.class_name = 'training.dual_discriminator.PoseShapeAwareDualDiscriminator'

    c.G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions.
    c.loss_kwargs.filter_mode = 'antialiased' # Filter mode for raw images ['antialiased', 'none', float [0-1]]
    c.D_kwargs.disc_c_noise = opts.disc_c_noise # Regularization for discriminator pose conditioning

    # Superresolution module: an explicit --sr-module wins, otherwise it is
    # chosen from the dataset resolution. Reported as a CLI error rather
    # than an assert so the check also runs under `python -O`.
    default_sr_modules = {
        512: 'training.superresolution.SuperresolutionHybrid8XDC',
        256: 'training.superresolution.SuperresolutionHybrid4X',
        128: 'training.superresolution.SuperresolutionHybrid2X',
    }
    resolution = c.training_set_kwargs.resolution
    if opts.sr_module is not None:
        sr_module = opts.sr_module
    elif resolution in default_sr_modules:
        sr_module = default_sr_modules[resolution]
    else:
        raise click.ClickException(f'Unsupported resolution {resolution}; pass --sr-module or make a new superresolution module')

    rendering_options = {
        'image_resolution': c.training_set_kwargs.resolution,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        'superresolution_module': sr_module,
        'c_gen_conditioning_zero': not opts.gen_pose_cond, # if true, fill generator pose conditioning label with dummy zero vector
        'gpc_reg_prob': opts.gpc_reg_prob if opts.gen_pose_cond else None,
        'decoder_activation': opts.decoder_activation, # activation function for decoder
        'use_torgb_raw': opts.use_torgb_raw, # use ToRGB layer for raw image output
        'use_background': opts.use_background, # use separate background generator
        'triplane_depth': opts.triplane_depth, # grid depth of each of tri-plane
        'c_scale': opts.c_scale, # multiplier for generator pose conditioning label
        'superresolution_noise_mode': opts.sr_noise_mode, # [random or none], whether to inject pixel noise into super-resolution layers
        'density_reg': opts.density_reg, # strength of density regularization
        'density_reg_p_dist': opts.density_reg_p_dist, # distance at which to sample perturbed points for density regularization
        'reg_type': opts.reg_type, # for experimenting with variations on density regularization
        'decoder_lr_mul': opts.decoder_lr_mul, # learning rate multiplier for decoder
        'sr_antialias': True,
        'radius_scale': opts.radius_scale,
    }

    if opts.cfg != 'full-head':
        raise click.ClickException(f"--cfg: unknown configuration '{opts.cfg}'; only 'full-head' is supported")
    # The sampling range is [2.7 - 0.45*s, 2.7 + 0.6*s] with s = radius_scale,
    # i.e. the interval [2.25, 3.3] shrunk towards 2.7.
    rendering_options.update({
        'depth_resolution': 48, # number of uniform samples to take per ray.
        'depth_resolution_importance': 48, # number of importance samples to take per ray.
        'ray_start': 2.25 + (2.7-2.25) * (1- opts.radius_scale), # near point along each ray to start taking samples.
        'ray_end': (3.3-2.7) * opts.radius_scale + 2.7 , # far point along each ray to stop taking samples.
        'box_warp': 1* opts.radius_scale,
    })

    if opts.density_reg > 0:
        c.G_reg_interval = opts.density_reg_every
    c.G_kwargs.rendering_kwargs = rendering_options
    c.loss_kwargs.blur_init_sigma = 10 # Blur the images seen by the discriminator.
    c.loss_kwargs.blur_fade_kimg = c.batch_size * opts.blur_fade_kimg / 32 # Fade out the blur during the first N kimg.

    c.loss_kwargs.density_noise_fade_kimg = opts.density_noise_fade_kimg
    c.loss_kwargs.gpc_reg_prob = opts.gpc_reg_prob if opts.gen_pose_cond else None
    c.loss_kwargs.gpc_reg_fade_kimg = opts.gpc_reg_fade_kimg
    c.loss_kwargs.bcg_reg_prob = opts.bcg_reg_prob
    c.loss_kwargs.dual_discrimination = True
    c.loss_kwargs.neural_rendering_resolution_initial = opts.neural_rendering_resolution_initial
    c.loss_kwargs.neural_rendering_resolution_final = opts.neural_rendering_resolution_final
    c.loss_kwargs.neural_rendering_resolution_fade_kimg = opts.neural_rendering_resolution_fade_kimg
    c.G_kwargs.sr_num_fp16_res = opts.sr_num_fp16_res

    c.G_kwargs.sr_kwargs = dnnlib.EasyDict(channel_base=opts.cbase, channel_max=opts.cmax, fused_modconv_default='inference_only')

    c.G_kwargs.thickness = opts.thickness
    c.loss_kwargs.style_mixing_prob = opts.style_mixing_prob
    c.loss_kwargs.thickness = opts.thickness

    c.metric_pose_sample_mode = opts.metric_pose_sample_mode

    c.loss_kwargs.pose_loss_weight = opts.pose_loss_weight

    c.loss_kwargs.input_pose_params_reg_loss_weight = opts.input_pose_params_reg_loss_weight
    c.loss_kwargs.input_pose_params_reg_loss_kimg = opts.input_pose_params_reg_loss_kimg

    c.train_g_pose_branch = opts.train_g_pose_branch

    c.G_kwargs.explicitly_symmetry = opts.explicitly_symmetry
    c.D_kwargs.explicitly_symmetry = opts.explicitly_symmetry

    # Augmentation.
    if opts.aug != 'noaug':
        c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1)
        if opts.aug == 'ada':
            c.ada_target = opts.target
        if opts.aug == 'fixed':
            c.augment_p = opts.p

    # Resume.
    if opts.resume is not None:
        c.resume_pkl = opts.resume
        c.ada_kimg = 100 # Make ADA react faster at the beginning.
        c.ema_rampup = None # Disable EMA rampup.
        if not opts.resume_blur:
            c.loss_kwargs.blur_init_sigma = 0 # Disable blur rampup.
            c.loss_kwargs.gpc_reg_fade_kimg = 0 # Disable swapping rampup
        # NOTE(review): nesting reconstructed from a whitespace-collapsed
        # source; assumed to apply only when resuming - confirm.
        c.resume_kimg = opts.resume_kimg

    # Performance-related toggles.
    c.G_kwargs.num_fp16_res = opts.g_num_fp16_res
    c.G_kwargs.conv_clamp = 256 if opts.g_num_fp16_res > 0 else None
    c.D_kwargs.num_fp16_res = opts.d_num_fp16_res
    c.D_kwargs.conv_clamp = 256 if opts.d_num_fp16_res > 0 else None

    c.D_kwargs.seg_channels = opts.seg_channels

    if opts.nobench:
        c.cudnn_benchmark = False

    # Description string.
    desc = f'{opts.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma}'
    if opts.desc is not None:
        desc += f'-{opts.desc}'

    # Launch.
    launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
class AlignedSMPL(torch.nn.Module):
    """Wrap an SMPL-style body model, re-centre its output on one joint and render it.

    `model` is called as ``model(betas=..., expression=..., return_verts=...,
    body_pose=..., return_shaped=...)`` and must expose ``faces`` (NumPy array),
    ``shapedirs`` (tensor, used only to pick the working device) and return an
    object with ``vertices`` and ``joints`` tensors.

    Shape (`betas`), `scale` and `transl` overrides are not implemented: every
    public method raises ``NotImplementedError`` when one is supplied and
    otherwise uses a unit scale and a zero translation.
    """

    # Row of the regressed joints used as the alignment origin
    # (annotated "Neck" in the original code).
    _ALIGN_JOINT = 12
    # Number of axis-angle joints in the zero-filled body pose given to the model.
    _NUM_BODY_JOINTS = 23
    # Body-pose slots driven by the 6-D input pose (first three / last three
    # values). NOTE(review): presumably neck and head -- confirm against the
    # body model's joint ordering.
    _FIRST_POSED_SLOT = 11
    _SECOND_POSED_SLOT = 14
    # Joint ranges gathered as facial landmarks (names follow the original comments).
    _NOSE = slice(86, 90)
    _RIGHT_EYE = slice(95, 101)
    _LEFT_EYE = slice(101, 107)

    def __init__(self, model, batch_size):
        """Load the joint regressor from disk and register the CUDA buffers.

        Args:
            model: Body model (see class docstring).
            batch_size: Batch size assumed by every other method.

        The regressor is read from ``transfer_data/smpl_joint_regressor.npy``
        relative to the current working directory; both buffers are moved to
        CUDA, so a CUDA device is required.
        """
        super().__init__()
        self.batch_size = batch_size
        smpl_joint_regressor = torch.from_numpy(
            np.load('transfer_data/smpl_joint_regressor.npy')).float().cuda().contiguous()
        self.register_buffer('smpl_joint_regressor', smpl_joint_regressor)

        self.model = model
        faces = torch.from_numpy(self.model.faces.astype(np.int32)).cuda().long().contiguous()
        self.register_buffer('faces', faces)

    def set_model(self, model):
        """Replace the wrapped body model."""
        self.model = model

    def set_batch_size(self, batch_size):
        """Change the batch size assumed by the generation/rendering methods."""
        self.batch_size = batch_size

    def get_align_coordinate(self, vertices):
        """Return the alignment joint regressed from `vertices`.

        Args:
            vertices: Tensor of shape (B, V, 3).

        Returns:
            Tensor of shape (B, 1, 3): row ``12`` of ``regressor @ vertices``.
        """
        # Original note on the regressor: 30 x 6890 (joints x vertices).
        batch_size = vertices.shape[0]
        regressor = self.smpl_joint_regressor[None, :, :].repeat(batch_size, 1, 1)
        smpl_joints = torch.bmm(regressor, vertices)
        return smpl_joints[:, self._ALIGN_JOINT, None, :]

    # ------------------------------------------------------------------
    # Rendering helpers.
    # ------------------------------------------------------------------

    @staticmethod
    def _build_scene(mesh, face, cam_param, color, cam_pose):
        """Build a pyrender scene holding the mesh (flipped 180 deg about X) and the camera."""
        tri_mesh = trimesh.Trimesh(mesh, face)
        tri_mesh.apply_transform(
            trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0]))
        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0, alphaMode='OPAQUE', baseColorFactor=color)
        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.3, 0.3, 0.3))
        scene.add(pyrender.Mesh.from_trimesh(tri_mesh, material=material, smooth=False), 'mesh')

        focal, princpt = cam_param['focal'], cam_param['princpt']
        camera = pyrender.IntrinsicsCamera(fx=focal[0], fy=focal[1], cx=princpt[0], cy=princpt[1])
        if cam_pose is not None:
            scene.add(camera, pose=cam_pose)
        else:
            scene.add(camera)
        return scene

    @staticmethod
    def _render_scene(scene, img):
        """Render `scene` at the resolution of `img` and return ``(rgba, depth)``.

        The off-screen renderer owns an OpenGL context; it is deleted even when
        rendering fails so that repeated calls do not leak contexts.
        """
        renderer = pyrender.OffscreenRenderer(
            viewport_width=img.shape[1], viewport_height=img.shape[0], point_size=1.0)
        try:
            return renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
        finally:
            renderer.delete()

    def render_mesh(self, img, mesh, face, cam_param, color=(1.0, 1.0, 0.9, 1.0), cam_pose=None):
        """Render a lit mesh and composite it over `img`.

        Args:
            img: Background image, array of shape (H, W, 3); also sets the viewport size.
            mesh: Vertex positions accepted by ``trimesh.Trimesh``.
            face: Triangle indices accepted by ``trimesh.Trimesh``.
            cam_param: Dict with ``'focal'`` and ``'princpt'`` pairs (fx, fy) / (cx, cy).
            color: RGBA base colour of the mesh material.
            cam_pose: Optional 4x4 camera pose; the scene default is used when None.

        Returns:
            uint8 array of shape (H, W, 3): rendered colour where depth > 0, `img` elsewhere.
        """
        scene = self._build_scene(mesh, face, cam_param, color, cam_pose)

        light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
        light_pose = np.eye(4)
        light_pose[:3, 3] = np.array([0, 0, -1])
        scene.add(light, pose=light_pose)
        # Three lights share the camera pose (default pose when cam_pose is None).
        for _ in range(3):
            scene.add(light, pose=cam_pose)
        for position in ([1, 1, -4], [-1, 0, -1], [0.2469, 1.8828, -2.4473]):
            light_pose[:3, 3] = np.array(position)
            scene.add(light, pose=light_pose)

        rgb, depth = self._render_scene(scene, img)
        rgb = rgb[:, :, :3].astype(np.float32)
        valid_mask = (depth > 0)[:, :, None]
        return np.where(valid_mask, rgb, img).astype(np.uint8)

    def render_depth(self, img, mesh, face, cam_param, color=(1.0, 1.0, 0.9, 1.0), cam_pose=None):
        """Render the mesh depth and composite it over `img`.

        Arguments are the same as for `render_mesh`; no lights are added.

        Returns:
            uint8 array of shape (H, W, 3): depth (replicated over the channel
            axis) where depth > 0, `img` elsewhere. The uint8 cast truncates the
            depth values; it is kept for interface compatibility.
        """
        scene = self._build_scene(mesh, face, cam_param, color, cam_pose)
        _, depth = self._render_scene(scene, img)
        valid_mask = (depth > 0)[:, :, None]
        # The depth map is (H, W); add a channel axis so it lines up with the
        # (H, W, 1) mask and the (H, W, 3) background instead of broadcasting
        # to an incompatible (H, W, W) array.
        return np.where(valid_mask, depth[:, :, None], img).astype(np.uint8)

    def get_projected_vertex(self, mesh, world2screen_matrix):
        """Project 3D points to 2D with a 4x4 homogeneous matrix.

        Args:
            mesh: Tensor of shape (N, 3).
            world2screen_matrix: Tensor of shape (4, 4).

        Returns:
            Tensor of shape (N, 2): x and y after division by the third
            projected coordinate.
        """
        homogeneous = torch.cat([mesh, mesh.new_ones((mesh.shape[0], 1))], dim=1)  # N x 4
        points_image = (world2screen_matrix @ homogeneous.T)[:3, :]  # 3 x N
        points_on_input_image = points_image / points_image[2, :]
        return points_on_input_image[:2, :].T

    # ------------------------------------------------------------------
    # Body-model helpers.
    # ------------------------------------------------------------------

    def _model_device(self):
        """Device on which the wrapped model keeps its tensors."""
        return self.model.shapedirs.device

    @staticmethod
    def _require_unset(**overrides):
        """Raise NotImplementedError for the first override that is not None."""
        for name, value in overrides.items():
            if value is not None:
                raise NotImplementedError(f'{name} is not supported')

    def _identity_scale_and_transl(self):
        """Return the unit scale (B, 1) and zero translation (B, 3) used everywhere."""
        device = self._model_device()
        scale = torch.ones([self.batch_size, 1]).to(device)
        transl = torch.zeros([self.batch_size, 3]).to(device)
        return scale, transl

    def _recenter_(self, points, align_joint_coordinate, scale, transl):
        """Shift `points` to the alignment origin, then translate and scale, in place.

        The in-place updates also enforce that `points` has `self.batch_size`
        entries, because an in-place result cannot be broadcast to a new shape.
        """
        points -= align_joint_coordinate
        points += transl.view(self.batch_size, 1, 3)
        points *= scale.view(self.batch_size, 1, 1)
        return points

    def _face_points(self, joints):
        """Concatenate the nose, right-eye and left-eye joints into (B, 16, 3)."""
        return torch.cat(
            [joints[:, self._NOSE, :], joints[:, self._RIGHT_EYE, :], joints[:, self._LEFT_EYE, :]],
            dim=1)

    def _run_unposed_model(self):
        """Run the body model without any body pose."""
        return self.model(betas=None,
                          expression=None,
                          return_verts=True,
                          body_pose=None,
                          return_shaped=False)

    def _run_posed_model(self, body_pose):
        """Run the body model with only the two posed slots set from `body_pose` (B, 6)."""
        body_pose_fill = torch.zeros(
            (self.batch_size, self._NUM_BODY_JOINTS, 3)).to(self._model_device())
        body_pose_fill[:, self._FIRST_POSED_SLOT, :] = body_pose[:, :3]
        body_pose_fill[:, self._SECOND_POSED_SLOT, :] = body_pose[:, 3:]
        return self.model(betas=None,
                          expression=None,
                          return_verts=True,
                          body_pose=body_pose_fill.reshape(self.batch_size, -1),
                          return_shaped=True)

    def generate_shaped_smpl(self, betas, scale, transl):
        """Generate the unposed body, centred on the alignment joint.

        Args:
            betas, scale, transl: Must all be None.

        Returns:
            Dict with ``'vertices'`` (B, V, 3), ``'align_joint_coordinate'``
            (B, 1, 3, measured before centring) and ``'face_points'`` (B, 16, 3).

        Raises:
            NotImplementedError: If any argument is not None.
        """
        self._require_unset(betas=betas, scale=scale, transl=transl)
        scale, transl = self._identity_scale_and_transl()

        shaped_output = self._run_unposed_model()
        vertices_no_pose = shaped_output.vertices
        joints_no_pose = shaped_output.joints

        align_joint_coordinate = self.get_align_coordinate(vertices_no_pose)  # B, 1, 3
        self._recenter_(vertices_no_pose, align_joint_coordinate, scale, transl)
        self._recenter_(joints_no_pose, align_joint_coordinate, scale, transl)

        return {
            'vertices': vertices_no_pose,
            'align_joint_coordinate': align_joint_coordinate,
            'face_points': self._face_points(joints_no_pose),
        }

    def generate_posed_smpl(self, betas, scale, transl, body_pose, align_joint_coordinate):
        """Generate the posed body, centred on a previously computed origin.

        Args:
            betas, scale, transl: Must all be None.
            body_pose: Tensor of shape (B, 6); values 0-2 and 3-5 are written to
                body-pose slots 11 and 14, every other joint stays at zero.
            align_joint_coordinate: Tensor broadcastable to (B, 1, 3), e.g. the
                value returned by `generate_shaped_smpl`.

        Returns:
            Dict with ``'vertices'`` (B, V, 3) and ``'face_points'`` (B, 16, 3).

        Raises:
            NotImplementedError: If `betas`, `scale` or `transl` is not None.
        """
        self._require_unset(betas=betas, scale=scale, transl=transl)
        scale, transl = self._identity_scale_and_transl()
        misc.assert_shape(body_pose, [self.batch_size, 6])

        output = self._run_posed_model(body_pose)
        vertices = self._recenter_(output.vertices, align_joint_coordinate, scale, transl)
        joints = self._recenter_(output.joints, align_joint_coordinate, scale, transl)

        return {
            'vertices': vertices,
            'face_points': self._face_points(joints),
        }

    # ------------------------------------------------------------------
    # Batched rendering.
    # ------------------------------------------------------------------

    @staticmethod
    def _cameras_to_numpy(cameras):
        """Return `cameras` as a NumPy array, accepting tensors (any device) or arrays."""
        if cameras is None:
            raise ValueError('cameras must be provided')
        if isinstance(cameras, torch.Tensor):
            return cameras.detach().cpu().numpy()
        return np.asarray(cameras)

    @staticmethod
    def _standard_intrinsics(resolution):
        """Camera intrinsics dict (focal length and principal point) for a square image."""
        focal = 5000.0 / 1024 * resolution / 0.75
        return {"focal": [focal, focal], "princpt": [resolution / 2, resolution / 2]}

    @staticmethod
    def _to_pyrender_camera_pose(camera_pose):
        """Convert a 4x4 camera pose to the convention passed to pyrender.

        Inverts the pose, negates rows 1 and 2 and inverts again, which is
        algebraically the same as negating columns 1 and 2 of `camera_pose`.
        """
        flipped = np.linalg.inv(camera_pose)
        flipped[[1, 2]] *= -1
        return np.linalg.inv(flipped)

    def _render_views(self, render_fn, vertices, resolution, cameras):
        """Render one view per batch element with `render_fn` and stack them as (B, 3, H, W)."""
        faces = self.model.faces
        intrinsics = self._standard_intrinsics(resolution)

        # Rotation by pi about the X axis, applied to row-vector vertices.
        angle = np.pi
        R = np.eye(3)
        R[1, 1] = np.cos(angle)
        R[1, 2] = -np.sin(angle)
        R[2, 1] = np.sin(angle)
        R[2, 2] = np.cos(angle)
        R = torch.from_numpy(R).float().to(self._model_device())
        R = R.unsqueeze(0).repeat(self.batch_size, 1, 1)  # B x 3 x 3

        # One device-to-host copy for the whole batch instead of one per view.
        vertices_pyrender = torch.matmul(vertices, R).detach().cpu().numpy()  # B x V x 3

        images = []
        for i in range(self.batch_size):
            # The first 16 values of each camera row are read as a flattened 4x4 pose.
            cam_pose = self._to_pyrender_camera_pose(cameras[i, :16].reshape(4, 4))
            image = render_fn(np.ones((resolution, resolution, 3)) * 255,
                              vertices_pyrender[i], faces,
                              intrinsics,
                              color=(0.4, 0.5, 0.9, 1.0),
                              cam_pose=cam_pose)
            images.append(np.transpose(image, (2, 0, 1)))  # 3 x H x W
        return np.stack(images, axis=0)

    def get_depth(self, vert, resolution=256, cameras=None):
        """Render depth images of already-aligned vertices.

        Args:
            vert: Tensor of shape (B, V, 3).
            resolution: Side length of the square output images.
            cameras: Tensor or array with one row per batch element whose first
                16 values are a flattened 4x4 camera pose. Required.

        Returns:
            uint8 array of shape (B, 3, resolution, resolution).

        Raises:
            ValueError: If `cameras` is None.
        """
        cameras = self._cameras_to_numpy(cameras)
        return self._render_views(self.render_depth, vert, resolution, cameras)

    def get_visualization(self, shape_pose_params, resolution=256, cameras=None):
        """Render the posed, aligned body over a white background.

        Args:
            shape_pose_params: Dict with a ``'pose'`` tensor of shape (B, 6).
                The keys ``'betas'``, ``'scale'`` and ``'transl'`` are not supported.
            resolution: Side length of the square output images.
            cameras: Tensor or array with one row per batch element whose first
                16 values are a flattened 4x4 camera pose. Required.

        Returns:
            uint8 array of shape (B, 3, resolution, resolution).

        Raises:
            NotImplementedError: If an unsupported key is present.
            ValueError: If `cameras` is None.
        """
        for key in ('betas', 'scale', 'transl'):
            if key in shape_pose_params:
                raise NotImplementedError(f'{key} is not supported')
        scale, transl = self._identity_scale_and_transl()

        body_pose = shape_pose_params['pose']
        misc.assert_shape(body_pose, [self.batch_size, 6])
        cameras = self._cameras_to_numpy(cameras)

        # The alignment origin comes from the unposed body so that posing does
        # not move the reference frame.
        align_joint_coordinate = self.get_align_coordinate(self._run_unposed_model().vertices)

        output = self._run_posed_model(body_pose)
        vertices = self._recenter_(output.vertices, align_joint_coordinate, scale, transl)

        return self._render_views(self.render_mesh, vertices, resolution, cameras)
at +https://github.com/NVlabs/stylegan2-ada/blob/main/training/augment.py""" + +import numpy as np +import scipy.signal +import torch +from torch_utils import persistence +from torch_utils import misc +from torch_utils.ops import upfirdn2d +from torch_utils.ops import grid_sample_gradfix +from torch_utils.ops import conv2d_gradfix + +#---------------------------------------------------------------------------- +# Coefficients of various wavelet decomposition low-pass filters. + +wavelets = { + 'haar': [0.7071067811865476, 0.7071067811865476], + 'db1': [0.7071067811865476, 0.7071067811865476], + 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], + 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], + 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], + 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], + 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 
0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], + 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], + 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], + 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], + 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], + 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], +} + +#---------------------------------------------------------------------------- +# Helpers for constructing transformation matrices. 

def _device_matches(requested, actual):
    """Return True when `requested` (string, index or torch.device) designates `actual`.

    A request without a device index (e.g. ``'cuda'``) matches any index of
    the same device type.
    """
    requested = torch.device(requested)
    return requested.type == actual.type and requested.index in (None, actual.index)


def matrix(*rows, device=None):
    """Build a (batch of) matrices from rows of scalars and/or tensors.

    Args:
        *rows: Equal-length sequences of entries. Each entry is a Python/NumPy
            scalar or a tensor.
        device: Target device. Used when no entry is a tensor; otherwise it is
            only checked against the device of the first tensor entry and may
            be given as a string or a ``torch.device``.

    Returns:
        Without tensor entries: ``misc.constant(np.asarray(rows), device=device)``.
        With tensor entries: a tensor of shape ``ref.shape + (len(rows), cols)``
        where ``ref`` is the first tensor entry; scalar entries are expanded to
        ``ref``'s shape and device.
    """
    assert all(len(row) == len(rows[0]) for row in rows)
    elems = [x for row in rows for x in row]
    ref = [x for x in elems if isinstance(x, torch.Tensor)]
    if len(ref) == 0:
        return misc.constant(np.asarray(rows), device=device)
    # Compare normalised devices: `'cpu' == torch.device('cpu')` is False, so a
    # plain equality test would reject equivalent string specifications.
    assert device is None or _device_matches(device, ref[0].device)
    elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
    return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))


def translate2d(tx, ty, **kwargs):
    """Return the 3x3 homogeneous matrix that translates 2D points by (tx, ty)."""
    return matrix(
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
        **kwargs)


def translate3d(tx, ty, tz, **kwargs):
    """Return the 4x4 homogeneous matrix that translates 3D points by (tx, ty, tz)."""
    return matrix(
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1],
        **kwargs)


def scale2d(sx, sy, **kwargs):
    """Return the 3x3 homogeneous matrix that scales 2D points by (sx, sy)."""
    return matrix(
        [sx, 0,  0],
        [0,  sy, 0],
        [0,  0,  1],
        **kwargs)


def scale3d(sx, sy, sz, **kwargs):
    """Return the 4x4 homogeneous matrix that scales 3D points by (sx, sy, sz)."""
    return matrix(
        [sx, 0,  0,  0],
        [0,  sy, 0,  0],
        [0,  0,  sz, 0],
        [0,  0,  0,  1],
        **kwargs)


def rotate2d(theta, **kwargs):
    """Return the 3x3 homogeneous 2D rotation matrix for the tensor angle `theta` (radians)."""
    return matrix(
        [torch.cos(theta), torch.sin(-theta), 0],
        [torch.sin(theta), torch.cos(theta),  0],
        [0,                0,                 1],
        **kwargs)


def rotate3d(v, theta, **kwargs):
    """Return the 4x4 homogeneous rotation by `theta` (radians) about axis `v`.

    `v` is indexed as ``v[..., 0:3]``; the formula is the axis-angle rotation
    matrix and assumes `v` has unit length -- TODO confirm at call sites.
    """
    vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
    s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
    return matrix(
        [vx*vx*cc+c,    vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
        [vy*vx*cc+vz*s, vy*vy*cc+c,    vy*vz*cc-vx*s, 0],
        [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c,    0],
        [0,             0,             0,             1],
        **kwargs)


def translate2d_inv(tx, ty, **kwargs):
    """Return the inverse of `translate2d(tx, ty)`."""
    return translate2d(-tx, -ty, **kwargs)


def scale2d_inv(sx, sy, **kwargs):
    """Return the inverse of `scale2d(sx, sy)`; `sx` and `sy` must be non-zero."""
    return scale2d(1 / sx, 1 / sy, **kwargs)


def rotate2d_inv(theta, **kwargs):
    """Return the inverse of `rotate2d(theta)`."""
    return rotate2d(-theta, **kwargs)

#----------------------------------------------------------------------------
# Versatile image augmentation pipeline from the paper
# "Training Generative Adversarial Networks with Limited Data".
+# +# All augmentations are disabled by default; individual augmentations can +# be enabled by setting their probability multipliers to 1. + +@persistence.persistent_class +class AugmentPipe(torch.nn.Module): + def __init__(self, + xflip=0, rotate90=0, xint=0, xint_max=0.125, + scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125, + brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1, + imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1, + noise=0, cutout=0, noise_std=0.1, cutout_size=0.5, + ): + super().__init__() + self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability. + + # Pixel blitting. + self.xflip = float(xflip) # Probability multiplier for x-flip. + self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations. + self.xint = float(xint) # Probability multiplier for integer translation. + self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions. + + # General geometric transformations. + self.scale = float(scale) # Probability multiplier for isotropic scaling. + self.rotate = float(rotate) # Probability multiplier for arbitrary rotation. + self.aniso = float(aniso) # Probability multiplier for anisotropic scaling. + self.xfrac = float(xfrac) # Probability multiplier for fractional translation. + self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling. + self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle. + self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling. + self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions. + + # Color transformations. + self.brightness = float(brightness) # Probability multiplier for brightness. + self.contrast = float(contrast) # Probability multiplier for contrast. 
+ self.lumaflip = float(lumaflip) # Probability multiplier for luma flip. + self.hue = float(hue) # Probability multiplier for hue rotation. + self.saturation = float(saturation) # Probability multiplier for saturation. + self.brightness_std = float(brightness_std) # Standard deviation of brightness. + self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast. + self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle. + self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation. + + # Image-space filtering. + self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering. + self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands. + self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification. + + # Image-space corruptions. + self.noise = float(noise) # Probability multiplier for additive RGB noise. + self.cutout = float(cutout) # Probability multiplier for cutout. + self.noise_std = float(noise_std) # Standard deviation of additive RGB noise. + self.cutout_size = float(cutout_size) # Size of the cutout rectangle, relative to image dimensions. + + # Setup orthogonal lowpass filter for geometric augmentations. + self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6'])) + + # Construct filter bank for image-space filtering. 
+ Hz_lo = np.asarray(wavelets['sym2']) # H(z) + Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z) + Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2 + Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2 + Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i) + for i in range(1, Hz_fbank.shape[0]): + Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1] + Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2]) + Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2 + self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32)) + + def forward(self, images, debug_percentile=None): + assert isinstance(images, torch.Tensor) and images.ndim == 4 + batch_size, num_channels, height, width = images.shape + device = images.device + if debug_percentile is not None: + debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device) + + # ------------------------------------- + # Select parameters for pixel blitting. + # ------------------------------------- + + # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in + I_3 = torch.eye(3, device=device) + G_inv = I_3 + + # Apply x-flip with probability (xflip * strength). + if self.xflip > 0: + i = torch.floor(torch.rand([batch_size], device=device) * 2) + i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 2)) + G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1) + + # Apply 90 degree rotations with probability (rotate90 * strength). 
+ if self.rotate90 > 0: + i = torch.floor(torch.rand([batch_size], device=device) * 4) + i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 4)) + G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i) + + # Apply integer translation with probability (xint * strength). + if self.xint > 0: + t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max + t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max) + G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height)) + + # -------------------------------------------------------- + # Select parameters for general geometric transformations. + # -------------------------------------------------------- + + # Apply isotropic scaling with probability (scale * strength). + if self.scale > 0: + s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std) + s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std)) + G_inv = G_inv @ scale2d_inv(s, s) + + # Apply pre-rotation with probability p_rot. + p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max + theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max) + G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling. 
+ + # Apply anisotropic scaling with probability (aniso * strength). + if self.aniso > 0: + s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std) + s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std)) + G_inv = G_inv @ scale2d_inv(s, 1 / s) + + # Apply post-rotation with probability p_rot. + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max + theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.zeros_like(theta) + G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling. + + # Apply fractional translation with probability (xfrac * strength). + if self.xfrac > 0: + t = torch.randn([batch_size, 2], device=device) * self.xfrac_std + t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std) + G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height) + + # ---------------------------------- + # Execute geometric transformations. + # ---------------------------------- + + # Execute if the transform is not identity. + if G_inv is not I_3: + + # Calculate padding. 
+ cx = (width - 1) / 2 + cy = (height - 1) / 2 + cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] + cp = G_inv @ cp.t() # [batch, xyz, idx] + Hz_pad = self.Hz_geom.shape[0] // 4 + margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] + margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] + margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) + margin = margin.max(misc.constant([0, 0] * 2, device=device)) + margin = margin.min(misc.constant([width-1, height-1] * 2, device=device)) + mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) + + # Pad image and adjust origin. + images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') + G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv + + # Upsample. + images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2) + G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) + G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) + + # Execute transformation. + shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2] + G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) + grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) + images = grid_sample_gradfix.grid_sample(images, grid) + + # Downsample and crop. + images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True) + + # -------------------------------------------- + # Select parameters for color transformations. + # -------------------------------------------- + + # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out + I_4 = torch.eye(4, device=device) + C = I_4 + + # Apply brightness with probability (brightness * strength). 
+ if self.brightness > 0: + b = torch.randn([batch_size], device=device) * self.brightness_std + b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b)) + if debug_percentile is not None: + b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std) + C = translate3d(b, b, b) @ C + + # Apply contrast with probability (contrast * strength). + if self.contrast > 0: + c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std) + c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c)) + if debug_percentile is not None: + c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std)) + C = scale3d(c, c, c) @ C + + # Apply luma flip with probability (lumaflip * strength). + v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis. + if self.lumaflip > 0: + i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2) + i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 2)) + C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection. + + # Apply hue rotation with probability (hue * strength). + if self.hue > 0 and num_channels > 1: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max + theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max) + C = rotate3d(v, theta) @ C # Rotate around v. + + # Apply saturation with probability (saturation * strength). 
+ if self.saturation > 0 and num_channels > 1: + s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std) + s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std)) + C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C + + # ------------------------------ + # Execute color transformations. + # ------------------------------ + + # Execute if the transform is not identity. + if C is not I_4: + images = images.reshape([batch_size, num_channels, height * width]) + if num_channels == 3: + images = C[:, :3, :3] @ images + C[:, :3, 3:] + elif num_channels == 1: + C = C[:, :3, :].mean(dim=1, keepdims=True) + images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:] + elif num_channels == 6: + images[:, :3] = C[:, :3, :3] @ images[:, :3] + C[:, :3, 3:] + images[:, 3:] = C[:, :3, :3] @ images[:, 3:] + C[:, :3, 3:] + else: + raise ValueError('Image must be RGB (3 channels) or L (1 channel)') + images = images.reshape([batch_size, num_channels, height, width]) + + # ---------------------- + # Image-space filtering. + # ---------------------- + + if self.imgfilter > 0: + num_bands = self.Hz_fbank.shape[0] + assert len(self.imgfilter_bands) == num_bands + expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f). + + # Apply amplification for each band with probability (imgfilter * strength * band_strength). + g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity). 
+ for i, band_strength in enumerate(self.imgfilter_bands): + t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std) + t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i)) + if debug_percentile is not None: + t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i) + t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector. + t[:, i] = t_i # Replace i'th element. + t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power. + g = g * t # Accumulate into global gain. + + # Construct combined amplification filter. + Hz_prime = g @ self.Hz_fbank # [batch, tap] + Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap] + Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap] + + # Apply filter. + p = self.Hz_fbank.shape[1] // 2 + images = images.reshape([1, batch_size * num_channels, height, width]) + images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect') + images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels) + images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels) + images = images.reshape([batch_size, num_channels, height, width]) + + # ------------------------ + # Image-space corruptions. + # ------------------------ + + # Apply additive RGB noise with probability (noise * strength). 
+ if self.noise > 0: + sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std + sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma)) + if debug_percentile is not None: + sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std) + images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma + + # Apply cutout with probability (cutout * strength). + if self.cutout > 0: + size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device) + size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size)) + center = torch.rand([batch_size, 2, 1, 1, 1], device=device) + if debug_percentile is not None: + size = torch.full_like(size, self.cutout_size) + center = torch.full_like(center, debug_percentile) + coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1]) + coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1]) + mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2) + mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2) + mask = torch.logical_or(mask_x, mask_y).to(torch.float32) + images = images * mask + + return images + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/training/crosssection_utils.py b/3DPortraitGAN_pyramid/training/crosssection_utils.py new file mode 100644 index 0000000..72d49f2 --- /dev/null +++ b/3DPortraitGAN_pyramid/training/crosssection_utils.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
def sample_cross_section(G, ws, resolution=256, w=1.2):
    """Sample the generator's density on a planar slice through the volume.

    A ``resolution`` x ``resolution`` grid of 3D points is built on the plane
    whose first coordinate is zero; the other two coordinates span
    ``[-w/2, w/2]`` (the first grid axis runs from ``+w/2`` down to ``-w/2``).
    The points are passed to ``G.sample_mixed`` together with random
    directions and ``ws``, and the ``'sigma'`` entry of the result is
    returned reshaped to ``[-1, 1, resolution, resolution]``.
    """
    plane_axis = 0  # index of the coordinate that is held at zero
    half_extent = w / 2
    device = ws.device

    descending = torch.linspace(half_extent, -half_extent, resolution, device=device)
    ascending = torch.linspace(-half_extent, half_extent, resolution, device=device)
    grid_a, grid_b = torch.meshgrid(descending, ascending, indexing='ij')
    grid_a = grid_a.reshape(-1, 1)
    grid_b = grid_b.reshape(-1, 1)

    # Insert the all-zero coordinate at the slicing axis, then broadcast the
    # same set of points over the batch dimension of ws.
    components = [grid_a, grid_b]
    components.insert(plane_axis, torch.zeros_like(grid_a))
    points = torch.cat(components, dim=-1).expand(ws.shape[0], -1, -1)

    directions = torch.randn_like(points)
    sigma = G.sample_mixed(points, directions, ws)['sigma']
    return sigma.reshape(-1, 1, resolution, resolution)
def matrix2angle(R):
    """Compute three Euler angles (in radians) from a rotation matrix.

    Adapted from PanoHead (https://github.com/SizheAn/PanoHead).
    Ref: http://www.gregslabaugh.net/publications/euler.pdf
    Refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv

    The trigonometric helpers come from numpy (already imported by this
    module); the previous version used bare ``atan2``/``asin``/``cos`` names
    that were never imported and raised ``NameError`` on the first call.

    Args:
        R: (3, 3) rotation matrix, indexable as ``R[i, j]``.
    Returns:
        (x, y, z) as Python floats, where
        x: yaw
        y: pitch
        z: roll
    """
    if R[2, 0] > 0.998:
        # Near gimbal lock: roll is fixed to zero, yaw is +90 degrees.
        z = 0
        x = np.pi / 2
        y = z + np.arctan2(-R[0, 1], -R[0, 2])
    elif R[2, 0] < -0.998:
        # Near gimbal lock on the other side: yaw is -90 degrees.
        z = 0
        x = -np.pi / 2
        y = -z + np.arctan2(R[0, 1], R[0, 2])
    else:
        x = np.arcsin(R[2, 0])
        y = np.arctan2(R[2, 1] / np.cos(x), R[2, 2] / np.cos(x))
        z = np.arctan2(R[1, 0] / np.cos(x), R[0, 0] / np.cos(x))

    # NOTE(review): this check is placed at function level, matching the
    # upstream PanoHead code; confirm against the un-mangled file.
    if abs(y) > np.pi / 2:
        # Pick the alternative solution so that pitch stays within +-90
        # degrees and yaw can cover the back half-space.
        if x > 0:
            x = np.pi - x
        else:
            x = -np.pi - x
        y = np.arctan2(R[2, 1] / np.cos(x), R[2, 2] / np.cos(x))
        z = np.arctan2(R[1, 0] / np.cos(x), R[0, 0] / np.cos(x))
    return float(x), float(y), float(z)


def get_poseangle(eg3dparams):
    """Return [yaw, pitch, roll] in degrees for a camera label.

    Adapted from PanoHead (https://github.com/SizheAn/PanoHead).

    The first 16 entries of ``eg3dparams`` are read as a row-major 4x4
    matrix. It is right-multiplied by a matrix that negates the second and
    third columns, inverted, and the rotation part of the result is
    converted with ``matrix2angle``.
    """
    convert = np.array([
        [1, 0, 0, 0],
        [0, -1, 0, 0],
        [0, 0, -1, 0],
        [0, 0, 0, 1],
    ]).astype(np.float32)

    entry_cam = np.array([float(p) for p in eg3dparams][:16]).reshape((4, 4))

    world2cam = np.linalg.inv(entry_cam @ convert)
    pose = matrix2angle(world2cam[:3, :3])
    angle = [p * 180 / np.pi for p in pose]

    return angle
+ max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. + use_labels = False, # Enable conditioning labels? False = label dimension is zero. + xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. + random_seed = 0, # Random seed to use when applying max_size. + rebal_raw_idx = None, # Rebalance the dataset by sampling from the raw_idx list + data_rebalance=False, # Rebalance the dataset by sampling from the raw_idx list + ): + self._name = name + self._raw_shape = list(raw_shape) + self._use_labels = use_labels + self._raw_labels = None + self._raw_poses = None + self._label_shape = None + self._pose_shape = None + + + if data_rebalance: + raise NotImplementedError + assert rebal_raw_idx is not None, "rebal_raw_idx must be provided if data_rebalance is True" + self._raw_idx = rebal_raw_idx + else: + self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) + + + self._raw_idx = self._filter_samples() + + # Apply max_size. + if (max_size is not None) and (self._raw_idx.size > max_size): + raise NotImplementedError + np.random.RandomState(random_seed).shuffle(self._raw_idx) + self._raw_idx = np.sort(self._raw_idx[:max_size]) + + # Apply xflip. 
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) + if xflip: + self._raw_idx = np.tile(self._raw_idx, 2) + self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) + + def _filter_samples(self): # to be overridden by subclass + raise NotImplementedError + + + def _get_raw_labels(self): + if self._raw_labels is None: + self._raw_labels,self._raw_poses = self._load_raw_labels() if self._use_labels else None + + if self._raw_labels is None: + raise Exception("_raw_labels is None") + self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) + + assert isinstance(self._raw_labels, np.ndarray) + assert self._raw_labels.shape[0] == self._raw_shape[0] + assert self._raw_labels.dtype in [np.float32, np.int64] + if self._raw_labels.dtype == np.int64: + assert self._raw_labels.ndim == 1 + assert np.all(self._raw_labels >= 0) + self._raw_labels_std = self._raw_labels.std(0) + + + if self._raw_poses is None: + raise Exception("_raw_poses is None") + self._raw_poses = np.zeros([self._raw_poses[0], 0], dtype=np.float32) + + assert isinstance(self._raw_poses, np.ndarray) + assert self._raw_poses.shape[0] == self._raw_shape[0] + assert self._raw_poses.dtype in [np.float32, np.int64] + if self._raw_poses.dtype == np.int64: + assert self._raw_poses.ndim == 1 + assert np.all(self._raw_poses >= 0) + self._raw_poses_std = self._raw_poses.std(0) + + return self._raw_labels + + def _get_raw_poses(self): + if self._raw_poses is None: + _ = self._get_raw_labels() + #raise Exception("please run _get_raw_labels first") + + return self._raw_poses + + + def close(self): # to be overridden by subclass + pass + + def _load_raw_image(self, raw_idx): # to be overridden by subclass + raise NotImplementedError + + def _load_raw_labels(self): # to be overridden by subclass + raise NotImplementedError + + + def __getstate__(self): + return dict(self.__dict__, _raw_labels=None, _raw_poses=None) + + def __del__(self): + try: + self.close() + except: + pass + + def 
__len__(self): + return self._raw_idx.size + + + + + def __getitem__(self, idx): + + + label = self.get_label(idx) + pose = self.get_coarse_pose(idx) + + # image = self._load_raw_image(self._raw_idx[idx]) + # assert isinstance(image, np.ndarray) + # assert list(image.shape) == self.image_shape + # assert image.dtype == np.uint8 + # if self._xflip[idx]: + # assert image.ndim == 3 # CHW + # image = image[:, :, ::-1] + # # # flip label + # # label = self.flip_yaw(label) + # # # flip pose + # # pose[[1, 2, 4, 5]] *= -1 + + image = self.get_image(idx) + + + return image, label,pose + + def flip_yaw(self, c): + pose_matrix = c.copy() + flipped = pose_matrix[:16].reshape(4,4) + flipped[0, 1] *= -1 + flipped[0, 2] *= -1 + flipped[1, 0] *= -1 + flipped[2, 0] *= -1 + flipped[0, 3] *= -1 + + flipped = flipped.reshape(16) + pose_matrix[:16] = flipped + + return pose_matrix + + def get_image(self, idx): + image = self._load_raw_image(self._raw_idx[idx]) + assert isinstance(image, np.ndarray) + assert list(image.shape) == self.image_shape + assert image.dtype == np.uint8 + if self._xflip[idx]: + assert image.ndim == 3 # CHW + image = image[:, :, ::-1] + + return image.copy() + + + def get_label(self, idx): + label = self._get_raw_labels()[self._raw_idx[idx]].copy() + if label.dtype == np.int64: + onehot = np.zeros(self.label_shape, dtype=np.float32) + onehot[label] = 1 + label = onehot + + if self._xflip[idx]: + assert label.shape == (25,) + label[[1, 2, 3, 4, 8]] *= -1 + + return label + + def get_coarse_pose(self, idx): + pose = self._get_raw_poses()[self._raw_idx[idx]].copy() + if pose.dtype == np.int64: + raise TypeError("pose should be float32") + onehot = np.zeros(self.pose_shape, dtype=np.float32) + onehot[pose] = 1 + pose = onehot + + if self._xflip[idx]: + pose_flip = pose.copy() + pose_flip[[1, 2, 4, 5]] *= -1 + + return pose_flip + + else: + return pose + + + + def get_details(self, idx): + d = dnnlib.EasyDict() + d.raw_idx = int(self._raw_idx[idx]) + d.xflip = 
(int(self._xflip[idx]) != 0) + d.raw_label = self._get_raw_labels()[d.raw_idx].copy() + # d.pose = self.get_coarse_pose(idx).copy() + + return d + + def get_label_std(self): + return self._raw_labels_std + + @property + def name(self): + return self._name + + @property + def image_shape(self): + return list(self._raw_shape[1:]) + + @property + def num_channels(self): + assert len(self.image_shape) == 3 # CHW + return self.image_shape[0] + + @property + def resolution(self): + assert len(self.image_shape) == 3 # CHW + assert self.image_shape[1] == self.image_shape[2] + return self.image_shape[1] + + @property + def label_shape(self): + if self._label_shape is None: + raw_labels = self._get_raw_labels() + if raw_labels.dtype == np.int64: + self._label_shape = [int(np.max(raw_labels)) + 1] + else: + self._label_shape = raw_labels.shape[1:] + return list(self._label_shape) + + @property + def pose_shape(self): + if self._pose_shape is None: + self._get_raw_labels() + if self._raw_poses.dtype == np.int64: + self._pose_shape = [int(np.max(self._raw_poses)) + 1] + else: + self._pose_shape = self._raw_poses.shape[1:] + return list(self._pose_shape) + + + @property + def label_dim(self): + assert len(self.label_shape) == 1 + return self.label_shape[0] + + @property + def has_labels(self): + return any(x != 0 for x in self.label_shape) + + @property + def has_onehot_labels(self): + return self._get_raw_labels().dtype == np.int64 + +#---------------------------------------------------------------------------- + +class ImageFolderDataset(Dataset): + def __init__(self, + path, # Path to directory or zip. + back_repeat = None, + resolution = None, # Ensure specific resolution, None = highest available. + data_rebalance_idx_file = None, + **super_kwargs, # Additional arguments for the Dataset base class. 
+ ): + self.min_yaw = 0 + self.max_yaw = 180 + self.max_pitch = 90 + self.back_repeat = 1 if back_repeat is None else back_repeat + self._path = path + self._zipfile = None + + if os.path.isdir(self._path): + raise NotImplementedError('Does not support directories yet') + self._type = 'dir' + self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} + elif self._file_ext(self._path) == '.zip': + self._type = 'zip' + self._all_fnames = set(self._get_zipfile().namelist()) + else: + raise IOError('Path must point to a directory or zip') + + if data_rebalance_idx_file is not None: + raise NotImplementedError('data_rebalance is not implemented yet') + rebal_idx_list_path =data_rebalance_idx_file + #print('load rebal_idx_list from ',rebal_idx_list_path) + with open(rebal_idx_list_path, 'r') as f: + rebal_raw_idx = json.load(f) + rebal_raw_idx = np.array(rebal_raw_idx) + else: + rebal_raw_idx = None + + + PIL.Image.init() + self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) + if len(self._image_fnames) == 0: + raise IOError('No image files found in the specified path') + + name = os.path.splitext(os.path.basename(self._path))[0] + raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) + if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): + raise IOError('Image files do not match the specified resolution') + super().__init__(name=name, raw_shape=raw_shape, rebal_raw_idx = rebal_raw_idx,**super_kwargs) + + + def _filter_samples(self): + if self.back_repeat>1: + raw_labels = self._get_raw_labels()[self._raw_idx] + label_list = [] + for entry in raw_labels: + label_list.append(get_poseangle(entry)) + poses = np.array(label_list) + # find [min_yaw, max_yaw] boolean + valid = (np.abs(poses[:,0])>=self.min_yaw) & (np.abs(poses[:,0])<=self.max_yaw) & 
(np.abs(poses[:,1])<=self.max_pitch) + # find back boolean: [max(90, self.min_yaw), max_yaw] + back_valid = (np.abs(poses[:,0])>= max(90, self.min_yaw)) & (np.abs(poses[:,0])<=self.max_yaw) & (np.abs(poses[:,1])<=self.max_pitch) + if not np.all(valid): + print(f"filtering samples by pose: ratio = {valid.sum()}/{len(self._raw_idx)}") + # boolean to index + valid_idx = self._raw_idx[valid] + back_idx = self._raw_idx[back_valid] + front_idx = np.array(list(set(valid_idx) - set(back_idx))) + + front_num = valid.sum()-len(back_idx) + front_back_ratio_min = front_num/2/len(back_idx) + print(f"if back num be the half of front num, at least repeat ({int(front_back_ratio_min)}) times.") + back_repeat = max(int(front_num/2/len(back_idx)), self.back_repeat) + + + + + # TODO: support the repeat times < 1 + # repeat [max(90, self.min_yaw), max_yaw] for multiple times + back_repeat_idx = np.tile(back_idx, back_repeat) + # merge front index and repeated back index + new_idx = np.concatenate((front_idx, back_repeat_idx)) + print(f"Repeat {len(back_idx)} back images till abs({self.max_yaw}) degree {back_repeat} times") + return new_idx + else: + return self._raw_idx + @staticmethod + def _file_ext(fname): + return os.path.splitext(fname)[1].lower() + + def _get_zipfile(self): + assert self._type == 'zip' + if self._zipfile is None: + self._zipfile = zipfile.ZipFile(self._path) + return self._zipfile + + def _open_file(self, fname): + if self._type == 'dir': + return open(os.path.join(self._path, fname), 'rb') + if self._type == 'zip': + return self._get_zipfile().open(fname, 'r') + return None + + def close(self): + try: + if self._zipfile is not None: + self._zipfile.close() + finally: + self._zipfile = None + + def __getstate__(self): + return dict(super().__getstate__(), _zipfile=None) + + def _load_raw_image(self, raw_idx): + fname = self._image_fnames[raw_idx] + with self._open_file(fname) as f: + if pyspng is not None and self._file_ext(fname) == '.png': + image = 
pyspng.load(f.read()) + else: + image = np.array(PIL.Image.open(f)) + if image.ndim == 2: + image = image[:, :, np.newaxis] # HW => HWC + image = image.transpose(2, 0, 1) # HWC => CHW + return image + + def _load_raw_labels(self): + fname = 'dataset.json' + if fname not in self._all_fnames: + return None + with self._open_file(fname) as f: + labels = json.load(f)['labels'] + if labels is None: + return None + labels = dict(labels) + labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] + labels = np.array(labels) + labels = np.squeeze(labels) + #print('labels shape', labels.shape) # N, 31 + labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) + + poses = labels[:,25:] + labels = labels[:,:25] + + # print('labels shape', labels.shape) # N, 25 + # print('poses shape', poses.shape) # N, 6 + + return labels, poses + + +#---------------------------------------------------------------------------- + + +class MaskLabeledDataset(ImageFolderDataset): + + def __init__(self, + img_path, # Path to directory or zip. + seg_path, # Path to directory or zip. + back_repeat = None, + **super_kwargs, # Additional arguments for the Dataset base class. 
+ ): + self.min_yaw = 0 + self.max_yaw = 180 + self.max_pitch = 90 + self.back_repeat = 1 if back_repeat is None else back_repeat + super().__init__(path=img_path, back_repeat = None,**super_kwargs) + + self._seg_dataset = ImageFolderDataset(seg_path, **super_kwargs) + + # Build the mapping from seg fname to seg raw index + seg_dict = {os.path.splitext(fname)[0]: idx for idx, fname in enumerate(self._seg_dataset._image_fnames)} + + # Build the mapping from index to seg raw index + self._seg_raw_idx = [] + for raw_idx in self._raw_idx: + fname = self._image_fnames[raw_idx] + key = os.path.splitext(fname)[0] + self._seg_raw_idx.append(seg_dict[key]) + self._seg_raw_idx = np.array(self._seg_raw_idx) + + def _filter_samples(self): + if self.back_repeat>1: + raw_labels = self._get_raw_labels()[self._raw_idx] + label_list = [] + for entry in raw_labels: + label_list.append(get_poseangle(entry)) + poses = np.array(label_list) + # find [min_yaw, max_yaw] boolean + valid = (np.abs(poses[:,0])>=self.min_yaw) & (np.abs(poses[:,0])<=self.max_yaw) & (np.abs(poses[:,1])<=self.max_pitch) + # find back boolean: [max(90, self.min_yaw), max_yaw] + back_valid = (np.abs(poses[:,0])>= max(90, self.min_yaw)) & (np.abs(poses[:,0])<=self.max_yaw) & (np.abs(poses[:,1])<=self.max_pitch) + if not np.all(valid): + print(f"filtering samples by pose: ratio = {valid.sum()}/{len(self._raw_idx)}") + # boolean to index + valid_idx = self._raw_idx[valid] + back_idx = self._raw_idx[back_valid] + front_idx = np.array(list(set(valid_idx) - set(back_idx))) + + front_num = valid.sum()-len(back_idx) + front_back_ratio_min = front_num/2/len(back_idx) + print(f"if back num be the half of front num, at least repeat ({int(front_back_ratio_min)}) times.") + back_repeat = max(int(front_num/2/len(back_idx)), self.back_repeat) + + + + + # TODO: support the repeat times < 1 + # repeat [max(90, self.min_yaw), max_yaw] for multiple times + back_repeat_idx = np.tile(back_idx, back_repeat) + # merge front index and 
repeated back index + new_idx = np.concatenate((front_idx, back_repeat_idx)) + print(f"Repeat {len(back_idx)} back images till abs({self.max_yaw}) degree {back_repeat} times") + return new_idx + else: + return self._raw_idx + + + + def __getitem__(self, idx): + # already flipped in the ImageFolderDataset + image = self.get_image(idx) + mask = self._seg_dataset.get_image(idx) + label = self.get_label(idx) + pose = self.get_coarse_pose(idx) + + + return image.copy(), mask.copy(), label,pose + diff --git a/3DPortraitGAN_pyramid/training/dual_discriminator.py b/3DPortraitGAN_pyramid/training/dual_discriminator.py new file mode 100644 index 0000000..1c753ec --- /dev/null +++ b/3DPortraitGAN_pyramid/training/dual_discriminator.py @@ -0,0 +1,502 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Discriminator architectures from the paper +"Efficient Geometry-aware 3D Generative Adversarial Networks".""" + +import numpy as np +import torch +from torch_utils import persistence +from torch_utils.ops import upfirdn2d +from training.networks_stylegan2 import DiscriminatorBlock, MappingNetwork, DiscriminatorEpilogue + + +@persistence.persistent_class +class SingleDiscriminator(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'. 
+ channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=4, # Use FP16 for the N highest resolutions. + conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + sr_upsample_factor=1, # Ignored for SingleDiscriminator + block_kwargs={}, # Arguments for DiscriminatorBlock. + mapping_kwargs={}, # Arguments for MappingNetwork. + epilogue_kwargs={}, # Arguments for DiscriminatorEpilogue. + ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, + **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, + **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, + **common_kwargs) + + def forward(self, img, c, 
def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
    """Resize an image batch to ``size`` x ``size`` with a selectable filter.

    Args:
        image_orig_tensor: input tensor accepted by
            ``torch.nn.functional.interpolate`` in bilinear mode.
        size: target height and width.
        f: resampling filter; only used by the ``'classic'`` mode.
        filter_mode: one of
            ``'antialiased'`` - bilinear interpolation with antialiasing;
            ``'classic'``     - 2x upsample with ``f``, bilinear resize to
                                ``size * 2 + 2``, then 2x downsample with ``f``;
            ``'none'``        - plain bilinear interpolation;
            a float in (0, 1) - blend ``(1 - m) * aliased + m * antialiased``.
    Returns:
        The resized tensor.
    Raises:
        ValueError: for an unsupported ``filter_mode`` (previously this fell
            through to an ``UnboundLocalError``).
        AssertionError: if a float ``filter_mode`` is outside (0, 1).
    """
    target = (size, size)

    if filter_mode == 'antialiased':
        return torch.nn.functional.interpolate(image_orig_tensor, size=target, mode='bilinear',
                                               align_corners=False, antialias=True)

    if filter_mode == 'classic':
        resized = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
        resized = torch.nn.functional.interpolate(resized, size=(size * 2 + 2, size * 2 + 2),
                                                  mode='bilinear', align_corners=False)
        return upfirdn2d.downsample2d(resized, f, down=2, flip_filter=True, padding=-1)

    if filter_mode == 'none':
        return torch.nn.functional.interpolate(image_orig_tensor, size=target, mode='bilinear',
                                               align_corners=False)

    if isinstance(filter_mode, float):
        assert 0 < filter_mode < 1

        filtered = torch.nn.functional.interpolate(image_orig_tensor, size=target, mode='bilinear',
                                                   align_corners=False, antialias=True)
        aliased = torch.nn.functional.interpolate(image_orig_tensor, size=target, mode='bilinear',
                                                  align_corners=False, antialias=False)
        return (1 - filter_mode) * aliased + (filter_mode) * filtered

    raise ValueError(f'Unsupported filter_mode: {filter_mode!r}')
+ img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=4, # Use FP16 for the N highest resolutions. + conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + disc_c_noise=0, # Corrupt camera parameters with X std dev of noise before disc. pose conditioning. + block_kwargs={}, # Arguments for DiscriminatorBlock. + mapping_kwargs={}, # Arguments for MappingNetwork. + epilogue_kwargs={}, # Arguments for DiscriminatorEpilogue. + ): + super().__init__() + img_channels *= 2 + + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, + **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, 
w_dim=cmap_dim, num_ws=None, w_avg_beta=None, + **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, + **common_kwargs) + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1])) + self.disc_c_noise = disc_c_noise + + def forward(self, img, c, update_emas=False, **block_kwargs): + image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) + img = torch.cat([img['image'], image_raw], 1) + + _ = update_emas # unused + x = None + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img = block(x, img, **block_kwargs) + + cmap = None + if self.c_dim > 0: + if self.disc_c_noise > 0: c += torch.randn_like(c) * c.std(0) * self.disc_c_noise + cmap = self.mapping(None, c) + x = self.b4(x, img, cmap) + return x + + def extra_repr(self): + return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}' + + +# ---------------------------------------------------------------------------- + +@persistence.persistent_class +class DummyDualDiscriminator(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=4, # Use FP16 for the N highest resolutions. + conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + block_kwargs={}, # Arguments for DiscriminatorBlock. + mapping_kwargs={}, # Arguments for MappingNetwork. + epilogue_kwargs={}, # Arguments for DiscriminatorEpilogue. 
+ ): + super().__init__() + img_channels *= 2 + + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, + **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, + **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, + **common_kwargs) + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1])) + + self.raw_fade = 1 + + def forward(self, img, c, update_emas=False, **block_kwargs): + self.raw_fade = max(0, self.raw_fade - 1 / (500000 / 32)) + + image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], + f=self.resample_filter) * self.raw_fade + img = torch.cat([img['image'], image_raw], 1) + + _ = update_emas # unused + x = None + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img = block(x, img, **block_kwargs) + + cmap = None + if self.c_dim > 0: + cmap = 
self.mapping(None, c) + x = self.b4(x, img, cmap) + return x + + def extra_repr(self): + return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}' + + +# ---------------------------------------------------------------------------- +from training.networks_stylegan2 import FullyConnectedLayer + + +class PoseShapeAwareDualDiscriminator(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + seg_channels, # Number of input color channels. + architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=4, # Use FP16 for the N highest resolutions. + conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + disc_c_noise=0, # Corrupt camera parameters with X std dev of noise before disc. pose conditioning. + explicitly_symmetry=False, + block_kwargs={}, # Arguments for DiscriminatorBlock. + mapping_kwargs={}, # Arguments for MappingNetwork. + epilogue_kwargs={}, # Arguments for DiscriminatorEpilogue. 
+ ): + super().__init__() + img_channels = img_channels * 2 + seg_channels + self.camera_param_dim = c_dim + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + self.pose_branch = DPoseBranch(num_betas=10, in_channel=channels_dict[4]*4*4) + self.c_dim += self.pose_branch.output_dim + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if self.c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, + **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if self.c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=self.c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, + **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, + **common_kwargs) + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1])) + self.disc_c_noise = disc_c_noise + + self.explicitly_symmetry = explicitly_symmetry + + def flip_yaw(self, matrix): + flipped_matrix = matrix.clone() + flipped = flipped_matrix[:, :16].reshape(-1, 4, 4) + flipped[:, 0, 1] *= -1 + flipped[:, 0, 2] *= -1 + flipped[:, 1, 0] *= -1 + flipped[:, 2, 0] *= -1 + flipped[:, 0, 3] 
*= -1 + + flipped = flipped.reshape(-1, 16) + flipped_matrix[:, :16] = flipped.clone() + + return flipped_matrix + + def predict_pose(self, img, c,update_emas=False, **block_kwargs): + + + if self.explicitly_symmetry: + theta = torch.atan2(c[:, [11]], c[:, [3]]) # math.atan2(z, x) + is_left = (theta >= -np.pi / 2) & (theta <= np.pi / 2) + + img_flip = torch.flip(img['image'], dims=[3]) + img_flip_raw = torch.flip(img['image_raw'], dims=[3]) + seg_flip = torch.flip(img['image_mask'], dims=[3]) + + is_left_img = is_left.unsqueeze(2).unsqueeze(3) + input_img = torch.where(is_left_img, img_flip, img['image']) # if left, flip image + misc.assert_shape(input_img, img_flip.shape ) + + is_left_img_raw = is_left.unsqueeze(2).unsqueeze(3) + input_img_raw = torch.where(is_left_img_raw, img_flip_raw, img['image_raw']) # if left, flip image_raw + misc.assert_shape(input_img_raw, img_flip_raw.shape ) + + is_left_seg = is_left.unsqueeze(2).unsqueeze(3) + input_seg = torch.where(is_left_seg, seg_flip, img['image_mask']) # if left, flip seg + misc.assert_shape(input_seg, seg_flip.shape ) + + img = {'image': input_img, 'image_raw': input_img_raw, 'image_mask': input_seg} + + image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) + seg = filtered_resizing(img['image_mask'], size=img['image'].shape[-1], f=self.resample_filter) + seg = 2 * seg - 1 # normalize to [-1,1] + img = torch.cat([img['image'], image_raw, seg], 1) + + _ = update_emas # unused + x = None + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img = block(x, img, **block_kwargs) + + + pose_branch_input_feature = self.b4.get_flatten_x(x, img) + pose_params = self.pose_branch(pose_branch_input_feature, c) + + flip_pose_params = pose_params.clone() + flip_pose_params[:, [1, 2, 4, 5]] *= -1 # flip y and z axis angles + + pose_params = torch.where(is_left, flip_pose_params, pose_params) + + + else: + raise NotImplementedError + image_raw = 
filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) + seg = filtered_resizing(img['image_mask'], size=img['image'].shape[-1], f=self.resample_filter) + seg = 2 * seg - 1 # normalize to [-1,1] + img = torch.cat([img['image'], image_raw, seg], 1) + + _ = update_emas # unused + x = None + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img = block(x, img, **block_kwargs) + + + pose_branch_input_feature = self.b4.get_flatten_x(x, img) + pose_params = self.pose_branch(pose_branch_input_feature, c) + + + return pose_params,pose_branch_input_feature + + def forward(self, img, c, gt_pose = None, update_emas=False, **block_kwargs): + + if self.explicitly_symmetry: + + pose_params,_ = self.predict_pose(img, c, update_emas, **block_kwargs) + + image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) + seg = filtered_resizing(img['image_mask'], size=img['image'].shape[-1], f=self.resample_filter) + seg = 2 * seg - 1 # normalize to [-1,1] + img = torch.cat([img['image'], image_raw, seg], 1) + + _ = update_emas # unused + x = None + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img = block(x, img, **block_kwargs) + + + pose_branch_input_feature = self.b4.get_flatten_x(x, img) + + else: + raise NotImplementedError + pose_params, pose_branch_input_feature = self.predict_pose(img, c, update_emas, **block_kwargs) + + if gt_pose is not None: + #raise NotImplementedError + c = torch.cat([c, gt_pose], dim=1) + else: + pose_label = pose_params.detach() # detach + c = torch.cat([c, pose_label], dim=1) + + cmap = None + if self.c_dim > 0: + if self.disc_c_noise > 0: c += torch.randn_like(c) * c.std(0) * self.disc_c_noise + cmap = self.mapping(None, c) + # x = self.b4(x, img, cmap) + x = self.b4(flatten_x=pose_branch_input_feature, cmap=cmap) + return x, pose_params + + def extra_repr(self): + return f'c_dim={self.c_dim:d}, 
img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}' + + +from torch_utils import misc + + +class DPoseBranch(torch.nn.Module): + def __init__(self, num_betas, in_channel): + super().__init__() + self.num_betas = num_betas + hidden_dim = 64 + self.in_channel = in_channel + # + # predict_betas = predict_transl = predict_scale = False + # predict_pose = True + + out_dim = 6 + + # if predict_betas: + # out_dim += num_betas + # if predict_transl: + # out_dim += 3 + # if predict_scale: + # out_dim += 1 + # if predict_pose: + # out_dim += 6 + + self.in_channel += 25 # c dim + + self.output_dim = out_dim + self.net = torch.nn.Sequential( + # linear + # FullyConnectedLayer(self.in_channel, hidden_dim), + # torch.nn.LeakyReLU(0.2), + # FullyConnectedLayer(hidden_dim, self.output_dim) # betas, scale, transl, rots of neck and head + FullyConnectedLayer(self.in_channel, 2048, activation='lrelu'), + FullyConnectedLayer(2048, 512, activation='lrelu'), + FullyConnectedLayer(512, 128, activation='lrelu'), + FullyConnectedLayer(128, 32, activation='lrelu'), + FullyConnectedLayer(32, self.output_dim) + ) + + + def forward(self, feature, camera_parameters): + # misc.assert_shape(feature, [None, self.in_channel]) + # misc.assert_shape(camera_parameters, [None, 25]) + feature = torch.cat([feature, camera_parameters], dim=1) + + pose = self.net(feature) # (B, num_betas + 1 + 3 + 6) + + return pose \ No newline at end of file diff --git a/3DPortraitGAN_pyramid/training/loss.py b/3DPortraitGAN_pyramid/training/loss.py new file mode 100644 index 0000000..9a9cfaa --- /dev/null +++ b/3DPortraitGAN_pyramid/training/loss.py @@ -0,0 +1,562 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Loss functions.""" + +import numpy as np +import torch +from torch_utils import training_stats +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import upfirdn2d +from training.dual_discriminator import filtered_resizing +from torch_utils import misc +import copy + + +# ---------------------------------------------------------------------------- + +class Loss: + def accumulate_gradients(self, phase, real_img, real_seg, real_c, real_pose, gen_z, gen_c, gen_pose,gain, cur_nimg, + cur_nimg_start): # to be overridden by subclass + raise NotImplementedError() + + +# ---------------------------------------------------------------------------- + +class StyleGAN2Loss(Loss): + def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, r1_gamma_seg=1000,style_mixing_prob=0, pl_weight=0, + density_noise_fade_kimg=0, + pl_batch_shrink=2, pl_decay=0.01, + pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0, + neural_rendering_resolution_initial=64, neural_rendering_resolution_final=None, + neural_rendering_resolution_fade_kimg=0, + gpc_reg_fade_kimg=1000, gpc_reg_prob=None, dual_discrimination=False, filter_mode='antialiased', + thickness=None, + pose_loss_weight = None, input_pose_params_reg_loss_weight = None,input_pose_params_reg_loss_kimg = None, + rank=None,bcg_reg_prob=0 + ): + super().__init__() + self.device = device + self.G = G + self.D = D + self.augment_pipe = augment_pipe + self.r1_gamma = r1_gamma + self.r1_gamma_seg = r1_gamma_seg + 
self.style_mixing_prob = style_mixing_prob + self.pl_weight = pl_weight + self.pl_batch_shrink = pl_batch_shrink + self.pl_decay = pl_decay + self.pl_no_weight_grad = pl_no_weight_grad + self.pl_mean = torch.zeros([], device=device) + self.blur_init_sigma = blur_init_sigma + self.blur_fade_kimg = blur_fade_kimg + self.r1_gamma_init = r1_gamma_init + self.r1_gamma_fade_kimg = r1_gamma_fade_kimg + self.neural_rendering_resolution_initial = neural_rendering_resolution_initial + self.neural_rendering_resolution_final = neural_rendering_resolution_final + self.neural_rendering_resolution_fade_kimg = neural_rendering_resolution_fade_kimg + self.density_noise_fade_kimg = density_noise_fade_kimg + self.gpc_reg_fade_kimg = gpc_reg_fade_kimg + self.gpc_reg_prob = gpc_reg_prob + self.dual_discrimination = dual_discrimination + self.filter_mode = filter_mode + self.resample_filter = upfirdn2d.setup_filter([1, 3, 3, 1], device=device) + self.blur_raw_target = True + self.bcg_reg_prob = bcg_reg_prob + assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1) + + + self.thickness = thickness + self.pose_loss_weight = pose_loss_weight + self.input_pose_params_reg_loss_weight = input_pose_params_reg_loss_weight + self.input_pose_params_reg_loss_kimg = input_pose_params_reg_loss_kimg + + + # for snap + self.swapping_prob = None + self.neural_rendering_resolution = None + self.blur_sigma = None + + + self.rank = rank + + def run_G(self, z, c, pose_params, swapping_prob, neural_rendering_resolution, update_emas=False): + if swapping_prob is not None: + c_swapped = torch.roll(c.clone(), 1, 0) + p_swapped = torch.roll(pose_params.clone(), 1, 0) + rand_ = torch.rand((c.shape[0], 1), device=c.device) + c_gen_conditioning = torch.where(rand_ < swapping_prob, c_swapped, c) + pose_params_conditioning = torch.where(rand_ < swapping_prob, p_swapped, pose_params) + else: + c_gen_conditioning = torch.zeros_like(c) + pose_params_conditioning = torch.zeros([c.shape[0],6]).to(c.device) + + 
ws = self.G.mapping(z, c_gen_conditioning, pose_params_conditioning,update_emas=update_emas) + if self.style_mixing_prob > 0: + with torch.autograd.profiler.record_function('style_mixing'): + cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) + cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, + torch.full_like(cutoff, ws.shape[1])) + ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c,pose_params, update_emas=False)[:, cutoff:] + + if self.bcg_reg_prob > 0: + ws_swapped = torch.roll(ws.clone(), 1, 0) + ws_bcg = torch.where(torch.rand((ws.shape[0], 1, 1), device=ws.device) < self.bcg_reg_prob, ws_swapped, ws) + else: + ws_bcg = None + + + gen_output = self.G.synthesis(ws, c, neural_rendering_resolution=neural_rendering_resolution, + update_emas=update_emas, + apply_def=True, pose_params=pose_params,ws_bcg = ws_bcg + ) + return gen_output, ws + + + + def run_D(self, img, c, gt_pose=None, blur_sigma=0, blur_sigma_raw=0, update_emas=False): + blur_size = np.floor(blur_sigma * 3) + if blur_size > 0: + with torch.autograd.profiler.record_function('blur'): + f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div( + blur_sigma).square().neg().exp2() + img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum()) + + if self.augment_pipe is not None: + raise NotImplementedError + augmented_pair = self.augment_pipe(torch.cat([img['image'], + torch.nn.functional.interpolate(img['image_raw'], + size=img['image'].shape[2:], + mode='bilinear', + antialias=True)], + dim=1)) + img['image'] = augmented_pair[:, :img['image'].shape[1]] + img['image_raw'] = torch.nn.functional.interpolate(augmented_pair[:, img['image'].shape[1]:], + size=img['image_raw'].shape[2:], mode='bilinear', + antialias=True) + + logits, pose = self.D(img, c, gt_pose=gt_pose, update_emas=update_emas) + return logits, pose + + def run_D_pose_prediction(self, img, c, blur_sigma=0): + blur_size = np.floor(blur_sigma * 
3) + if blur_size > 0: + with torch.autograd.profiler.record_function('blur'): + f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div( + blur_sigma).square().neg().exp2() + img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum()) + + if self.augment_pipe is not None: + augmented_pair = self.augment_pipe(torch.cat([img['image'], + torch.nn.functional.interpolate(img['image_raw'], + size=img['image'].shape[2:], + mode='bilinear', + antialias=True)], + dim=1)) + img['image'] = augmented_pair[:, :img['image'].shape[1]] + img['image_raw'] = torch.nn.functional.interpolate(augmented_pair[:, img['image'].shape[1]:], + size=img['image_raw'].shape[2:], mode='bilinear', + antialias=True) + + pose, _ = self.D.predict_pose(img, c) + return pose + + def get_pose_params_D(self, real_img, real_seg, real_c, cur_nimg): + blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), + 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0 + r1_gamma = self.r1_gamma + + alpha = min(cur_nimg / (self.gpc_reg_fade_kimg * 1e3), 1) if self.gpc_reg_fade_kimg > 0 else 1 + swapping_prob = (1 - alpha) * 1 + alpha * self.gpc_reg_prob if self.gpc_reg_prob is not None else None + + if not isinstance(real_img,dict): + if self.neural_rendering_resolution_final is not None: + alpha = min(cur_nimg / (self.neural_rendering_resolution_fade_kimg * 1e3), 1) + neural_rendering_resolution = int(np.rint(self.neural_rendering_resolution_initial * ( + 1 - alpha) + self.neural_rendering_resolution_final * alpha)) + else: + neural_rendering_resolution = self.neural_rendering_resolution_initial + real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=self.resample_filter, + filter_mode=self.filter_mode) + real_seg_raw = filtered_resizing(real_seg, size=neural_rendering_resolution, f=self.resample_filter, + filter_mode=self.filter_mode) + if self.blur_raw_target: + blur_size = np.floor(blur_sigma * 3) + if blur_size > 0: + f = torch.arange(-blur_size, 
blur_size + 1, device=real_img_raw.device).div( + blur_sigma).square().neg().exp2() + real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum()) + + real_img = {'image': real_img, 'image_raw': real_img_raw, 'image_mask': real_seg_raw} + + else: + assert 'image_raw' in real_img.keys(), 'image_raw is not in real_img.keys()' + assert 'image' in real_img.keys(), 'image is not in real_img.keys()' + + + # get pose_params from real image + real_img_tmp_image = real_img['image'].detach().requires_grad_(True) + real_img_tmp_image_raw = real_img['image_raw'].detach().requires_grad_(True) + real_img_tmp_image_mask = real_img['image_mask'].detach().requires_grad_(True) + real_img_tmp = {'image': real_img_tmp_image, 'image_raw': real_img_tmp_image_raw, 'image_mask': real_img_tmp_image_mask} + + predicted_real_pose = self.run_D_pose_prediction(real_img_tmp, real_c, blur_sigma=blur_sigma) + return predicted_real_pose + + def get_pose_params_G(self,z,c): + predicted_pose = self.G.get_pose_params(z,c) + return predicted_pose + + def accumulate_gradients(self, phase, real_img, real_seg, real_c, real_pose, + gen_z, gen_c,gen_pose, + gain, cur_nimg, cur_nimg_start): + assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] + if self.G.rendering_kwargs.get('density_reg', 0) == 0: + phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase) + if self.r1_gamma == 0: + phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase) + blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), + 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0 + self.blur_sigma = blur_sigma + r1_gamma = self.r1_gamma + self.G.rendering_kwargs["density_noise"] = max(1 - cur_nimg / (self.density_noise_fade_kimg * 1e3), + 0) if self.density_noise_fade_kimg > 0 else 0 + + alpha = min(cur_nimg / (self.gpc_reg_fade_kimg * 1e3), 1) if self.gpc_reg_fade_kimg > 0 else 1 + swapping_prob = (1 - alpha) * 1 + alpha * self.gpc_reg_prob if self.gpc_reg_prob is not None else None + 
self.swapping_prob = swapping_prob + + if self.neural_rendering_resolution_final is not None: + alpha = min((cur_nimg-cur_nimg_start) / (self.neural_rendering_resolution_fade_kimg * 1e3), 1) + neural_rendering_resolution = int(np.rint(self.neural_rendering_resolution_initial * ( + 1 - alpha) + self.neural_rendering_resolution_final * alpha)) + else: + neural_rendering_resolution = self.neural_rendering_resolution_initial + + self.neural_rendering_resolution = neural_rendering_resolution + + real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=self.resample_filter, + filter_mode=self.filter_mode) + real_seg_raw = filtered_resizing(real_seg, size=neural_rendering_resolution, f=self.resample_filter, + filter_mode=self.filter_mode) + + + if self.blur_raw_target: + blur_size = np.floor(blur_sigma * 3) + if blur_size > 0: + f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div( + blur_sigma).square().neg().exp2() + real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum()) + + real_img = {'image': real_img, 'image_raw': real_img_raw, 'image_mask': real_seg_raw} + + + input_pose_params = self.get_pose_params_G(gen_z,gen_c) + + + for i in range(input_pose_params.shape[1]): + training_stats.report('pose_scale/input_pose_params_{}'.format(i), + (input_pose_params[:, i]).abs().mean() / np.pi * 180) + + + # Gmain: Maximize logits for generated images. 
+ if phase in ['Gmain', 'Gboth']: + with torch.autograd.profiler.record_function('Gmain_forward'): + gen_img, _gen_ws = self.run_G(gen_z, gen_c, input_pose_params, swapping_prob=swapping_prob, + neural_rendering_resolution=neural_rendering_resolution) + + + gen_logits, predict_gen_pose = self.run_D(gen_img, gen_c, gt_pose=None, blur_sigma=blur_sigma) + training_stats.report('Loss/scores/fake_posed', gen_logits) + training_stats.report('Loss/signs/fake_posed', gen_logits.sign()) + loss_Gmain = torch.nn.functional.softplus(-gen_logits) + + # Lpreg + if self.input_pose_params_reg_loss_weight>0 and cur_nimg<(self.input_pose_params_reg_loss_kimg+200) * 1e3: + + if cur_nimg 0 and self.G.rendering_kwargs[ + 'reg_type'] == 'l1': + if swapping_prob is not None: + # c_swapped = torch.roll(gen_c.clone(), 1, 0) + # c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) + c_swapped = torch.roll(gen_c.clone(), 1, 0) + p_swapped = torch.roll(input_pose_params.clone(), 1, 0) + rand_ = torch.rand([], device=gen_c.device) + c_gen_conditioning = torch.where( rand_< swapping_prob, c_swapped, gen_c) + pose_params_conditioning = torch.where(rand_ < swapping_prob, p_swapped, input_pose_params) + else: + c_gen_conditioning = torch.zeros_like(gen_c) + pose_params_conditioning = torch.zeros([gen_c.shape[0],6]).to(gen_c.device) + + + ws = self.G.mapping(gen_z, c_gen_conditioning, pose_params_conditioning,update_emas=False) + if self.style_mixing_prob > 0: + with torch.autograd.profiler.record_function('style_mixing'): + cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) + cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, + torch.full_like(cutoff, ws.shape[1])) + ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, input_pose_params,update_emas=False)[:, cutoff:] + initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1 + perturbed_coordinates 
= initial_coordinates + torch.randn_like(initial_coordinates) * self.G.rendering_kwargs['density_reg_p_dist'] + all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) + sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)[ + 'sigma'] + sigma_initial = sigma[:, :sigma.shape[1] // 2] + sigma_perturbed = sigma[:, sigma.shape[1] // 2:] + + TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs[ + 'density_reg'] + training_stats.report('Loss/G_reg/TVloss_L1', TVloss) + TVloss.mul(gain).backward() + + # Alternative density regularization + if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs[ + 'reg_type'] == 'monotonic-detach': + if swapping_prob is not None: + # c_swapped = torch.roll(gen_c.clone(), 1, 0) + # c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) + c_swapped = torch.roll(gen_c.clone(), 1, 0) + p_swapped = torch.roll(input_pose_params.clone(), 1, 0) + rand_ = torch.rand([], device=gen_c.device) + c_gen_conditioning = torch.where( rand_< swapping_prob, c_swapped, gen_c) + pose_params_conditioning = torch.where(rand_ < swapping_prob, p_swapped, input_pose_params) + else: + c_gen_conditioning = torch.zeros_like(gen_c) + pose_params_conditioning = torch.zeros([gen_c.shape[0],6]).to(gen_c.device) + + ws = self.G.mapping(gen_z, c_gen_conditioning, pose_params_conditioning,update_emas=False) + + initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front + + perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind + all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) + sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)[ + 'sigma'] + 
sigma_initial = sigma[:, :sigma.shape[1] // 2] + sigma_perturbed = sigma[:, sigma.shape[1] // 2:] + + monotonic_loss = torch.relu(sigma_initial.detach() - sigma_perturbed).mean() * 10 + monotonic_loss.mul(gain).backward() + + if swapping_prob is not None: + # c_swapped = torch.roll(gen_c.clone(), 1, 0) + # c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) + c_swapped = torch.roll(gen_c.clone(), 1, 0) + p_swapped = torch.roll(input_pose_params.clone(), 1, 0) + rand_ = torch.rand([], device=gen_c.device) + c_gen_conditioning = torch.where( rand_< swapping_prob, c_swapped, gen_c) + pose_params_conditioning = torch.where(rand_ < swapping_prob, p_swapped, input_pose_params) + else: + c_gen_conditioning = torch.zeros_like(gen_c) + pose_params_conditioning = torch.zeros([gen_c.shape[0],6]).to(gen_c.device) + + ws = self.G.mapping(gen_z, c_gen_conditioning,pose_params_conditioning, update_emas=False) + if self.style_mixing_prob > 0: + with torch.autograd.profiler.record_function('style_mixing'): + cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) + cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, + torch.full_like(cutoff, ws.shape[1])) + ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, input_pose_params,update_emas=False)[:, cutoff:] + initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1 + perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp'] + all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) + sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)[ + 'sigma'] + sigma_initial = sigma[:, :sigma.shape[1] // 2] + sigma_perturbed = sigma[:, sigma.shape[1] // 2:] + + TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * 
self.G.rendering_kwargs[ + 'density_reg'] + training_stats.report('Loss/G_reg/TVloss_monotonic-detach', TVloss) + TVloss.mul(gain).backward() + + # Alternative density regularization + if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs[ + 'reg_type'] == 'monotonic-fixed': + if swapping_prob is not None: + # c_swapped = torch.roll(gen_c.clone(), 1, 0) + # c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) + c_swapped = torch.roll(gen_c.clone(), 1, 0) + p_swapped = torch.roll(input_pose_params.clone(), 1, 0) + rand_ = torch.rand([], device=gen_c.device) + c_gen_conditioning = torch.where( rand_< swapping_prob, c_swapped, gen_c) + pose_params_conditioning = torch.where(rand_ < swapping_prob, p_swapped, input_pose_params) + else: + c_gen_conditioning = torch.zeros_like(gen_c) + pose_params_conditioning = torch.zeros([gen_c.shape[0],6]).to(gen_c.device) + + ws = self.G.mapping(gen_z, c_gen_conditioning, pose_params_conditioning,update_emas=False) + + initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front + + perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind + all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) + sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)[ + 'sigma'] + sigma_initial = sigma[:, :sigma.shape[1] // 2] + sigma_perturbed = sigma[:, sigma.shape[1] // 2:] + + monotonic_loss = torch.relu(sigma_initial - sigma_perturbed).mean() * 10 + monotonic_loss.mul(gain).backward() + + if swapping_prob is not None: + # c_swapped = torch.roll(gen_c.clone(), 1, 0) + # c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) + c_swapped = torch.roll(gen_c.clone(), 1, 0) + p_swapped = 
torch.roll(input_pose_params.clone(), 1, 0) + rand_ = torch.rand([], device=gen_c.device) + c_gen_conditioning = torch.where( rand_< swapping_prob, c_swapped, gen_c) + pose_params_conditioning = torch.where(rand_ < swapping_prob, p_swapped, input_pose_params) + else: + c_gen_conditioning = torch.zeros_like(gen_c) + pose_params_conditioning = torch.zeros([gen_c.shape[0],6]).to(gen_c.device) + + + ws = self.G.mapping(gen_z, c_gen_conditioning, pose_params_conditioning,update_emas=False) + if self.style_mixing_prob > 0: + with torch.autograd.profiler.record_function('style_mixing'): + cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) + cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, + torch.full_like(cutoff, ws.shape[1])) + ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, input_pose_params,update_emas=False)[:, cutoff:] + initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1 + perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp'] + all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) + sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)[ + 'sigma'] + sigma_initial = sigma[:, :sigma.shape[1] // 2] + sigma_perturbed = sigma[:, sigma.shape[1] // 2:] + + TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs[ + 'density_reg'] + training_stats.report('Loss/G_reg/TVloss_monotonic-fixed', TVloss) + TVloss.mul(gain).backward() + + # Dmain: Minimize logits for generated images. 
+ loss_Dgen = 0 + if phase in ['Dmain', 'Dboth']: + with torch.autograd.profiler.record_function('Dgen_forward'): + + gen_img, _gen_ws = self.run_G(gen_z, gen_c, input_pose_params, swapping_prob=swapping_prob, + neural_rendering_resolution=neural_rendering_resolution, update_emas=True) + gen_logits, predict_gen_pose = self.run_D(gen_img, gen_c, gt_pose=None, blur_sigma=blur_sigma, + update_emas=True) + + training_stats.report('Loss/scores/fake', gen_logits) + training_stats.report('Loss/signs/fake', gen_logits.sign()) + loss_Dgen = torch.nn.functional.softplus( gen_logits) # -log (1 - sigmoid(gen_logits)) = log (1 + exp(gen_logits)) = softplus(gen_logits) + + pose_param_loss = (predict_gen_pose - input_pose_params).square().sum([1]) * self.pose_loss_weight + training_stats.report('Loss/D/Poseloss', pose_param_loss) + + for i in range(predict_gen_pose.shape[1]): + training_stats.report('Loss_pose/fake_{}'.format(i), + (predict_gen_pose[:, i] - input_pose_params[:, i]).abs().mean() / np.pi * 180) + training_stats.report('pose_scale/fake_{}'.format(i), + (predict_gen_pose[:, i]).abs().mean() / np.pi * 180) + + + + + with torch.autograd.profiler.record_function('Dgen_backward'): + (loss_Dgen + pose_param_loss).mean().mul(gain).backward() + + + # Dmain: Maximize logits for real images. + # Dr1: Apply R1 regularization. 
+ if phase in ['Dmain', 'Dreg', 'Dboth']: + name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1' + with torch.autograd.profiler.record_function(name + '_forward'): + real_img_tmp_image = real_img['image'].detach().requires_grad_(phase in ['Dreg', 'Dboth']) + real_img_tmp_image_raw = real_img['image_raw'].detach().requires_grad_(phase in ['Dreg', 'Dboth']) + real_img_tmp_image_mask = real_img['image_mask'].detach().requires_grad_(phase in ['Dreg', 'Dboth']) + real_img_tmp = {'image': real_img_tmp_image, 'image_raw': real_img_tmp_image_raw, 'image_mask': real_img_tmp_image_mask} + + real_logits, predicted_real_pose = self.run_D(real_img_tmp, real_c, + gt_pose=None, + blur_sigma=blur_sigma) + + training_stats.report('Loss/scores/real', real_logits) + training_stats.report('Loss/signs/real', real_logits.sign()) + + + for i in range(predicted_real_pose.shape[1]): + training_stats.report('Loss_pose/real_{}'.format(i), ( + predicted_real_pose[:, i] - real_pose[:, i]).abs().mean() / np.pi * 180) + training_stats.report('pose_scale/real_{}'.format(i), + (predicted_real_pose[:, i]).abs().mean() / np.pi * 180) + + + loss_Dreal = 0 + if phase in ['Dmain', 'Dboth']: + loss_Dreal = torch.nn.functional.softplus( + -real_logits) # - log sigmoid(real_logits) = log (1 + exp(-real_logits)) = softplus(-real_logits) + training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) + training_stats.report('Loss/D/loss_gen', loss_Dgen) + training_stats.report('Loss/D/loss_real', loss_Dreal) + + + # + + loss_Dr1 = 0 + if phase in ['Dreg', 'Dboth']: + if self.dual_discrimination: + with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): + r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], + inputs=[real_img_tmp['image'], real_img_tmp['image_raw'], real_img_tmp['image_mask']], + create_graph=True, only_inputs=True) + r1_grads_image = r1_grads[0] + r1_grads_image_raw = r1_grads[1] + r1_grads_image_mask = r1_grads[2] + 
r1_penalty = r1_grads_image.square().sum([1,2,3]) + r1_grads_image_raw.square().sum([1,2,3]) + r1_penalty_seg = r1_grads_image_mask.square().sum([1, 2, 3]) + else: # single discrimination + with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): + r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image'], real_img_tmp['image_mask']], + create_graph=True, only_inputs=True) + r1_grads_image = r1_grads[0] + r1_grads_image_mask = r1_grads[1] + r1_penalty = r1_grads_image.square().sum([1, 2, 3]) + r1_penalty_seg = r1_grads_image_mask.square().sum([1, 2, 3]) + loss_Dr1 = r1_penalty * (self.r1_gamma / 2) + r1_penalty_seg * (self.r1_gamma_seg / 2) + training_stats.report('Loss/r1_penalty', r1_penalty) + training_stats.report('Loss/r1_penalty_seg', r1_penalty_seg) + training_stats.report('Loss/D/reg', loss_Dr1) + + + with torch.autograd.profiler.record_function(name + '_backward'): + (loss_Dreal + loss_Dr1).mean().mul(gain).backward() + +# ---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/training/networks_stylegan2.py b/3DPortraitGAN_pyramid/training/networks_stylegan2.py new file mode 100644 index 0000000..cd56a4a --- /dev/null +++ b/3DPortraitGAN_pyramid/training/networks_stylegan2.py @@ -0,0 +1,1138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
+ +"""Network architectures from the paper +"Analyzing and Improving the Image Quality of StyleGAN". +Matches the original implementation of configs E-F by Karras et al. at +https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py""" + +""" +3D-aware stylegan2 backbone architectures from the paper +"Mimic3D: Thriving 3D-Aware GANs via 3D-to-2D Imitation (ICCV 2023)" +https://github.com/SeanChenxy/Mimic3D/blob/main/training/networks_stylegan2.py +""" + +import numpy as np +import torch +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_resample +from torch_utils.ops import upfirdn2d +from torch_utils.ops import bias_act +from torch_utils.ops import fma + + +# ---------------------------------------------------------------------------- + +@misc.profiled_function +def normalize_2nd_moment(x, dim=1, eps=1e-8): + return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() + + +# ---------------------------------------------------------------------------- + +@misc.profiled_function +def modulated_conv2d( + x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. + weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. + styles, # Modulation coefficients of shape [batch_size, in_channels]. + noise=None, # Optional noise tensor to add to the output activations. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + padding=0, # Padding with respect to the upsampled image. + resample_filter=None, + # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). + demodulate=True, # Apply weight demodulation? + flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). 
+ fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation? +): + batch_size = x.shape[0] + out_channels, in_channels, kh, kw = weight.shape + misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] + misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] + misc.assert_shape(styles, [batch_size, in_channels]) # [NI] + + # Pre-normalize inputs to avoid FP16 overflow. + if x.dtype == torch.float16 and demodulate: + weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3], + keepdim=True)) # max_Ikk + styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I + + # Calculate per-sample weights and demodulation coefficients. + w = None + dcoefs = None + if demodulate or fused_modconv: + w = weight.unsqueeze(0) # [NOIkk] + w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] + if demodulate: + dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO] + if demodulate and fused_modconv: + w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] + + # Execute by scaling the activations before and after the convolution. + if not fused_modconv: + x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) + x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, + padding=padding, flip_weight=flip_weight) + if demodulate and noise is not None: + x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) + elif demodulate: + x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) + elif noise is not None: + x = x.add_(noise.to(x.dtype)) + return x + + # Execute as one fused op using grouped convolution. 
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant + batch_size = int(batch_size) + misc.assert_shape(x, [batch_size, in_channels, None, None]) + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, + groups=batch_size, flip_weight=flip_weight) + x = x.reshape(batch_size, -1, *x.shape[2:]) + if noise is not None: + x = x.add_(noise) + return x + + +# ---------------------------------------------------------------------------- + +@persistence.persistent_class +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias=True, # Apply additive bias before the activation function? + activation='linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=1, # Learning rate multiplier. + bias_init=0, # Initial value for the additive bias. 
+ ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + + +# ---------------------------------------------------------------------------- + +@persistence.persistent_class +class Conv2dLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias=True, # Apply additive bias before the activation function? + activation='linear', # Activation function: 'relu', 'lrelu', etc. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + channels_last=False, # Expect the input to have memory_format=channels_last? + trainable=True, # Update the weights of this layer during training? 
+ ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.activation = activation + self.up = up + self.down = down + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.act_gain = bias_act.activation_funcs[activation].def_gain + + memory_format = torch.channels_last if channels_last else torch.contiguous_format + weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format) + bias = torch.zeros([out_channels]) if bias else None + if trainable: + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) if bias is not None else None + else: + self.register_buffer('weight', weight) + if bias is not None: + self.register_buffer('bias', bias) + else: + self.bias = None + + def forward(self, x, gain=1): + w = self.weight * self.weight_gain + b = self.bias.to(x.dtype) if self.bias is not None else None + flip_weight = (self.up == 1) # slightly faster + x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, + padding=self.padding, flip_weight=flip_weight) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) + return x + + def extra_repr(self): + return ' '.join([ + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},', + f'up={self.up}, down={self.down}']) + + +# ---------------------------------------------------------------------------- + +@persistence.persistent_class +class MappingNetwork(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. 
+ w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers=8, # Number of mapping layers. + embed_features=None, # Label embedding dimensionality, None = same as w_dim. + layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta=0.998, # Decay for tracking the moving average of W during training, None = do not track. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): + # Embed, normalize, and concat inputs. + x = None + with torch.autograd.profiler.record_function('input'): + if self.z_dim > 0: + misc.assert_shape(z, [None, self.z_dim]) + x = normalize_2nd_moment(z.to(torch.float32)) + if self.c_dim > 0: + misc.assert_shape(c, [None, self.c_dim]) + y = normalize_2nd_moment(self.embed(c.to(torch.float32))) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. 
+ for idx in range(self.num_layers): + layer = getattr(self, f'fc{idx}') + x = layer(x) + + # Update moving average of W. + if update_emas and self.w_avg_beta is not None: + with torch.autograd.profiler.record_function('update_w_avg'): + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast. + if self.num_ws is not None: + with torch.autograd.profiler.record_function('broadcast'): + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. + if truncation_psi != 1: + with torch.autograd.profiler.record_function('truncate'): + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + return x + + def extra_repr(self): + return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' + + +# ---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size=3, # Convolution kernel size. + up=1, # Integer upsampling factor. + use_noise=True, # Enable noise input? + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + channels_last=False, # Use channels_last format for the weights? 
+ roll_out=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.resolution = resolution + self.up = up + self.use_noise = use_noise + self.activation = activation + self.conv_clamp = conv_clamp + self.roll_out = roll_out + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.act_gain = bias_act.activation_funcs[activation].def_gain + + affine_scale = 1 + if self.roll_out in ['b', 'a']: + affine_scale = 9 + elif self.roll_out in ['s']: + affine_scale = 3 + self.affine = FullyConnectedLayer(w_dim, in_channels * affine_scale, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn( + [out_channels, in_channels * (1, 3)[self.roll_out in ['b', 'a']], + kernel_size, kernel_size]).to(memory_format=memory_format)) + if use_noise: + self.register_buffer('noise_const', torch.randn([resolution, resolution * (1, 3)[self.roll_out == 'w']])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + + def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1, **_): + assert noise_mode in ['random', 'const', 'none'] + # noise_mode = 'const' + in_resolution = self.resolution // self.up + misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution * (1, 3)[self.roll_out == 'w']]) + styles = self.affine(w) + if self.roll_out in ['b', 'a', 's']: + styles = styles.view(styles.shape[0], 3, styles.shape[1] // 3).view(styles.shape[0] * 3, + styles.shape[1] // 3) + if self.roll_out in ['b', 'a', ]: + x = aware3d_att(x) if self.roll_out == 'a' else aware3d(x) + noise = None + if self.use_noise and noise_mode == 'random': + noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution * (1, 3)[self.roll_out == 'w']], + device=x.device) * self.noise_strength + if 
self.use_noise and noise_mode == 'const': + noise = self.noise_const * self.noise_strength + + flip_weight = (self.up == 1) # slightly faster + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, + padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, + fused_modconv=fused_modconv) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp) + return x + + def extra_repr(self): + return ' '.join([ + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},', + f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}']) + + +def aware3d(x): + if isinstance(x, list): + x_xy, x_yz, x_zx = x + B, _, H, W = x_xy.shape + B *= 3 + else: + x_ = x.view(-1, 3, x.shape[1], x.shape[2], x.shape[3]) + x_xy, x_yz, x_zx = x_[:, 0], x_[:, 1], x_[:, 2] + B, _, H, W = x.shape + x_zy = x_yz.permute(0, 1, 3, 2) + x_xz = x_zx.permute(0, 1, 3, 2) + x_yx = x_xy.permute(0, 1, 3, 2) + + x_zy_pz = x_zy.mean(dim=-1, keepdim=True).repeat(1, 1, 1, x_xy.shape[-1]) + x_xz_pz = x_xz.mean(dim=-2, keepdim=True).repeat(1, 1, x_xy.shape[-2], 1) + x_xy_ = torch.cat([x_xy, x_zy_pz, x_xz_pz], 1) + + x_yx_px = x_yx.mean(dim=-2, keepdim=True).repeat(1, 1, x_yz.shape[-2], 1) + x_xz_px = x_xz.mean(dim=-1, keepdim=True).repeat(1, 1, 1, x_yz.shape[-1]) + x_yz_ = torch.cat([x_yx_px, x_yz, x_xz_px], 1) + + x_yx_py = x_yx.mean(dim=-1, keepdim=True).repeat(1, 1, 1, x_zx.shape[-1]) + x_zy_py = x_zy.mean(dim=-2, keepdim=True).repeat(1, 1, x_zx.shape[-2], 1) + x_zx_ = torch.cat([x_yx_py, x_zy_py, x_zx], 1) + + x = torch.cat([x_xy_[:, None], x_yz_[:, None], x_zx_[:, None]], 1).view(B, -1, H, W) + return x + + +def aware3d_att(x): + x_ = x.view(-1, 3, x.shape[1], x.shape[2], x.shape[3]) + x_cyx, x_czy, x_cxz = x_[:, 0], x_[:, 1], x_[:, 2] + + 
x_yxc = x_cyx.permute(0, 2, 3, 1) + x_ycz = x_czy.permute(0, 3, 1, 2) + x_yzc = x_czy.permute(0, 3, 2, 1) + x_yxz = torch.einsum('byxc,bycz->byxz', x_yxc, x_ycz) + x_yxz = torch.softmax(x_yxz, dim=-1) + x_cyx_f_czy = torch.einsum('byxz,byzc->byxc', x_yxz, x_yzc).permute(0, 3, 1, 2) + x_xyc = x_cyx.permute(0, 3, 2, 1) + x_xcz = x_cxz.permute(0, 2, 1, 3) + x_xzc = x_cxz.permute(0, 2, 3, 1) + x_xyz = torch.einsum('bxyc,bxcz->bxyz', x_xyc, x_xcz) + x_xyz = torch.softmax(x_xyz, dim=-1) + x_cyx_f_cxz = torch.einsum('bxyz,bxzc->bxyc', x_xyz, x_xzc).permute(0, 3, 2, 1) + x_cyx_ = torch.cat([x_cyx, x_cyx_f_czy, x_cyx_f_cxz], 1) + + x_zyc = x_czy.permute(0, 2, 3, 1) + x_zcx = x_cxz.permute(0, 3, 1, 2) + x_zxc = x_cxz.permute(0, 3, 2, 1) + x_zyx = torch.einsum('bzyc,bzcx->bzyx', x_zyc, x_zcx) + x_zyx = torch.softmax(x_zyx, dim=-1) + x_czy_f_cxz = torch.einsum('bzyx,bzxc->bzyc', x_zyx, x_zxc).permute(0, 3, 1, 2) + x_ycx = x_cyx.permute(0, 2, 1, 3) + x_yzx = torch.einsum('byzc,bycx->byzx', x_yzc, x_ycx) + x_yzx = torch.softmax(x_yzx, dim=-1) + x_czy_f_cyx = torch.einsum('byzx,byxc->byzc', x_yzx, x_yxc).permute(0, 3, 2, 1) + x_czy_ = torch.cat([x_czy, x_czy_f_cxz, x_czy_f_cyx], 1) + + x_xcy = x_cyx.permute(0, 3, 1, 2) + x_xzy = torch.einsum('bxzc,bxcy->bxzy', x_xzc, x_xcy) + x_xzy = torch.softmax(x_xzy, dim=-1) + x_cxz_f_cyx = torch.einsum('bxzy,bxyc->bxzc', x_xzy, x_xyc).permute(0, 3, 1, 2) + x_zcy = x_czy.permute(0, 2, 1, 3) + x_zxy = torch.einsum('bzxc,bzcy->bzxy', x_zxc, x_zcy) + x_zxy = torch.softmax(x_zxy, dim=-1) + x_cxz_f_czy = torch.einsum('bzxy,bzyc->bzxc', x_zxy, x_zyc).permute(0, 3, 2, 1) + x_cxz_ = torch.cat([x_cxz, x_cxz_f_cyx, x_cxz_f_czy], 1) + + x = torch.cat([x_cyx_[:, None], x_czy_[:, None], x_cxz_[:, None]], 1).view(x.shape[0], -1, x.shape[2], x.shape[3]) + return x + + +# ---------------------------------------------------------------------------- + +@persistence.persistent_class +class ToRGBLayer(torch.nn.Module): + def __init__(self, in_channels, 
out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False, + roll_out=None): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.conv_clamp = conv_clamp + self.roll_out = roll_out + + affine_scale = 1 + if self.roll_out in ['b', 'a']: + affine_scale = 9 + elif self.roll_out in ['s']: + affine_scale = 3 + self.affine = FullyConnectedLayer(w_dim, in_channels * affine_scale, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn( + [out_channels, in_channels * (1, 3)[self.roll_out in ['b', 'a']], + kernel_size, kernel_size]).to(memory_format=memory_format)) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + + def forward(self, x, w, fused_modconv=True): + styles = self.affine(w) * self.weight_gain + if self.roll_out in ['b', 'a', 's']: + styles = styles.view(styles.shape[0], 3, styles.shape[1] // 3).view(styles.shape[0] * 3, + styles.shape[1] // 3) + if self.roll_out in ['b', 'a', ]: + x = aware3d_att(x) if self.roll_out == 'a' else aware3d(x) + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) + x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) + return x + + def extra_repr(self): + return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}' + + +# ---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? 
+ up=2, + architecture='skip', # Architecture: 'orig', 'skip', 'resnet'. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16=False, # Use FP16 for this block? + fp16_channels_last=False, # Use channels-last memory format with FP16? + fused_modconv_default=True, + # Default value of fused_modconv. 'inference_only' = True for inference, False for training. + roll_out=None, + **layer_kwargs, # Arguments for SynthesisLayer. + ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.w_dim = w_dim + self.resolution = resolution + self.img_channels = img_channels + self.is_last = is_last + self.up = up + self.roll_out = roll_out + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.fused_modconv_default = fused_modconv_default + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.num_conv = 0 + self.num_torgb = 0 + + if in_channels == 0: + self.const = torch.nn.Parameter(torch.randn([(1, 3)[self.roll_out in ['b', 'a']], out_channels, resolution, + resolution * (1, 3)[self.roll_out == 'w']])) + + if in_channels != 0: + self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=self.up, + roll_out=roll_out, + resample_filter=resample_filter, conv_clamp=conv_clamp, + channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, roll_out=roll_out, + conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + if is_last or architecture == 'skip': + self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, + conv_clamp=conv_clamp, channels_last=self.channels_last, roll_out=self.roll_out) + 
self.num_torgb += 1 + + if in_channels != 0 and architecture == 'resnet': + self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, + resample_filter=resample_filter, channels_last=self.channels_last) + + def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs): + _ = update_emas # unused + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + if ws.device.type != 'cuda': + force_fp32 = True + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + fused_modconv = self.fused_modconv_default + if fused_modconv == 'inference_only': + fused_modconv = (not self.training) + + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.repeat([ws.shape[0], 1, 1, 1]) + else: + misc.assert_shape(x, [None, self.in_channels, self.resolution // self.up, + self.resolution // self.up * (1, 3)[self.roll_out == 'w']]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. 
@persistence.persistent_class
class Hierarchy3DAwareSynthesisNetwork(torch.nn.Module):
    """Synthesis stack that pairs each 2D SynthesisBlock with a 3D-aware block.

    For every resolution in ``block_resolutions`` a ``b{res}`` SynthesisBlock
    is registered; resolutions in the internal ``aware3d_res`` table
    (4 .. 256) additionally get a ``b3d{res}`` Aware3DBlock.  ``forward``
    returns the outputs of the 3D-aware blocks keyed by resolution rather
    than a single image.
    """

    def __init__(self,
        w_dim,                  # Intermediate latent (W) dimensionality.
        img_resolution,         # Output image resolution.
        img_channels,           # Number of color channels.
        channel_base=32768,     # Not used by this class; kept for signature compatibility.
        channel_max=512,        # Not used by this class; kept for signature compatibility.
        num_fp16_res=4,         # Use FP16 for the N highest resolutions.
        **block_kwargs,         # Arguments for SynthesisBlock.
    ):
        aware3d_att = False
        aware3d_res = [4, 8, 16, 32, 64, 128, 256]
        add_block = 0

        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.num_fp16_res = num_fp16_res
        self.add_block = add_block
        self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.num_ws = 0
        for res in self.block_resolutions:
            # Every block works on img_channels features; only the 4x4 block
            # starts from a learned constant (in_channels == 0).
            in_channels = img_channels if res > 4 else 0
            out_channels = img_channels
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution) and self.add_block == 0
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                                   img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f'b{res}', block)
            if res in aware3d_res:
                block3d = Aware3DBlock(img_channels, res, w_dim, aware3d_att,
                                       block_kwargs.copy())
                setattr(self, f'b3d{res}', block3d)

    def forward(self, ws, **block_kwargs):
        """Run all blocks and collect the 3D-aware outputs.

        Args:
            ws: latents of shape ``[batch, num_ws, w_dim]`` (checked below).
            **block_kwargs: forwarded to every 2D block and, as a dict, to
                every 3D-aware block.

        Returns:
            dict mapping each resolution to the output of its ``b3d{res}`` block.

        Raises:
            NotImplementedError: if a resolution has no 3D-aware block, or a
                3D-aware block returns a list.  The original code wrote
                ``assert NotImplementedError`` for these cases, which always
                passes and therefore never signalled the unsupported path.
        """
        block_ws = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            w_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                # Consecutive blocks share the ToRGB latent: advance by num_conv only.
                block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                w_idx += block.num_conv

        x = img = img3d = None
        feature_maps = {}
        for res, cur_ws in zip(self.block_resolutions, block_ws):
            block = getattr(self, f'b{res}')
            block3d = getattr(self, f'b3d{res}', None)
            if block3d is None:
                raise NotImplementedError(
                    f'resolution {res} has no 3D-aware block; running a 2D-only block is not implemented')

            # 2D Branch
            x, img = block(x, img, cur_ws, **block_kwargs)  # 2D Synthesis Block

            # 3D Branch
            img3d = block3d(img3d, img, cur_ws, block_kwargs)  # 3D-Aware Block
            if isinstance(img3d, list):
                raise NotImplementedError('list outputs from 3D-aware blocks are not implemented')
            feature_maps[res] = img3d

        return feature_maps

    def extra_repr(self):
        return ' '.join([
            f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
            f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
            f'num_fp16_res={self.num_fp16_res:d}'])


@persistence.persistent_class
class SR3DBlock(torch.nn.Module):
    """Two chained SynthesisBlocks applied separately to each third of the channels.

    ``block2`` is built with ``up=2`` and ``block3`` with ``up=1``, both at
    resolution ``img_resolution * 2`` with ``roll_out='s'``.
    """

    def __init__(self, img_channels, img_resolution, w_dim, block_kwargs):
        super().__init__()
        # Work on a copy so the caller's dict is not modified.
        block_kwargs = dict(block_kwargs, roll_out='s')
        plane_channels = img_channels // 3
        self.block2 = SynthesisBlock(plane_channels, plane_channels, w_dim=w_dim, resolution=img_resolution * 2,
                                     up=2,
                                     img_channels=32, is_last=True, use_fp16=False, **block_kwargs)
        self.block3 = SynthesisBlock(plane_channels, plane_channels, w_dim=w_dim, resolution=img_resolution * 2,
                                     up=1,
                                     img_channels=32, is_last=True, use_fp16=False, **block_kwargs)

    def forward(self, img, ws):
        """Return ``[img2, img3]``: the outputs of ``block2`` and ``block3``.

        The channel axis of ``img`` is split into 3 groups that are folded
        into the batch axis before the blocks run and unfolded afterwards.
        Only the last latent of ``ws`` is used, repeated 3 times.
        """
        ws = ws[:, -1:, :].repeat(1, 3, 1)
        batch, height, width = img.shape[0], img.shape[-2], img.shape[-1]
        img = img.view(batch, 3, -1, height, width).view(batch * 3, -1, height, width)
        x, img2 = self.block2(img, None, ws)
        x, img3 = self.block3(img2, None, ws)
        img2 = img2.view(-1, 3, img2.shape[-3], img2.shape[-2], img2.shape[-1]).view(-1, 3 * img2.shape[-3],
                                                                                      img2.shape[-2], img2.shape[-1])
        # The original used img.shape[-1] (the pre-upsampling width) here.
        img3 = img3.view(-1, 3, img3.shape[-3], img3.shape[-2], img3.shape[-1]).view(-1, 3 * img3.shape[-3],
                                                                                      img3.shape[-2], img3.shape[-1])

        return [img2, img3]


# ----------------------------------------------------------------------------
@persistence.persistent_class
class Aware3DBlock(torch.nn.Module):
    """Single ``up=2`` SynthesisBlock applied separately to each third of the channels."""

    def __init__(self, img_channels, img_resolution, w_dim, aware3d_att, block_kwargs):
        super().__init__()
        # Work on a copy so the caller's dict is not modified.
        block_kwargs = dict(block_kwargs, roll_out=('b', 'a')[aware3d_att])
        up = 2
        plane_channels = img_channels // 3
        self.block = SynthesisBlock(plane_channels, plane_channels, w_dim=w_dim, resolution=img_resolution * up,
                                    up=up,
                                    img_channels=plane_channels, is_last=True, use_fp16=False, **block_kwargs)

    def forward(self, x, img, ws, block_kwargs):
        """Fold the 3 channel groups of ``img`` into the batch axis and run the block.

        Args:
            x: optional tensor added to the folded ``img`` (skipped when None).
            img: input whose channel axis is split into 3 groups.
            ws: latents; only the last one is used, repeated 3 times.
            block_kwargs: dict of extra keyword arguments for the SynthesisBlock.

        Returns:
            The image output of the SynthesisBlock (batch axis still folded).
        """
        batch, height, width = img.shape[0], img.shape[-2], img.shape[-1]
        img = img.view(batch, 3, -1, height, width).view(batch * 3, -1, height, width)
        if x is not None:
            img = img + x

        ws = ws[:, -1:, :].repeat(1, 3, 1)
        _, img = self.block(img, None, ws, **block_kwargs)
        return img
@persistence.persistent_class
class SynthesisNetwork(torch.nn.Module):
    """Stack of SynthesisBlocks from 4x4 up to ``img_resolution``."""

    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        channel_base    = 32768,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 4,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        log2_res = int(np.log2(img_resolution))
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = log2_res
        self.img_channels = img_channels
        self.num_fp16_res = num_fp16_res
        self.block_resolutions = [2 ** exponent for exponent in range(2, log2_res + 1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (log2_res + 1 - num_fp16_res), 8)

        self.num_ws = 0
        prev_channels = 0  # The 4x4 block has no input features.
        for res in self.block_resolutions:
            cur_channels = channels_dict[res]
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(prev_channels, cur_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=(res >= fp16_resolution), **block_kwargs)
            # Only the final block contributes its ToRGB latent to the total.
            self.num_ws += block.num_conv + (block.num_torgb if is_last else 0)
            setattr(self, f'b{res}', block)
            prev_channels = cur_channels

    def forward(self, ws, **block_kwargs):
        """Split ``ws`` per block, run the blocks in order and return the final image."""
        blocks = [getattr(self, f'b{res}') for res in self.block_resolutions]
        per_block_ws = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            offset = 0
            for block in blocks:
                per_block_ws.append(ws.narrow(1, offset, block.num_conv + block.num_torgb))
                offset += block.num_conv

        x = img = None
        for block, cur_ws in zip(blocks, per_block_ws):
            x, img = block(x, img, cur_ws, **block_kwargs)
        return img

    def extra_repr(self):
        return (
            f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d}, '
            f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}, '
            f'num_fp16_res={self.num_fp16_res:d}')
@persistence.persistent_class
class DiscriminatorBlock(torch.nn.Module):
    """One discriminator resolution level: optional FromRGB, two convs, optional resnet skip."""

    def __init__(self,
        in_channels,                        # Number of input channels, 0 = first block.
        tmp_channels,                       # Number of intermediate channels.
        out_channels,                       # Number of output channels.
        resolution,                         # Resolution of this block.
        img_channels,                       # Number of input color channels.
        first_layer_idx,                    # Index of the first layer.
        architecture='resnet',              # Architecture: 'orig', 'skip', 'resnet'.
        activation='lrelu',                 # Activation function: 'relu', 'lrelu', etc.
        resample_filter=[1, 3, 3, 1],       # Low-pass filter to apply when resampling activations.
        conv_clamp=None,                    # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16=False,                     # Use FP16 for this block?
        fp16_channels_last=False,           # Use channels-last memory format with FP16?
        freeze_layers=0,                    # Freeze-D: Number of layers to freeze.
    ):
        assert in_channels in [0, tmp_channels]
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.img_channels = img_channels
        self.first_layer_idx = first_layer_idx
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))

        self.num_layers = 0

        def claim_layer():
            # Reserve the next global layer index and report whether it is trainable.
            layer_idx = self.first_layer_idx + self.num_layers
            self.num_layers += 1
            return layer_idx >= freeze_layers

        needs_fromrgb = (in_channels == 0 or architecture == 'skip')
        if needs_fromrgb:
            self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
                                       trainable=claim_layer(), conv_clamp=conv_clamp,
                                       channels_last=self.channels_last)

        self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
                                 trainable=claim_layer(), conv_clamp=conv_clamp,
                                 channels_last=self.channels_last)

        self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
                                 trainable=claim_layer(), resample_filter=resample_filter, conv_clamp=conv_clamp,
                                 channels_last=self.channels_last)

        if architecture == 'resnet':
            self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
                                    trainable=claim_layer(), resample_filter=resample_filter,
                                    channels_last=self.channels_last)

    def forward(self, x, img, force_fp32=False):
        """Process features ``x`` and/or image ``img``; return ``(x, img)`` for the next block."""
        reference = img if x is None else x
        if reference.device.type != 'cuda':
            force_fp32 = True  # FP16 is only used on CUDA devices.
        if self.use_fp16 and not force_fp32:
            dtype = torch.float16
        else:
            dtype = torch.float32
        if self.channels_last and not force_fp32:
            memory_format = torch.channels_last
        else:
            memory_format = torch.contiguous_format

        # Input.
        if x is not None:
            misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # FromRGB.
        if self.in_channels == 0 or self.architecture == 'skip':
            misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            img = img.to(dtype=dtype, memory_format=memory_format)
            rgb_features = self.fromrgb(img)
            if x is None:
                x = rgb_features
            else:
                x = x + rgb_features
            if self.architecture == 'skip':
                img = upfirdn2d.downsample2d(img, self.resample_filter)
            else:
                img = None

        # Main layers.
        if self.architecture == 'resnet':
            residual = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x)
            x = self.conv1(x, gain=np.sqrt(0.5))
            x = residual.add_(x)
        else:
            x = self.conv0(x)
            x = self.conv1(x)

        assert x.dtype == dtype
        return x, img

    def extra_repr(self):
        return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
@persistence.persistent_class
class DiscriminatorEpilogue(torch.nn.Module):
    """Last discriminator stage, split into a convolutional part and a fully-connected part.

    ``get_flatten_x`` runs the optional FromRGB, minibatch-stddev and conv
    layers and returns the flattened features; ``forward`` consumes those
    flattened features and produces the (optionally conditioned) score.
    """

    def __init__(self,
        in_channels,                    # Number of input channels.
        cmap_dim,                       # Dimensionality of mapped conditioning label, 0 = no label.
        resolution,                     # Resolution of this block.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        mbstd_group_size    = 4,        # Group size for the minibatch standard deviation layer, None = entire minibatch.
        mbstd_num_channels  = 1,        # Number of features for the minibatch standard deviation layer, 0 = disable.
        activation          = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.cmap_dim = cmap_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.architecture = architecture

        if architecture == 'skip':
            self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
        self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
        self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
        self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
        self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)

    def get_flatten_x(self, x, img, force_fp32=False):
        """Run the convolutional part and return features flattened to ``[N, -1]``.

        Args:
            x: features of shape ``[N, in_channels, resolution, resolution]``.
            img: image of shape ``[N, img_channels, resolution, resolution]``;
                only read when ``architecture == 'skip'``.
            force_fp32: accepted for interface compatibility; unused.
        """
        misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
        _ = force_fp32 # unused
        dtype = torch.float32
        memory_format = torch.contiguous_format

        # FromRGB.
        x = x.to(dtype=dtype, memory_format=memory_format)
        if self.architecture == 'skip':
            misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            img = img.to(dtype=dtype, memory_format=memory_format)
            x = x + self.fromrgb(img)

        # Main layers.
        if self.mbstd is not None:
            x = self.mbstd(x)
        x = self.conv(x)

        return x.flatten(1)

    def forward(self, flatten_x, cmap, force_fp32=False):
        """Map flattened features (see ``get_flatten_x``) to the final score.

        Args:
            flatten_x: tensor of shape ``[N, in_channels * resolution**2]``.
            cmap: mapped conditioning of shape ``[N, cmap_dim]``; only read
                when ``cmap_dim > 0``.
            force_fp32: accepted for interface compatibility; unused.

        Returns:
            Tensor of shape ``[N, 1]`` when conditioned, otherwise the raw
            output of the final fully-connected layer.
        """
        _ = force_fp32 # unused
        misc.assert_shape(flatten_x, [None, self.in_channels * self.resolution * self.resolution])
        dtype = torch.float32

        x = self.fc(flatten_x)
        x = self.out(x)

        # Conditioning.
        if self.cmap_dim > 0:
            misc.assert_shape(cmap, [None, self.cmap_dim])
            x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        assert x.dtype == dtype
        return x

    def extra_repr(self):
        return f'resolution={self.resolution:d}, architecture={self.architecture:s}'

#----------------------------------------------------------------------------

@persistence.persistent_class
class Discriminator(torch.nn.Module):
    """Chain of DiscriminatorBlocks followed by a DiscriminatorEpilogue at 4x4."""

    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)

    def forward(self, img, c, update_emas=False, **block_kwargs):
        """Score ``img`` (optionally conditioned on ``c``).

        ``update_emas`` is accepted for interface compatibility and ignored.
        """
        _ = update_emas # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        # The epilogue's forward() takes flattened features, so the
        # convolutional stage has to be run explicitly first.  Calling
        # self.b4(x, img, cmap) would pass the 4D feature map where the
        # flattened tensor is expected.
        flatten_x = self.b4.get_flatten_x(x, img)
        x = self.b4(flatten_x, cmap)
        return x

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

#----------------------------------------------------------------------------
a/3DPortraitGAN_pyramid/training/networks_stylegan3.py b/3DPortraitGAN_pyramid/training/networks_stylegan3.py new file mode 100644 index 0000000..40e5508 --- /dev/null +++ b/3DPortraitGAN_pyramid/training/networks_stylegan3.py @@ -0,0 +1,517 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Generator architecture from the paper +"Alias-Free Generative Adversarial Networks".""" + +import numpy as np +import scipy.signal +import scipy.optimize +import torch +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import filtered_lrelu +from torch_utils.ops import bias_act + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def modulated_conv2d( + x, # Input tensor: [batch_size, in_channels, in_height, in_width] + w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width] + s, # Style tensor: [batch_size, in_channels] + demodulate = True, # Apply weight demodulation? 
@persistence.persistent_class
class FullyConnectedLayer(torch.nn.Module):
    """Linear layer with runtime weight/bias gains and an optional fused activation."""

    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        bias            = True,     # Apply additive bias before the activation function?
        lr_multiplier   = 1,        # Learning rate multiplier.
        weight_init     = 1,        # Initial standard deviation of the weight tensor.
        bias_init       = 0,        # Initial value of the additive bias.
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation

        # Parameters are stored divided by lr_multiplier; the gains below
        # multiply it back in at forward time.
        init_scale = weight_init / lr_multiplier
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * init_scale)
        initial_bias = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
        if bias:
            self.bias = torch.nn.Parameter(torch.from_numpy(initial_bias / lr_multiplier))
        else:
            self.bias = None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        """Apply the scaled linear map, bias and activation to ``x``."""
        weight_t = (self.weight.to(x.dtype) * self.weight_gain).t()

        scaled_bias = self.bias
        if scaled_bias is not None:
            scaled_bias = scaled_bias.to(x.dtype)
            if self.bias_gain != 1:
                scaled_bias = scaled_bias * self.bias_gain

        # Linear activation with a bias is a single fused addmm.
        if scaled_bias is not None and self.activation == 'linear':
            return torch.addmm(scaled_bias.unsqueeze(0), x, weight_t)

        projected = x.matmul(weight_t)
        return bias_act.bias_act(projected, scaled_bias, act=self.activation)

    def extra_repr(self):
        return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
@persistence.persistent_class
class SynthesisInput(torch.nn.Module):
    """Fourier-feature input layer whose sampling grid is transformed by the latent.

    Holds fixed random frequencies and phases (buffers), a trainable
    ``[channels, channels]`` mixing weight, and an affine layer that maps
    ``w`` to four values used as rotation and translation parameters.
    """

    def __init__(self,
        w_dim,          # Intermediate latent (W) dimensionality.
        channels,       # Number of output channels.
        size,           # Output spatial size: int or [width, height].
        sampling_rate,  # Output sampling rate.
        bandwidth,      # Output bandwidth.
    ):
        super().__init__()
        self.w_dim = w_dim
        self.channels = channels
        self.size = np.broadcast_to(np.asarray(size), [2])  # Always stored as a length-2 array.
        self.sampling_rate = sampling_rate
        self.bandwidth = bandwidth

        # Draw random frequencies from uniform 2D disc.
        freqs = torch.randn([self.channels, 2])
        radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
        freqs /= radii * radii.square().exp().pow(0.25)
        freqs *= bandwidth
        phases = torch.rand([self.channels]) - 0.5  # Values in [-0.5, 0.5).

        # Setup parameters and buffers.
        self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels]))
        # Zero weights with bias [1,0,0,0]: the affine output starts as (1, 0, 0, 0).
        self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0])
        self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image.
        self.register_buffer('freqs', freqs)
        self.register_buffer('phases', phases)

    def forward(self, w):
        """Return Fourier features of shape ``[w.shape[0], channels, size[1], size[0]]``.

        ``w`` is passed through ``self.affine``; the four outputs are used as
        rotation (first two) and translation (last two) parameters.
        """
        # Introduce batch dimension.
        transforms = self.transform.unsqueeze(0) # [batch, row, col]
        freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
        phases = self.phases.unsqueeze(0) # [batch, channel]

        # Apply learned transformation.
        t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
        t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
        m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
        m_r[:, 0, 0] = t[:, 0]  # r'_c
        m_r[:, 0, 1] = -t[:, 1] # r'_s
        m_r[:, 1, 0] = t[:, 1]  # r'_s
        m_r[:, 1, 1] = t[:, 0]  # r'_c
        m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
        m_t[:, 0, 2] = -t[:, 2] # t'_x
        m_t[:, 1, 2] = -t[:, 3] # t'_y
        transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.

        # Transform frequencies.
        # Phases must be updated first: this line reads the untransformed freqs.
        phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
        freqs = freqs @ transforms[:, :2, :2]

        # Dampen out-of-band frequencies that may occur due to the user-specified transform.
        amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)

        # Construct sampling grid.
        theta = torch.eye(2, 3, device=w.device)
        theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
        theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
        grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)

        # Compute Fourier features.
        x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
        x = x + phases.unsqueeze(1).unsqueeze(2)
        x = torch.sin(x * (np.pi * 2))
        x = x * amplitudes.unsqueeze(1).unsqueeze(2)

        # Apply trainable mapping.
        weight = self.weight / np.sqrt(self.channels)
        x = x @ weight.t()

        # Ensure correct shape.
        x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
        misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])])
        return x

    def extra_repr(self):
        return '\n'.join([
            f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},',
            f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}'])

#----------------------------------------------------------------------------
+ in_sampling_rate, # Input sampling rate (s). + out_sampling_rate, # Output sampling rate (s). + in_cutoff, # Input cutoff frequency (f_c). + out_cutoff, # Output cutoff frequency (f_c). + in_half_width, # Input transition band half-width (f_h). + out_half_width, # Output Transition band half-width (f_h). + + # Hyperparameters. + conv_kernel = 3, # Convolution kernel size. Ignored for final the ToRGB layer. + filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling. + lrelu_upsampling = 2, # Relative sampling rate for leaky ReLU. Ignored for final the ToRGB layer. + use_radial_filters = False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers. + conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping. + magnitude_ema_beta = 0.999, # Decay rate for the moving average of input magnitudes. + ): + super().__init__() + self.w_dim = w_dim + self.is_torgb = is_torgb + self.is_critically_sampled = is_critically_sampled + self.use_fp16 = use_fp16 + self.in_channels = in_channels + self.out_channels = out_channels + self.in_size = np.broadcast_to(np.asarray(in_size), [2]) + self.out_size = np.broadcast_to(np.asarray(out_size), [2]) + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + self.conv_kernel = 1 if is_torgb else conv_kernel + self.conv_clamp = conv_clamp + self.magnitude_ema_beta = magnitude_ema_beta + + # Setup parameters and buffers. 
+ self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1) + self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel])) + self.bias = torch.nn.Parameter(torch.zeros([self.out_channels])) + self.register_buffer('magnitude_ema', torch.ones([])) + + # Design upsampling filter. + self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) + assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate + self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1 + self.register_buffer('up_filter', self.design_lowpass_filter( + numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate)) + + # Design downsampling filter. + self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) + assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate + self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1 + self.down_radial = use_radial_filters and not self.is_critically_sampled + self.register_buffer('down_filter', self.design_lowpass_filter( + numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial)) + + # Compute padding. + pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling. + pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling. + pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters. + pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3). 
+ pad_hi = pad_total - pad_lo + self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] + + def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False): + assert noise_mode in ['random', 'const', 'none'] # unused + misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])]) + misc.assert_shape(w, [x.shape[0], self.w_dim]) + + # Track input magnitude. + if update_emas: + with torch.autograd.profiler.record_function('update_magnitude_ema'): + magnitude_cur = x.detach().to(torch.float32).square().mean() + self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta)) + input_gain = self.magnitude_ema.rsqrt() + + # Execute affine layer. + styles = self.affine(w) + if self.is_torgb: + weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2)) + styles = styles * weight_gain + + # Execute modulated conv2d. + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 + x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles, + padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain) + + # Execute bias, filtered leaky ReLU, and clamping. + gain = 1 if self.is_torgb else np.sqrt(2) + slope = 1 if self.is_torgb else 0.2 + x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype), + up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp) + + # Ensure correct shape and dtype. + misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])]) + assert x.dtype == dtype + return x + + @staticmethod + def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): + assert numtaps >= 1 + + # Identity filter. + if numtaps == 1: + return None + + # Separable Kaiser low-pass filter. 
+ if not radial: + f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) + return torch.as_tensor(f, dtype=torch.float32) + + # Radially symmetric jinc-based filter. + x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs + r = np.hypot(*np.meshgrid(x, x)) + f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) + beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2))) + w = np.kaiser(numtaps, beta) + f *= np.outer(w, w) + f /= np.sum(f) + return torch.as_tensor(f, dtype=torch.float32) + + def extra_repr(self): + return '\n'.join([ + f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},', + f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},', + f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},', + f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},', + f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},', + f'in_size={list(self.in_size)}, out_size={list(self.out_size)},', + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}']) + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisNetwork(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB. + num_critical = 2, # Number of critically sampled layers at the end. + first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}). + first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}). + last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff. 
+ margin_size = 10, # Number of additional pixels outside the image. + output_scale = 0.25, # Scale factor for the output image. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + **layer_kwargs, # Arguments for SynthesisLayer. + ): + super().__init__() + self.w_dim = w_dim + self.num_ws = num_layers + 2 + self.img_resolution = img_resolution + self.img_channels = img_channels + self.num_layers = num_layers + self.num_critical = num_critical + self.margin_size = margin_size + self.output_scale = output_scale + self.num_fp16_res = num_fp16_res + + # Geometric progression of layer cutoffs and min. stopbands. + last_cutoff = self.img_resolution / 2 # f_{c,N} + last_stopband = last_cutoff * last_stopband_rel # f_{t,N} + exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1) + cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i] + stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i] + + # Compute remaining layer parameters. + sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i] + half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i] + sizes = sampling_rates + self.margin_size * 2 + sizes[-2:] = self.img_resolution + channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max)) + channels[-1] = self.img_channels + + # Construct layers. 
+ self.input = SynthesisInput( + w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]), + sampling_rate=sampling_rates[0], bandwidth=cutoffs[0]) + self.layer_names = [] + for idx in range(self.num_layers + 1): + prev = max(idx - 1, 0) + is_torgb = (idx == self.num_layers) + is_critically_sampled = (idx >= self.num_layers - self.num_critical) + use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution) + layer = SynthesisLayer( + w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16, + in_channels=int(channels[prev]), out_channels= int(channels[idx]), + in_size=int(sizes[prev]), out_size=int(sizes[idx]), + in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]), + in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx], + in_half_width=half_widths[prev], out_half_width=half_widths[idx], + **layer_kwargs) + name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}' + setattr(self, name, layer) + self.layer_names.append(name) + + def forward(self, ws, **layer_kwargs): + misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) + ws = ws.to(torch.float32).unbind(dim=1) + + # Execute layers. + x = self.input(ws[0]) + for name, w in zip(self.layer_names, ws[1:]): + x = getattr(self, name)(x, w, **layer_kwargs) + if self.output_scale != 1: + x = x * self.output_scale + + # Ensure correct shape and dtype. 
+ misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution]) + x = x.to(torch.float32) + return x + + def extra_repr(self): + return '\n'.join([ + f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', + f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', + f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},', + f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}']) + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class Generator(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + mapping_kwargs = {}, # Arguments for MappingNetwork. + **synthesis_kwargs, # Arguments for SynthesisNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) + self.num_ws = self.synthesis.num_ws + self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs): + ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) + img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs) + return img + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/training/neural_renderer.py b/3DPortraitGAN_pyramid/training/neural_renderer.py new file mode 100644 index 0000000..a4b5b8c --- /dev/null +++ 
b/3DPortraitGAN_pyramid/training/neural_renderer.py @@ -0,0 +1,245 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import math +import torch +from torch_utils import persistence +from training.networks_stylegan2 import ToRGBLayer, SynthesisNetwork + +from training.networks_stylegan2 import Hierarchy3DAwareGenerator as StyleGAN2Backbone +from training.volumetric_rendering.renderer import ImportanceRenderer +from training.volumetric_rendering.ray_sampler import RaySampler +import dnnlib +import numpy as np + + +@persistence.persistent_class +class TriPlaneGenerator(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + sr_num_fp16_res = 0, + mapping_kwargs = {}, # Arguments for MappingNetwork. + rendering_kwargs = {}, + sr_kwargs = {}, + batch_size=1, + explicitly_symmetry=False, + thickness= 0.05, + **synthesis_kwargs, # Arguments for SynthesisNetwork. 
+ ): + super().__init__() + bcg_synthesis_kwargs = synthesis_kwargs.copy() + bcg_synthesis_kwargs["channel_base"] = 32768 + bcg_synthesis_kwargs["channel_max"] = 512 + + self.z_dim=z_dim + self.c_dim=c_dim + self.w_dim=w_dim + self.img_resolution=img_resolution + self.img_channels=img_channels + + self.trigrid_channel = 12 + self.decode_channel = 32 + + self.batch_size = batch_size + self.renderer = ImportanceRenderer(w_dim = w_dim, num_ws = 14, batch_size = self.batch_size,thickness =thickness,box_warp = rendering_kwargs['box_warp']) + self.ray_sampler = RaySampler() + + self.decoder = OSGDecoder(self.trigrid_channel, {'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), + 'decoder_output_dim': self.decode_channel, + 'decoder_activation': rendering_kwargs['decoder_activation']}) + + self.torgb = ToRGBLayer(self.decode_channel, 3, w_dim) if rendering_kwargs.get('use_torgb_raw', False) else None + + self.bcg_synthesis = SynthesisNetwork(w_dim, img_resolution=128, + img_channels=self.decode_channel, **bcg_synthesis_kwargs) if rendering_kwargs.get('use_background', False) else None + + self.neural_rendering_resolution = 64 + self.rendering_kwargs = rendering_kwargs + + self._last_planes = None + + self.explicitly_symmetry = explicitly_symmetry + + self.avg_c = torch.tensor([[ 1.0000e+00, 1.0505e-09, 4.3685e-08, -1.1805e-07, 0.0000e+00, + -9.9951e-01, 2.4033e-02, -1.1805e-07, 4.3714e-08, -2.4033e-02, + -9.9951e-01, 2.6992e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, + 1.0000e+00, 6.5104e+00, 0.0000e+00, 5.0000e-01, 0.0000e+00, + 6.5104e+00, 5.0000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+00]]).float().cuda() + + def flip_yaw(self, matrix): + flipped_matrix = matrix.clone() + flipped = flipped_matrix[:, :16].reshape(-1, 4, 4) + flipped[:, 0, 1] *= -1 + flipped[:, 0, 2] *= -1 + flipped[:, 1, 0] *= -1 + flipped[:, 2, 0] *= -1 + flipped[:, 0, 3] *= -1 + + flipped = flipped.reshape(-1, 16) + flipped_matrix[:, :16] = flipped.clone() + + return flipped_matrix + + + + def 
set_batch_size(self, batch_size): + self.renderer.set_batch_size(batch_size) + + def render_meshes(self,shape_pose_params,resolution,cameras): + + return self.renderer.render_meshes(shape_pose_params, resolution, cameras) + + + + def render_planes(self, ws, planes, c, neural_rendering_resolution=None, update_emas=False, chunk = None,render_bg = True,patch_resolution=None, + apply_def=False, pose_params = None,ws_bcg=None, + **synthesis_kwargs): + cam2world_matrix = c[:, :16].view(-1, 4, 4) + intrinsics = c[:, 16:25].view(-1, 3, 3) + + if neural_rendering_resolution is None: + neural_rendering_resolution = self.neural_rendering_resolution + + patch_info = [] + if patch_resolution is None: + # Create a batch of rays for volume rendering + ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution) + H = W = neural_rendering_resolution + else: + ray_origins, ray_directions,patch_info = self.ray_sampler.patch_forward(cam2world_matrix, intrinsics, + patch_resolution, + patch_scale=patch_resolution/neural_rendering_resolution) + H = W = patch_resolution + + + N, M, _ = ray_origins.shape + + + + # Reshape output into three D*32-channel planes, where D=self.rendering_kwargs['triplane_depth'], defines the depth of the tri-grid + for res_k in planes: + # b, c, H,W + # planes[res_k] = planes[res_k].view(len(planes[res_k]), 3, -1, planes[res_k].shape[-2], planes[res_k].shape[-1]) + if len(planes[res_k].shape) == 4: + planes[res_k] = planes[res_k].view(len(planes[res_k]) // 3, 3, planes[res_k].shape[-3], + planes[res_k].shape[-2], planes[res_k].shape[-1]) + + if chunk is not None: + + feature_list, depth_list, weight_list = list(), list(), list() + for _ro, _rd in zip(torch.split(ray_origins, chunk, dim=1), torch.split(ray_directions, chunk, dim=1)): + render_output = self.renderer(planes, self.decoder, _ro, + _rd, self.rendering_kwargs, apply_def = apply_def, ws = ws, pose_params = pose_params ) # channels last + + _f = 
render_output['rgb_final'] + _d = render_output['depth_final'] + _w = render_output['weights'] + feature_list.append(_f) + depth_list.append(_d) + weight_list.append(_w) + feature_samples = torch.cat(feature_list, 1) + depth_samples = torch.cat(depth_list, 1) + weights_samples = torch.cat(weight_list, 1) + else: + + # Perform volume rendering + render_output = self.renderer(planes, self.decoder, ray_origins, + ray_directions, self.rendering_kwargs, apply_def = apply_def, ws = ws, pose_params = pose_params ) # channels last + # {'rgb_final': rgb_final, 'depth_final': depth_final, 'weights': weights.sum(2)} + feature_samples = render_output['rgb_final'] + depth_samples = render_output['depth_final'] + weights_samples = render_output['weights'] + + + # Reshape into 'raw' neural-rendered image + + feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous() + depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) + weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W) + + if self.decoder.activation == "sigmoid": + feature_image = feature_image * 2 - 1 # Scale to (-1, 1), taken from ray marcher + # Generate Background + if self.bcg_synthesis and render_bg: + ws_bcg = ws[:,:self.bcg_synthesis.num_ws] if ws_bcg is None else ws_bcg[:,:self.bcg_synthesis.num_ws] + if ws_bcg.size(1) < self.bcg_synthesis.num_ws: + ws_bcg = torch.cat([ws_bcg, ws_bcg[:,-1:].repeat(1,self.bcg_synthesis.num_ws-ws_bcg.size(1),1)], 1) + bcg_image = self.bcg_synthesis(ws_bcg, update_emas=update_emas, **synthesis_kwargs) + bcg_image = torch.nn.functional.interpolate(bcg_image, size=feature_image.shape[2:], + mode='bilinear', align_corners=False, antialias=self.rendering_kwargs['sr_antialias']) + feature_image = feature_image + (1-weights_samples) * bcg_image + + # Generate Raw image + if self.torgb: + rgb_image = self.torgb(feature_image, ws[:,-1], fused_modconv=False) + rgb_image = rgb_image.to(dtype=torch.float32, 
memory_format=torch.contiguous_format) + + else: + rgb_image = feature_image[:, :3] + + + mask_image = weights_samples * (1 + 2 * 0.001) - 0.001 + + return {'image_raw': rgb_image, 'image_depth': depth_image, "image_mask": mask_image,'patch_info':patch_info} + + + def sample_trigrid(self, coordinates, directions, planes, update_emas=False, **synthesis_kwargs): + # Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes. + # planes = planes.view(len(planes), 3, 32 * self.rendering_kwargs['triplane_depth'], planes.shape[-2], + # planes.shape[-1]) + for res_k in planes: + # b, c, H,W + if len(planes[res_k].shape) == 4: + planes[res_k] = planes[res_k].view(len(planes[res_k]) // 3, 3, planes[res_k].shape[-3], + planes[res_k].shape[-2], planes[res_k].shape[-1]) + + return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs) + + + + +from training.networks_stylegan2 import FullyConnectedLayer + +class OSGDecoder(torch.nn.Module): + def __init__(self, n_features, options): + super().__init__() + self.hidden_dim = 32 + + self.net = torch.nn.Sequential( + FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']), + torch.nn.Softplus(), + FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul']) + ) + self.activation = options['decoder_activation'] + + def forward(self, sampled_features, ray_directions): + # Aggregate features + sampled_features = sampled_features.mean(1) + x = sampled_features + + N, M, C = x.shape + x = x.view(N*M, C) + + x = self.net(x) + x = x.view(N, M, -1) + rgb = x[..., 1:] + sigma = x[..., 0:1] + if self.activation == "sigmoid": + # Original EG3D + rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 + elif self.activation == "lrelu": + # StyleGAN2-style, use with toRGB + rgb = torch.nn.functional.leaky_relu(rgb, 0.2, inplace=True) * math.sqrt(2) + return {'rgb': rgb, 'sigma': sigma} + diff 
--git a/3DPortraitGAN_pyramid/training/smpl_triplane.py b/3DPortraitGAN_pyramid/training/smpl_triplane.py new file mode 100644 index 0000000..bd4ca20 --- /dev/null +++ b/3DPortraitGAN_pyramid/training/smpl_triplane.py @@ -0,0 +1,492 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import math +import torch +from torch_utils import persistence +from training.networks_stylegan2 import ToRGBLayer, SynthesisNetwork + +from training.networks_stylegan2 import Hierarchy3DAwareGenerator as StyleGAN2Backbone +from training.volumetric_rendering.renderer import ImportanceRenderer +from training.volumetric_rendering.ray_sampler import RaySampler +import dnnlib + +""" +Mask guidance, background synthesis and tri-grid representation from the paper +"PanoHead: Geometry-Aware 3D Full-Head Synthesis in 360°" +https://github.com/SizheAn/PanoHead/blob/main/training/triplane.py +""" + +@persistence.persistent_class +class TriPlaneGenerator(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + sr_num_fp16_res = 0, + mapping_kwargs = {}, # Arguments for MappingNetwork. 
+ rendering_kwargs = {}, + sr_kwargs = {}, + batch_size=1, + explicitly_symmetry=False, + thickness= 0.05, + **synthesis_kwargs, # Arguments for SynthesisNetwork. + ): + super().__init__() + bcg_synthesis_kwargs = synthesis_kwargs.copy() + bcg_synthesis_kwargs["channel_base"] = 32768 + bcg_synthesis_kwargs["channel_max"] = 512 + + self.z_dim=z_dim + self.c_dim=c_dim + self.w_dim=w_dim + self.img_resolution=img_resolution + self.img_channels=img_channels + + self.trigrid_channel = 12 + self.decode_channel = 32 + + self.batch_size = batch_size + self.renderer = ImportanceRenderer(w_dim = w_dim, num_ws = 14, batch_size = self.batch_size,thickness =thickness,box_warp = rendering_kwargs['box_warp']) + self.ray_sampler = RaySampler() + # self.backbone = StyleGAN2Backbone(z_dim, c_dim+6, w_dim, img_resolution=512, img_channels=self.trigrid_channel*3*rendering_kwargs['triplane_depth'], mapping_kwargs=mapping_kwargs, **synthesis_kwargs) + self.backbone = StyleGAN2Backbone(z_dim, c_dim + 6, w_dim, img_resolution=256, + img_channels=self.trigrid_channel * 3 * rendering_kwargs['triplane_depth'], + mapping_kwargs=mapping_kwargs, roll_out=None, + **synthesis_kwargs) # forbid roll_out in main G + + self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=self.decode_channel, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs) + self.decoder = OSGDecoder(self.trigrid_channel, {'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), + 'decoder_output_dim': self.decode_channel, + 'decoder_activation': rendering_kwargs['decoder_activation']}) + + self.torgb = ToRGBLayer(self.decode_channel, 3, w_dim) if rendering_kwargs.get('use_torgb_raw', False) else None + + self.bcg_synthesis = SynthesisNetwork(w_dim, img_resolution=self.superresolution.input_resolution, + img_channels=self.decode_channel, **bcg_synthesis_kwargs) if 
rendering_kwargs.get('use_background', False) else None + + self.pose_branch = GPoseBranch(z_dim = z_dim, c_dim = c_dim) + self.neural_rendering_resolution = 64 + self.rendering_kwargs = rendering_kwargs + + self._last_planes = None + + self.explicitly_symmetry = explicitly_symmetry + + self.avg_c = torch.tensor([[ 1.0000e+00, 1.0505e-09, 4.3685e-08, -1.1805e-07, 0.0000e+00, + -9.9951e-01, 2.4033e-02, -1.1805e-07, 4.3714e-08, -2.4033e-02, + -9.9951e-01, 2.6992e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, + 1.0000e+00, 6.5104e+00, 0.0000e+00, 5.0000e-01, 0.0000e+00, + 6.5104e+00, 5.0000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+00]]).float().cuda() + + self.plane_shapes = {} + + planes = self.backbone.synthesis(torch.zeros(4,self.backbone.synthesis.num_ws,w_dim), update_emas=False, **synthesis_kwargs) + + # Reshape output into three D*32-channel planes, where D=self.rendering_kwargs['triplane_depth'], defines the depth of the tri-grid + for res_k in planes: + # b, c, H,W + # planes[res_k] = planes[res_k].view(len(planes[res_k]), 3, -1, planes[res_k].shape[-2], planes[res_k].shape[-1]) + planes[res_k] = planes[res_k].view(len(planes[res_k]) // 3, 3, planes[res_k].shape[-3], + planes[res_k].shape[-2], planes[res_k].shape[-1]) + if res_k not in self.plane_shapes: + self.plane_shapes[res_k] = planes[res_k].shape + + def flip_yaw(self, matrix): + flipped_matrix = matrix.clone() + flipped = flipped_matrix[:, :16].reshape(-1, 4, 4) + flipped[:, 0, 1] *= -1 + flipped[:, 0, 2] *= -1 + flipped[:, 1, 0] *= -1 + flipped[:, 2, 0] *= -1 + flipped[:, 0, 3] *= -1 + + flipped = flipped.reshape(-1, 16) + flipped_matrix[:, :16] = flipped.clone() + + return flipped_matrix + + def get_pose_params(self, z, c): + if self.explicitly_symmetry: + # check if c is a left face + theta = torch.atan2(c[:, [11]], c[:, [3]]) # math.atan2(z, x) + is_left = (theta >= -np.pi / 2) & (theta <= np.pi / 2) + + + flip_c = self.flip_yaw(c) + input_c = torch.where(is_left, flip_c, c) # if left, flip c + + pose_params 
= self.pose_branch(z, input_c) + + flip_pose_params = pose_params.clone() + flip_pose_params[:, [1, 2, 4, 5]] *= -1 # flip y and z axis angles + + pose_params = torch.where(is_left, flip_pose_params, pose_params) # if left, flip back pose_params + + return pose_params + else: + raise NotImplementedError + return self.pose_branch(z, c) + + def set_batch_size(self, batch_size): + self.renderer.set_batch_size(batch_size) + + def render_meshes(self,shape_pose_params,resolution,cameras): + + return self.renderer.render_meshes(shape_pose_params, resolution, cameras) + + + def mapping(self, z, c, p, truncation_psi=1, truncation_cutoff=None, update_emas=False): + if self.rendering_kwargs['c_gen_conditioning_zero']: + raise NotImplementedError + p = torch.zeros([c.shape[0], 6]).to(c.device) + c = self.avg_c.repeat(c.shape[0], 1).to(c.device) + c = torch.cat([c, p], dim=1) + + else: + + if p is None: + p = torch.zeros([c.shape[0],6]).to(c.device) + c = torch.cat([c,p],dim=1) + return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) + + + def synthesis(self, ws, c, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, + apply_def=False, pose_params = None,ws_bcg=None, + **synthesis_kwargs): + cam2world_matrix = c[:, :16].view(-1, 4, 4) + intrinsics = c[:, 16:25].view(-1, 3, 3) + + if neural_rendering_resolution is None: + neural_rendering_resolution = self.neural_rendering_resolution + else: + self.neural_rendering_resolution = neural_rendering_resolution + + # Create a batch of rays for volume rendering + ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution) + + # Create triplanes by running StyleGAN backbone + N, M, _ = ray_origins.shape + if use_cached_backbone and self._last_planes is not None: + planes = self._last_planes + else: + planes = 
self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs) + + if cache_backbone: + self._last_planes = planes + + # Reshape output into three D*32-channel planes, where D=self.rendering_kwargs['triplane_depth'], defines the depth of the tri-grid + for res_k in planes: + # b, c, H,W + # planes[res_k] = planes[res_k].view(len(planes[res_k]), 3, -1, planes[res_k].shape[-2], planes[res_k].shape[-1]) + planes[res_k] = planes[res_k].view(N, 3, planes[res_k].shape[-3], + planes[res_k].shape[-2], planes[res_k].shape[-1]) + + + # Perform volume rendering + render_output = self.renderer(planes, self.decoder, ray_origins, + ray_directions, self.rendering_kwargs, apply_def = apply_def, ws = ws, pose_params = pose_params ) # channels last + # {'rgb_final': rgb_final, 'depth_final': depth_final, 'weights': weights.sum(2)} + feature_samples = render_output['rgb_final'] + depth_samples = render_output['depth_final'] + weights_samples = render_output['weights'] + + + # Reshape into 'raw' neural-rendered image + H = W = self.neural_rendering_resolution + feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous() + depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) + weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W) + + # Run superresolution to get final image + if self.decoder.activation == "sigmoid": + feature_image = feature_image * 2 - 1 # Scale to (-1, 1), taken from ray marcher + # Generate Background + if self.bcg_synthesis: + ws_bcg = ws[:,:self.bcg_synthesis.num_ws] if ws_bcg is None else ws_bcg[:,:self.bcg_synthesis.num_ws] + if ws_bcg.size(1) < self.bcg_synthesis.num_ws: + ws_bcg = torch.cat([ws_bcg, ws_bcg[:,-1:].repeat(1,self.bcg_synthesis.num_ws-ws_bcg.size(1),1)], 1) + bcg_image = self.bcg_synthesis(ws_bcg, update_emas=update_emas, **synthesis_kwargs) + bcg_image = torch.nn.functional.interpolate(bcg_image, size=feature_image.shape[2:], + mode='bilinear', 
def gen_planes(self, ws, c, neural_rendering_resolution=None, update_emas=False, cache_backbone=False,
               use_cached_backbone=False,
               apply_def=False, pose_params=None, ws_bcg=None,
               **synthesis_kwargs):
    """Run (or reuse) the backbone and return its planes in 5-D layout, without rendering.

    Args:
        ws: Latent codes passed to ``self.backbone.synthesis``; returned unchanged.
        c: Camera labels. Columns 0-15 are viewed as a 4x4 matrix and columns
            16-24 as a 3x3 matrix, then handed to ``self.ray_sampler``.
        neural_rendering_resolution: When given, it is also stored on
            ``self.neural_rendering_resolution``; otherwise the stored value is used.
        update_emas: Forwarded to the backbone.
        cache_backbone: Store the plane dict on ``self._last_planes``.
        use_cached_backbone: Reuse ``self._last_planes`` when it is not ``None``.
        apply_def, pose_params, ws_bcg: Accepted but not used by this method.
        **synthesis_kwargs: Forwarded to the backbone.

    Returns:
        ``(planes, ws)``. Each 4-D entry ``(L, C, H, W)`` of the ``planes`` dict
        is replaced in place by a ``(L // 3, 3, C, H, W)`` view.
    """
    cam2world_matrix = c[:, :16].view(-1, 4, 4)
    intrinsics = c[:, 16:25].view(-1, 3, 3)

    if neural_rendering_resolution is None:
        neural_rendering_resolution = self.neural_rendering_resolution
    else:
        self.neural_rendering_resolution = neural_rendering_resolution

    # The rays themselves are not needed to build planes. The sampler call is
    # kept so any validation it performs on the camera input still happens.
    self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)

    # Create the planes by running the backbone, unless a cached set is requested.
    if use_cached_backbone and self._last_planes is not None:
        planes = self._last_planes
    else:
        planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)

    if cache_backbone:
        self._last_planes = planes

    for res_k in planes:
        plane = planes[res_k]
        # The dict is reshaped in place, so cached planes are already 5-D on
        # reuse. Only 4-D tensors are reshaped (same guard as render_planes and
        # sample_trigrid); reshaping a 5-D tensor again would raise.
        if len(plane.shape) == 4:
            planes[res_k] = plane.view(len(plane) // 3, 3, plane.shape[-3],
                                       plane.shape[-2], plane.shape[-1])

    return planes, ws
render_output['weights'] + feature_list.append(_f) + depth_list.append(_d) + weight_list.append(_w) + feature_samples = torch.cat(feature_list, 1) + depth_samples = torch.cat(depth_list, 1) + weights_samples = torch.cat(weight_list, 1) + else: + + # Perform volume rendering + render_output = self.renderer(planes, self.decoder, ray_origins, + ray_directions, self.rendering_kwargs, apply_def = apply_def, ws = ws, pose_params = pose_params ) # channels last + # {'rgb_final': rgb_final, 'depth_final': depth_final, 'weights': weights.sum(2)} + feature_samples = render_output['rgb_final'] + depth_samples = render_output['depth_final'] + weights_samples = render_output['weights'] + + + # Reshape into 'raw' neural-rendered image + H = W = neural_rendering_resolution + feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous() + depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) + weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W) + + # Run superresolution to get final image + if self.decoder.activation == "sigmoid": + feature_image = feature_image * 2 - 1 # Scale to (-1, 1), taken from ray marcher + # Generate Background + if self.bcg_synthesis: + ws_bcg = ws[:,:self.bcg_synthesis.num_ws] if ws_bcg is None else ws_bcg[:,:self.bcg_synthesis.num_ws] + if ws_bcg.size(1) < self.bcg_synthesis.num_ws: + ws_bcg = torch.cat([ws_bcg, ws_bcg[:,-1:].repeat(1,self.bcg_synthesis.num_ws-ws_bcg.size(1),1)], 1) + bcg_image = self.bcg_synthesis(ws_bcg, update_emas=update_emas, **synthesis_kwargs) + bcg_image = torch.nn.functional.interpolate(bcg_image, size=feature_image.shape[2:], + mode='bilinear', align_corners=False, antialias=self.rendering_kwargs['sr_antialias']) + feature_image = feature_image + (1-weights_samples) * bcg_image + + # Generate Raw image + if self.torgb: + rgb_image = self.torgb(feature_image, ws[:,-1], fused_modconv=False) + rgb_image = rgb_image.to(dtype=torch.float32, 
def sample_trigrid(self, coordinates, directions, planes, update_emas=False, **synthesis_kwargs):
    """Query the renderer's model at arbitrary 3D coordinates using given planes.

    Any 4-D entry ``(L, C, H, W)`` of ``planes`` is replaced in place by a
    ``(L // 3, 3, C, H, W)`` view; entries of other rank are left untouched.
    ``update_emas`` and ``synthesis_kwargs`` are accepted but unused.

    Returns whatever ``self.renderer.run_model`` returns.
    """
    for key, grid in list(planes.items()):
        if len(grid.shape) != 4:
            continue  # already in tri-grid layout
        channels, height, width = grid.shape[-3], grid.shape[-2], grid.shape[-1]
        planes[key] = grid.view(len(grid) // 3, 3, channels, height, width)

    run_model = self.renderer.run_model
    return run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
def sample(self, coordinates, directions, z, c, p, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
    """Map ``z`` to latents, synthesize planes, and query the model at 3D coordinates.

    ``z``, ``c`` and ``p`` go through ``self.mapping``; the resulting latents
    drive ``self.backbone.synthesis``. Every plane ``(L, C, H, W)`` is replaced
    in the returned dict by a ``(L // 3, 3, C, H, W)`` view before
    ``self.renderer.run_model`` is called.

    Returns whatever ``self.renderer.run_model`` returns.
    """
    latents = self.mapping(
        z, c, p,
        truncation_psi=truncation_psi,
        truncation_cutoff=truncation_cutoff,
        update_emas=update_emas,
    )
    planes = self.backbone.synthesis(latents, update_emas=update_emas, **synthesis_kwargs)

    for key in list(planes):
        grid = planes[key]
        channels, height, width = grid.shape[-3], grid.shape[-2], grid.shape[-1]
        planes[key] = grid.view(len(grid) // 3, 3, channels, height, width)

    return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
class OSGDecoder(torch.nn.Module):
    """Small MLP that turns sampled plane features into colour features and density.

    ``options`` must provide ``decoder_lr_mul``, ``decoder_output_dim`` and
    ``decoder_activation``.
    """

    def __init__(self, n_features, options):
        super().__init__()
        self.hidden_dim = 32
        lr_mul = options['decoder_lr_mul']
        out_features = 1 + options['decoder_output_dim']  # one extra channel for sigma

        layers = [
            FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=lr_mul),
            torch.nn.Softplus(),
            FullyConnectedLayer(self.hidden_dim, out_features, lr_multiplier=lr_mul),
        ]
        self.net = torch.nn.Sequential(*layers)
        self.activation = options['decoder_activation']

    def forward(self, sampled_features, ray_directions):
        """Decode features averaged over dim 1; ``ray_directions`` is unused.

        Returns a dict with ``'rgb'`` (all channels after the first) and
        ``'sigma'`` (the first channel), both shaped ``(N, M, -1)``.
        """
        # Aggregate features across dim 1, then flatten to one row per point.
        pooled = sampled_features.mean(1)
        batch, points, channels = pooled.shape
        decoded = self.net(pooled.view(batch * points, channels)).view(batch, points, -1)

        rgb = decoded[..., 1:]
        sigma = decoded[..., 0:1]

        mode = self.activation
        if mode == "sigmoid":
            # Original EG3D
            rgb = torch.sigmoid(rgb) * (1 + 2 * 0.001) - 0.001
        elif mode == "lrelu":
            # StyleGAN2-style, use with toRGB
            rgb = torch.nn.functional.leaky_relu(rgb, 0.2, inplace=True) * math.sqrt(2)

        return {'rgb': rgb, 'sigma': sigma}
= predict_transl = predict_scale = False + # predict_pose = True + + out_dim = 6 + + # if predict_betas: + # out_dim += num_betas + # if predict_transl: + # out_dim += 3 + # if predict_scale: + # out_dim += 1 + # if predict_pose: + # out_dim += 6 + + self.output_dim = out_dim + self.net = torch.nn.Sequential( + FullyConnectedLayer(self.in_channel, 128, activation='lrelu'), + FullyConnectedLayer(128, 32, activation='lrelu'), + FullyConnectedLayer(32, self.output_dim) + ) + + + def forward(self, z, c): + # misc.assert_shape(feature, [None, self.in_channel]) + # misc.assert_shape(camera_parameters, [None, 25]) + feature = torch.cat([z, c], dim=1) + + pose = self.net(feature) # (B, num_betas + 1 + 3 + 6) + + + return pose \ No newline at end of file diff --git a/3DPortraitGAN_pyramid/training/superresolution.py b/3DPortraitGAN_pyramid/training/superresolution.py new file mode 100644 index 0000000..43321df --- /dev/null +++ b/3DPortraitGAN_pyramid/training/superresolution.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
# for 512x512 generation
@persistence.persistent_class
class SuperresolutionHybrid8X(torch.nn.Module):
    """Two-block super-resolution head that asserts a 512x512 output resolution.

    Inputs whose last dimension is not ``input_resolution`` (128) are
    bilinearly resized to 128x128 before the two synthesis blocks run.
    """

    def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
                 num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,  # IGNORE
                 **block_kwargs):
        super().__init__()
        assert img_resolution == 512

        fp16_enabled = sr_num_fp16_res > 0
        # Settings shared by both blocks; clamping is only enabled with FP16.
        shared = dict(
            w_dim=512,
            img_channels=3,
            use_fp16=fp16_enabled,
            conv_clamp=(256 if fp16_enabled else None),
            **block_kwargs,
        )

        self.input_resolution = 128
        self.sr_antialias = sr_antialias
        self.block0 = SynthesisBlock(channels, 128, resolution=256, is_last=False, **shared)
        self.block1 = SynthesisBlock(128, 64, resolution=512, is_last=True, **shared)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        """Upsample ``rgb``/``x`` through both blocks and return the final RGB image."""
        # Only the last latent is used, repeated three times for each block.
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] != self.input_resolution:
            resize = dict(
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=self.sr_antialias,
            )
            x = torch.nn.functional.interpolate(x, **resize)
            rgb = torch.nn.functional.interpolate(rgb, **resize)

        for block in (self.block0, self.block1):
            x, rgb = block(x, rgb, ws, **block_kwargs)
        return rgb
sr_num_fp16_res > 0 + self.input_resolution = 64 + self.sr_antialias = sr_antialias + self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=64, + img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=128, + img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1])) + + def forward(self, rgb, x, ws, **block_kwargs): + ws = ws[:, -1:, :].repeat(1, 3, 1) + + if x.shape[-1] != self.input_resolution: + x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False, antialias=self.sr_antialias) + rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False, antialias=self.sr_antialias) + + x, rgb = self.block0(x, rgb, ws, **block_kwargs) + x, rgb = self.block1(x, rgb, ws, **block_kwargs) + return rgb + +#---------------------------------------------------------------------------- + +# TODO: Delete (here for backwards compatibility with old 256x256 models) +@persistence.persistent_class +class SuperresolutionHybridDeepfp32(torch.nn.Module): + def __init__(self, channels, img_resolution, sr_num_fp16_res, + num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE + **block_kwargs): + super().__init__() + assert img_resolution == 256 + use_fp16 = sr_num_fp16_res > 0 + + self.input_resolution = 128 + self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=128, + img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=256, + img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + 
@persistence.persistent_class
class SynthesisBlockNoUp(torch.nn.Module):
    """Synthesis block whose convolutions keep the input resolution.

    For ``in_channels != 0`` the block runs ``conv0`` then ``conv1``; for
    ``in_channels == 0`` it starts from a learned constant and runs only
    ``conv1``. When this is the last block or the architecture is 'skip', a
    ToRGB output is added to the incoming ``img``.
    """

    def __init__(self,
        in_channels,                            # Number of input channels, 0 = first block.
        out_channels,                           # Number of output channels.
        w_dim,                                  # Intermediate latent (W) dimensionality.
        resolution,                             # Resolution of this block.
        img_channels,                           # Number of output color channels.
        is_last,                                # Is this the last block?
        architecture            = 'skip',       # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter         = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp              = 256,          # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16                = False,        # Use FP16 for this block?
        fp16_channels_last      = False,        # Use channels-last memory format with FP16?
        fused_modconv_default   = True,         # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
        **layer_kwargs,                         # Arguments for SynthesisLayer.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.fused_modconv_default = fused_modconv_default
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        # Counters used by forward() to check how many latents ``ws`` must carry.
        self.num_conv = 0
        self.num_torgb = 0

        if in_channels == 0:
            self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))

        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution,
                conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
            conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
                conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            # conv0/conv1 do not upsample in this block, so the residual branch
            # must not either. With up=2 the skip output would have twice the
            # spatial size of the main path and ``y.add_(x)`` in forward()
            # would fail on mismatched shapes.
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False,
                resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
        """Run the block and return ``(x, img)``.

        ``ws`` must have shape ``[N, num_conv + num_torgb, w_dim]``; one latent
        is consumed per layer in order. ``img`` may be ``None``; otherwise the
        ToRGB output is added to it in place. FP32 is forced when ``ws`` is not
        on a CUDA device.
        """
        _ = update_emas # unused
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        if ws.device.type != 'cuda':
            force_fp32 = True
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            fused_modconv = self.fused_modconv_default
        if fused_modconv == 'inference_only':
            fused_modconv = (not self.training)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB. The incoming image is not upsampled here because the block
        # keeps its resolution.
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img

    def extra_repr(self):
        return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
sr_num_fp16_res > 0 + self.input_resolution = 128 + self.sr_antialias = sr_antialias + self.block0 = SynthesisBlock(channels, 256, w_dim=512, resolution=256, + img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.block1 = SynthesisBlock(256, 128, w_dim=512, resolution=512, + img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + + def forward(self, rgb, x, ws, **block_kwargs): + ws = ws[:, -1:, :].repeat(1, 3, 1) + + if x.shape[-1] != self.input_resolution: + x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False, antialias=self.sr_antialias) + rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False, antialias=self.sr_antialias) + + x, rgb = self.block0(x, rgb, ws, **block_kwargs) + x, rgb = self.block1(x, rgb, ws, **block_kwargs) + return rgb + +#---------------------------------------------------------------------------- \ No newline at end of file diff --git a/3DPortraitGAN_pyramid/training/training_loop.py b/3DPortraitGAN_pyramid/training/training_loop.py new file mode 100644 index 0000000..681de57 --- /dev/null +++ b/3DPortraitGAN_pyramid/training/training_loop.py @@ -0,0 +1,714 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
def setup_snapshot_image_grid(training_set, random_seed=0):
    """Pick the dataset samples shown in the snapshot image grid.

    Samples are grouped by their reversed raw label. Each grid cell pair holds
    one index from a label group plus the index half a dataset length further
    on (modulo the dataset length).

    Returns:
        ``((gw, gh), images, segs, labels, poses)`` where the last four are
        ``np.stack``-ed from the selected dataset items.
    """
    rnd = np.random.RandomState(random_seed)

    # Grid size derived from a 7680x4320 canvas scaled by image_shape[2] / 512.
    h = int(7680 * (training_set.image_shape[2]/512))
    w = int(4320 * (training_set.image_shape[2] / 512))
    gh = np.clip(h // training_set.image_shape[2], 7, 8)
    gw = np.clip(w // training_set.image_shape[1], 4, 4)

    # Group training samples by label.
    label_groups = dict()  # label => [idx, ...]
    for idx in range(len(training_set)):
        key = tuple(training_set.get_details(idx).raw_label.flat[::-1])
        label_groups.setdefault(key, []).append(idx)

    # Reorder.
    label_order = list(label_groups.keys())
    rnd.shuffle(label_order)
    for key in label_order:
        rnd.shuffle(label_groups[key])

    # Organize into grid: every step contributes a pair of indices.
    grid_indices = []
    for row in range(gh):
        for pair in range(gw//2):
            key = label_order[(row + pair*gh) % len(label_order)]
            members = list(set(label_groups[key]))
            first = members[0]
            partner = (first + len(training_set)//2) % len(training_set)
            grid_indices.extend([first, partner])
            # Rotate the group so a later visit starts from a different member.
            label_groups[key] = [members[(i + gw) % len(members)] for i in range(len(members))]

    # Load data.
    samples = [training_set[i] for i in grid_indices]
    images, segs, labels, poses = zip(*samples)
    return (gw, gh), np.stack(images), np.stack(segs), np.stack(labels), np.stack(poses)
None = disable lazy regularization. + D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization. + augment_p = 0, # Initial value of augmentation probability. + ada_target = None, # ADA target value. None = fixed p. + ada_interval = 4, # How often to perform ADA adjustment? + ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit. + total_kimg = 25000, # Total length of the training, measured in thousands of real images. + kimg_per_tick = 4, # Progress snapshot interval. + image_snapshot_ticks = 50, # How often to save image snapshots? None = disable. + network_snapshot_ticks = 50, # How often to save network snapshots? None = disable. + resume_pkl = None, # Network pickle to resume training from. + resume_kimg = 0, # First kimg to report when resuming training. + cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? + abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks. + progress_fn = None, # Callback function for updating training progress. Called for all ranks. + train_g_pose_branch = None, + metric_pose_sample_mode = None, +): + print('Random seed: %d' % random_seed) + # Initialize. + start_time = time.time() + device = torch.device('cuda', rank) + np.random.seed(random_seed * num_gpus + rank) + torch.cuda.set_device(device) + torch.manual_seed(random_seed * num_gpus + rank) + torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed. + torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy. + torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy. + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False # Improves numerical accuracy. + conv2d_gradfix.enabled = True # Improves training speed. # TODO: ENABLE + grid_sample_gradfix.enabled = False # Avoids errors with the augmentation pipe. 
+ + # Load training set. + if rank == 0: + print('Loading training set...') + training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset + training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed) + training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs)) + if rank == 0: + print() + print('Num images: ', len(training_set)) + print('Image shape:', training_set.image_shape) + print('Label shape:', training_set.label_shape) + print('Pose shape:', training_set.pose_shape) + print() + print('>>>>>>>>>>>>>>> image_snapshot_ticks:', image_snapshot_ticks) + print('>>>>>>>>>>>>>>> network_snapshot_ticks:', network_snapshot_ticks) + + # Construct networks. + if rank == 0: + print('Constructing networks...') + common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels) + G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module + G.register_buffer('dataset_label_std', torch.tensor(training_set.get_label_std()).to(device)) + D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module + G_ema = copy.deepcopy(G).eval() + D_ema = copy.deepcopy(D).eval() + + # Resume from existing pickle. 
+ if (resume_pkl is not None) and (rank == 0): + print(f'Resuming from "{resume_pkl}"') + with dnnlib.util.open_url(resume_pkl) as f: + resume_data = legacy.load_network_pkl(f) + for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]: + misc.copy_params_and_buffers(resume_data[name], module, require_all=False) + + if 'D_ema' in resume_data: + print(f'copy params of D_ema of "{resume_pkl} to D_ema') + misc.copy_params_and_buffers(resume_data['D_ema'], D_ema, require_all=False) + else: + print(f'copy params of D of "{resume_pkl} to D_ema') + misc.copy_params_and_buffers(resume_data['D'], D_ema, require_all=False) + + # Print network summary tables. + if rank == 0: + z = torch.empty([batch_gpu, G.z_dim], device=device) + c = torch.empty([batch_gpu, G.c_dim], device=device) + p = torch.empty([batch_gpu, 6], device=device) + img = misc.print_module_summary(G, [z, c, ]) + misc.print_module_summary(D, [img, c ]) + + print('plane_shapes:') + for res_k in G.plane_shapes: + print(res_k, G.plane_shapes[res_k]) + # Setup augmentation. + if rank == 0: + print('Setting up augmentation...') + augment_pipe = None + ada_stats = None + if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None): + augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module + augment_pipe.p.copy_(torch.as_tensor(augment_p)) + if ada_target is not None: + ada_stats = training_stats.Collector(regex='Loss/signs/real') + + # Distribute across GPUs. + if rank == 0: + print(f'Distributing across {num_gpus} GPUs...') + for module in [G, D, G_ema,D_ema, augment_pipe]: + if module is not None: + for param in misc.params_and_buffers(module): + if param.numel() > 0 and num_gpus > 1: + torch.distributed.broadcast(param, src=0) + + # Setup training phases. 
+ if rank == 0: + print('Setting up training phases...') + loss = dnnlib.util.construct_class_by_name(device=device, G=G, D=D, augment_pipe=augment_pipe,rank = rank,**loss_kwargs) # subclass of training.loss.Loss + phases = [] + for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]: + params_list = [] + params_name_list = [] + for p_name, p in module.named_parameters(): + if name == 'G': + if 'aligned_SMPL' not in p_name: + if not train_g_pose_branch: + if 'pose_branch' not in p_name: + params_list.append(p) + params_name_list.append(p_name) + else: + params_list.append(p) + params_name_list.append(p_name) + else: + params_list.append(p) + params_name_list.append(p_name) + + + + if rank ==0: + print(f'params_name_list of {name}:',params_name_list) + + if reg_interval is None: + opt = dnnlib.util.construct_class_by_name(params=params_list, **opt_kwargs) # subclass of torch.optim.Optimizer + phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)] + + + else: # Lazy regularization. + mb_ratio = reg_interval / (reg_interval + 1) + opt_kwargs = dnnlib.EasyDict(opt_kwargs) + opt_kwargs.lr = opt_kwargs.lr * mb_ratio + opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas] + opt = dnnlib.util.construct_class_by_name(params=params_list, **opt_kwargs) # subclass of torch.optim.Optimizer + phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)] + phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)] + + + + for phase in phases: + phase.start_event = None + phase.end_event = None + if rank == 0: + phase.start_event = torch.cuda.Event(enable_timing=True) + phase.end_event = torch.cuda.Event(enable_timing=True) + print('phase: ',phase.name) + + # Export sample images. 
+ grid_size = None + grid_z = None + grid_c = None + if rank == 0: + print('Exporting sample images...') + grid_size, images,segs, labels,poses = setup_snapshot_image_grid(training_set=training_set,random_seed=random.randint(0, 1000000)) + save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size) + save_image_grid(segs, os.path.join(run_dir, 'segs.jpg'), drange=[0, 255], grid_size=grid_size) + grid_images = (torch.from_numpy(images).to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu) + grid_segs = (torch.from_numpy(segs).to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu) + + #grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu) + + if G.rendering_kwargs['c_gen_conditioning_zero']: + raise NotImplementedError + grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu) + else: + #raise NotImplementedError + grid_z = [] + for i in range(labels.shape[0]//2): + sample_z = torch.randn([1, G.z_dim], device=device) + grid_z.append(sample_z) + grid_z.append(sample_z) + grid_z = torch.cat(grid_z,dim=0).split(batch_gpu) + + + grid_c = torch.from_numpy(labels).to(device).split(batch_gpu) + grid_poses = torch.from_numpy(poses).to(device).split(batch_gpu) + + real_shape_real_pose = [] + for real_pose, c in zip(grid_poses, grid_c): + real_shape_pose_param = {'pose': real_pose} + real_shape_real_pose.append( + G_ema.render_meshes(real_shape_pose_param, resolution=training_set.image_shape[2], cameras=c) + ) + real_shape_real_pose = np.concatenate(real_shape_real_pose, axis=0) + save_image_grid(real_shape_real_pose, + os.path.join(run_dir, f'mesh_coarse_real_pose.png'), + drange=[0, 255], grid_size=grid_size) + #exit() + + # Initialize logs. 
+ if rank == 0: + print('Initializing logs...') + stats_collector = training_stats.Collector(regex='.*') + stats_metrics = dict() + stats_jsonl = None + stats_tfevents = None + if rank == 0: + stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt') + try: + import torch.utils.tensorboard as tensorboard + stats_tfevents = tensorboard.SummaryWriter(run_dir) + except ImportError as err: + print('Skipping tfevents export:', err) + + # Train. + if rank == 0: + print(f'Training for {total_kimg} kimg...') + print() + cur_nimg = resume_kimg * 1000 + cur_tick = 0 + tick_start_nimg = cur_nimg + tick_start_time = time.time() + maintenance_time = tick_start_time - start_time + batch_idx = 0 + if progress_fn is not None: + progress_fn(0, total_kimg) + + + + while True: + # Fetch training data. + with torch.autograd.profiler.record_function('data_fetch'): + + phase_real_img, phase_real_seg, phase_real_c, phase_real_pose = next(training_set_iterator) + + + phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu) + phase_real_seg = (phase_real_seg.to(device).to(torch.float32) / 255.0).split(batch_gpu) + phase_real_c = phase_real_c.to(device).split(batch_gpu) + phase_real_pose = phase_real_pose.to(device).split(batch_gpu) + + all_gen_z = torch.randn([len(phases) * (batch_size // num_gpus), G.z_dim], device=device) # 4 * 8 + all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split((batch_size // num_gpus))] + + random_idx = [np.random.randint(len(training_set)) for _ in range(len(phases) * (batch_size // num_gpus))] + + + all_gen_c = [training_set.get_label(gen_idx) for gen_idx in random_idx] + all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device) + all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split((batch_size // num_gpus))] + + + all_gen_pose = [training_set.get_coarse_pose(gen_idx) for gen_idx in random_idx] + all_gen_pose = 
torch.from_numpy(np.stack(all_gen_pose)).pin_memory().to(device) + all_gen_pose = [phase_gen_pose.split(batch_gpu) for phase_gen_pose in all_gen_pose.split((batch_size // num_gpus))] + + assert len(phases) == len(all_gen_z) == len(all_gen_c) ==len(all_gen_pose) + # Execute training phases. + for phase, phase_gen_z,phase_gen_c,phase_gen_pose in zip(phases, all_gen_z,all_gen_c,all_gen_pose): # 4 + if batch_idx % phase.interval != 0: + continue + + + if phase.start_event is not None: + phase.start_event.record(torch.cuda.current_stream(device)) + + # Accumulate gradients. + phase.opt.zero_grad(set_to_none=True) + phase.module.requires_grad_(True) + for real_img, real_seg, real_c,real_pose, gen_z,gen_c,gen_pose in \ + zip(phase_real_img, phase_real_seg, phase_real_c, phase_real_pose, phase_gen_z,phase_gen_c,phase_gen_pose): + + loss.accumulate_gradients(phase=phase.name, real_img=real_img,real_seg = real_seg, real_c=real_c,real_pose = real_pose, + gen_z=gen_z,gen_c = gen_c, gen_pose = gen_pose, + + gain=phase.interval, cur_nimg=cur_nimg,cur_nimg_start = resume_kimg * 1000) + phase.module.requires_grad_(False) + + # Update weights. + with torch.autograd.profiler.record_function(phase.name + '_opt'): + + params = [param for param in phase.module.parameters() if param.numel() > 0 and param.grad is not None] + if len(params) > 0: + flat = torch.cat([param.grad.flatten() for param in params]) + if num_gpus > 1: + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + phase.opt.step() + + + + # Phase done. + if phase.end_event is not None: + phase.end_event.record(torch.cuda.current_stream(device)) + + # Update G_ema. 
+ with torch.autograd.profiler.record_function('Gema'): + ema_nimg = ema_kimg * 1000 + if ema_rampup is not None: + ema_nimg = min(ema_nimg, cur_nimg * ema_rampup) + ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8)) + for p_ema, p in zip(G_ema.parameters(), G.parameters()): + p_ema.copy_(p.lerp(p_ema, ema_beta)) + for b_ema, b in zip(G_ema.buffers(), G.buffers()): + b_ema.copy_(b) + G_ema.neural_rendering_resolution = G.neural_rendering_resolution + G_ema.rendering_kwargs = G.rendering_kwargs.copy() + + with torch.autograd.profiler.record_function('Dema'): + ema_nimg = ema_kimg * 1000 + if ema_rampup is not None: + ema_nimg = min(ema_nimg, cur_nimg * ema_rampup) + ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8)) + for p_ema, p in zip(D_ema.parameters(), D.parameters()): + p_ema.copy_(p.lerp(p_ema, ema_beta)) + for b_ema, b in zip(D_ema.buffers(), D.buffers()): + b_ema.copy_(b) + + + # Update state. + cur_nimg += batch_size + batch_idx += 1 + + # Execute ADA heuristic. + if (ada_stats is not None) and (batch_idx % ada_interval == 0): + ada_stats.update() + adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000) + augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device))) + + # Perform maintenance tasks once per tick. + done = (cur_nimg >= total_kimg * 1000) + if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): + continue + + # Print status line, accumulating the same information in training_stats. 
+ tick_end_time = time.time() + fields = [] + fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] + fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"] + fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] + fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] + fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] + fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] + fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] + fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] + fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"] + torch.cuda.reset_peak_memory_stats() + fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"] + + if loss.swapping_prob is not None: + fields += [f"swap prob {training_stats.report0('Progress/swap_prob', float(loss.swapping_prob)):.3f}"] + if loss.neural_rendering_resolution is not None: + fields += [f"render_res {training_stats.report0('Progress/rendering_res', float(loss.neural_rendering_resolution)):.3f}"] + # if loss.noise_alpha is not None: + # fields += [f"noise_alpha {training_stats.report0('Progress/noise_alpha', float(loss.noise_alpha)):.3f}"] + # if loss.noise_scale is not None: + # fields += [f"noise_scale {training_stats.report0('Progress/noise_scale', float(loss.noise_scale)):.3f}"] + + # if loss.predict_label_alpha is not None: + # fields += [f"predict_label_alpha 
{training_stats.report0('Progress/predict_label_alpha', float(loss.predict_label_alpha)):.3f}"] + + training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60)) + training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60)) + if rank == 0: + print(' '.join(fields)) + + # Check for abort. + if (not done) and (abort_fn is not None) and abort_fn(): + done = True + if rank == 0: + print() + print('Aborting...') + + + + + if (rank == 0) and ((image_snapshot_ticks is not None) and (done or (cur_tick % image_snapshot_ticks == 0) ) ): # or (cur_tick<50 and cur_tick % 5 == 0 ) ) # (cur_tick!=0) and + print('gen images...') + with torch.no_grad(): + predicted_real_pose_params_D = [] + for vis_real_img,vis_real_seg, vis_c in zip(grid_images,grid_segs, grid_c): + pose_param = loss.get_pose_params_D(vis_real_img,vis_real_seg, vis_c, cur_nimg) + predicted_real_pose_params_D.append(pose_param) + + predicted_fake_pose_params_G = [] + for vis_z, vis_c in zip(grid_z, grid_c): + pose_param = loss.get_pose_params_G(vis_z, vis_c) + predicted_fake_pose_params_G.append(pose_param) + + + real_pose_mesh = [] + for predicted_real_pose, c in zip(predicted_real_pose_params_D, grid_c): + real_pose_param = {'pose': predicted_real_pose} + real_pose_mesh.append( + G_ema.render_meshes(real_pose_param, resolution=training_set.image_shape[2], cameras=c) + ) + real_pose_mesh = np.concatenate(real_pose_mesh, axis=0) + save_image_grid(real_pose_mesh, + os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_mesh_real_pose_D.png'), + drange=[0, 255], grid_size=grid_size) + + + snap_pose = predicted_fake_pose_params_G + cond_c = torch.tensor([[ 1.0000e+00, 1.0505e-09, 4.3685e-08, -1.1805e-07, 0.0000e+00, + -9.9951e-01, 2.4033e-02, -1.1805e-07, 4.3714e-08, -2.4033e-02, + -9.9951e-01, 2.6992e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, + 1.0000e+00, 6.5104e+00, 0.0000e+00, 5.0000e-01, 0.0000e+00, + 6.5104e+00, 5.0000e-01, 0.0000e+00, 0.0000e+00, 
1.0000e+00]]).float().to(device) + + + #out = [G_ema(z=z, c=c, noise_mode='const',apply_def = True, pose_params = pose) for z, c, pose in zip(grid_z, grid_c, snap_pose)] + grid_ws = [G_ema.mapping(z, cond_c.expand(z.shape[0], -1),None) for z in grid_z] + out =[G_ema.synthesis(ws, c=c, noise_mode='const',apply_def = True, pose_params = pose) for ws, c,pose in zip(grid_ws, grid_c,snap_pose)] + images = torch.cat([o['image'].cpu() for o in out]).numpy() + #print('images range: ',np.max(images),np.min(images)) + images_raw = torch.cat([o['image_raw'].cpu() for o in out]).numpy() + images_depth = -torch.cat([o['image_depth'].cpu() for o in out]).numpy() + images_alpha = torch.cat([o['image_mask'].cpu() for o in out]).numpy() + #background_raw = torch.cat([o['image_background'].cpu() for o in out]).numpy() + save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_0.png'), drange=[-1,1], grid_size=grid_size) + save_image_grid(images_raw, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_2_raw.png'), drange=[-1,1], grid_size=grid_size) + save_image_grid(images_depth, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_4_depth.png'), drange=[images_depth.min(), images_depth.max()], grid_size=grid_size) + save_image_grid(images_alpha, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_4_alpha.jpg'), drange=[0, 1], grid_size=grid_size) + #save_image_grid(background_raw, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_4_background.jpg'), drange=[-1, 1], grid_size=grid_size) + with torch.no_grad(): + predicted_fake_pose_params_D = [] + for o,vis_c,vis_pose in zip(out,grid_c,snap_pose): + pose_param = loss.get_pose_params_D(o['image'],o['image_mask'],vis_c, cur_nimg) + predicted_fake_pose_params_D.append(pose_param) + + fake_pose_mesh = [] + for predicted_fake_pose, c in zip(predicted_fake_pose_params_D, grid_c): + fake_pose_param = {'pose': predicted_fake_pose} + fake_pose_mesh.append( + G_ema.render_meshes(fake_pose_param, 
resolution=training_set.image_shape[2], cameras=c) + ) + fake_pose_mesh = np.concatenate(fake_pose_mesh, axis=0) + save_image_grid(fake_pose_mesh, + os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_mesh_fake_pose_D.png'), + drange=[0, 255], grid_size=grid_size) + + input_pose_mesh = [] + for input_pose, c in zip(predicted_fake_pose_params_G, grid_c): + input_pose_param = {'pose': input_pose} + input_pose_mesh.append( + G_ema.render_meshes(input_pose_param, resolution=training_set.image_shape[2], cameras=c) + ) + input_pose_mesh = np.concatenate(input_pose_mesh, axis=0) + save_image_grid(input_pose_mesh, + os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_mesh_input_pose_G.png'), + drange=[0, 255], grid_size=grid_size) + + + + + # no_pose_out = [G_ema(z=z, c=c, noise_mode='const', apply_def=False, pose_params=None) for z, c in zip(grid_z, grid_c)] + no_pose_out =[G_ema.synthesis(ws, c=c, noise_mode='const',apply_def = False, pose_params = None) for ws, c in zip(grid_ws, grid_c)] + images = torch.cat([o['image'].cpu() for o in no_pose_out]).numpy() + images_raw = torch.cat([o['image_raw'].cpu() for o in no_pose_out]).numpy() + images_depth = -torch.cat([o['image_depth'].cpu() for o in no_pose_out]).numpy() + save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_1_no_pose.png'), drange=[-1, 1], + grid_size=grid_size) + save_image_grid(images_raw, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_3_no_pose_raw.png'), drange=[-1, 1], + grid_size=grid_size) + save_image_grid(images_depth, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_5_no_pose_depth.png'), + drange=[images_depth.min(), images_depth.max()], grid_size=grid_size) + + + + # if (loss.fronzen_D is not None) and ((network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0)): + # if rank ==0 : + # print('update loss.fronzen_D...') + # misc.copy_params_and_buffers(D, loss.fronzen_D, require_all=True) + # Save network snapshot. 
+ snapshot_pkl = None + snapshot_data = None + if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0): + snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs)) + for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('D_ema', D_ema), ('augment_pipe', augment_pipe)]: + if module is not None: + if num_gpus > 1: + misc.check_ddp_consistency(module, ignore_regex=r'.*\.[^.]+_(avg|ema)') + module = copy.deepcopy(module).eval().requires_grad_(False).cpu() + snapshot_data[name] = module + del module # conserve memory + snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl') + if rank == 0: + with open(snapshot_pkl, 'wb') as f: + pickle.dump(snapshot_data, f) + + pose_predict_kwargs = { + 'blur_sigma' : loss.blur_sigma, + 'neural_rendering_resolution': loss.neural_rendering_resolution, + 'resample_filter': loss.resample_filter.cpu().numpy().tolist(), + 'filter_mode': loss.filter_mode + } + with open(os.path.join(run_dir, f'pose_predict_kwargs-{cur_nimg//1000:06d}.json'), 'wt') as f: + json.dump(pose_predict_kwargs, f, indent=2) + + + # Evaluate metrics. 
+ if (cur_tick!=0) and (snapshot_data is not None) and (len(metrics) > 0): + if rank == 0: + print(run_dir) + print('Evaluating metrics...') + for metric in metrics: + progress = metric_utils.ProgressMonitor(verbose=True) + # result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'], + # dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, + # rank=rank, device=device, progress=progress + # ) + result_dict = metric_main.calc_metric(metric=metric, + G=snapshot_data['G_ema'], + dataset_kwargs=training_set_kwargs, + num_gpus=num_gpus, + rank=rank, + device=device, + metric_pose_sample_mode = metric_pose_sample_mode, + progress=progress, + D = snapshot_data['D'] if metric_pose_sample_mode == 'D_predict' else None, + pose_predict_kwargs = { + 'blur_sigma' : loss.blur_sigma, + 'neural_rendering_resolution': loss.neural_rendering_resolution, + 'resample_filter': loss.resample_filter, + 'filter_mode': loss.filter_mode + } if metric_pose_sample_mode == 'D_predict' else None + ) + + if rank == 0: + metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl) + stats_metrics.update(result_dict.results) + del snapshot_data # conserve memory + + # Collect statistics. + for phase in phases: + value = [] + if (phase.start_event is not None) and (phase.end_event is not None): + phase.end_event.synchronize() + value = phase.start_event.elapsed_time(phase.end_event) + training_stats.report0('Timing/' + phase.name, value) + stats_collector.update() + stats_dict = stats_collector.as_dict() + + # Update logs. 
+ timestamp = time.time() + if stats_jsonl is not None: + fields = dict(stats_dict, timestamp=timestamp) + stats_jsonl.write(json.dumps(fields) + '\n') + stats_jsonl.flush() + if stats_tfevents is not None: + global_step = int(cur_nimg / 1e3) + walltime = timestamp - start_time + for name, value in stats_dict.items(): + stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime) + for name, value in stats_metrics.items(): + stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime) + stats_tfevents.flush() + if progress_fn is not None: + progress_fn(cur_nimg // 1000, total_kimg) + + # Update state. + cur_tick += 1 + tick_start_nimg = cur_nimg + tick_start_time = time.time() + maintenance_time = tick_start_time - tick_end_time + if done: + break + + # Done. + if rank == 0: + print() + print('Exiting...') + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/training/volumetric_rendering/__init__.py b/3DPortraitGAN_pyramid/training/volumetric_rendering/__init__.py new file mode 100644 index 0000000..daba665 --- /dev/null +++ b/3DPortraitGAN_pyramid/training/volumetric_rendering/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
# MIT License

# Copyright (c) 2022 Petr Kellnhofer

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch


def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
    """
    Left-multiplies MxM @ NxM. Returns NxM.

    Implemented as ``vectors4 @ matrix.T``, so every row ``v`` of ``vectors4``
    is mapped to ``matrix @ v``.
    """
    return torch.matmul(vectors4, matrix.T)


def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
    """
    Normalize vector lengths along the last dimension.

    The division is unguarded: a zero-length vector produces NaN components.
    """
    return vectors / (torch.norm(vectors, dim=-1, keepdim=True))


def torch_dot(x: torch.Tensor, y: torch.Tensor):
    """
    Dot product of two tensors, taken over the last dimension.
    """
    return (x * y).sum(-1)


def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
    """
    Author: Petr Kellnhofer
    Intersects rays with an axis-aligned cube centred at the origin whose side
    length is ``box_side_length`` (slab method).

    ``rays_o`` and ``rays_d`` are detached and flattened to (-1, 3) internally,
    so no gradient flows through the result.

    Returns ``(tmin, tmax)``, the entry and exit distances along each ray, each
    shaped like ``rays_o`` with the last dimension replaced by 1.
    Rays that miss the box get ``tmin = -1`` and ``tmax = -2``.
    https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
    """
    o_shape = rays_o.shape
    rays_o = rays_o.detach().reshape(-1, 3)
    rays_d = rays_d.detach().reshape(-1, 3)

    half_side = box_side_length / 2
    bb_min = [-1 * half_side, -1 * half_side, -1 * half_side]
    bb_max = [1 * half_side, 1 * half_side, 1 * half_side]
    # Row 0 holds the minimum corner, row 1 the maximum corner.
    bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
    is_valid = torch.ones(rays_o.shape[:-1], dtype=torch.bool, device=rays_o.device)

    # Precompute inverse for stability.
    invdir = 1 / rays_d
    # sign == 1 where the direction component is negative; it selects which
    # corner row is the near plane for that axis.
    sign = (invdir < 0).long()

    # Intersect with the pair of planes perpendicular to X.
    tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
    tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]

    # Intersect with the pair of planes perpendicular to Y.
    tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
    tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]

    # Reject rays whose X and Y slab intervals do not overlap.
    is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False

    # Keep the intersection of the two intervals.
    tmin = torch.max(tmin, tymin)
    tmax = torch.min(tmax, tymax)

    # Intersect with the pair of planes perpendicular to Z.
    tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
    tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]

    # Reject rays whose running interval does not overlap the Z slab.
    is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False

    # Keep the intersection of the intervals.
    tmin = torch.max(tmin, tzmin)
    tmax = torch.min(tmax, tzmax)

    # Mark invalid.
    tmin[torch.logical_not(is_valid)] = -1
    tmax[torch.logical_not(is_valid)] = -2

    return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)


def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
    """
    Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.

    With ``num == 1`` the result is just ``start[None]``, as numpy does.
    """
    # create a tensor of 'num' steps from 0 to 1
    steps = torch.arange(num, dtype=torch.float32, device=start.device)
    if num > 1:
        # Dividing when num == 1 would compute 0 / 0 and fill the output with NaN.
        steps = steps / (num - 1)

    # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
    # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
    #   "cannot statically infer the expected size of a list in this context", hence the code below
    for _ in range(start.ndim):
        steps = steps.unsqueeze(-1)

    # the output starts at 'start' and increments until 'stop' in each dimension
    out = start[None] + steps * (stop - start)[None]

    return out
"""
The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class MipRayMarcher2(nn.Module):
    """Composite per-sample colors and densities along rays into colors and depths."""

    def __init__(self):
        super().__init__()

    def run_forward(self, colors, densities, depths, rendering_options):
        """Composite the samples of each ray using midpoint quadrature.

        The code slices samples on dim 2 and reduces over dim -2 interchangeably,
        so the inputs are presumably 4-D (batch, rays, samples, channels) — confirm
        against the caller.

        Args:
            colors: per-sample colors.
            densities: per-sample raw (pre-activation) densities.
            depths: per-sample depths along each ray.
            rendering_options: mapping read for ``'clamp_mode'`` (must be
                ``'softplus'``) and the optional ``'white_back'`` flag.

        Returns:
            ``(composite_rgb, composite_depth, weights)``; ``weights`` has one
            entry per interval between consecutive samples.

        Raises:
            AssertionError: if ``clamp_mode`` is anything other than ``'softplus'``.
        """
        # Work on the intervals between consecutive samples: interval length and
        # the average of the two endpoint values.
        deltas = depths[:, :, 1:] - depths[:, :, :-1]
        colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
        densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
        depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2

        if rendering_options['clamp_mode'] == 'softplus':
            densities_mid = F.softplus(densities_mid - 1)  # activation bias of -1 makes things initialize better
        else:
            # Raised explicitly (not via `assert`) so the check still runs under
            # `python -O`; the exception type is the one `assert` used to raise.
            raise AssertionError("MipRayMarcher only supports `clamp_mode`=`softplus`!")

        density_delta = densities_mid * deltas

        alpha = 1 - torch.exp(-density_delta)

        # Transmittance before each interval: exclusive cumulative product of
        # (1 - alpha); the small constant keeps the product from reaching zero.
        alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2)
        weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]

        composite_rgb = torch.sum(weights * colors_mid, -2)
        weight_total = weights.sum(2)
        # Rays with zero total weight divide 0 / 0 here; handled just below.
        composite_depth = torch.sum(weights * depths_mid, -2) / weight_total

        # clip the composite to min/max range of depths
        # (NaN becomes +inf first, so the clamp sends empty rays to the far depth)
        composite_depth = torch.nan_to_num(composite_depth, float('inf'))
        composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))

        if rendering_options.get('white_back', False):
            # Fill the remaining transmittance with white (value 1).
            composite_rgb = composite_rgb + 1 - weight_total

        return composite_rgb, composite_depth, weights

    def forward(self, colors, densities, depths, rendering_options):
        """Run :meth:`run_forward` and return its result unchanged."""
        composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)

        return composite_rgb, composite_depth, weights
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""
The ray sampler is a module that takes in camera matrices and resolution and batches of rays.
Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
"""

import torch


class RaySampler(torch.nn.Module):
    """Turn camera extrinsics/intrinsics into one world-space ray per pixel."""

    def __init__(self):
        super().__init__()
        self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None

    @staticmethod
    def _pixel_center_grid(resolution, device):
        """Return a (2, resolution, resolution) grid of pixel centres in (0, 1).

        Channel 0 varies along the first (row) axis, channel 1 along the
        second (column) axis.
        """
        ticks = torch.arange(resolution, dtype=torch.float32, device=device)
        return torch.stack(torch.meshgrid(ticks, ticks, indexing='ij')) * (1. / resolution) + (0.5 / resolution)

    @staticmethod
    def _rays_from_uv(cam2world_matrix, intrinsics, uv):
        """Un-project normalised image coordinates to world-space rays.

        uv: (N, M, 2) with (x, y) in the last dimension.
        Returns ray origins (N, M, 3) and unit ray directions (N, M, 3).
        """
        cam_locs_world = cam2world_matrix[:, :3, 3]
        fx = intrinsics[:, 0, 0].unsqueeze(-1)
        fy = intrinsics[:, 1, 1].unsqueeze(-1)
        cx = intrinsics[:, 0, 2].unsqueeze(-1)
        cy = intrinsics[:, 1, 2].unsqueeze(-1)
        sk = intrinsics[:, 0, 1].unsqueeze(-1)

        x_cam = uv[:, :, 0]
        y_cam = uv[:, :, 1]
        # Sized from the actual number of UV samples (not resolution**2) so
        # that any grid size un-projects without a shape mismatch.
        z_cam = torch.ones_like(x_cam)

        # Inverse of the pinhole projection (with skew) at unit depth.
        x_lift = (x_cam - cx + cy * sk / fy - sk * y_cam / fy) / fx
        y_lift = (y_cam - cy) / fy

        cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)

        world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]

        ray_dirs = world_rel_points - cam_locs_world[:, None, :]
        ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)

        ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)

        return ray_origins, ray_dirs

    def forward(self, cam2world_matrix, intrinsics, resolution):
        """
        Create batches of rays and return origins and directions.

        cam2world_matrix: (N, 4, 4)
        intrinsics: (N, 3, 3)
        resolution: int

        ray_origins: (N, M, 3)
        ray_dirs: (N, M, 3)

        where M = resolution**2.
        """
        uv = self._pixel_center_grid(resolution, cam2world_matrix.device)
        # flip(0) swaps (row, col) into (x, y) order before flattening.
        uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
        uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
        return self._rays_from_uv(cam2world_matrix, intrinsics, uv)

    def patch_forward(self, cam2world_matrix, intrinsics, resolution, patch_scale=1):
        """
        Create batches of rays for a (possibly cropped) patch and return
        origins and directions.

        cam2world_matrix: (N, 4, 4)
        intrinsics: (N, 3, 3)
        resolution: int, side length of the returned patch
        patch_scale: the full image is int(resolution / patch_scale) pixels
            wide; when that is larger than `resolution`, a random
            resolution x resolution window is cropped per batch element.

        ray_origins: (N, M, 3)
        ray_dirs: (N, M, 3)
        patch_info: list of (top, left) crop offsets, one per batch element;
            empty when no cropping took place.
        """
        batch = cam2world_matrix.shape[0]
        full_resolution = int(resolution / patch_scale)
        patch_info = []
        uv = self._pixel_center_grid(full_resolution, cam2world_matrix.device)
        if full_resolution > resolution:
            n_offsets = full_resolution - resolution + 1
            crops = []
            for _ in range(batch):
                top = torch.randint(n_offsets, ()).item()
                left = torch.randint(n_offsets, ()).item()
                # torch.cat below copies, so slicing a view is enough here.
                crops.append(uv[None, :, top:top + resolution, left:left + resolution])
                patch_info.append((top, left))
            uv = torch.cat(crops, 0)
        else:
            uv = uv.unsqueeze(0).repeat(batch, 1, 1, 1)
        # flip(1) swaps (row, col) into (x, y) order before flattening.
        uv = uv.flip(1).reshape(batch, 2, -1).transpose(2, 1)

        ray_origins, ray_dirs = self._rays_from_uv(cam2world_matrix, intrinsics, uv)
        return ray_origins, ray_dirs, patch_info
+# """ +# return torch.tensor([[[1, 0, 0], +# [0, 1, 0], +# [0, 0, 1]], +# [[1, 0, 0], +# [0, 0, 1], +# [0, 1, 0]], +# [[0, 0, 1], +# [1, 0, 0], +# [0, 1, 0]]], dtype=torch.float32) + +# correct tri-planes, see https://github.com/NVlabs/eg3d/issues/67 +def generate_planes(): + """ + Defines planes by the three vectors that form the "axes" of the + plane. Should work with arbitrary number of planes and planes of + arbitrary orientation. + """ + return torch.tensor([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + [[1, 0, 0], + [0, 0, 1], + [0, 1, 0]], + [[0, 1, 0], + [0, 0, 1], + [1, 0, 0]]], dtype=torch.float32) + +def project_onto_planes(planes, coordinates): + """ + Does a projection of a 3D point onto a batch of 2D planes, + returning 2D plane coordinates. + + Takes plane axes of shape n_planes, 3, 3 + # Takes coordinates of shape N, M, 3 + # returns projections of shape N*n_planes, M, 2 + """ + N, M, C = coordinates.shape + n_planes, _, _ = planes.shape + coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) + inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) + projections = torch.bmm(coordinates, inv_planes) + return projections + +def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None, triplane_depth=1,render_high_freq = True): + assert padding_mode == 'zeros' + output_features = None + + + _, M, _ = coordinates.shape + coordinates = (2 / box_warp) * coordinates # TODO: add specific box bounds + projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1).unsqueeze(2) # (N x n_planes) x 1 x 1 x M x 3 + for res_k in plane_features: + plane_feature = plane_features[res_k] + N, n_planes, CD, H, W = plane_feature.shape + # _, M, _ = coordinates.shape + C, D = CD // triplane_depth, triplane_depth + plane_feature = plane_feature.view(N * n_planes, C, 
D, H, W) + + # coordinates = (2/box_warp) * coordinates # TODO: add specific box bounds + + # projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1).unsqueeze(2) # (N x n_planes) x 1 x 1 x M x 3 + output_feature = torch.nn.functional.grid_sample(plane_feature, projected_coordinates.float(), mode=mode, + padding_mode=padding_mode, align_corners=False).permute(0, + 4, + 3, + 2, + 1).reshape(N, n_planes, M, C) + if output_features is None: + output_features = output_feature + else: + output_features += output_feature + + output_features /= len(plane_features) + + return output_features +def sample_from_3dgrid(grid, coordinates): + """ + Expects coordinates in shape (batch_size, num_points_per_batch, 3) + Expects grid in shape (1, channels, H, W, D) + (Also works if grid has batch size) + Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) + """ + batch_size, n_coords, n_dims = coordinates.shape + sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1), + coordinates.reshape(batch_size, 1, 1, -1, n_dims), + mode='bilinear', padding_mode='zeros', align_corners=False) + N, C, H, W, D = sampled_features.shape + sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) + return sampled_features + +def triplane_crop_mask(xyz_unformatted, thresh, boxwarp, allow_bottom=True): + # bw,tc = boxwarp, thresh + bw = boxwarp + tc = boxwarp * thresh + device = xyz_unformatted.device + # xyz = 0.5 * (xyz_unformatted+1) * torch.tensor([-1,1,-1]).to(device)[None,None,:] + xyz = (xyz_unformatted) * torch.tensor([-1,1,-1]).to(device)[None,None,:] + ans = (xyz[:,:,[0,2]].abs() <= (bw/2-tc)).all(dim=-1,keepdim=True) + if allow_bottom: + ans = ans | ( + (xyz[:,:,1:2] <= -(bw/2-tc)) & + (xyz[:,:,[0,2]].abs() <= (bw/2-tc)).all(dim=-1,keepdim=True) + ) + return ~ans +def cull_clouds_mask(denities, thresh): + denities = torch.nn.functional.softplus(denities - 1) # activation 
bias of -1 makes things initialize better + alpha = 1 - torch.exp(-denities) + return alpha < thresh + + + +class ImportanceRenderer(torch.nn.Module): + def __init__(self, w_dim, num_ws,batch_size,thickness,box_warp): + super().__init__() + self.ray_marcher = MipRayMarcher2() + self.plane_axes = generate_planes() + self.batch_size = batch_size + self.num_betas = 10 + body_model_smpl = smplx.create('./smplx_models', + model_type='smpl', + gender='neutral', + use_compressed=False, + use_face_contour=True, + num_betas=self.num_betas, + num_expression_coeffs=10, + ext='npz', + batch_size = batch_size + ).cuda() + self.aligned_SMPL = AlignedSMPL(model=body_model_smpl,batch_size=batch_size) + + + + shaped_smpl_data = self.aligned_SMPL.generate_shaped_smpl( + betas=None, + scale=None, # shape_params['scale'], + transl=None, # shape_params['transl'] + ) + shaped_smpl = shaped_smpl_data['vertices'].detach().contiguous() + align_points = shaped_smpl_data['align_joint_coordinate'].detach().contiguous() + + self.register_buffer('shaped_smpl', shaped_smpl) + self.register_buffer('align_points', align_points) + + # shaped_smpl [B,N,3] + # filter points that outside box + box_side_length = box_warp + # shaped_smpl: B,N,3 + point_mask = shaped_smpl[0:1,:,0] > -box_side_length/2 # 1,N + point_mask = point_mask & (shaped_smpl[0:1,:,0] < box_side_length/2) + point_mask = point_mask & (shaped_smpl[0:1,:,1] > -box_side_length/2) + point_mask = point_mask & (shaped_smpl[0:1,:,1] < box_side_length/2) + point_mask = point_mask & (shaped_smpl[0:1,:,2] > -box_side_length/2) + point_mask = point_mask & (shaped_smpl[0:1,:,2] < box_side_length/2) + point_mask = point_mask.squeeze(0).cuda() # N + + faces = self.aligned_SMPL.faces # [20908, 3] + face_mask = torch.ones(faces.shape[0],dtype=torch.bool).cuda() # [20908] + for i in range(faces.shape[0]): + face_mask[i] = point_mask[faces[i,0]] and point_mask[faces[i,1]] and point_mask[faces[i,2]] + self.register_buffer('face_mask', face_mask) + + 
self.thickness = thickness + + # shaped_smpl [B,N,3] + # filter points that not on the head + # shaped_smpl: B,N,3 + + # + # point_mask = shaped_smpl[0:1, :, 1] > 0 # 1,N + + point_mask = shaped_smpl[0:1, :, 1] > 0.06 # 1,N + point_mask = point_mask & (shaped_smpl[0:1, :, 2] < -0.0) + + point_mask = point_mask.squeeze(0).cuda() # N + + faces = self.aligned_SMPL.faces # [20908, 3] + head_face_mask = torch.ones(faces.shape[0], dtype=torch.bool).cuda() # [20908] + for i in range(faces.shape[0]): + head_face_mask[i] = point_mask[faces[i, 0]] and point_mask[faces[i, 1]] and point_mask[faces[i, 2]] + self.register_buffer('head_face_mask', head_face_mask) + + self.back_head_depth = None + # + # print('head_face_mask shape:',head_face_mask.shape) + + + def set_batch_size(self,batch_size): + self.batch_size = batch_size + body_model_smpl = smplx.create('./smplx_models', + model_type='smpl', + gender='neutral', + use_compressed=False, + use_face_contour=True, + num_betas=self.num_betas, + num_expression_coeffs=10, + ext='npz', + batch_size=batch_size + ).to(self.aligned_SMPL.model.shapedirs.device) + self.aligned_SMPL.set_model(body_model_smpl) + self.aligned_SMPL.set_batch_size(batch_size) + shaped_smpl_data = self.aligned_SMPL.generate_shaped_smpl( + betas=None, + scale=None, # shape_params['scale'], + transl=None, # shape_params['transl'] + ) + shaped_smpl = shaped_smpl_data['vertices'].detach().contiguous() + align_points = shaped_smpl_data['align_joint_coordinate'].detach().contiguous() + self.register_buffer('shaped_smpl', shaped_smpl) + self.register_buffer('align_points', align_points) + + + def render_meshes(self, shape_pose_params,resolution,cameras): + images = self.aligned_SMPL.get_visualization(shape_pose_params, resolution, cameras) + return images + + + def get_deformed_coordinate(self, ws, pose_params, original_coordinate): + + + posed_smpl = self.aligned_SMPL.generate_posed_smpl(betas=None, + body_pose=pose_params, + scale=None, # shape_params['scale'], + 
transl=None, # shape_params['transl'], + align_joint_coordinate=self.align_points)['vertices'] + # misc.assert_shape(posed_smpl, [None, 10475, 3]) + + + mode = 'kaolin' + if mode == 'pytorch3d': + raise NotImplementedError + import pytorch3d.ops + #raise NotImplementedError + with torch.no_grad(): + + smpl_def_on_mesh = self.shaped_smpl - posed_smpl # [B, , 3] + + # find the nearest face in posed_smpl for each vertex in original_coordinate + knn_res = pytorch3d.ops.knn_points(p1=original_coordinate, p2=posed_smpl, K=1) + distance = knn_res[0] # [B, N, 1] + p1_index = knn_res[1].repeat(1, 1, 3) # [B, N, 3] + misc.assert_shape(p1_index, [original_coordinate.shape[0], original_coordinate.shape[1],3]) + + + DistToMesh = distance.squeeze(-1) # [B, N] + + SmplDef = smpl_def_on_mesh.gather(1, p1_index) # [B, N, 3] + mask = DistToMesh < self.thickness# [B, N] + + + scale = 5. + SmplDef1 = SmplDef / torch.exp(DistToMesh.unsqueeze(-1) * scale) # [B, N, 3] + + scale = DistToMesh.unsqueeze(-1) / (self.thickness * 2) * 20 + SmplDef2 = torch.zeros_like(SmplDef).to(SmplDef.device) + + SmplDef = torch.where(mask.unsqueeze(-1), SmplDef1, SmplDef2) # [B, N, 3] + elif mode == 'kaolin': + faces = self.aligned_SMPL.faces.clone() # [20908, 3] + faces = faces[self.face_mask, :] + # find the nearest face in shaped_smplx for each vertex in original_coordinate + vertex_faces = posed_smpl.clone() # [B, 6085, 3] + + with torch.no_grad(): + face_vertices = index_vertices_by_faces(vertex_faces, faces) + distance, index, dist_type = point_to_mesh_distance(original_coordinate, face_vertices) # B, N + distance = torch.sqrt(distance) # [B, N, 1] + selected_posed_smpl_vertices = [] + selected_shaped_smpl_vertices = [] + + for i in range(original_coordinate.shape[0]): + selected_face = faces[index[i]] + selected_posed_smpl_vertices.append(index_vertices_by_faces(posed_smpl[i:i + 1], + selected_face)) # [1, N, 3, 3] + selected_shaped_smpl_vertices.append(index_vertices_by_faces(self.shaped_smpl[i:i + 
1], + selected_face)) # [1, N, 3, 3] + + selected_posed_smpl_vertices = torch.cat(selected_posed_smpl_vertices, dim=0) # [B, N, 3, 3] + selected_shaped_smpl_vertices = torch.cat(selected_shaped_smpl_vertices, dim=0) # [B, N, 3, 3] + + y_axes = torch.cross(selected_posed_smpl_vertices[:, :, 1, :] - selected_posed_smpl_vertices[:, :, 0, :], + selected_posed_smpl_vertices[:, :, 2, :] - selected_posed_smpl_vertices[:, :, 0, + :]) # [B, N, 3] + y_axes = y_axes / torch.norm(y_axes, dim=2, keepdim=True) # [B, N, 3] + + x_axes = selected_posed_smpl_vertices[:, :, 1, :] - selected_posed_smpl_vertices[:, :, 0, :] # [B, N, 3] + x_axes = x_axes / torch.norm(x_axes, dim=2, keepdim=True) # [B, N, 3] + + z_axes = torch.cross(x_axes, y_axes) # [B, N, 3] + + posed_smpl_coordinate = torch.stack( + [torch.sum((original_coordinate - selected_posed_smpl_vertices[:, :, 0, :]) * x_axes, dim=2), + torch.sum((original_coordinate - selected_posed_smpl_vertices[:, :, 0, :]) * y_axes, dim=2), + torch.sum((original_coordinate - selected_posed_smpl_vertices[:, :, 0, :]) * z_axes, dim=2)], + dim=2) # [B, N, 3] + del x_axes, y_axes, z_axes + y_axes = torch.cross(selected_shaped_smpl_vertices[:, :, 1, :] - selected_shaped_smpl_vertices[:, :, 0, :], + selected_shaped_smpl_vertices[:, :, 2, :] - selected_shaped_smpl_vertices[:, :, 0, :]) + y_axes = y_axes / torch.norm(y_axes, dim=2, keepdim=True) + + x_axes = selected_shaped_smpl_vertices[:, :, 1, :] - selected_shaped_smpl_vertices[:, :, 0, :] + x_axes = x_axes / torch.norm(x_axes, dim=2, keepdim=True) + + z_axes = torch.cross(x_axes, y_axes) + + new_coordinate = posed_smpl_coordinate[:, :, 0:1] * x_axes + \ + posed_smpl_coordinate[:, :, 1:2] * y_axes + \ + posed_smpl_coordinate[:, :, 2:3] * z_axes + \ + selected_shaped_smpl_vertices[:, :, 0, :] # [B, N, 3] + + SmplDef = new_coordinate - original_coordinate # [B, N, 3] + + DistToMesh = distance.unsqueeze(-1) # [B, N, 1] + + mask = DistToMesh < self.thickness # [B, N,1] + + SmplDef2 = 
torch.zeros_like(SmplDef).to(SmplDef.device) + SmplDef = torch.where(mask, SmplDef, SmplDef2) # [B, N, 3] + + else: + raise NotImplementedError + + original_coordinate = original_coordinate + SmplDef + return original_coordinate + + def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options, apply_def = False, ws = None, pose_params = None, triplane_crop=0.1, cull_clouds=None, binarize_clouds=None ): + _ = ws + if apply_def: + assert pose_params is not None + else: + assert pose_params is None + + self.plane_axes = self.plane_axes.to(ray_origins.device) + + if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto': + ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) + is_ray_valid = ray_end > ray_start + if torch.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + else: + # Create stratified depth samples + depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + + batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape + + # Coarse Pass + sample_coordinates = (ray_origins.unsqueeze(-2) + depths_coarse * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) + # deform the sample_coordinates + if apply_def: + sample_coordinates = self.get_deformed_coordinate(None, pose_params, sample_coordinates) + + + out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) + colors_coarse = out['rgb'] + 
densities_coarse = out['sigma'] + + xyz_coarse = out['xyz'] + + if triplane_crop: + # print(xyz_fine.amin(dim=(0,1))) + # print(xyz_fine.amax(dim=(0,1))) + cropmask = triplane_crop_mask(xyz_coarse, triplane_crop, rendering_options['box_warp']) + densities_coarse[cropmask] = -1e3 + if binarize_clouds: + ccmask = cull_clouds_mask(densities_coarse, binarize_clouds) + densities_coarse[ccmask] = -1e3 + densities_coarse[~ccmask] = 1e3 + elif cull_clouds: + ccmask = cull_clouds_mask(densities_coarse, cull_clouds) + densities_coarse[ccmask] = -1e3 + + colors_coarse = colors_coarse.reshape(batch_size, num_rays, samples_per_ray, colors_coarse.shape[-1]) + densities_coarse = densities_coarse.reshape(batch_size, num_rays, samples_per_ray, 1) + xyz_coarse = xyz_coarse.reshape(batch_size, num_rays, samples_per_ray, xyz_coarse.shape[-1]) + + # Fine Pass + N_importance = rendering_options['depth_resolution_importance'] + if N_importance > 0: + _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + depths_fine = self.sample_importance(depths_coarse, weights, N_importance) + + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, N_importance, -1).reshape(batch_size, -1, 3) + sample_coordinates = (ray_origins.unsqueeze(-2) + depths_fine * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + # deform the sample_coordinates + if apply_def: + sample_coordinates = self.get_deformed_coordinate(None, pose_params, sample_coordinates) + + out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) + colors_fine = out['rgb'] + densities_fine = out['sigma'] + xyz_fine = out['xyz'] + if triplane_crop: + # print(xyz_fine.amin(dim=(0,1))) + # print(xyz_fine.amax(dim=(0,1))) + cropmask = triplane_crop_mask(xyz_fine, triplane_crop, rendering_options['box_warp']) + densities_fine[cropmask] = -1e3 + if binarize_clouds: + ccmask = cull_clouds_mask(densities_fine, binarize_clouds) + 
densities_fine[ccmask] = -1e3 + densities_fine[~ccmask] = 1e3 + elif cull_clouds: + ccmask = cull_clouds_mask(densities_fine, cull_clouds) + densities_fine[ccmask] = -1e3 + xyz_fine = xyz_fine.reshape(batch_size, num_rays, N_importance, xyz_fine.shape[-1]) + colors_fine = colors_fine.reshape(batch_size, num_rays, N_importance, colors_fine.shape[-1]) + densities_fine = densities_fine.reshape(batch_size, num_rays, N_importance, 1) + + # all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, + # depths_fine, colors_fine, densities_fine) + all_depths, all_colors, all_densities, all_xyz = self.unify_samples( + depths_coarse, colors_coarse, densities_coarse, xyz_coarse, + depths_fine, colors_fine, densities_fine, xyz_fine, + ) + + # Aggregate + # rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options) + + all_colors_ = torch.cat([all_colors, all_xyz], dim=-1) + rgb_final_, depth_final, weights = self.ray_marcher(all_colors_, all_densities, all_depths, rendering_options) + rgb_final = rgb_final_[...,:-3] + xyz_final = rgb_final_[...,-3:] + else: + # rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + colors_coarse_ = torch.cat([colors_coarse, xyz_coarse], dim=-1) + rgb_final_, depth_final, weights = self.ray_marcher(colors_coarse_, densities_coarse, depths_coarse, rendering_options) + rgb_final = rgb_final_[...,:-3] + xyz_final = rgb_final_[...,-3:] + + + output = {'rgb_final': rgb_final, 'depth_final': depth_final, 'weights': weights.sum(2)} + + return output + + def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): + self.plane_axes = self.plane_axes.to(planes[list(planes.keys())[0]].device) + sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'], triplane_depth=options['triplane_depth']) + 
+ out = decoder(sampled_features, sample_directions) + if options.get('density_noise', 0) > 0: + out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] + out['xyz'] = sample_coordinates#.permute(0,2,1)[...,None] + return out + + def sort_samples(self, all_depths, all_colors, all_densities): + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + return all_depths, all_colors, all_densities + + # def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2): + # all_depths = torch.cat([depths1, depths2], dim = -2) + # all_colors = torch.cat([colors1, colors2], dim = -2) + # all_densities = torch.cat([densities1, densities2], dim = -2) + + # _, indices = torch.sort(all_depths, dim=-2) + # all_depths = torch.gather(all_depths, -2, indices) + # all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + # all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + + # return all_depths, all_colors, all_densities + def unify_samples(self, depths1, colors1, densities1, xyz1, depths2, colors2, densities2, xyz2): + all_depths = torch.cat([depths1, depths2], dim = -2) + all_colors = torch.cat([colors1, colors2], dim = -2) + all_xyz = torch.cat([xyz1, xyz2], dim = -2) + all_densities = torch.cat([densities1, densities2], dim = -2) + + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_xyz = torch.gather(all_xyz, -2, indices.expand(-1, -1, -1, all_xyz.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + + return all_depths, all_colors, all_densities, all_xyz + + def 
sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False): + """ + Return depths of approximately uniformly spaced samples along rays. + """ + N, M, _ = ray_origins.shape + if disparity_space_sampling: + depths_coarse = torch.linspace(0, + 1, + depth_resolution, + device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = 1/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse) + else: + if type(ray_start) == torch.Tensor: + depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] + else: + depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = (ray_end - ray_start)/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + + return depths_coarse + + def sample_importance(self, z_vals, weights, N_importance): + """ + Return depths of importance sampled points along rays. See NeRF importance sampling for more. 
+ """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher + + # smooth weights + weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1) + weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], + N_importance).detach().reshape(batch_size, num_rays, N_importance, 1) + return importance_z_vals + + def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + Outputs: + samples: the sampled samples + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) 
+ pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = torch.linspace(0, 1, N_importance, device=bins.device) + u = u.expand(N_rays, N_importance) + else: + u = torch.rand(N_rays, N_importance, device=bins.device) + u = u.contiguous() + + inds = torch.searchsorted(cdf, u, right=True) + below = torch.clamp_min(inds-1, 0) + above = torch.clamp_max(inds, N_samples_) + + inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) + cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) + bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[...,1]-cdf_g[...,0] + denom[denom 0: + self.defer_frames -= 1 + elif self.dump_image: + if 'image' in viz.result: + self.dump_png(viz.result.image) + self.dump_image = False + elif self.dump_gui: + viz.capture_next_frame() + self.dump_gui = False + captured_frame = viz.pop_captured_frame() + if captured_frame is not None: + self.dump_png(captured_frame) + +#---------------------------------------------------------------------------- diff --git a/3DPortraitGAN_pyramid/viz/conditioning_pose_widget.py b/3DPortraitGAN_pyramid/viz/conditioning_pose_widget.py new file mode 100644 index 0000000..90ba693 --- /dev/null +++ b/3DPortraitGAN_pyramid/viz/conditioning_pose_widget.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
# Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import numpy as np
import imgui
import dnnlib
from gui_utils import imgui_utils

#----------------------------------------------------------------------------

class ConditioningPoseWidget:
    """GUI rows for the conditioning yaw/pitch and the neck/head pose.

    Every call publishes the current values on `viz.args`
    (`conditioning_yaw`, `conditioning_pitch`, `conditioning_body_pose`).
    """

    def __init__(self, viz):
        self.viz = viz
        self.pose = dnnlib.EasyDict(yaw=0, pitch=0, anim=False, speed=0.25)
        self.pose_def = dnnlib.EasyDict(self.pose)  # snapshot used by "Reset"

        self.neck_pose = dnnlib.EasyDict(x=0, y=0, z=0)
        self.head_pose = dnnlib.EasyDict(x=0, y=0, z=0)

    def drag(self, dx, dy):
        """Apply a mouse drag of (dx, dy) to yaw and pitch."""
        font_size = self.viz.font_size
        self.pose.yaw += -dx / font_size * 3e-2
        self.pose.pitch += -dy / font_size * 3e-2

    def _xyz_row(self, caption, widget_id, target):
        """Draw one labelled 3-float input row and store edits in `target`."""
        viz = self.viz
        imgui.text(caption)
        imgui.same_line(viz.label_w)
        cur_x = target.x
        cur_y = target.y
        cur_z = target.z
        with imgui_utils.item_width(viz.font_size * 10):
            changed, (new_x, new_y, new_z) = imgui.input_float3(
                widget_id, cur_x, cur_y, cur_z, format='%+.2f',
                flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE)
            if changed:
                target.x = new_x
                target.y = new_y
                target.z = new_z

    def _draw_controls(self):
        """Draw the yaw/pitch row followed by the neck and head rows."""
        viz = self.viz

        # Yaw / pitch text input.
        imgui.text('Cond Pose')
        imgui.same_line(viz.label_w)
        cur_yaw = self.pose.yaw
        cur_pitch = self.pose.pitch
        with imgui_utils.item_width(viz.font_size * 5):
            changed, (typed_yaw, typed_pitch) = imgui.input_float2(
                '##frac', cur_yaw, cur_pitch, format='%+.2f',
                flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE)
            if changed:
                self.pose.yaw = typed_yaw
                self.pose.pitch = typed_pitch

        # Drag / Snap / Reset buttons on the same row.
        imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2)
        _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag', width=viz.button_w)
        if dragging:
            self.drag(dx, dy)
        imgui.same_line()
        snapped = dnnlib.EasyDict(self.pose, yaw=round(self.pose.yaw, 1), pitch=round(self.pose.pitch, 1))
        can_snap = (self.pose != snapped)
        if imgui_utils.button('Snap', width=viz.button_w, enabled=can_snap):
            self.pose = snapped
        imgui.same_line()
        can_reset = (self.pose != self.pose_def)
        if imgui_utils.button('Reset', width=-1, enabled=can_reset):
            self.pose = dnnlib.EasyDict(self.pose_def)

        self._xyz_row('Cond NeckPose', '##neck_pose', self.neck_pose)
        self._xyz_row('Cond HeadPose', '##head_pose', self.head_pose)

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        if show:
            self._draw_controls()

        viz = self.viz
        viz.args.conditioning_yaw = self.pose.yaw
        viz.args.conditioning_pitch = self.pose.pitch

        neck, head = self.neck_pose, self.head_pose
        viz.args.conditioning_body_pose = [neck.x, neck.y, neck.z, head.x, head.y, head.z]


#----------------------------------------------------------------------------
class LatentWidget:
    """Seed picker that treats the latent as a point on a 2D grid of integer seeds.

    ``latent.x`` / ``latent.y`` are continuous grid coordinates. Every call
    writes ``viz.args.w0_seeds`` as ``[seed, weight]`` pairs for the four
    surrounding grid cells, keeping only those with positive weight.
    """

    def __init__(self, viz):
        self.viz = viz
        self.latent = dnnlib.EasyDict(x=1, y=0, anim=False, speed=0.25)
        self.latent_def = dnnlib.EasyDict(self.latent)  # Restored by Reset.
        self.step_y = 100  # Seed increment for one unit step along y.

    def drag(self, dx, dy):
        """Move the latent position by a mouse-drag delta scaled by font size."""
        font_size = self.viz.font_size
        self.latent.x += dx / font_size * 4e-2
        self.latent.y += dy / font_size * 4e-2

    def _draw_controls(self):
        """Draw the seed/fraction inputs, Drag, Anim, speed slider, Snap and Reset."""
        viz = self.viz
        imgui.text('Latent')
        imgui.same_line(viz.label_w)

        # Integer seed derived from the rounded grid position.
        seed_value = round(self.latent.x) + round(self.latent.y) * self.step_y
        with imgui_utils.item_width(viz.font_size * 8):
            seed_edited, seed_value = imgui.input_int('##seed', seed_value, step=0)
            if seed_edited:
                self.latent.x = seed_value
                self.latent.y = 0

        imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing)
        # Fractional offset from the nearest integer grid point.
        offset_x = self.latent.x - round(self.latent.x)
        offset_y = self.latent.y - round(self.latent.y)
        with imgui_utils.item_width(viz.font_size * 5):
            frac_edited, (typed_x, typed_y) = imgui.input_float2(
                '##frac', offset_x, offset_y,
                format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE)
            if frac_edited:
                self.latent.x += typed_x - offset_x
                self.latent.y += typed_y - offset_y

        imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2)
        _clicked, is_dragging, delta_x, delta_y = imgui_utils.drag_button('Drag', width=viz.button_w)
        if is_dragging:
            self.drag(delta_x, delta_y)

        imgui.same_line(viz.label_w + viz.font_size * 13 + viz.button_w + viz.spacing * 3)
        _clicked, self.latent.anim = imgui.checkbox('Anim', self.latent.anim)

        imgui.same_line(round(viz.font_size * 28.7))
        slider_w = -2 - viz.button_w * 2 - viz.spacing * 2
        with imgui_utils.item_width(slider_w), imgui_utils.grayed_out(not self.latent.anim):
            speed_edited, new_speed = imgui.slider_float(
                '##speed', self.latent.speed, -5, 5, format='Speed %.3f', power=3)
            if speed_edited:
                self.latent.speed = new_speed

        imgui.same_line()
        rounded = dnnlib.EasyDict(self.latent, x=round(self.latent.x), y=round(self.latent.y))
        if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.latent != rounded)):
            self.latent = rounded

        imgui.same_line()
        if imgui_utils.button('Reset', width=-1, enabled=(self.latent != self.latent_def)):
            self.latent = dnnlib.EasyDict(self.latent_def)

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the controls if requested, advance the animation, publish ``w0_seeds``."""
        viz = self.viz
        if show:
            self._draw_controls()

        if self.latent.anim:
            self.latent.x += viz.frame_delta * self.latent.speed

        blend = []  # [[seed, weight], ...]
        viz.args.w0_seeds = blend
        floor_x = np.floor(self.latent.x)
        floor_y = np.floor(self.latent.y)
        # Visit the cell corners in the order (0,0), (1,0), (0,1), (1,1).
        for ofs_y in (0, 1):
            for ofs_x in (0, 1):
                corner_x = floor_x + ofs_x
                corner_y = floor_y + ofs_y
                # Wrap the combined seed into the unsigned 32-bit range.
                seed = (int(corner_x) + int(corner_y) * self.step_y) & ((1 << 32) - 1)
                weight = (1 - abs(self.latent.x - corner_x)) * (1 - abs(self.latent.y - corner_y))
                if weight > 0:
                    blend.append([seed, weight])
class LayerWidget:
    """Layer browser with per-layer display options (channels, scale, FFT).

    Reads ``viz.result['layers']`` and ``viz.result['stats']`` and writes the
    selected layer name and display options to ``viz.args`` on every call.
    """

    def __init__(self, viz):
        self.viz = viz
        self.prev_layers = None
        self.cur_layer = None
        self.sel_channels = 3
        self.base_channel = 0
        self.img_scale_db = 0
        self.img_normalize = False
        self.fft_show = False
        self.fft_all = True
        self.fft_range_db = 50
        self.fft_beta = 8
        self.refocus = False

    def _clamp_base_channel(self, upper):
        """Keep ``base_channel`` within ``[0, upper]``."""
        self.base_channel = min(max(self.base_channel, 0), upper)

    def _draw_layer_list(self, layers, height, bg_color, dim_color):
        """Draw the scrollable list of layers and update ``cur_layer`` on selection."""
        viz = self.viz
        list_w = viz.font_size * 28

        # Begin list.
        imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0])
        imgui.push_style_color(imgui.COLOR_CHILD_BACKGROUND, *bg_color)
        imgui.push_style_color(imgui.COLOR_HEADER, 0, 0, 0, 0)
        imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, 0.16, 0.29, 0.48, 0.5)
        imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, 0.16, 0.29, 0.48, 0.9)
        imgui.begin_child('##list', width=list_w, height=height, border=True,
                          flags=imgui.WINDOW_ALWAYS_VERTICAL_SCROLLBAR)

        # List items.
        for entry in layers:
            is_current = (self.cur_layer == entry.name)
            _opened, is_current = imgui.selectable(f'##{entry.name}_selectable', is_current)
            imgui.same_line(viz.spacing)
            _clicked, is_current = imgui.checkbox(f'{entry.name}##radio', is_current)
            if is_current:
                self.cur_layer = entry.name
                if self.refocus:
                    imgui.set_scroll_here()
                    viz.skip_frame()  # Focus will change on next frame.
                    self.refocus = False
            imgui.same_line(list_w - viz.font_size * 13)
            imgui.text_colored('x'.join(str(dim) for dim in entry.shape[2:]), *dim_color)
            imgui.same_line(list_w - viz.font_size * 8)
            imgui.text_colored(str(entry.shape[1]), *dim_color)
            imgui.same_line(list_w - viz.font_size * 5)
            imgui.text_colored(entry.dtype, *dim_color)

        # End list.
        if len(layers) == 0:
            imgui.text_colored('No layers found', *dim_color)
        imgui.end_child()
        imgui.pop_style_color(4)
        imgui.pop_style_var(1)

    def _draw_stats(self, bg_color, dim_color):
        """Draw the statistics table from ``viz.result['stats']`` (or 'N/A')."""
        viz = self.viz
        raw = viz.result.get('stats', None)
        cells = [f'{raw[idx]:g}' if raw is not None else 'N/A' for idx in range(6)]
        table = [
            ['Statistic', 'All channels', 'Selected'],
            ['Mean', cells[0], cells[1]],
            ['Std', cells[2], cells[3]],
            ['Max', cells[4], cells[5]],
        ]
        table_h = imgui.get_text_line_height_with_spacing() * len(table) + viz.spacing
        imgui.push_style_color(imgui.COLOR_CHILD_BACKGROUND, *bg_color)
        imgui.begin_child('##stats', width=-1, height=table_h, border=True)
        for row_idx, row in enumerate(table):
            for col_idx, cell in enumerate(row):
                if col_idx != 0:
                    imgui.same_line(viz.font_size * (4 + (col_idx - 1) * 6))
                # Header row and first column are drawn dimmed.
                if col_idx == 0 or row_idx == 0:
                    imgui.text_colored(cell, *dim_color)
                else:
                    imgui.text(cell)
        imgui.end_child()
        imgui.pop_style_color(1)

    def _draw_fft_controls(self, base_channel_max):
        """Draw the FFT toggles plus the range and Kaiser-beta sliders."""
        viz = self.viz
        slider_w = -1 - viz.button_w - viz.spacing

        # FFT & all.
        _clicked, self.fft_show = imgui.checkbox('FFT', self.fft_show)
        imgui.same_line(viz.font_size * 4)
        with imgui_utils.grayed_out(not self.fft_show or base_channel_max == 0):
            _clicked, self.fft_all = imgui.checkbox('All channels', self.fft_all)
        imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w)
        with imgui_utils.grayed_out(not self.fft_show):
            if imgui_utils.button('Reset##fft_flags', width=-1, enabled=(self.fft_show or not self.fft_all)):
                self.fft_show = False
                self.fft_all = True

        # FFT range.
        with imgui_utils.grayed_out(not self.fft_show):
            with imgui_utils.item_width(slider_w):
                _changed, self.fft_range_db = imgui.slider_float(
                    '##fft_range_db', self.fft_range_db,
                    min_value=0.1, max_value=100, format='Range +-%.1f dB')
            imgui.same_line()
            if imgui_utils.button('Reset##fft_range_db', width=-1, enabled=(self.fft_range_db != 50)):
                self.fft_range_db = 50

        # FFT beta.
        with imgui_utils.grayed_out(not self.fft_show):
            with imgui_utils.item_width(slider_w):
                _changed, self.fft_beta = imgui.slider_float(
                    '##fft_beta', self.fft_beta,
                    min_value=0, max_value=50, format='Kaiser beta %.2f', power=2.63)
            imgui.same_line()
            if imgui_utils.button('Reset##fft_beta', width=-1, enabled=(self.fft_beta != 8)):
                self.fft_beta = 8

    def _draw_options(self, height, num_channels, base_channel_max, bg_color, dim_color):
        """Draw the options pane that sits to the right of the layer list."""
        viz = self.viz

        # Begin options.
        imgui.begin_child('##options', width=-1, height=height, border=False)

        # RGB & normalize.
        use_rgb = (self.sel_channels == 3)
        _clicked, use_rgb = imgui.checkbox('RGB', use_rgb)
        self.sel_channels = 3 if use_rgb else 1
        imgui.same_line(viz.font_size * 4)
        _clicked, self.img_normalize = imgui.checkbox('Normalize', self.img_normalize)
        imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w)
        if imgui_utils.button('Reset##img_flags', width=-1,
                              enabled=(self.sel_channels != 3 or self.img_normalize)):
            self.sel_channels = 3
            self.img_normalize = False

        # Image scale.
        with imgui_utils.item_width(-1 - viz.button_w - viz.spacing):
            _changed, self.img_scale_db = imgui.slider_float(
                '##scale', self.img_scale_db, min_value=-40, max_value=40, format='Scale %+.1f dB')
        imgui.same_line()
        if imgui_utils.button('Reset##scale', width=-1, enabled=(self.img_scale_db != 0)):
            self.img_scale_db = 0

        # Base channel.
        self._clamp_base_channel(base_channel_max)
        narrow_w = imgui.get_text_line_height_with_spacing()
        with imgui_utils.grayed_out(base_channel_max == 0):
            with imgui_utils.item_width(-1 - viz.button_w - narrow_w * 2 - viz.spacing * 3):
                _changed, self.base_channel = imgui.drag_int(
                    '##channel', self.base_channel, change_speed=0.05,
                    min_value=0, max_value=base_channel_max, format=f'Channel %d/{num_channels}')
            imgui.same_line()
            if imgui_utils.button('-##channel', width=narrow_w):
                self.base_channel -= 1
            imgui.same_line()
            if imgui_utils.button('+##channel', width=narrow_w):
                self.base_channel += 1
        imgui.same_line()
        # Re-clamp after the -/+ buttons before deciding whether Reset is enabled.
        self._clamp_base_channel(base_channel_max)
        if imgui_utils.button('Reset##channel', width=-1,
                              enabled=(self.base_channel != 0 and base_channel_max > 0)):
            self.base_channel = 0

        # Stats.
        self._draw_stats(bg_color, dim_color)

        # FFT controls.
        self._draw_fft_controls(base_channel_max)

        # End options.
        imgui.end_child()

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Resolve the current layer, optionally draw the GUI, publish ``viz.args``."""
        viz = self.viz
        layers = viz.result.get('layers', [])
        if self.prev_layers != layers:
            # The layer list changed: scroll to the selected item on the next draw.
            self.prev_layers = layers
            self.refocus = True

        current = next((cand for cand in layers if cand.name == self.cur_layer), None)
        if current is None and len(layers) > 0:
            # Fall back to the last layer when the previous selection is gone.
            current = layers[-1]
            self.cur_layer = current.name
        num_channels = 0 if current is None else current.shape[1]
        base_channel_max = max(num_channels - self.sel_channels, 0)

        if show:
            bg_color = [0.16, 0.29, 0.48, 0.2]
            dim_color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
            dim_color[-1] *= 0.5  # Halve the alpha of the regular text colour.
            pane_h = imgui.get_text_line_height_with_spacing() * 12 + viz.spacing

            self._draw_layer_list(layers, pane_h, bg_color, dim_color)
            imgui.same_line()
            self._draw_options(pane_h, num_channels, base_channel_max, bg_color, dim_color)

        self._clamp_base_channel(base_channel_max)
        # The last layer is reported as None rather than by name.
        if len(layers) > 0 and self.cur_layer != layers[-1].name:
            viz.args.layer_name = self.cur_layer
        else:
            viz.args.layer_name = None
        viz.args.update(sel_channels=self.sel_channels, base_channel=self.base_channel,
                        img_scale_db=self.img_scale_db, img_normalize=self.img_normalize)
        viz.args.fft_show = self.fft_show
        if self.fft_show:
            viz.args.update(fft_all=self.fft_all, fft_range_db=self.fft_range_db, fft_beta=self.fft_beta)
class PerformanceWidget:
    """Shows GUI and render timings and exposes FPS limit, vsync, async and FP32 toggles."""

    def __init__(self, viz):
        self.viz = viz
        # Fixed-length histories; NaN entries are placeholders until real samples arrive.
        self.gui_times = [float('nan')] * 60
        self.render_times = [float('nan')] * 30
        self.fps_limit = 60
        self.use_vsync = False
        self.is_async = False
        self.force_fp32 = False

    def _draw_timing(self, label, plot_id, samples):
        """Draw label, plot, mean time and FPS, leaving the cursor at the options column."""
        viz = self.viz
        imgui.text(label)
        imgui.same_line(viz.label_w)
        with imgui_utils.item_width(viz.font_size * 8):
            imgui.plot_lines(plot_id, array.array('f', samples), scale_min=0)
        imgui.same_line(viz.label_w + viz.font_size * 9)
        # The ``> 0`` filter also drops the NaN placeholders.
        positive = [x for x in samples if x > 0]
        mean_t = np.mean(positive) if len(positive) > 0 else 0
        imgui.text(f'{mean_t*1e3:.1f} ms' if mean_t > 0 else 'N/A')
        imgui.same_line(viz.label_w + viz.font_size * 14)
        imgui.text(f'{1/mean_t:.1f} FPS' if mean_t > 0 else 'N/A')
        imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3)

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Record the latest timings, optionally draw both rows, apply the settings."""
        viz = self.viz

        # Slide the histories forward by one sample.
        self.gui_times = self.gui_times[1:] + [viz.frame_delta]
        if 'render_time' in viz.result:
            self.render_times = self.render_times[1:] + [viz.result.render_time]
            del viz.result.render_time  # Consume it so it is only counted once.

        if show:
            right_col = None  # Computed per row because imgui state may change.

            self._draw_timing('GUI', '##gui_times', self.gui_times)
            with imgui_utils.item_width(viz.font_size * 6):
                _changed, self.fps_limit = imgui.input_int(
                    'FPS limit', self.fps_limit, flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE)
                self.fps_limit = min(max(self.fps_limit, 5), 1000)
            right_col = imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing
            imgui.same_line(right_col)
            _clicked, self.use_vsync = imgui.checkbox('Vertical sync', self.use_vsync)

            self._draw_timing('Render', '##render_times', self.render_times)
            _clicked, self.is_async = imgui.checkbox('Separate process', self.is_async)
            right_col = imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing
            imgui.same_line(right_col)
            _clicked, self.force_fp32 = imgui.checkbox('Force FP32', self.force_fp32)

        viz.set_fps_limit(self.fps_limit)
        viz.set_vsync(self.use_vsync)
        viz.set_async(self.is_async)
        viz.args.force_fp32 = self.force_fp32
def _locate_results(pattern):
    """Return ``pattern`` unchanged (no short-hand expansion is performed here)."""
    return pattern

#----------------------------------------------------------------------------

class PickleWidget:
    """Text field plus Recent/Browse popups for choosing the network pickle.

    The resolved path (or ``None`` when loading failed) is published as
    ``viz.args.pkl`` on every call.
    """

    def __init__(self, viz):
        self.viz = viz
        self.search_dirs = []
        self.cur_pkl = None
        self.user_pkl = ''
        self.recent_pkls = []
        self.browse_cache = dict()  # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...}
        self.browse_refocus = False
        self.load('', ignore_errors=True)

    def add_recent(self, pkl, ignore_errors=False):
        """Append the resolved ``pkl`` to the recent list if it is not already there.

        Errors from ``resolve_pkl`` are re-raised unless ``ignore_errors`` is true.
        """
        try:
            resolved = self.resolve_pkl(pkl)
            if resolved not in self.recent_pkls:
                self.recent_pkls.append(resolved)
        except Exception:
            # Was a bare ``except:``, which also swallowed KeyboardInterrupt/SystemExit.
            if not ignore_errors:
                raise

    def load(self, pkl, ignore_errors=False):
        """Resolve ``pkl`` and make it the current pickle.

        On success the path moves to the front of the recent list and rendering
        is deferred. On failure ``cur_pkl`` becomes ``None`` and ``viz.result``
        carries either an informational message (empty input) or the captured
        error; the exception is re-raised unless ``ignore_errors`` is true.
        """
        viz = self.viz
        viz.clear_result()
        viz.skip_frame()  # The input field will change on next frame.
        try:
            resolved = self.resolve_pkl(pkl)
            name = resolved.replace('\\', '/').split('/')[-1]
            self.cur_pkl = resolved
            self.user_pkl = resolved
            viz.result.message = f'Loading {name}...'
            viz.defer_rendering()
            if resolved in self.recent_pkls:
                self.recent_pkls.remove(resolved)
            self.recent_pkls.insert(0, resolved)
        except Exception:
            # Was a bare ``except:``; Ctrl-C and interpreter exit now propagate.
            self.cur_pkl = None
            self.user_pkl = pkl
            if pkl == '':
                viz.result = dnnlib.EasyDict(message='No network pickle loaded')
            else:
                viz.result = dnnlib.EasyDict(error=renderer.CapturedException())
            if not ignore_errors:
                raise

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the pickle row and popups, handle drag-and-drop, publish ``viz.args.pkl``."""
        viz = self.viz
        recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl]
        if show:
            imgui.text('Pickle')
            imgui.same_line(viz.label_w)
            changed, self.user_pkl = imgui_utils.input_text(
                '##pkl', self.user_pkl, 1024,
                flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE),
                width=(-1 - viz.button_w * 2 - viz.spacing * 2),
                help_text=' | | | | /.pkl')
            if changed:
                self.load(self.user_pkl, ignore_errors=True)
            if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '':
                imgui.set_tooltip(self.user_pkl)
            imgui.same_line()
            if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)):
                imgui.open_popup('recent_pkls_popup')
            imgui.same_line()
            if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=-1):
                imgui.open_popup('browse_pkls_popup')
                # Re-scan the directories each time the browser is opened.
                self.browse_cache.clear()
                self.browse_refocus = True

        if imgui.begin_popup('recent_pkls_popup'):
            for pkl in recent_pkls:
                clicked, _state = imgui.menu_item(pkl)
                if clicked:
                    self.load(pkl, ignore_errors=True)
            imgui.end_popup()

        if imgui.begin_popup('browse_pkls_popup'):
            def recurse(parents):
                # Directory listings are cached per parent tuple while the popup is open.
                key = tuple(parents)
                items = self.browse_cache.get(key, None)
                if items is None:
                    items = self.list_runs_and_pkls(parents)
                    self.browse_cache[key] = items
                for item in items:
                    if item.type == 'run' and imgui.begin_menu(item.name):
                        recurse([item.path])
                        imgui.end_menu()
                    if item.type == 'pkl':
                        clicked, _state = imgui.menu_item(item.name)
                        if clicked:
                            self.load(item.path, ignore_errors=True)
                if len(items) == 0:
                    with imgui_utils.grayed_out():
                        imgui.menu_item('No results found')
            recurse(self.search_dirs)
            if self.browse_refocus:
                imgui.set_scroll_here()
                viz.skip_frame()  # Focus will change on next frame.
                self.browse_refocus = False
            imgui.end_popup()

        paths = viz.pop_drag_and_drop_paths()
        if paths is not None and len(paths) >= 1:
            self.load(paths[0], ignore_errors=True)

        viz.args.pkl = self.cur_pkl

    def list_runs_and_pkls(self, parents):
        """List run directories and snapshot pickles directly inside ``parents``.

        Returns ``dnnlib.EasyDict`` items with ``type`` ('run' or 'pkl'),
        ``name`` and ``path``, sorted by name (underscores compared as spaces)
        and then by path. Parents that are not directories are skipped.
        """
        items = []
        run_regex = re.compile(r'\d+-.*')
        pkl_regex = re.compile(r'network-snapshot-\d+\.pkl')
        for parent in set(parents):
            if os.path.isdir(parent):
                for entry in os.scandir(parent):
                    if entry.is_dir() and run_regex.fullmatch(entry.name):
                        items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name)))
                    if entry.is_file() and pkl_regex.fullmatch(entry.name):
                        items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name)))

        items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path))
        return items

    def resolve_pkl(self, pattern):
        """Turn user input into a concrete pickle location.

        URLs are returned unchanged. A directory resolves to the last
        ``network-snapshot-*.pkl`` in sorted order. Any other path is made
        absolute.

        Raises:
            TypeError: ``pattern`` is not a string.
            ValueError: ``pattern`` is empty.
            IOError: ``pattern`` is a directory without snapshot pickles.
        """
        # Explicit raises instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let '' resolve to the current directory.
        if not isinstance(pattern, str):
            raise TypeError(f'Pickle pattern must be a string, got {type(pattern).__name__}')
        if pattern == '':
            raise ValueError('Pickle pattern must not be empty')

        # URL => return as is.
        if dnnlib.util.is_url(pattern):
            return pattern

        # Short-hand pattern => locate.
        path = _locate_results(pattern)

        # Run dir => pick the last saved snapshot.
        if os.path.isdir(path):
            pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl')))
            if len(pkl_files) == 0:
                raise IOError(f'No network pickle found in "{path}"')
            path = pkl_files[-1]

        # Normalize.
        path = os.path.abspath(path)
        return path
class PoseWidget:
    """Controls for the camera yaw/pitch and the look-at point.

    Every call publishes ``yaw``, ``pitch`` and ``lookat_point`` on
    ``viz.args``; ``lookat_point`` is ``None`` while 'Auto Detect' is selected.
    """

    def __init__(self, viz):
        self.viz = viz
        self.pose = dnnlib.EasyDict(yaw=0, pitch=0, anim=False, speed=0.25)
        self.pose_def = dnnlib.EasyDict(self.pose)  # Restored by Reset.

        self.lookat_point_choice = 0
        # ``lookat_point_option`` and ``lookat_point_labels`` are parallel lists
        # indexed by ``lookat_point_choice``.
        self.lookat_point_option = ['auto', 'ffhq', 'shapenet', 'afhq', 'manual']
        self.lookat_point_labels = ['Auto Detect', 'FFHQ Default', 'Shapenet Default', 'AFHQ Default', 'Manual']
        self.lookat_point = (0.0, 0.0, 0.2)

    def drag(self, dx, dy):
        """Apply a mouse-drag delta to yaw/pitch, scaled by the GUI font size."""
        font_size = self.viz.font_size
        self.pose.yaw -= dx / font_size * 3e-2
        self.pose.pitch -= dy / font_size * 3e-2

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw both rows when ``show`` is true, then publish the pose on ``viz.args``."""
        viz = self.viz
        if show:
            imgui.text('Pose')
            imgui.same_line(viz.label_w)
            with imgui_utils.item_width(viz.font_size * 5):
                edited, (yaw_in, pitch_in) = imgui.input_float2(
                    '##pose', self.pose.yaw, self.pose.pitch,
                    format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE)
                if edited:
                    self.pose.yaw = yaw_in
                    self.pose.pitch = pitch_in

            buttons_x = viz.label_w + viz.font_size * 13 + viz.spacing * 2
            imgui.same_line(buttons_x)
            _clicked, is_dragging, delta_x, delta_y = imgui_utils.drag_button('Drag', width=viz.button_w)
            if is_dragging:
                self.drag(delta_x, delta_y)

            imgui.same_line()
            rounded = dnnlib.EasyDict(self.pose, yaw=round(self.pose.yaw, 1), pitch=round(self.pose.pitch, 1))
            if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.pose != rounded)):
                self.pose = rounded

            imgui.same_line()
            if imgui_utils.button('Reset', width=-1, enabled=(self.pose != self.pose_def)):
                self.pose = dnnlib.EasyDict(self.pose_def)

            # Second row: look-at point selection.
            imgui.text('LookAt Point')
            imgui.same_line(viz.label_w)
            with imgui_utils.item_width(viz.font_size * 8):
                _clicked, self.lookat_point_choice = imgui.combo(
                    '', self.lookat_point_choice, self.lookat_point_labels)
            mode = self.lookat_point_option[self.lookat_point_choice]

            # Fixed look-at points for the dataset presets.
            fixed_points = {
                'ffhq': (0.0, 0.0, 0.2),
                'shapenet': (0.0, 0.0, 0.0),
                'afhq': (0.0, 0.0, 0.0),
            }
            if mode == 'auto':
                self.lookat_point = None
            elif mode in fixed_points:
                self.lookat_point = fixed_points[mode]
                changes_enabled = False
            elif mode == 'manual':
                # Coming from 'auto' there is no point yet; start at the origin.
                if self.lookat_point is None:
                    self.lookat_point = (0.0, 0.0, 0.0)
                changes_enabled = True

            if mode != 'auto':
                imgui.same_line(buttons_x)
                with imgui_utils.item_width(viz.font_size * 16):
                    with imgui_utils.grayed_out(not changes_enabled):
                        # Presets are shown read-only; only 'manual' is editable.
                        field_flags = imgui.INPUT_TEXT_READ_ONLY if not changes_enabled else 0
                        _changed, self.lookat_point = imgui.input_float3(
                            '##lookat', *self.lookat_point, format='%.2f', flags=field_flags)

        viz.args.yaw = self.pose.yaw
        viz.args.pitch = self.pose.pitch

        viz.args.lookat_point = self.lookat_point
def _locate_results(pattern):
    """Return ``pattern`` unchanged (no short-hand expansion is performed here)."""
    return pattern

#----------------------------------------------------------------------------

class PyramidTrigridWidget:
    """Text field plus Browse popup for choosing a pyramid tri-grid checkpoint.

    The current checkpoint path is published as
    ``viz.args.pyramid_tri_grid_ckpt`` on every call.
    """

    def __init__(self, viz):
        self.viz = viz
        self.search_dirs = []
        self.cur_pyramid_trigrid = None
        self.cur_ws = None
        self.cur_pth = None  # Declared up front; load() below assigns the real value.
        self.user_pth = ''
        self.recent_pths = []
        self.browse_cache = dict()  # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...}
        self.browse_refocus = False
        self.load('', ignore_errors=True)

    def resolve_pth(self, pattern):
        """Validate a checkpoint path and return it unchanged.

        ``add_recent`` called this method but it was never defined, so every
        call ended in ``AttributeError``. The path is deliberately not
        normalized so that entries match the raw paths ``load`` stores in
        ``recent_pths``.

        Raises:
            TypeError: ``pattern`` is not a string.
            ValueError: ``pattern`` is empty.
        """
        if not isinstance(pattern, str):
            raise TypeError(f'Checkpoint path must be a string, got {type(pattern).__name__}')
        if pattern == '':
            raise ValueError('Checkpoint path must not be empty')
        return _locate_results(pattern)

    def add_recent(self, pth, ignore_errors=False):
        """Append ``pth`` to the recent list if it is valid and not already there.

        Errors from ``resolve_pth`` are re-raised unless ``ignore_errors`` is true.
        """
        try:
            resolved = self.resolve_pth(pth)
            if resolved not in self.recent_pths:
                self.recent_pths.append(resolved)
        except Exception:
            # Was a bare ``except:``, which also swallowed KeyboardInterrupt/SystemExit.
            if not ignore_errors:
                raise

    def load(self, pth, ignore_errors=False):
        """Make ``pth`` the current checkpoint and move it to the front of the recent list.

        If anything in the success path raises, ``cur_pth`` becomes ``None`` and
        ``viz.result`` carries a message or the captured error; the exception
        is re-raised unless ``ignore_errors`` is true.
        """
        viz = self.viz
        viz.clear_result()
        viz.skip_frame()  # The input field will change on next frame.
        try:
            # NOTE(review): the path is used as given, so '' is stored as cur_pth
            # and the 'No pyramid tri-grid ckpt loaded' branch below is not taken
            # for it. Confirm whether the renderer expects '' or None before
            # routing this through resolve_pth().
            resolved = pth
            name = resolved.replace('\\', '/').split('/')[-1]
            self.cur_pth = resolved
            self.user_pth = resolved
            viz.result.message = f'Loading {name}...'
            viz.defer_rendering()
            if resolved in self.recent_pths:
                self.recent_pths.remove(resolved)
            self.recent_pths.insert(0, resolved)
        except Exception:
            # Was a bare ``except:``; Ctrl-C and interpreter exit now propagate.
            self.cur_pth = None
            self.user_pth = pth
            if pth == '':
                viz.result = dnnlib.EasyDict(message='No pyramid tri-grid ckpt loaded')
            else:
                viz.result = dnnlib.EasyDict(error=renderer.CapturedException())
            if not ignore_errors:
                raise

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the checkpoint row and popups, handle drag-and-drop, publish the path."""
        viz = self.viz
        recent_pths = [pth for pth in self.recent_pths if pth != self.user_pth]
        if show:
            imgui.text('Pyramid Tri-Grid Ckpt:')
            imgui.same_line(round(viz.font_size * 8.5))
            changed, self.user_pth = imgui_utils.input_text(
                '##pth', self.user_pth, 1024,
                flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE),
                width=(-1 - viz.button_w * 2 - viz.spacing * 2),
                help_text='.pth')
            if changed:
                self.load(self.user_pth, ignore_errors=True)
            if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pth != '':
                imgui.set_tooltip(self.user_pth)

            imgui.same_line()
            if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=-1):
                imgui.open_popup('browse_pths_popup')
                # Re-scan the directories each time the browser is opened.
                self.browse_cache.clear()
                self.browse_refocus = True

        # NOTE(review): nothing in this widget opens 'recent_pths_popup' (there is
        # no "Recent..." button), so this popup body appears unreachable here.
        if imgui.begin_popup('recent_pths_popup'):
            for pth in recent_pths:
                clicked, _state = imgui.menu_item(pth)
                if clicked:
                    self.load(pth, ignore_errors=True)
            imgui.end_popup()

        if imgui.begin_popup('browse_pths_popup'):
            def recurse(parents):
                # Directory listings are cached per parent tuple while the popup is open.
                key = tuple(parents)
                items = self.browse_cache.get(key, None)
                if items is None:
                    items = self.list_runs_and_pths(parents)
                    self.browse_cache[key] = items
                for item in items:
                    if item.type == 'run' and imgui.begin_menu(item.name):
                        recurse([item.path])
                        imgui.end_menu()
                    if item.type == 'pth':
                        clicked, _state = imgui.menu_item(item.name)
                        if clicked:
                            self.load(item.path, ignore_errors=True)
                if len(items) == 0:
                    with imgui_utils.grayed_out():
                        imgui.menu_item('No results found')
            recurse(self.search_dirs)
            if self.browse_refocus:
                imgui.set_scroll_here()
                viz.skip_frame()  # Focus will change on next frame.
                self.browse_refocus = False
            imgui.end_popup()

        paths = viz.pop_drag_and_drop_paths()
        if paths is not None and len(paths) >= 1:
            self.load(paths[0], ignore_errors=True)

        viz.args.pyramid_tri_grid_ckpt = self.cur_pth

    def list_runs_and_pths(self, parents):
        """List run directories and snapshot checkpoints directly inside ``parents``.

        Returns ``dnnlib.EasyDict`` items with ``type`` ('run' or 'pth'),
        ``name`` and ``path``, sorted by name (underscores compared as spaces)
        and then by path. Parents that are not directories are skipped.
        """
        items = []
        run_regex = re.compile(r'\d+-.*')
        pth_regex = re.compile(r'network-snapshot-\d+\.pth')
        for parent in set(parents):
            if os.path.isdir(parent):
                for entry in os.scandir(parent):
                    if entry.is_dir() and run_regex.fullmatch(entry.name):
                        items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name)))
                    if entry.is_file() and pth_regex.fullmatch(entry.name):
                        items.append(dnnlib.EasyDict(type='pth', name=entry.name, path=os.path.join(parent, entry.name)))

        items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path))
        return items
class RenderDepthSampleWidget:
    """Two combo boxes that choose the depth-sample multipliers.

    ``depth_mult`` and ``depth_importance_mult`` are indices into
    ``render_types``; the looked-up values are published on ``viz.args``.
    """

    def __init__(self, viz):
        self.viz = viz
        self.depth_mult = 2
        self.depth_importance_mult = 2
        # ``render_types`` and ``labels`` are parallel lists.
        self.render_types = [.5, 1, 2, 4]
        self.labels = ['0.5x', '1x', '2x', '4x']

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the combos when ``show`` is true and publish the selected multipliers."""
        viz = self.viz

        if show:
            combo_w = viz.font_size * 4
            imgui.text('Render Type')
            imgui.same_line(viz.label_w)
            with imgui_utils.item_width(combo_w):
                _clicked, self.depth_mult = imgui.combo(
                    'Depth Sample Multiplier', self.depth_mult, self.labels)
            imgui.same_line(viz.label_w + viz.font_size * 16 + viz.spacing * 2)
            with imgui_utils.item_width(combo_w):
                _clicked, self.depth_importance_mult = imgui.combo(
                    'Depth Sample Importance Multiplier', self.depth_importance_mult, self.labels)

        multipliers = self.render_types
        viz.args.depth_mult = multipliers[self.depth_mult]
        viz.args.depth_importance_mult = multipliers[self.depth_importance_mult]
class RenderTypeWidget:
    """Combo box that selects which renderer output to display.

    ``render_type`` is an index into ``render_types``; the selected key is
    published as ``viz.args.render_type``.
    """

    def __init__(self, viz):
        self.viz = viz
        self.render_type = 0
        # ``render_types`` and ``labels`` are parallel lists.
        self.render_types = ['image_raw', 'image_depth']
        self.labels = ['RGB Image', 'Depth Image']

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the combo when ``show`` is true and publish the selected render type."""
        viz = self.viz

        if show:
            imgui.text('Render Type')
            imgui.same_line(viz.label_w)
            with imgui_utils.item_width(viz.font_size * 10):
                _clicked, choice = imgui.combo('', self.render_type, self.labels)
                self.render_type = choice

        viz.args.render_type = self.render_types[self.render_type]
import copy
import os
import sys
import traceback
import numpy as np
import torch
import torch.fft
import torch.nn
import matplotlib.cm
import dnnlib
from torch_utils.ops import upfirdn2d
import legacy # pylint: disable=import-error

from camera_utils import LookAtPoseSampler

#----------------------------------------------------------------------------

class CapturedException(Exception):
    """Exception carrying the formatted traceback of the exception being handled.

    When constructed without `msg` inside an ``except`` block, the message is
    the current traceback (or the message of an already-captured exception).
    """
    def __init__(self, msg=None):
        if msg is None:
            _type, value, _traceback = sys.exc_info()
            assert value is not None
            if isinstance(value, CapturedException):
                msg = str(value)
            else:
                msg = traceback.format_exc()
        assert isinstance(msg, str)
        super().__init__(msg)

#----------------------------------------------------------------------------

class CaptureSuccess(Exception):
    """Raised from a forward hook to abort the forward pass and return `out`."""
    def __init__(self, out):
        super().__init__()
        self.out = out

#----------------------------------------------------------------------------

def _sinc(x):
    """Return sin(pi*x) / (pi*x) element-wise, with the value 1 at x == 0."""
    y = (x * np.pi).abs()
    z = torch.sin(y) / y.clamp(1e-30, float('inf'))
    return torch.where(y < 1e-30, torch.ones_like(x), z)

def _lanczos_window(x, a):
    """Return the Lanczos window sinc(x/a) for |x| < a and 0 elsewhere."""
    x = x.abs() / a
    return torch.where(x < 1, _sinc(x), torch.zeros_like(x))

#----------------------------------------------------------------------------

def _construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
    """Build a windowed-sinc 2D FIR filter for the linear part of `mat`.

    Only ``mat[:2, :2]`` is used. The result is a square float32 tensor with
    side ``amax * 2 * up - 1``.
    """
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)

    # Construct 2D filter taps in input & output coordinate spaces.
    taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
    yi, xi = torch.meshgrid(taps, taps)
    xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)

    # Convolution of two oriented 2D sinc filters.
    fi = _sinc(xi * cutoff_in) * _sinc(yi * cutoff_in)
    fo = _sinc(xo * cutoff_out) * _sinc(yo * cutoff_out)
    f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real

    # Convolution of two oriented 2D Lanczos windows.
    wi = _lanczos_window(xi, a) * _lanczos_window(yi, a)
    wo = _lanczos_window(xo, a) * _lanczos_window(yo, a)
    w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real

    # Construct windowed FIR filter.
    f = f * w

    # Finalize.
    c = (aflt - amax) * up
    f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
    f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
    f = f / f.sum([0,2], keepdim=True) / (up ** 2)
    f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
    return f

#----------------------------------------------------------------------------

def _apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    """Resample the 4D tensor `x` under the 3x3 affine matrix `mat`.

    Returns ``(z, m)``: the resampled tensor and a mask produced by sampling
    an interior-ones tensor with the same grid. Extra keyword arguments are
    forwarded to `_construct_affine_bandlimit_filter`.
    """
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)

    # Construct filter.
    f = _construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    p = f.shape[0] // 2

    # Construct sampling grid.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    theta[0, 2] += 1 / up / W
    theta[1, 2] += 1 / up / H
    theta[0, :] *= W / (W + p / up * 2)
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)

    # Resample image.
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)

    # Form mask.
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m

#----------------------------------------------------------------------------

class Renderer:
    """Loads a generator pickle plus a pyramid tri-grid checkpoint and renders views.

    Loaded pickles, tweaked networks, tri-grid checkpoints, pinned host
    buffers and colormaps are cached on the instance between calls.
    """

    # Resolutions of the `trigrids_<res>` tensors read from a tri-grid checkpoint.
    _TRIGRID_RESOLUTIONS = (8, 16, 32, 64, 128, 256, 512)

    def __init__(self):
        self._device        = torch.device('cuda')
        self._pkl_data      = dict()    # {pkl: dict | CapturedException, ...}
        self._networks      = dict()    # {cache_key: torch.nn.Module, ...}
        self._pinned_bufs   = dict()    # {(shape, dtype): torch.Tensor, ...}
        self._cmaps         = dict()    # {name: torch.Tensor, ...}
        self._is_timing     = False
        self._start_event   = torch.cuda.Event(enable_timing=True)
        self._end_event     = torch.cuda.Event(enable_timing=True)
        self._net_layers    = dict()    # {cache_key: [dnnlib.EasyDict, ...], ...}

        self.input_data     = dict()    # {ckpt_path: {'trigrid': dict, 'ws': torch.Tensor}, ...}

    def render(self, **args):
        """Run `_render_impl(**args)` and return an EasyDict of CPU results.

        Exceptions raised by the implementation are captured and returned as a
        string in ``res.error``. ``res.render_time`` (seconds) is set unless a
        network had to be loaded during the call.
        """
        self._is_timing = True
        self._start_event.record(torch.cuda.current_stream(self._device))
        res = dnnlib.EasyDict()
        try:
            self._render_impl(res, **args)
        except Exception:
            # Was a bare `except:`; KeyboardInterrupt / SystemExit now propagate.
            res.error = CapturedException()
        self._end_event.record(torch.cuda.current_stream(self._device))
        if 'image' in res:
            res.image = self.to_cpu(res.image).numpy()
        if 'stats' in res:
            res.stats = self.to_cpu(res.stats).numpy()
        if 'error' in res:
            res.error = str(res.error)
        if self._is_timing:
            self._end_event.synchronize()
            res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
            self._is_timing = False
        return res

    def get_pyramid_tri_grid_ws(self, pyramid_tri_grid_ckpt, device):
        """Return ``(trigrid, ws)`` for a checkpoint path, loading it on first use.

        `trigrid` maps each resolution in `_TRIGRID_RESOLUTIONS` to the
        checkpoint's ``trigrids_<res>`` tensor moved to `device`. Results are
        cached per checkpoint path.
        """
        data = self.input_data.get(pyramid_tri_grid_ckpt, None)
        if data is None:
            print(f'Loading "{pyramid_tri_grid_ckpt}"... ', end='', flush=True)
            # NOTE(security): torch.load unpickles the file; only open trusted checkpoints.
            ckpt = torch.load(pyramid_tri_grid_ckpt, map_location=lambda storage, loc: storage)['model']
            trigrid = {
                size: ckpt[f'trigrids_{size}'].to(device).detach()
                for size in self._TRIGRID_RESOLUTIONS
            }
            ws = ckpt['ws'].to(device)
            print('Done.')
            data = {'trigrid': trigrid, 'ws': ws}
            self.input_data[pyramid_tri_grid_ckpt] = data

        return data['trigrid'], data['ws']

    def get_network(self, pkl, key, **tweak_kwargs):
        """Return the tweaked network stored under `key` in the pickle `pkl`.

        Both the loaded pickle and the tweaked copy are cached; failures are
        cached too and re-raised as `CapturedException` on every later call.
        """
        data = self._pkl_data.get(pkl, None)
        if data is None:
            print(f'Loading "{pkl}"... ', end='', flush=True)
            try:
                with dnnlib.util.open_url(pkl, verbose=False) as f:
                    data = legacy.load_network_pkl(f)
                print('Done.')
            except Exception:
                # Was a bare `except:`, which would cache a Ctrl+C as a permanent load failure.
                data = CapturedException()
                print('Failed!')
            self._pkl_data[pkl] = data
            self._ignore_timing()   # loading time must not be reported as render time
        if isinstance(data, CapturedException):
            raise data

        orig_net = data[key]
        cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
        net = self._networks.get(cache_key, None)
        if net is None:
            try:
                net = copy.deepcopy(orig_net)
                net = self._tweak_network(net, **tweak_kwargs)
                net.to(self._device)
            except Exception:
                net = CapturedException()
            self._networks[cache_key] = net
            self._ignore_timing()
        if isinstance(net, CapturedException):
            raise net

        return net

    def _tweak_network(self, net):
        """Rebuild `net` as a fresh `TriPlaneGenerator` with batch size 1.

        The generator is re-instantiated from the pickled ``init_args`` /
        ``init_kwargs`` using the current source code, and parameters and
        buffers are copied over with ``require_all=False``.
        """
        from training.smpl_triplane import TriPlaneGenerator
        from torch_utils import misc
        print("Reloading Modules!")
        init_kwargs = net.init_kwargs
        print('G.init_args: ', net.init_args)
        print('G.init_kwargs: ', init_kwargs)
        G_new = TriPlaneGenerator(*net.init_args, **init_kwargs).eval().requires_grad_(False).to(self._device)
        misc.copy_params_and_buffers(net, G_new, require_all=False)
        G_new.neural_rendering_resolution = net.neural_rendering_resolution
        G_new.rendering_kwargs = net.rendering_kwargs
        G_new.batch_size = 1
        G_new.set_batch_size(1)
        net = G_new
        print('>>>> G batch: ', net.batch_size)

        return net

    def _get_pinned_buf(self, ref):
        """Return a cached pinned host buffer matching the shape and dtype of `ref`."""
        key = (tuple(ref.shape), ref.dtype)
        buf = self._pinned_bufs.get(key, None)
        if buf is None:
            buf = torch.empty(ref.shape, dtype=ref.dtype).pin_memory()
            self._pinned_bufs[key] = buf
        return buf

    def to_device(self, buf):
        """Copy `buf` to the render device via a pinned staging buffer."""
        return self._get_pinned_buf(buf).copy_(buf).to(self._device)

    def to_cpu(self, buf):
        """Copy `buf` to host memory via a pinned staging buffer and return a clone."""
        return self._get_pinned_buf(buf).copy_(buf).clone()

    def _ignore_timing(self):
        """Suppress `render_time` reporting for the current `render` call."""
        self._is_timing = False

    def _apply_cmap(self, x, name='viridis'):
        """Map values of `x` in [0, 1] to RGB through a 1024-entry colormap table."""
        cmap = self._cmaps.get(name, None)
        if cmap is None:
            # `matplotlib.cm.get_cmap` is gone in matplotlib 3.9; prefer the
            # colormap registry whenever the installed version exposes it.
            registry = getattr(matplotlib, 'colormaps', None)
            cmap = registry[name] if registry is not None else matplotlib.cm.get_cmap(name)
            cmap = cmap(np.linspace(0, 1, num=1024), bytes=True)[:, :3]
            cmap = self.to_device(torch.from_numpy(cmap))
            self._cmaps[name] = cmap
        hi = cmap.shape[0] - 1
        x = (x * hi + 0.5).clamp(0, hi).to(torch.int64)
        x = torch.nn.functional.embedding(x, cmap)
        return x

    def _render_impl(self, res,
        pkl             = None,
        pyramid_tri_grid_ckpt = None,
        w0_seeds        = [[0, 1]],
        stylemix_idx    = [],
        stylemix_seed   = 0,
        trunc_psi       = 1,
        trunc_cutoff    = 0,
        random_seed     = 0,
        noise_mode      = 'const',
        force_fp32      = False,
        layer_name      = None,
        sel_channels    = 3,
        base_channel    = 0,
        img_scale_db    = 0,
        img_normalize   = False,
        fft_show        = False,
        fft_all         = True,
        fft_range_db    = 50,
        fft_beta        = 8,
        input_transform = None,
        untransform     = False,

        yaw             = 0,
        pitch           = 0,
        lookat_point    = (0, 0.0649, 0),
        conditioning_yaw    = 0,
        conditioning_pitch  = 0,
        conditioning_body_pose = None,
        body_pose = None,
        focal_length    = 4.2647,
        render_type     = 'image',

        do_backbone_caching = False,

        depth_mult            = 1,
        depth_importance_mult = 1,
    ):
        """Render one view of the tri-grid checkpoint and fill `res` in place.

        Sets ``res.img_resolution``, ``res.num_ws``, ``res.has_noise``,
        ``res.has_input_transform``, ``res.stats`` and ``res.image``, or only
        ``res.error`` when either input file is missing.

        The latent, style-mixing, truncation, noise, conditioning, body-pose,
        `layer_name` and `do_backbone_caching` arguments are accepted so the
        GUI can pass them, but the tri-grid render path does not use them.
        The list defaults are never mutated.
        """
        # `os.path.exists(None)` raises TypeError, so reject unset paths explicitly.
        if pkl is None or pyramid_tri_grid_ckpt is None \
                or not os.path.exists(pyramid_tri_grid_ckpt) or not os.path.exists(pkl):
            res.error = 'Pyramid Tri-Grid or pkl file does not exist'
            return

        # Dig up network details.
        G = self.get_network(pkl, 'G_ema').eval().requires_grad_(False).to(self._device)
        res.img_resolution = G.img_resolution
        res.num_ws = G.backbone.num_ws
        res.has_noise = any('noise_const' in name for name, _buf in G.backbone.named_buffers())
        res.has_input_transform = (hasattr(G.backbone, 'input') and hasattr(G.backbone.input, 'transform'))

        # Remember the network's own depth resolutions once, so the multipliers
        # are always applied to the original values rather than compounding.
        if 'depth_resolution_default' not in G.rendering_kwargs:
            G.rendering_kwargs['depth_resolution_default'] = G.rendering_kwargs['depth_resolution']
            G.rendering_kwargs['depth_resolution_importance_default'] = G.rendering_kwargs['depth_resolution_importance']

        G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution_default'] * depth_mult)
        G.rendering_kwargs['depth_resolution_importance'] = int(G.rendering_kwargs['depth_resolution_importance_default'] * depth_importance_mult)

        pyramid_tri_grid, ws = self.get_pyramid_tri_grid_ws(pyramid_tri_grid_ckpt, self._device)

        # Set input transform.
        # NOTE(review): the flag is derived from `G.backbone.input` but the write
        # goes to `G.synthesis.input`; confirm which attribute is intended.
        if res.has_input_transform:
            m = np.eye(3)
            try:
                if input_transform is not None:
                    m = np.linalg.inv(np.asarray(input_transform))
            except np.linalg.LinAlgError:
                res.error = CapturedException()
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        # Camera pivot: fall back to the built-in default when no look-at point is given.
        if lookat_point is None:
            camera_pivot = torch.tensor([0, 0.0649, 0])
        else:
            camera_pivot = torch.tensor(lookat_point)
        camera_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7)

        torch.manual_seed(random_seed)

        # Set camera params: flattened 4x4 pose followed by flattened 3x3 intrinsics.
        pose = LookAtPoseSampler.sample(3.14/2 + yaw, 3.14/2 + pitch, camera_pivot, radius=camera_radius)
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]])
        c = torch.cat([pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1).to(ws.device)

        out = self.run_tri_grid_render(G, ws, pyramid_tri_grid, c)

        # Untransform.
        if untransform and res.has_input_transform:
            out, _mask = _apply_affine_transformation(out.to(torch.float32), G.synthesis.input.transform, amax=6) # Override amax to hit the fast path in upfirdn2d.

        # Select channels and compute statistics.
        if isinstance(out, dict):
            # Model output: pick the requested render type (also works for dict subclasses).
            out = out[render_type][0].to(torch.float32)
        else:
            out = out[0].to(torch.float32)

        if sel_channels > out.shape[0]:
            sel_channels = 1
        base_channel = max(min(base_channel, out.shape[0] - sel_channels), 0)
        sel = out[base_channel : base_channel + sel_channels]
        res.stats = torch.stack([
            out.mean(),               sel.mean(),
            out.std(),                sel.std(),
            out.norm(float('inf')),   sel.norm(float('inf')),
        ])

        # Normalize if type is 'image_depth'. The updates are deliberately
        # in-place: `sel` is a slice of `out`, so it sees the normalized values.
        if render_type == 'image_depth':
            out -= out.min()
            out /= out.max().clamp(min=1e-8)   # guard against a constant depth map

            out -= .5
            out *= -2

        # Scale and convert to uint8.
        img = sel
        if img_normalize:
            img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8)
        img = img * (10 ** (img_scale_db / 20))
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
        res.image = img

        # FFT.
        if fft_show:
            sig = out if fft_all else sel
            sig = sig.to(torch.float32)
            sig = sig - sig.mean(dim=[1,2], keepdim=True)
            sig = sig * torch.kaiser_window(sig.shape[1], periodic=False, beta=fft_beta, device=self._device)[None, :, None]
            sig = sig * torch.kaiser_window(sig.shape[2], periodic=False, beta=fft_beta, device=self._device)[None, None, :]
            fft = torch.fft.fftn(sig, dim=[1,2]).abs().square().sum(dim=0)
            fft = fft.roll(shifts=[fft.shape[0] // 2, fft.shape[1] // 2], dims=[0,1])
            fft = (fft / fft.mean()).log10() * 10 # dB
            fft = self._apply_cmap((fft / fft_range_db + 1) / 2)
            res.image = torch.cat([img.expand_as(fft), fft], dim=1)

    @staticmethod
    def run_synthesis_net(net, *args, capture_layer=None, **kwargs): # => out, layers
        """Run ``net.synthesis`` while recording every 4D/5D tensor output.

        Returns ``(out, layers)``. If a recorded layer is named
        `capture_layer`, the forward pass stops there and `out` is that
        layer's output. The forward hooks are always removed, even when the
        forward pass raises.
        """
        submodule_names = {mod: name for name, mod in net.named_modules()}
        unique_names = set()
        layers = []

        def module_hook(module, _inputs, outputs):
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [out for out in outputs if isinstance(out, torch.Tensor) and out.ndim in [4, 5]]
            for idx, out in enumerate(outputs):
                if out.ndim == 5: # G-CNN => remove group dimension.
                    out = out.mean(2)
                name = submodule_names[module]
                if name == '':
                    name = 'output'
                if len(outputs) > 1:
                    name += f':{idx}'
                if name in unique_names:
                    suffix = 2
                    while f'{name}_{suffix}' in unique_names:
                        suffix += 1
                    name += f'_{suffix}'
                unique_names.add(name)
                shape = [int(x) for x in out.shape]
                dtype = str(out.dtype).split('.')[-1]
                layers.append(dnnlib.EasyDict(name=name, shape=shape, dtype=dtype))
                if name == capture_layer:
                    raise CaptureSuccess(out)

        hooks = [module.register_forward_hook(module_hook) for module in net.modules()]
        try:
            out = net.synthesis(*args, **kwargs)
        except CaptureSuccess as e:
            out = e.out
        finally:
            for hook in hooks:
                hook.remove()
        return out, layers

    @staticmethod
    def run_tri_grid_render(net, w, trigrid, c): # => out
        """Render the tri-grid planes for the first camera in `c` via ``net.render_planes``."""
        out = net.render_planes(ws=w, planes=trigrid, c=c[0:1], noise_mode='const',
                                neural_rendering_resolution=256, chunk=4096)
        return out

#----------------------------------------------------------------------------
import imgui
from gui_utils import imgui_utils

#----------------------------------------------------------------------------

class StyleMixingWidget:
    """GUI row with a seed, an animate toggle and one checkbox per W layer.

    When at least one of the first ``num_ws`` checkboxes is ticked, the ticked
    indices and the seed are written to ``viz.args.stylemix_idx`` and
    ``viz.args.stylemix_seed``.
    """

    def __init__(self, viz):
        self.viz        = viz
        self.seed_def   = 1000
        self.seed       = self.seed_def
        self.animate    = False
        self.enables    = []

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the row when `show` is true, then publish the style-mix arguments."""
        viz = self.viz
        num_ws = viz.result.get('num_ws', 0)
        # Without a result yet, 18 placeholder checkboxes are drawn (grayed out).
        num_enables = viz.result.get('num_ws', 18)
        self.enables += [False] * max(num_enables - len(self.enables), 0)

        if show:
            imgui.text('Stylemix')
            imgui.same_line(viz.label_w)
            with imgui_utils.item_width(viz.font_size * 8), imgui_utils.grayed_out(num_ws == 0):
                _changed, self.seed = imgui.input_int('##seed', self.seed)
            imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing)
            with imgui_utils.grayed_out(num_ws == 0):
                _clicked, self.animate = imgui.checkbox('Anim', self.animate)

            pos2 = imgui.get_content_region_max()[0] - 1 - viz.button_w
            pos1 = pos2 - imgui.get_text_line_height() - viz.spacing
            pos0 = viz.label_w + viz.font_size * 12
            # With a single checkbox there are no gaps to spread over; using 1
            # avoids the ZeroDivisionError the original hit for num_ws == 1.
            num_gaps = max(num_enables - 1, 1)
            imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0])
            for idx in range(num_enables):
                imgui.same_line(round(pos0 + (pos1 - pos0) * (idx / num_gaps)))
                if idx == 0:
                    imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() + 3)
                with imgui_utils.grayed_out(num_ws == 0):
                    _clicked, self.enables[idx] = imgui.checkbox(f'##{idx}', self.enables[idx])
                if imgui.is_item_hovered():
                    imgui.set_tooltip(f'{idx}')
            imgui.pop_style_var(1)

            imgui.same_line(pos2)
            imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() - 3)
            with imgui_utils.grayed_out(num_ws == 0):
                if imgui_utils.button('Reset', width=-1, enabled=(self.seed != self.seed_def or self.animate or any(self.enables[:num_enables]))):
                    self.seed = self.seed_def
                    self.animate = False
                    self.enables = [False] * num_enables

        # Only the first `num_ws` checkboxes correspond to real layers; stale
        # entries left over from a larger network must not be published.
        active = self.enables[:num_ws]
        if any(active):
            viz.args.stylemix_idx = [idx for idx, enable in enumerate(active) if enable]
            viz.args.stylemix_seed = self.seed & ((1 << 32) - 1)
        if self.animate:
            self.seed += 1

#----------------------------------------------------------------------------

class TruncationNoiseWidget:
    """GUI row for truncation psi / cutoff and noise enable / seed / animation.

    On every call the current values are written to ``viz.args`` as
    ``trunc_psi``, ``trunc_cutoff``, ``random_seed`` and ``noise_mode``.
    """

    def __init__(self, viz):
        self.viz            = viz
        self.prev_num_ws    = 0
        # Defaults that the Reset button restores.
        self.trunc_psi_def      = 0.12
        self.trunc_cutoff_def   = 7
        self.trunc_psi      = self.trunc_psi_def
        self.trunc_cutoff   = self.trunc_cutoff_def
        self.noise_enable   = True
        self.noise_seed     = 0
        self.noise_anim     = False

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the row when `show` is true, then publish truncation and noise settings."""
        viz = self.viz
        num_ws = viz.result.get('num_ws', 0)
        has_noise = viz.result.get('has_noise', False)
        if num_ws > 0 and num_ws != self.prev_num_ws:
            # Keep the cutoff in range, and keep "no cutoff" meaning "no cutoff"
            # when the number of layers changes.
            if self.trunc_cutoff > num_ws or self.trunc_cutoff == self.prev_num_ws:
                self.trunc_cutoff = num_ws
            self.prev_num_ws = num_ws

        if show:
            imgui.text('Truncate')
            imgui.same_line(viz.label_w)
            with imgui_utils.item_width(viz.font_size * 10), imgui_utils.grayed_out(num_ws == 0):
                _changed, self.trunc_psi = imgui.slider_float('##psi', self.trunc_psi, -1, 2, format='Psi %.2f')
            imgui.same_line()
            if num_ws == 0:
                imgui_utils.button('Cutoff 0', width=(viz.font_size * 8 + viz.spacing), enabled=False)
            else:
                with imgui_utils.item_width(viz.font_size * 8 + viz.spacing):
                    changed, new_cutoff = imgui.slider_int('##cutoff', self.trunc_cutoff, 0, num_ws, format='Cutoff %d')
                    if changed:
                        self.trunc_cutoff = min(max(new_cutoff, 0), num_ws)

            with imgui_utils.grayed_out(not has_noise):
                imgui.same_line()
                _clicked, self.noise_enable = imgui.checkbox('Noise##enable', self.noise_enable)
                imgui.same_line(viz.font_size * 28.7)
                with imgui_utils.grayed_out(not self.noise_enable):
                    with imgui_utils.item_width(-3 - viz.button_w - viz.spacing - viz.font_size * 4):
                        _changed, self.noise_seed = imgui.input_int('##seed', self.noise_seed)
                    imgui.same_line(spacing=0)
                    _clicked, self.noise_anim = imgui.checkbox('Anim##noise', self.noise_anim)

            # The default cutoff cannot exceed the number of layers of the loaded network.
            def_cutoff = min(self.trunc_cutoff_def, num_ws) if num_ws > 0 else self.trunc_cutoff_def
            # Compare against the values Reset actually restores. The original
            # tested psi == 1 / cutoff == num_ws, so Reset never became disabled.
            # A tolerance is used because the slider value may come back
            # slightly rounded (presumably single precision in imgui).
            is_def_trunc = (abs(self.trunc_psi - self.trunc_psi_def) < 1e-6 and self.trunc_cutoff == def_cutoff)
            is_def_noise = (self.noise_enable and self.noise_seed == 0 and not self.noise_anim)
            with imgui_utils.grayed_out(is_def_trunc and not has_noise):
                imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w)
                if imgui_utils.button('Reset', width=-1, enabled=(not is_def_trunc or not is_def_noise)):
                    self.prev_num_ws = num_ws
                    self.trunc_psi = self.trunc_psi_def
                    self.trunc_cutoff = def_cutoff
                    self.noise_enable = True
                    self.noise_seed = 0
                    self.noise_anim = False

        if self.noise_anim:
            self.noise_seed += 1
        viz.args.update(trunc_psi=self.trunc_psi, trunc_cutoff=self.trunc_cutoff, random_seed=self.noise_seed)
        viz.args.noise_mode = ('none' if not self.noise_enable else 'const' if self.noise_seed == 0 else 'random')

#----------------------------------------------------------------------------
from inspect import formatargvalues
import numpy as np
import imgui
import dnnlib
from gui_utils import imgui_utils

#----------------------------------------------------------------------------

class ZoomWidget:
    """GUI row with a field-of-view slider plus Snap and Reset buttons.

    On every call the field of view (degrees) is converted to a focal length
    and written to ``viz.args.focal_length``.
    """

    def __init__(self, viz):
        self.viz = viz
        self.fov_default = 12.447863
        self.fov = 12.447863

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the row when `show` is true, then publish the focal length."""
        if show:
            self._draw_row(self.viz)

        # Same arithmetic as before, split into named steps.
        half_fov = self.fov * 3.14159 / 360
        self.viz.args.focal_length = float(1 / (np.tan(half_fov) * 1.414))

    def _draw_row(self, layout):
        """Emit the caption, slider and the two buttons, updating `self.fov`."""
        imgui.text('FOV')
        imgui.same_line(layout.label_w)
        with imgui_utils.item_width(layout.font_size * 10):
            _changed, self.fov = imgui.slider_float('##fov', self.fov, 12, 45, format='%.2f Degrees')

        imgui.same_line(layout.label_w + layout.font_size * 13 + layout.button_w + layout.spacing * 3)
        whole_degrees = round(self.fov)
        can_snap = (self.fov != whole_degrees)
        if imgui_utils.button('Snap', width=layout.button_w, enabled=can_snap):
            self.fov = whole_degrees
        imgui.same_line()
        can_reset = abs(self.fov - self.fov_default) > .01
        if imgui_utils.button('Reset', width=-1, enabled=can_reset):
            self.fov = self.fov_default

#----------------------------------------------------------------------------
Wu](https://onethousandwu.com/), [Hao Xu](https://xh38.github.io/), [Xiangjun Tang](https://yuyujunjun.github.io/), [Xien Chen](https://vision.cs.yale.edu/members/xien-chen.html), [Siyu Tang](https://inf.ethz.ch/people/person-detail.MjYyNzgw.TGlzdC8zMDQsLTg3NDc3NjI0MQ==.html), Zhebin Zhang, Chen Li, [Xiaogang Jin*](http://www.cad.zju.edu.cn/home/jin) + +![1f31e](assets/1f31e.png)[Paper]() ![1f431](assets/1f431.png)[Supplementary (Google Drive)]() ![1f98b](assets/1f98b.png)[Project Page]() + +This is the official code repository for our SIG'24 paper: "Portrait3D: Text-Guided High-Quality 3D Portrait Generation Using Pyramid Representation and GANs Prior". + +![Representative_Image](./assets/Representative_Image.jpg) + + +## News ✨ + +- Our paper has been **accepted by SIGGRAPH 2024** ![1f973](assets/1f973.png)! +- We have released all the source code and pre-trained models ![1f389](./assets/1f389.png)! + + +## Requirements + +1. Tested on Python 3.8 +2. At least 12 GB of GPU memory +3. Tested on NVIDIA RTX 3080Ti with 12 GB of memory (Windows, 1.5h per portrait) +4. Tested on NVIDIA RTX 4090 with 24 GB of memory (Linux, 0.5h per portrait) +5. CUDA>=11.6 + +## Installation + +Clone this repo to `$PROJECT_ROOT$`. 
+ +**Create environment** + +``` +cd $PROJECT_ROOT$ +conda env create -f environment.yaml +conda activate text_to_3dportrait +``` + +**Torch and torchvision Installation** + +``` +pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 -f https://download.pytorch.org/whl/torch_stable.html +``` + +**OSMesa Dependencies (For Linux)** + +``` +sudo apt install libosmesa6 libosmesa6-dev +``` + +**Installing Additional Requirements** + +``` +pip install -r requirements.txt +``` + +**kaolin Installation** + +``` +pip install kaolin==0.13.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.12.1_cu116.html +``` + +**Stable-diffusion Installation** + +``` +cd stable-diffusion +pip install -e . +cd .. +``` + + + +**SMPL Model Setup** + +1. Download [SMPL_python_v.1.0.0.zip](https://smpl.is.tue.mpg.de/download.php) (version 1.0.0 for Python 2.7 (female/male. 10 shape PCs) ). Save `basicModel_f_lbs_10_207_0_v1.0.0.pkl` to `3DPortraitGAN_pyramid/smplx_models/smpl/SMPL_FEMALE.pkl`, save `basicModel_m_lbs_10_207_0_v1.0.0.pkl` to `3DPortraitGAN_pyramid/smplx_models/smpl/SMPL_MALE.pkl`. + +2. Download [SMPLIFY_CODE_V2.ZIP](http://smplify.is.tue.mpg.de/), and save `basicModel_neutral_lbs_10_207_0_v1.0.0.pkl` to `3DPortraitGAN_pyramid/smplx_models/smpl/SMPL_NEUTRAL.pkl`. 
+ +| Download Link | Save Path | +| ------------------------------------------------------------ | -------------------------------------------------------- | +| [basicModel_f_lbs_10_207_0_v1.0.0.pkl](https://smpl.is.tue.mpg.de/download.php) | 3DPortraitGAN_pyramid/smplx_models/smpl/SMPL_FEMALE.pkl | +| [basicModel_m_lbs_10_207_0_v1.0.0.pkl](https://smpl.is.tue.mpg.de/download.php) | 3DPortraitGAN_pyramid/smplx_models/smpl/SMPL_MALE.pkl | +| [basicModel_neutral_lbs_10_207_0_v1.0.0.pkl](http://smplify.is.tue.mpg.de/) | 3DPortraitGAN_pyramid/smplx_models/smpl/SMPL_NEUTRAL.pkl | + + + +## Inference + +### 3DPortraitGAN_pyramid Model + +Our 3DPortraitGAN_pyramid draws inspiration from the 3D-aware StyleGAN2 backbone implemented in [SeanChenxy/Mimic3D](https://github.com/SeanChenxy/Mimic3D), and integrates concepts of mask guidance, background synthesis, and tri-grid representation adapted from [SizheAn/PanoHead](https://github.com/SizheAn/PanoHead). We extend our sincere gratitude for these significant contributions! 
+ +#### (Recommended) Pretrained models + +Download the pre-trained models of 3DPortraitGAN_pyramid: + +| Download Link | Description | Save Path | +| ------------------------------------------------------------ | --------------------------------------------------- | ------------------------------ | +| [model_512.pkl](https://drive.google.com/file/d/1P6k4UwGGNmxa6-rQr2oyIOmAPiLAd_WE/view?usp=sharing) | Pre-trained model of 3DPortraitGAN_pyramid | ./3DPortraitGAN_pyramid/models | +| [model_512.json](https://drive.google.com/file/d/1R6FoQXi4PyIvXtOVoKRohfOXkEkWXdJb/view?usp=sharing) | Pose prediction parameters of 3DPortraitGAN_pyramid | ./3DPortraitGAN_pyramid/models | +| [decoder_512.ckpt](https://drive.google.com/file/d/1r0Lqu1TMm-1Pjj8K963RVM_y72OglJdu/view?usp=sharing) | Decoder checkpoint extracted from model_512.pkl | ./3DPortraitGAN_pyramid/models | +| [vgg16.pt](https://drive.google.com/file/d/1av5jH9jzuOobV9s2gyzx0w9a4xqco82H/view?usp=sharing) | vgg16 | ./3DPortraitGAN_pyramid/models | + +#### (Optional) Training + +Skip this section if you are using the pre-trained 3DPortraitGAN_pyramid model listed above. + +If you are interested in the training process, please see our training instructions [here](https://github.com/oneThousand1000/Portrait3D/tree/main/3DPortraitGAN_pyramid). + + + +### Random Image Generation + +#### Preparing Prompts + +First, prepare your prompts. These should be organized in the following structure: + +``` +test_data +│ +└─── 001 +│ │ +│ └─── prompt.txt (should begin with "upper body photo") +└─── 002 +│ │ +│ └─── prompt.txt (should begin with "upper body photo") +└─── ... +``` + +An example is available in `$PROJECT_ROOT$/test_data`. 
+ + + +#### Image generation + +Download the Realistic_Vision_V5.1_noVAE model [here](https://huggingface.co/SG161222/Realistic_Vision_V5.1_noVAE). + +We use the original Stable Diffusion implementation in this use case. To convert the diffusers-version model to the original-stable-diffusion-version, follow the steps below: + +``` +cd stable-diffusion + +activate text_to_3dportrait + +git clone git@github.com:huggingface/diffusers.git + +cd diffusers/scripts + +python convert_diffusers_to_original_stable_diffusion.py --model_path $PATH_of_Realistic_Vision_V5.1_noVAE$ --checkpoint_path $PATH_of_Realistic_Vision_V5.1_noVAE$/realisticVisionV51_v51VAE.ckpt + +cd ../../../ +``` + +Then randomly generate images: + +``` +cd stable-diffusion + +activate text_to_3dportrait + +python get_test_data_df.py --test_data_dir ../test_data --sample_num 6 --scale 5 --df_ckpt $PATH_of_Realistic_Vision_V5.1_noVAE$/realisticVisionV51_v51VAE.ckpt + +cd .. +``` + +The generated images will be stored at `$PROJECT_ROOT$/test_data/image_id/samples`. + +**Note:** We found that using a smaller scale (for example, `--scale 3`) tends to generate better results for specific characters, such as "Tyrion Lannister in the Game of Thrones". Feel free to experiment with different scales to improve the outcome. + + + +#### Image Processing + +Our image processing code is largely adapted from [hongsukchoi/3DCrowdNet_RELEASE](https://github.com/hongsukchoi/3DCrowdNet_RELEASE). + +**Installation** + +```text +conda create -n portrait3d_data python=3.8 + +activate portrait3d_data + +cd data_processing + +pip install torch==1.10.2+cu113 torchvision==0.11.3+cu113 -f https://download.pytorch.org/whl/torch_stable.html + +pip install -r requirements.txt + +python -m pip install -e detectron2 + +cd .. 
+``` + + + +For Windows: + +``` +pip install pywin32==306 +``` + + + +Windows users who encounter errors while installing detectron2 should open an `x64 Native Tools Command Prompt` for Visual Studio and run `python -m pip install -e detectron2`. + + + +**Pretrained models** + +| Download Link | Save Path | +| ------------------------------------------------------------ | ------------------------------------------------------------ | +| [R_101_FPN_DL_soft_s1x.pkl](https://drive.google.com/file/d/1rgrW9bAVbarft57mogUfawRSu2JCUKIT/view?usp=sharing) | `./data_processing/detectron2/projects/DensePose` | +| [phi_smpl_27554_256.pkl](https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_smpl_27554_256.pkl) | `./data_processing/detectron2/projects/DensePose` | +| [pose_higher_hrnet_w32_512.pth](https://drive.google.com/drive/folders/1zJbBbIHVQmHJp89t5CD1VF5TIzldpHXn) | `./data_processing/HigherHRNet-Human-Pose-Estimation/models/pytorch/pose_coco` | +| [crowdhuman_yolov5m.pt](https://drive.google.com/file/d/1gglIwqxaH2iTvy6lZlXuAcMpd_U0GCUb/view?usp=sharing) | `./data_processing/yolov5_crowdhuman` | +| [basicModel_neutral_lbs_10_207_0_v1.0.0.pkl](http://smplify.is.tue.mpg.de/) | `./data_processing/common/utils/smplpytorch/smplpytorch/native/models` | +| [VPOSER_CKPT](https://drive.google.com/drive/folders/1KNw99d4-_6DqYXfBp2S3_4OMQ_nMW0uQ?usp=sharing) | `./data_processing/common/utils/human_model_files/smpl/VPOSER_CKPT` | +| [J_regressor_extra.npy](https://drive.google.com/file/d/1B9e65ahe6TRGv7xE45sScREAAznw9H4t/view?usp=sharing) | `./data_processing/data` | +|
[demo_checkpoint.pth.tar](https://drive.google.com/drive/folders/1YYQHbtxvdljqZNo8CIyFOmZ5yXuwtEhm?usp=sharing) | `./data_processing/demo` | + +If you encounter `RuntimeError: Subtraction, the - operator, with a bool tensor is not supported.`, you may refer to [this issue](https://github.com/mks0601/I2L-MeshNet_RELEASE/issues/6#issuecomment-675152527) for a solution, or replace lines 301–304 of `anaconda3/lib/python3.8/site-packages/torchgeometry/core/conversion.py` with the following: + +``` +mask_c0 = mask_d2.float() * mask_d0_d1.float() +mask_c1 = mask_d2.float() * (1 - mask_d0_d1.float()) +mask_c2 = (1 - mask_d2.float()) * mask_d0_nd1.float() +mask_c3 = (1 - mask_d2.float()) * (1 - mask_d0_nd1.float()) +``` + + + +Then process the randomly generated images to produce aligned images following the alignment setting of 3DPortraitGAN_pyramid: + +``` +cd data_processing + +activate portrait3d_data +python preprocess_img_for_inversion.py --test_data_dir=$PROJECT_ROOT$/test_data + +cd .. +``` + + + +**Note:** Manually review and discard any subpar images located in `$PROJECT_ROOT$/test_data/image_id/samples_new_crop/aligned_images`. For optimal inversion results, we recommend keeping an aligned image that shows a frontal view with only a minor body pose. + + + +### 3D Portrait Inversion + +**Inversion** + +Before proceeding further, always ensure that you have removed all unsatisfactory images in `test_data/image_id/samples_new_crop/aligned_images`. This step is crucial to prevent suboptimal results. + +Note that projection is run only on the first image in `test_data/image_id/samples_new_crop/aligned_images`.
+ +``` +cd 3DPortraitGAN_pyramid + +activate text_to_3dportrait + +python run_inversion_with_pose_optimization.py \ + --model_pkl=./models/model_512.pkl \ + --pose_prediction_kwargs_path=./models/model_512.json \ + --test_data_dir=../test_data \ + --inversion_name=final_inversion \ + --with_pose_optim +``` + + + +**Generate Pyramid Tri-grid from Inversion results** + +``` +python run_trigrid_gen.py \ + --network=./models/model_512.pkl \ + --inversion_name=final_inversion + +cd .. +``` + + + +### 3D Portrait Generation and Optimization + +Our image generation code is largely adapted from [ashawkey/stable-dreamfusion](https://github.com/ashawkey/stable-dreamfusion). We express our gratitude for their significant contributions! + +``` +cd stable-dreamfusion-3DPortrait + +python portrait3d_main.py \ + --trigrid_decoder_ckpt=../3DPortraitGAN_pyramid/models/decoder_512.ckpt \ + --inversion_name=final_inversion \ + --network_path=../3DPortraitGAN_pyramid/models/model_512.pkl \ + --test_data_dir=../test_data \ + --df_ckpt=$PATH_of_Realistic_Vision_V5.1_noVAE$ +``` + +The results will be stored and organized as: + +``` +stable-dreamfusion-3DPortrait/output/text_to_3dportrait/image_id +│ +└─── trigrid.pkl [Original pyramid tri-grid generated from inversion results] +│ +└─── validation [SDS validation images] +│ +└─── checkpoints [SDS checkpoints] +│ +└─── run [SDS run file] +│ +└─── results [SDS rendering results] +| +└─── data [21 rendered views, refer to Section 3.5 in our paper] +| +└─── update_data [21 refined views, refer to Section 3.5 in our paper] +| +└─── log [Pyramid tri-grid optimization log files, refer to Section 3.5 in our paper] +│ │ +│ └─── ckpt +│ │ │ +│ │ └─── epoch_00019.pth [Final pyramid tri-grid] +│ └─── img +│ +└─── results_final [Final rendering results] +``` + + + +## Results Gallery + +We offer a gallery of 300 3D portraits (with their corresponding prompts) generated by our method, all viewable 
and accessible on [huggingface](https://huggingface.co/datasets/onethousand/Portrait3D_gallery). + +``` +Portrait3D_gallery +│ +└─── 000 +│ │ +│ └─── 000_pyramid_trigrid.pth [the pyramid trigrid file] +│ │ +│ └─── 000_prompt.txt [the prompt] +│ │ +│ └─── 000_preview.png [the preview image] +│ │ +│ └─── ... +└─── 001 +│ │ +│ └─── ... +└─── 002 +│ │ +│ └─── ... +│ +└─── ... +``` + +To visualize these 3D portraits, use the following visualizer: + +``` +cd 3DPortraitGAN_pyramid + +activate text_to_3dportrait + +python pyramid_trigrid_visualizer.py +``` + +Input the path of your `model_512.pkl` into the `Pickle` field, and input the pyramid tri-grid path into the `Pyramid Tri-Grid Ckpt` field. + +Please observe that we **maintain the neural rendering resolution at 256** for optimal rendering speed. + + +Enjoy traversing through these results 😉! + + + +## Contact + +[onethousand@zju.edu.cn](mailto:onethousand@zju.edu.cn) / [onethousand1250@gmail.com](mailto:onethousand1250@gmail.com) + + + +## Citation + +If you find this project helpful to your research, please consider citing: + +``` +Coming soon. +``` + + + +## Acknowledgements + +The work is supported by the Information Technology Center and State Key Lab of CAD&CG, Zhejiang University. We extend our sincere gratitude for the generous provision of necessary computing resources. + +We also want to express our thanks to those in the open-source community for their valuable contributions. 
+ + + diff --git a/assets/1f31e.png b/assets/1f31e.png new file mode 100644 index 0000000..322c9c1 Binary files /dev/null and b/assets/1f31e.png differ diff --git a/assets/1f389.png b/assets/1f389.png new file mode 100644 index 0000000..b796f8d Binary files /dev/null and b/assets/1f389.png differ diff --git a/assets/1f431.png b/assets/1f431.png new file mode 100644 index 0000000..ce018a0 Binary files /dev/null and b/assets/1f431.png differ diff --git a/assets/1f973.png b/assets/1f973.png new file mode 100644 index 0000000..0104196 Binary files /dev/null and b/assets/1f973.png differ diff --git a/assets/1f98b.png b/assets/1f98b.png new file mode 100644 index 0000000..27a9d93 Binary files /dev/null and b/assets/1f98b.png differ diff --git a/assets/Representative_Image.jpg b/assets/Representative_Image.jpg new file mode 100644 index 0000000..f59e861 Binary files /dev/null and b/assets/Representative_Image.jpg differ diff --git a/assets/gui.mp4 b/assets/gui.mp4 new file mode 100644 index 0000000..8e6ec78 Binary files /dev/null and b/assets/gui.mp4 differ diff --git a/assets/samples.mp4 b/assets/samples.mp4 new file mode 100644 index 0000000..4eb67a8 Binary files /dev/null and b/assets/samples.mp4 differ diff --git a/data_processing/.gitignore b/data_processing/.gitignore new file mode 100644 index 0000000..b188b7a --- /dev/null +++ b/data_processing/.gitignore @@ -0,0 +1,233 @@ +# Created by .ignore support plugin (hsz.mobi) +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +command + +samples/ +#demo + +HigherHRNet-Human-Pose-Estimation/models + +/demo/*.tar +demo/my_input +simple_HRNet/ +CID/ +process_input_images.py +# Custom +*.pkl +common/utils/human_model_files +data/*/data +data/*/annotations +data/*/images +data/*.npy +data/*/*.npy +output/ +*.pyc + +# C extensions +*.so + +.idea/* + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ + +lib64/ +parts/ +sdist/ +var/ 
+wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +### macOS template +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/dictionaries +.idea/**/shelf + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids 
+.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# CMake +cmake-build-debug/ +cmake-build-release/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +/simple-HRNet/ +/simple_HRNet/ + + +# detectron2 +detectron2/projects/DensePose/*.pkl + +# HRNet +/HigherHRNet-Human-Pose-Estimation/models/ + + +/low_resolution_data/ + +# yolo +/yolov5_crowdhuman/crowdhuman_yolov5m.pt + +# classifier +/classifier/*/*.pth + + +/common/utils/smplpytorch/smplpytorch/native/models/*.pkl \ No newline at end of file diff --git a/data_processing/.gitkeep b/data_processing/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/.gitignore b/data_processing/HigherHRNet-Human-Pose-Estimation/.gitignore new file mode 100644 index 0000000..c1b85b6 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/.gitignore @@ -0,0 +1,95 @@ +# IntelliJ project files +.idea +*.iml +out +gen + +### Vim template +[._]*.s[a-w][a-z] +[._]s[a-w][a-z] +*.un~ +Session.vim +.netrwhist +*~ + +### IPythonNotebook template +# Temporary data +.ipynb_checkpoints/ + +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +#lib/ +#lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these 
files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +*.ipynb +*.params +*.json +.vscode/ + +lib/pycocotools/_mask.c +lib/nms/cpu_nms.c + +output/* +models/* +log/* +data/* +external/ + +draws/ +plot/ + diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/LICENSE b/data_processing/HigherHRNet-Human-Pose-Estimation/LICENSE new file mode 100644 index 0000000..d4e6a8f --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 HRNet + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/README.md b/data_processing/HigherHRNet-Human-Pose-Estimation/README.md new file mode 100644 index 0000000..d07031b --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/README.md @@ -0,0 +1,272 @@ +# [HigherHRNet: Scale-Aware Representation Learning for Bottom-Up Human Pose Estimation (CVPR 2020)](https://arxiv.org/abs/1908.10357) + +## News +* \[2021/04/12\] Welcome to check out our recent work on bottom-up pose estimation (CVPR 2021) [HRNet-DEKR](https://github.com/HRNet/DEKR)! +* \[2020/07/05\] [A very nice blog](https://towardsdatascience.com/overview-of-human-pose-estimation-neural-networks-hrnet-higherhrnet-architectures-and-faq-1954b2f8b249) from Towards Data Science introducing HRNet and HigherHRNet for human pose estimation. +* \[2020/03/12\] Support train/test on the CrowdPose dataset. +* \[2020/02/24\] HigherHRNet is accepted to CVPR2020! +* \[2019/11/23\] Code and models for [HigherHRNet](https://arxiv.org/abs/1908.10357) are now released! +* \[2019/08/27\] HigherHRNet is now on [ArXiv](https://arxiv.org/abs/1908.10357). We will also release code and models, stay tuned! + +## Introduction +This is the official code of [HigherHRNet: Scale-Aware Representation Learning for Bottom-Up Human Pose Estimation](https://arxiv.org/abs/1908.10357). +Bottom-up human pose estimation methods have difficulties in predicting the correct pose for small persons due to challenges in scale variation. In this paper, we present **HigherHRNet**: a novel bottom-up human pose estimation method for learning scale-aware representations using high-resolution feature pyramids. 
Equipped with multi-resolution supervision for training and multi-resolution aggregation for inference, the proposed approach is able to solve the scale variation challenge in *bottom-up multi-person* pose estimation and localize keypoints more precisely, especially for small person. The feature pyramid in HigherHRNet consists of feature map outputs from HRNet and upsampled higher-resolution outputs through a transposed convolution. HigherHRNet outperforms the previous best bottom-up method by 2.5% AP for medium person on COCO test-dev, showing its effectiveness in handling scale variation. Furthermore, HigherHRNet achieves new state-of-the-art result on COCO test-dev (70.5% AP) without using refinement or other post-processing techniques, surpassing all existing bottom-up methods. HigherHRNet even surpasses all top-down methods on CrowdPose test (67.6% AP), suggesting its robustness in crowded scene. + +![Illustrating the architecture of the proposed Higher-HRNet](/figures/arch_v2.png) + +## Main Results +### Results on COCO val2017 without multi-scale test +| Method | Backbone | Input size | #Params | GFLOPs | AP | Ap .5 | AP .75 | AP (M) | AP (L) | +|--------------------|----------|------------|---------|--------|-------|-------|--------|--------|--------| +| HigherHRNet | HRNet-w32 | 512 | 28.6M | 47.9 | 67.1 | 86.2 | 73.0 | 61.5 | 76.1 | +| HigherHRNet | HRNet-w32 | 640 | 28.6M | 74.8 | 68.5 | 87.1 | 74.7 | 64.3 | 75.3 | +| HigherHRNet | HRNet-w48 | 640 | 63.8M | 154.3 | 69.9 | 87.2 | 76.1 | 65.4 | 76.4 | + +### Results on COCO val2017 *with* multi-scale test +| Method | Backbone | Input size | #Params | GFLOPs | AP | Ap .5 | AP .75 | AP (M) | AP (L) | +|--------------------|----------|------------|---------|--------|-------|-------|--------|--------|--------| +| HigherHRNet | HRNet-w32 | 512 | 28.6M | 47.9 | 69.9 | 87.1 | 76.0 | 65.3 | 77.0 | +| HigherHRNet | HRNet-w32 | 640 | 28.6M | 74.8 | 70.6 | 88.1 | 76.9 | 66.6 | 76.5 | +| HigherHRNet | HRNet-w48 | 640 
| 63.8M | 154.3 | 72.1 | 88.4 | 78.2 | 67.8 | 78.3 | + +### Results on COCO test-dev2017 without multi-scale test +| Method | Backbone | Input size | #Params | GFLOPs | AP | AP .5 | AP .75 | AP (M) | AP (L) | +|--------------------|----------|------------|---------|--------|-------|-------|--------|--------|--------| +| OpenPose\* | - | - | - | - | 61.8 | 84.9 | 67.5 | 57.1 | 68.2 | +| Hourglass | Hourglass | 512 | 277.8M | 206.9 | 56.6 | 81.8 | 61.8 | 49.8 | 67.0 | +| PersonLab | ResNet-152 | 1401 | 68.7M | 405.5 | 66.5 | 88.0 | 72.6 | 62.4 | 72.3 | +| PifPaf | - | - | - | - | 66.7 | - | - | 62.4 | 72.9 | +| Bottom-up HRNet | HRNet-w32 | 512 | 28.5M | 38.9 | 64.1 | 86.3 | 70.4 | 57.4 | 73.9 | +| **HigherHRNet** | HRNet-w32 | 512 | 28.6M | 47.9 | 66.4 | 87.5 | 72.8 | 61.2 | 74.2 | +| **HigherHRNet** | HRNet-w48 | 640 | 63.8M | 154.3 | **68.4** | **88.2** | **75.1** | **64.4** | **74.2** | + +### Results on COCO test-dev2017 *with* multi-scale test +| Method | Backbone | Input size | #Params | GFLOPs | AP | AP .5 | AP .75 | AP (M) | AP (L) | +|--------------------|----------|------------|---------|--------|-------|-------|--------|--------|--------| +| Hourglass | Hourglass | 512 | 277.8M | 206.9 | 63.0 | 85.7 | 68.9 | 58.0 | 70.4 | +| Hourglass\* | Hourglass | 512 | 277.8M | 206.9 | 65.5 | 86.8 | 72.3 | 60.6 | 72.6 | +| PersonLab | ResNet-152 | 1401 | 68.7M | 405.5 | 68.7 | 89.0 | 75.4 | 64.1 | 75.5 | +| **HigherHRNet** | HRNet-w48 | 640 | 63.8M | 154.3 | **70.5** | **89.3** | **77.2** | **66.6** | **75.8** | + +### Results on CrowdPose test +| Method | AP | AP .5 | AP .75 | AP (E) | AP (M) | AP (H) | +|--------------------|-------|-------|--------|--------|--------|--------| +| Mask-RCNN | 57.2 | 83.5 | 60.3 | 69.4 | 57.9 | 45.8 | +| AlphaPose | 61.0 | 81.3 | 66.0 | 71.2 | 61.4 | 51.1 | +| SPPE | 66.0
| 84.2 | 71.5 | 75.5 | 66.3 | 57.4 | +| OpenPose | - | - | - | 62.7 | 48.7 | 32.3 | +| **HigherHRNet** | 65.9 | 86.4 | 70.6 | 73.3 | 66.5 | 57.9 | +| **HigherHRNet+** | **67.6** | **87.4** | **72.6** | **75.8** | **68.1** | **58.9** | + +*Note: + indicates using multi-scale test.* + +## Environment +The code is developed using Python 3.6 on Ubuntu 16.04. NVIDIA GPUs are needed. The code is developed and tested using 4 NVIDIA P100 GPU cards. Other platforms or GPU cards are not fully tested. + +## Quick start +### Installation +1. Install PyTorch >= v1.1.0 following the [official instructions](https://pytorch.org/). + - **Tested with PyTorch v1.4.0** +2. Clone this repo; we will refer to the cloned directory as ${POSE_ROOT}. +3. Install dependencies: + ``` + pip install -r requirements.txt + ``` +4. Install [COCOAPI](https://github.com/cocodataset/cocoapi): + ``` + # COCOAPI=/path/to/clone/cocoapi + git clone https://github.com/cocodataset/cocoapi.git $COCOAPI + cd $COCOAPI/PythonAPI + # Install into global site-packages + make install + # Alternatively, if you do not have permissions or prefer + # not to install the COCO API into global site-packages + python3 setup.py install --user + ``` + Note that instructions like `# COCOAPI=/path/to/clone/cocoapi` indicate that you should pick a path where you'd like to have the software cloned and then set an environment variable (COCOAPI in this case) accordingly. +5. Install [CrowdPoseAPI](https://github.com/Jeff-sjtu/CrowdPose) in exactly the same way as COCOAPI. + - **There is a bug in the CrowdPoseAPI; please revert https://github.com/Jeff-sjtu/CrowdPose/commit/785e70d269a554b2ba29daf137354103221f479e** +6.
Create the output (training model output) and log (TensorBoard log) directories: + + ``` + mkdir output + mkdir log + ``` + + Your directory tree should look like this: + + ``` + ${POSE_ROOT} + ├── data + ├── experiments + ├── lib + ├── log + ├── models + ├── output + ├── tools + ├── README.md + └── requirements.txt + ``` + +7. Download pretrained models from our model zoo ([GoogleDrive](https://drive.google.com/open?id=1bdXVmYrSynPLSk5lptvgyQ8fhziobD50) or [OneDrive](https://1drv.ms/f/s!AhIXJn_J-blW4AwKRMklXVzndJT0)) + ``` + ${POSE_ROOT} + `-- models + `-- pytorch + |-- imagenet + | `-- hrnet_w32-36af842e.pth + `-- pose_coco + `-- pose_higher_hrnet_w32_512.pth + + ``` + +### Data preparation + +**For COCO data**, please download from [COCO download](http://cocodataset.org/#download); 2017 Train/Val is needed for COCO keypoints training and validation. +Download and extract them under ${POSE_ROOT}/data, and make them look like this: +``` +${POSE_ROOT} +|-- data +`-- |-- coco + `-- |-- annotations + | |-- person_keypoints_train2017.json + | `-- person_keypoints_val2017.json + `-- images + |-- train2017 + | |-- 000000000009.jpg + | |-- 000000000025.jpg + | |-- 000000000030.jpg + | |-- ... + `-- val2017 + |-- 000000000139.jpg + |-- 000000000285.jpg + |-- 000000000632.jpg + |-- ... +``` + +**For CrowdPose data**, please download from [CrowdPose download](https://github.com/Jeff-sjtu/CrowdPose#dataset); Train/Val is needed for CrowdPose keypoints training and validation.
+Download and extract them under ${POSE_ROOT}/data, and make them look like this: +``` +${POSE_ROOT} +|-- data +`-- |-- crowd_pose + `-- |-- json + | |-- crowdpose_train.json + | |-- crowdpose_val.json + | |-- crowdpose_trainval.json (generated by tools/crowdpose_concat_train_val.py) + | `-- crowdpose_test.json + `-- images + |-- 100000.jpg + |-- 100001.jpg + |-- 100002.jpg + |-- 100003.jpg + |-- 100004.jpg + |-- 100005.jpg + |-- ... +``` +After downloading the data, run `python tools/crowdpose_concat_train_val.py` under `${POSE_ROOT}` to create the trainval set. + +### Training and Testing + +#### Testing on COCO val2017 dataset using model zoo's models ([GoogleDrive](https://drive.google.com/drive/folders/1X9-TzWpwbX2zQf2To8lB-ZQHMYviYYh6?usp=sharing)) + + +For single-scale testing: + +``` +python tools/valid.py \ + --cfg experiments/coco/higher_hrnet/w32_512_adam_lr1e-3.yaml \ + TEST.MODEL_FILE models/pytorch/pose_coco/pose_higher_hrnet_w32_512.pth +``` + +By default, we use horizontal flip. To test without flip: + +``` +python tools/valid.py \ + --cfg experiments/coco/higher_hrnet/w32_512_adam_lr1e-3.yaml \ + TEST.MODEL_FILE models/pytorch/pose_coco/pose_higher_hrnet_w32_512.pth \ + TEST.FLIP_TEST False +``` + +Multi-scale testing is also supported, although we do not report results in our paper: + +``` +python tools/valid.py \ + --cfg experiments/coco/higher_hrnet/w32_512_adam_lr1e-3.yaml \ + TEST.MODEL_FILE models/pytorch/pose_coco/pose_higher_hrnet_w32_512.pth \ + TEST.SCALE_FACTOR '[0.5, 1.0, 2.0]' +``` + + +#### Training on COCO train2017 dataset + +``` +python tools/dist_train.py \ + --cfg experiments/coco/higher_hrnet/w32_512_adam_lr1e-3.yaml +``` + +By default, training uses all available GPUs on the machine.
To specify GPUs, use + +``` +CUDA_VISIBLE_DEVICES=0,1 python tools/dist_train.py \ + --cfg experiments/coco/higher_hrnet/w32_512_adam_lr1e-3.yaml +``` + +#### Mixed-precision training +Due to the large input size of bottom-up methods, we train our Higher-HRNet with mixed precision using the following command: +``` +python tools/dist_train.py \ + --cfg experiments/coco/higher_hrnet/w32_512_adam_lr1e-3.yaml \ + FP16.ENABLED True FP16.DYNAMIC_LOSS_SCALE True +``` + +#### Synchronized BatchNorm training +If you have limited GPU memory, try reducing the batch size and training our Higher-HRNet with SyncBN using the following command: +``` +python tools/dist_train.py \ + --cfg experiments/coco/higher_hrnet/w32_512_adam_lr1e-3.yaml \ + FP16.ENABLED True FP16.DYNAMIC_LOSS_SCALE True \ + MODEL.SYNC_BN True +``` + +Our code for mixed-precision training is borrowed from [NVIDIA Apex API](https://github.com/NVIDIA/apex). + +#### Training on CrowdPose trainval dataset + +``` +python tools/dist_train.py \ + --cfg experiments/crowd_pose/higher_hrnet/w32_512_adam_lr1e-3.yaml +``` + + +### Other applications +Many other dense prediction tasks, such as segmentation, face alignment, and object detection, have benefited from HRNet. More information can be found at [Deep High-Resolution Representation Learning](https://jingdongwang2017.github.io/Projects/HRNet/). + +### Other implementations +[mmpose](https://github.com/open-mmlab/mmpose) + +## Citation +If you find this work or code helpful in your research, please cite: +```` +@inproceedings{cheng2020bottom, + title={HigherHRNet: Scale-Aware Representation Learning for Bottom-Up Human Pose Estimation}, + author={Bowen Cheng and Bin Xiao and Jingdong Wang and Honghui Shi and Thomas S.
Huang and Lei Zhang}, + booktitle={CVPR}, + year={2020} +} + +@inproceedings{SunXLW19, + title={Deep High-Resolution Representation Learning for Human Pose Estimation}, + author={Ke Sun and Bin Xiao and Dong Liu and Jingdong Wang}, + booktitle={CVPR}, + year={2019} +} + +@article{wang2019deep, + title={Deep High-Resolution Representation Learning for Visual Recognition}, + author={Wang, Jingdong and Sun, Ke and Cheng, Tianheng and Jiang, Borui and Deng, Chaorui and Zhao, Yang and Liu, Dong and Mu, Yadong and Tan, Mingkui and Wang, Xinggang and Liu, Wenyu and Xiao, Bin}, + journal={TPAMI}, + year={2019} +} +```` + diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/coco/higher_hrnet/w32_512_adam_lr1e-3.yaml b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/coco/higher_hrnet/w32_512_adam_lr1e-3.yaml new file mode 100644 index 0000000..c202448 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/coco/higher_hrnet/w32_512_adam_lr1e-3.yaml @@ -0,0 +1,129 @@ +AUTO_RESUME: True +DATA_DIR: '' +GPUS: (0,) +LOG_DIR: log +OUTPUT_DIR: output +PRINT_FREQ: 100 +CUDNN: + BENCHMARK: True + DETERMINISTIC: False + ENABLED: True +DATASET: + SIGMA: 2 + DATASET: coco_kpt + DATASET_TEST: coco + DATA_FORMAT: jpg + FLIP: 0.5 + INPUT_SIZE: 512 + OUTPUT_SIZE: [128, 256] + MAX_NUM_PEOPLE: 30 + MAX_ROTATION: 30 + MAX_SCALE: 1.5 + SCALE_TYPE: 'short' + MAX_TRANSLATE: 40 + MIN_SCALE: 0.75 + NUM_JOINTS: 17 + ROOT: 'data/coco' + TEST: val2017 + TRAIN: train2017 +DEBUG: + DEBUG: True + SAVE_BATCH_IMAGES_GT: False + SAVE_BATCH_IMAGES_PRED: False + SAVE_HEATMAPS_GT: True + SAVE_HEATMAPS_PRED: True + SAVE_TAGMAPS_PRED: True +LOSS: + NUM_STAGES: 2 + AE_LOSS_TYPE: exp + WITH_AE_LOSS: [True, False] + PUSH_LOSS_FACTOR: [0.001, 0.001] + PULL_LOSS_FACTOR: [0.001, 0.001] + WITH_HEATMAPS_LOSS: [True, True] + HEATMAPS_LOSS_FACTOR: [1.0, 1.0] +MODEL: + EXTRA: + FINAL_CONV_KERNEL: 1 + PRETRAINED_LAYERS: ['*'] + STEM_INPLANES: 64 + STAGE2: + 
NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + - 256 + FUSE_METHOD: SUM + DECONV: + NUM_DECONVS: 1 + NUM_CHANNELS: + - 32 + KERNEL_SIZE: + - 4 + NUM_BASIC_BLOCKS: 4 + CAT_OUTPUT: + - True + INIT_WEIGHTS: True + NAME: pose_higher_hrnet + NUM_JOINTS: 17 + PRETRAINED: 'models/pytorch/imagenet/hrnet_w32-36af842e.pth' + TAG_PER_JOINT: True +TEST: + FLIP_TEST: True + IMAGES_PER_GPU: 1 + MODEL_FILE: '' + SCALE_FACTOR: [1] + DETECTION_THRESHOLD: 0.1 + WITH_HEATMAPS: (True, True) + WITH_AE: (True, False) + PROJECT2IMAGE: True + NMS_KERNEL: 5 + NMS_PADDING: 2 +TRAIN: + BEGIN_EPOCH: 0 + CHECKPOINT: '' + END_EPOCH: 300 + GAMMA1: 0.99 + GAMMA2: 0.0 + IMAGES_PER_GPU: 12 + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: [200, 260] + MOMENTUM: 0.9 + NESTEROV: False + OPTIMIZER: adam + RESUME: False + SHUFFLE: True + WD: 0.0001 +WORKERS: 4 diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/coco/higher_hrnet/w32_640_adam_lr1e-3.yaml b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/coco/higher_hrnet/w32_640_adam_lr1e-3.yaml new file mode 100644 index 0000000..d77e433 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/coco/higher_hrnet/w32_640_adam_lr1e-3.yaml @@ -0,0 +1,132 @@ +AUTO_RESUME: True +DATA_DIR: '' +GPUS: (0,) +LOG_DIR: log +OUTPUT_DIR: output +PRINT_FREQ: 100 +FP16: + ENABLED: True + DYNAMIC_LOSS_SCALE: True +CUDNN: + BENCHMARK: True + DETERMINISTIC: False + ENABLED: True +DATASET: + SIGMA: 2 + DATASET: coco_kpt + DATASET_TEST: coco + DATA_FORMAT: jpg + FLIP: 0.5 + INPUT_SIZE: 640 + OUTPUT_SIZE: [160, 320] + MAX_NUM_PEOPLE: 30 + MAX_ROTATION: 30 + 
MAX_SCALE: 1.5 + SCALE_TYPE: 'short' + MAX_TRANSLATE: 40 + MIN_SCALE: 0.75 + NUM_JOINTS: 17 + ROOT: 'data/coco' + TEST: val2017 + TRAIN: train2017 +DEBUG: + DEBUG: True + SAVE_BATCH_IMAGES_GT: False + SAVE_BATCH_IMAGES_PRED: False + SAVE_HEATMAPS_GT: True + SAVE_HEATMAPS_PRED: True + SAVE_TAGMAPS_PRED: True +LOSS: + NUM_STAGES: 2 + AE_LOSS_TYPE: exp + WITH_AE_LOSS: [True, False] + PUSH_LOSS_FACTOR: [0.001, 0.001] + PULL_LOSS_FACTOR: [0.001, 0.001] + WITH_HEATMAPS_LOSS: [True, True] + HEATMAPS_LOSS_FACTOR: [1.0, 1.0] +MODEL: + EXTRA: + FINAL_CONV_KERNEL: 1 + PRETRAINED_LAYERS: ['*'] + STEM_INPLANES: 64 + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + - 256 + FUSE_METHOD: SUM + DECONV: + NUM_DECONVS: 1 + NUM_CHANNELS: + - 32 + KERNEL_SIZE: + - 4 + NUM_BASIC_BLOCKS: 4 + CAT_OUTPUT: + - True + INIT_WEIGHTS: True + NAME: pose_higher_hrnet + NUM_JOINTS: 17 + PRETRAINED: 'models/pytorch/imagenet/hrnet_w32-36af842e.pth' + TAG_PER_JOINT: True +TEST: + FLIP_TEST: True + IMAGES_PER_GPU: 1 + MODEL_FILE: '' + SCALE_FACTOR: [1] + DETECTION_THRESHOLD: 0.1 + WITH_HEATMAPS: (True, True) + WITH_AE: (True, False) + PROJECT2IMAGE: True + NMS_KERNEL: 5 + NMS_PADDING: 2 +TRAIN: + BEGIN_EPOCH: 0 + CHECKPOINT: '' + END_EPOCH: 300 + GAMMA1: 0.99 + GAMMA2: 0.0 + IMAGES_PER_GPU: 12 + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: [200, 260] + MOMENTUM: 0.9 + NESTEROV: False + OPTIMIZER: adam + RESUME: False + SHUFFLE: True + WD: 0.0001 +WORKERS: 4 diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/coco/higher_hrnet/w48_640_adam_lr1e-3.yaml 
b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/coco/higher_hrnet/w48_640_adam_lr1e-3.yaml new file mode 100644 index 0000000..f259608 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/coco/higher_hrnet/w48_640_adam_lr1e-3.yaml @@ -0,0 +1,132 @@ +AUTO_RESUME: True +DATA_DIR: '' +GPUS: (0,) +LOG_DIR: log +OUTPUT_DIR: output +PRINT_FREQ: 100 +FP16: + ENABLED: True + DYNAMIC_LOSS_SCALE: True +CUDNN: + BENCHMARK: True + DETERMINISTIC: False + ENABLED: True +DATASET: + SIGMA: 2 + DATASET: coco_kpt + DATASET_TEST: coco + DATA_FORMAT: jpg + FLIP: 0.5 + INPUT_SIZE: 640 + OUTPUT_SIZE: [160, 320] + MAX_NUM_PEOPLE: 30 + MAX_ROTATION: 30 + MAX_SCALE: 1.5 + SCALE_TYPE: 'short' + MAX_TRANSLATE: 40 + MIN_SCALE: 0.75 + NUM_JOINTS: 17 + ROOT: 'data/coco' + TEST: val2017 + TRAIN: train2017 +DEBUG: + DEBUG: True + SAVE_BATCH_IMAGES_GT: False + SAVE_BATCH_IMAGES_PRED: False + SAVE_HEATMAPS_GT: True + SAVE_HEATMAPS_PRED: True + SAVE_TAGMAPS_PRED: True +LOSS: + NUM_STAGES: 2 + AE_LOSS_TYPE: exp + WITH_AE_LOSS: [True, False] + PUSH_LOSS_FACTOR: [0.001, 0.001] + PULL_LOSS_FACTOR: [0.001, 0.001] + WITH_HEATMAPS_LOSS: [True, True] + HEATMAPS_LOSS_FACTOR: [1.0, 1.0] +MODEL: + EXTRA: + FINAL_CONV_KERNEL: 1 + PRETRAINED_LAYERS: ['*'] + STEM_INPLANES: 64 + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + - 384 + FUSE_METHOD: SUM + DECONV: + NUM_DECONVS: 1 + NUM_CHANNELS: + - 48 + KERNEL_SIZE: + - 4 + NUM_BASIC_BLOCKS: 4 + CAT_OUTPUT: + - True + INIT_WEIGHTS: True + NAME: pose_higher_hrnet + NUM_JOINTS: 17 + PRETRAINED: 'models/pytorch/imagenet/hrnet_w48-8ef0771d.pth' + TAG_PER_JOINT: 
True +TEST: + FLIP_TEST: True + IMAGES_PER_GPU: 1 + MODEL_FILE: '' + SCALE_FACTOR: [1] + DETECTION_THRESHOLD: 0.1 + WITH_HEATMAPS: (True, True) + WITH_AE: (True, False) + PROJECT2IMAGE: True + NMS_KERNEL: 5 + NMS_PADDING: 2 +TRAIN: + BEGIN_EPOCH: 0 + CHECKPOINT: '' + END_EPOCH: 300 + GAMMA1: 0.99 + GAMMA2: 0.0 + IMAGES_PER_GPU: 10 + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: [200, 260] + MOMENTUM: 0.9 + NESTEROV: False + OPTIMIZER: adam + RESUME: False + SHUFFLE: True + WD: 0.0001 +WORKERS: 4 diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_512_adam_lr1e-3.yaml b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_512_adam_lr1e-3.yaml new file mode 100644 index 0000000..51cea9d --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_512_adam_lr1e-3.yaml @@ -0,0 +1,129 @@ +AUTO_RESUME: True +DATA_DIR: '' +GPUS: (0,) +LOG_DIR: log +OUTPUT_DIR: output +PRINT_FREQ: 100 +CUDNN: + BENCHMARK: True + DETERMINISTIC: False + ENABLED: True +DATASET: + SIGMA: 2 + DATASET: crowd_pose_kpt + DATASET_TEST: crowd_pose + DATA_FORMAT: jpg + FLIP: 0.5 + INPUT_SIZE: 512 + OUTPUT_SIZE: [128, 256] + MAX_NUM_PEOPLE: 30 + MAX_ROTATION: 30 + MAX_SCALE: 1.5 + SCALE_TYPE: 'short' + MAX_TRANSLATE: 40 + MIN_SCALE: 0.75 + NUM_JOINTS: 14 + ROOT: 'data/crowd_pose' + TEST: test + TRAIN: trainval +DEBUG: + DEBUG: True + SAVE_BATCH_IMAGES_GT: False + SAVE_BATCH_IMAGES_PRED: False + SAVE_HEATMAPS_GT: True + SAVE_HEATMAPS_PRED: True + SAVE_TAGMAPS_PRED: True +LOSS: + NUM_STAGES: 2 + AE_LOSS_TYPE: exp + WITH_AE_LOSS: [True, False] + PUSH_LOSS_FACTOR: [0.001, 0.001] + PULL_LOSS_FACTOR: [0.001, 0.001] + WITH_HEATMAPS_LOSS: [True, True] + HEATMAPS_LOSS_FACTOR: [1.0, 1.0] +MODEL: + EXTRA: + FINAL_CONV_KERNEL: 1 + PRETRAINED_LAYERS: ['*'] + STEM_INPLANES: 64 + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 
64 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + - 256 + FUSE_METHOD: SUM + DECONV: + NUM_DECONVS: 1 + NUM_CHANNELS: + - 32 + KERNEL_SIZE: + - 4 + NUM_BASIC_BLOCKS: 4 + CAT_OUTPUT: + - True + INIT_WEIGHTS: True + NAME: pose_higher_hrnet + NUM_JOINTS: 14 + PRETRAINED: 'models/pytorch/imagenet/hrnet_w32-36af842e.pth' + TAG_PER_JOINT: True +TEST: + FLIP_TEST: True + IMAGES_PER_GPU: 1 + MODEL_FILE: '' + SCALE_FACTOR: [1] + DETECTION_THRESHOLD: 0.1 + WITH_HEATMAPS: (True, True) + WITH_AE: (True, False) + PROJECT2IMAGE: True + NMS_KERNEL: 5 + NMS_PADDING: 2 +TRAIN: + BEGIN_EPOCH: 0 + CHECKPOINT: '' + END_EPOCH: 300 + GAMMA1: 0.99 + GAMMA2: 0.0 + IMAGES_PER_GPU: 12 + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: [200, 260] + MOMENTUM: 0.9 + NESTEROV: False + OPTIMIZER: adam + RESUME: False + SHUFFLE: True + WD: 0.0001 +WORKERS: 4 diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_512_adam_lr1e-3_coco.yaml b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_512_adam_lr1e-3_coco.yaml new file mode 100644 index 0000000..3309275 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_512_adam_lr1e-3_coco.yaml @@ -0,0 +1,140 @@ +AUTO_RESUME: True +DATA_DIR: '' +GPUS: (0,) +LOG_DIR: log +OUTPUT_DIR: output +PRINT_FREQ: 100 +CUDNN: + BENCHMARK: True + DETERMINISTIC: False + ENABLED: True +DATASET: + SIGMA: 2 + DATASET: crowd_pose_kpt + DATASET_TEST: crowd_pose + DATA_FORMAT: jpg + FLIP: 0.5 + INPUT_SIZE: 512 + OUTPUT_SIZE: [128, 256] + MAX_NUM_PEOPLE: 30 + MAX_ROTATION: 30 + MAX_SCALE: 1.5 + SCALE_TYPE: 'short' + MAX_TRANSLATE: 40 + MIN_SCALE: 0.75 + NUM_JOINTS: 14 + ROOT: 
'data/crowd_pose' + TEST: test + TRAIN: trainval +DEBUG: + DEBUG: True + SAVE_BATCH_IMAGES_GT: False + SAVE_BATCH_IMAGES_PRED: False + SAVE_HEATMAPS_GT: True + SAVE_HEATMAPS_PRED: True + SAVE_TAGMAPS_PRED: True +LOSS: + NUM_STAGES: 2 + AE_LOSS_TYPE: exp + WITH_AE_LOSS: [True, False] + PUSH_LOSS_FACTOR: [0.001, 0.001] + PULL_LOSS_FACTOR: [0.001, 0.001] + WITH_HEATMAPS_LOSS: [True, True] + HEATMAPS_LOSS_FACTOR: [1.0, 1.0] +MODEL: + EXTRA: + FINAL_CONV_KERNEL: 1 + PRETRAINED_LAYERS: + - 'conv1' + - 'bn1' + - 'conv2' + - 'bn2' + - 'layer1' + - 'transition1' + - 'stage2' + - 'transition2' + - 'stage3' + - 'transition3' + - 'stage4' + STEM_INPLANES: 64 + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + - 256 + FUSE_METHOD: SUM + DECONV: + NUM_DECONVS: 1 + NUM_CHANNELS: + - 32 + KERNEL_SIZE: + - 4 + NUM_BASIC_BLOCKS: 4 + CAT_OUTPUT: + - True + INIT_WEIGHTS: True + NAME: pose_higher_hrnet + NUM_JOINTS: 14 + PRETRAINED: 'models/pytorch/pose_coco/pose_higher_hrnet_w32_512.pth' + TAG_PER_JOINT: True +TEST: + FLIP_TEST: True + IMAGES_PER_GPU: 1 + MODEL_FILE: '' + SCALE_FACTOR: [1] + DETECTION_THRESHOLD: 0.1 + WITH_HEATMAPS: (True, True) + WITH_AE: (True, False) + PROJECT2IMAGE: True + NMS_KERNEL: 5 + NMS_PADDING: 2 +TRAIN: + BEGIN_EPOCH: 0 + CHECKPOINT: '' + END_EPOCH: 300 + GAMMA1: 0.99 + GAMMA2: 0.0 + IMAGES_PER_GPU: 12 + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: [200, 260] + MOMENTUM: 0.9 + NESTEROV: False + OPTIMIZER: adam + RESUME: False + SHUFFLE: True + WD: 0.0001 +WORKERS: 4 diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_512_adam_lr1e-3_syncbn.yaml 
b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_512_adam_lr1e-3_syncbn.yaml new file mode 100644 index 0000000..8da9464 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_512_adam_lr1e-3_syncbn.yaml @@ -0,0 +1,130 @@ +AUTO_RESUME: True +DATA_DIR: '' +GPUS: (0,) +LOG_DIR: log +OUTPUT_DIR: output +PRINT_FREQ: 100 +CUDNN: + BENCHMARK: True + DETERMINISTIC: False + ENABLED: True +DATASET: + SIGMA: 2 + DATASET: crowd_pose_kpt + DATASET_TEST: crowd_pose + DATA_FORMAT: jpg + FLIP: 0.5 + INPUT_SIZE: 512 + OUTPUT_SIZE: [128, 256] + MAX_NUM_PEOPLE: 30 + MAX_ROTATION: 30 + MAX_SCALE: 1.5 + SCALE_TYPE: 'short' + MAX_TRANSLATE: 40 + MIN_SCALE: 0.75 + NUM_JOINTS: 14 + ROOT: 'data/crowd_pose' + TEST: test + TRAIN: trainval +DEBUG: + DEBUG: True + SAVE_BATCH_IMAGES_GT: False + SAVE_BATCH_IMAGES_PRED: False + SAVE_HEATMAPS_GT: True + SAVE_HEATMAPS_PRED: True + SAVE_TAGMAPS_PRED: True +LOSS: + NUM_STAGES: 2 + AE_LOSS_TYPE: exp + WITH_AE_LOSS: [True, False] + PUSH_LOSS_FACTOR: [0.001, 0.001] + PULL_LOSS_FACTOR: [0.001, 0.001] + WITH_HEATMAPS_LOSS: [True, True] + HEATMAPS_LOSS_FACTOR: [1.0, 1.0] +MODEL: + EXTRA: + FINAL_CONV_KERNEL: 1 + PRETRAINED_LAYERS: ['*'] + STEM_INPLANES: 64 + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + - 256 + FUSE_METHOD: SUM + DECONV: + NUM_DECONVS: 1 + NUM_CHANNELS: + - 32 + KERNEL_SIZE: + - 4 + NUM_BASIC_BLOCKS: 4 + CAT_OUTPUT: + - True + INIT_WEIGHTS: True + NAME: pose_higher_hrnet + NUM_JOINTS: 14 + PRETRAINED: 'models/pytorch/imagenet/hrnet_w32-36af842e.pth' + TAG_PER_JOINT: True + 
SYNC_BN: True +TEST: + FLIP_TEST: True + IMAGES_PER_GPU: 1 + MODEL_FILE: '' + SCALE_FACTOR: [1] + DETECTION_THRESHOLD: 0.1 + WITH_HEATMAPS: (True, True) + WITH_AE: (True, False) + PROJECT2IMAGE: True + NMS_KERNEL: 5 + NMS_PADDING: 2 +TRAIN: + BEGIN_EPOCH: 0 + CHECKPOINT: '' + END_EPOCH: 300 + GAMMA1: 0.99 + GAMMA2: 0.0 + IMAGES_PER_GPU: 12 + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: [200, 260] + MOMENTUM: 0.9 + NESTEROV: False + OPTIMIZER: adam + RESUME: False + SHUFFLE: True + WD: 0.0001 +WORKERS: 4 diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_640_adam_lr1e-3.yaml b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_640_adam_lr1e-3.yaml new file mode 100644 index 0000000..0b278ec --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w32_640_adam_lr1e-3.yaml @@ -0,0 +1,132 @@ +AUTO_RESUME: True +DATA_DIR: '' +GPUS: (0,) +LOG_DIR: log +OUTPUT_DIR: output +PRINT_FREQ: 100 +FP16: + ENABLED: True + DYNAMIC_LOSS_SCALE: True +CUDNN: + BENCHMARK: True + DETERMINISTIC: False + ENABLED: True +DATASET: + SIGMA: 2 + DATASET: crowd_pose_kpt + DATASET_TEST: crowd_pose + DATA_FORMAT: jpg + FLIP: 0.5 + INPUT_SIZE: 640 + OUTPUT_SIZE: [160, 320] + MAX_NUM_PEOPLE: 30 + MAX_ROTATION: 30 + MAX_SCALE: 1.5 + SCALE_TYPE: 'short' + MAX_TRANSLATE: 40 + MIN_SCALE: 0.75 + NUM_JOINTS: 14 + ROOT: 'data/crowd_pose' + TEST: test + TRAIN: trainval +DEBUG: + DEBUG: True + SAVE_BATCH_IMAGES_GT: False + SAVE_BATCH_IMAGES_PRED: False + SAVE_HEATMAPS_GT: True + SAVE_HEATMAPS_PRED: True + SAVE_TAGMAPS_PRED: True +LOSS: + NUM_STAGES: 2 + AE_LOSS_TYPE: exp + WITH_AE_LOSS: [True, False] + PUSH_LOSS_FACTOR: [0.001, 0.001] + PULL_LOSS_FACTOR: [0.001, 0.001] + WITH_HEATMAPS_LOSS: [True, True] + HEATMAPS_LOSS_FACTOR: [1.0, 1.0] +MODEL: + EXTRA: + FINAL_CONV_KERNEL: 1 + PRETRAINED_LAYERS: ['*'] + STEM_INPLANES: 64 + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: 
BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + - 256 + FUSE_METHOD: SUM + DECONV: + NUM_DECONVS: 1 + NUM_CHANNELS: + - 32 + KERNEL_SIZE: + - 4 + NUM_BASIC_BLOCKS: 4 + CAT_OUTPUT: + - True + INIT_WEIGHTS: True + NAME: pose_higher_hrnet + NUM_JOINTS: 14 + PRETRAINED: 'models/pytorch/imagenet/hrnet_w32-36af842e.pth' + TAG_PER_JOINT: True +TEST: + FLIP_TEST: True + IMAGES_PER_GPU: 1 + MODEL_FILE: '' + SCALE_FACTOR: [1] + DETECTION_THRESHOLD: 0.1 + WITH_HEATMAPS: (True, True) + WITH_AE: (True, False) + PROJECT2IMAGE: True + NMS_KERNEL: 5 + NMS_PADDING: 2 +TRAIN: + BEGIN_EPOCH: 0 + CHECKPOINT: '' + END_EPOCH: 300 + GAMMA1: 0.99 + GAMMA2: 0.0 + IMAGES_PER_GPU: 12 + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: [200, 260] + MOMENTUM: 0.9 + NESTEROV: False + OPTIMIZER: adam + RESUME: False + SHUFFLE: True + WD: 0.0001 +WORKERS: 4 diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w48_640_adam_lr1e-3.yaml b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w48_640_adam_lr1e-3.yaml new file mode 100644 index 0000000..162941f --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/experiments/crowd_pose/higher_hrnet/w48_640_adam_lr1e-3.yaml @@ -0,0 +1,132 @@ +AUTO_RESUME: True +DATA_DIR: '' +GPUS: (0,) +LOG_DIR: log +OUTPUT_DIR: output +PRINT_FREQ: 100 +FP16: + ENABLED: True + DYNAMIC_LOSS_SCALE: True +CUDNN: + BENCHMARK: True + DETERMINISTIC: False + ENABLED: True +DATASET: + SIGMA: 2 + DATASET: crowd_pose_kpt + DATASET_TEST: crowd_pose + DATA_FORMAT: jpg + FLIP: 0.5 + INPUT_SIZE: 640 + OUTPUT_SIZE: [160, 320] + MAX_NUM_PEOPLE: 30 + MAX_ROTATION: 30 + MAX_SCALE: 1.5 + 
SCALE_TYPE: 'short' + MAX_TRANSLATE: 40 + MIN_SCALE: 0.75 + NUM_JOINTS: 14 + ROOT: 'data/crowd_pose' + TEST: test + TRAIN: trainval +DEBUG: + DEBUG: True + SAVE_BATCH_IMAGES_GT: False + SAVE_BATCH_IMAGES_PRED: False + SAVE_HEATMAPS_GT: True + SAVE_HEATMAPS_PRED: True + SAVE_TAGMAPS_PRED: True +LOSS: + NUM_STAGES: 2 + AE_LOSS_TYPE: exp + WITH_AE_LOSS: [True, False] + PUSH_LOSS_FACTOR: [0.001, 0.001] + PULL_LOSS_FACTOR: [0.001, 0.001] + WITH_HEATMAPS_LOSS: [True, True] + HEATMAPS_LOSS_FACTOR: [1.0, 1.0] +MODEL: + EXTRA: + FINAL_CONV_KERNEL: 1 + PRETRAINED_LAYERS: ['*'] + STEM_INPLANES: 64 + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 48 + - 96 + - 192 + - 384 + FUSE_METHOD: SUM + DECONV: + NUM_DECONVS: 1 + NUM_CHANNELS: + - 48 + KERNEL_SIZE: + - 4 + NUM_BASIC_BLOCKS: 4 + CAT_OUTPUT: + - True + INIT_WEIGHTS: True + NAME: pose_higher_hrnet + NUM_JOINTS: 14 + PRETRAINED: 'models/pytorch/imagenet/hrnet_w48-8ef0771d.pth' + TAG_PER_JOINT: True +TEST: + FLIP_TEST: True + IMAGES_PER_GPU: 1 + MODEL_FILE: '' + SCALE_FACTOR: [1] + DETECTION_THRESHOLD: 0.1 + WITH_HEATMAPS: (True, True) + WITH_AE: (True, False) + PROJECT2IMAGE: True + NMS_KERNEL: 5 + NMS_PADDING: 2 +TRAIN: + BEGIN_EPOCH: 0 + CHECKPOINT: '' + END_EPOCH: 300 + GAMMA1: 0.99 + GAMMA2: 0.0 + IMAGES_PER_GPU: 10 + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: [200, 260] + MOMENTUM: 0.9 + NESTEROV: False + OPTIMIZER: adam + RESUME: False + SHUFFLE: True + WD: 0.0001 +WORKERS: 4 diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/figures/arch_v2.png b/data_processing/HigherHRNet-Human-Pose-Estimation/figures/arch_v2.png new file mode 100644 index 0000000..77aac93 Binary 
files /dev/null and b/data_processing/HigherHRNet-Human-Pose-Estimation/figures/arch_v2.png differ diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/config/__init__.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/config/__init__.py new file mode 100644 index 0000000..aeeb6ba --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/config/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from .default import _C as cfg +from .default import update_config +from .default import check_config diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/config/default.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/config/default.py new file mode 100644 index 0000000..0ecf29b --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/config/default.py @@ -0,0 +1,219 @@ + +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bin Xiao (leoxiaobin@gmail.com) +# Modified by Bowen Cheng (bcheng9@illinois.edu) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from yacs.config import CfgNode as CN + +from .models import MODEL_EXTRAS + + +_C = CN() + +_C.OUTPUT_DIR = '' +_C.LOG_DIR = '' +_C.DATA_DIR = '' +_C.GPUS = (0,) +_C.WORKERS = 4 +_C.PRINT_FREQ = 20 +_C.AUTO_RESUME = False +_C.PIN_MEMORY = True +_C.RANK = 0 +_C.VERBOSE = True +_C.DIST_BACKEND = 'nccl' +_C.MULTIPROCESSING_DISTRIBUTED = True + +# FP16 training params +_C.FP16 = CN() +_C.FP16.ENABLED = False +_C.FP16.STATIC_LOSS_SCALE = 1.0 +_C.FP16.DYNAMIC_LOSS_SCALE = False + +# Cudnn related params +_C.CUDNN = CN() +_C.CUDNN.BENCHMARK = True +_C.CUDNN.DETERMINISTIC = False +_C.CUDNN.ENABLED = True + +# common params for NETWORK +_C.MODEL = CN() +_C.MODEL.NAME = 'pose_multi_resolution_net_v16' +_C.MODEL.INIT_WEIGHTS = True +_C.MODEL.PRETRAINED = '' +_C.MODEL.NUM_JOINTS = 17 +_C.MODEL.TAG_PER_JOINT = True +_C.MODEL.EXTRA = CN(new_allowed=True) +_C.MODEL.SYNC_BN = False + +_C.LOSS = CN() +_C.LOSS.NUM_STAGES = 1 +_C.LOSS.WITH_HEATMAPS_LOSS = (True,) +_C.LOSS.HEATMAPS_LOSS_FACTOR = (1.0,) +_C.LOSS.WITH_AE_LOSS = (True,) +_C.LOSS.AE_LOSS_TYPE = 'max' +_C.LOSS.PUSH_LOSS_FACTOR = (0.001,) +_C.LOSS.PULL_LOSS_FACTOR = (0.001,) + +# DATASET related params +_C.DATASET = CN() +_C.DATASET.ROOT = '' +_C.DATASET.DATASET = 'coco_kpt' +_C.DATASET.DATASET_TEST = 'coco' +_C.DATASET.NUM_JOINTS = 17 +_C.DATASET.MAX_NUM_PEOPLE = 30 +_C.DATASET.TRAIN = 'train2017' +_C.DATASET.TEST = 'val2017' +_C.DATASET.DATA_FORMAT = 'jpg' + +# training data augmentation +_C.DATASET.MAX_ROTATION = 30 +_C.DATASET.MIN_SCALE = 0.75 +_C.DATASET.MAX_SCALE = 1.25 +_C.DATASET.SCALE_TYPE = 'short' +_C.DATASET.MAX_TRANSLATE = 40 +_C.DATASET.INPUT_SIZE = 512 +_C.DATASET.OUTPUT_SIZE = [128, 256, 512] 
+_C.DATASET.FLIP = 0.5 + +# heatmap generator (default is OUTPUT_SIZE/64) +_C.DATASET.SIGMA = -1 +_C.DATASET.SCALE_AWARE_SIGMA = False +_C.DATASET.BASE_SIZE = 256.0 +_C.DATASET.BASE_SIGMA = 2.0 +_C.DATASET.INT_SIGMA = False + +_C.DATASET.WITH_CENTER = False + +# train +_C.TRAIN = CN() + +_C.TRAIN.LR_FACTOR = 0.1 +_C.TRAIN.LR_STEP = [90, 110] +_C.TRAIN.LR = 0.001 + +_C.TRAIN.OPTIMIZER = 'adam' +_C.TRAIN.MOMENTUM = 0.9 +_C.TRAIN.WD = 0.0001 +_C.TRAIN.NESTEROV = False +_C.TRAIN.GAMMA1 = 0.99 +_C.TRAIN.GAMMA2 = 0.0 + +_C.TRAIN.BEGIN_EPOCH = 0 +_C.TRAIN.END_EPOCH = 140 + +_C.TRAIN.RESUME = False +_C.TRAIN.CHECKPOINT = '' + +_C.TRAIN.IMAGES_PER_GPU = 32 +_C.TRAIN.SHUFFLE = True + +# testing +_C.TEST = CN() + +# size of images for each device +# _C.TEST.BATCH_SIZE = 32 +_C.TEST.IMAGES_PER_GPU = 32 +# Test Model Epoch +_C.TEST.FLIP_TEST = False +_C.TEST.ADJUST = True +_C.TEST.REFINE = True +_C.TEST.SCALE_FACTOR = [1] +# group +_C.TEST.DETECTION_THRESHOLD = 0.2 +_C.TEST.TAG_THRESHOLD = 1. +_C.TEST.USE_DETECTION_VAL = True +_C.TEST.IGNORE_TOO_MUCH = False +_C.TEST.MODEL_FILE = '' +_C.TEST.IGNORE_CENTER = True +_C.TEST.NMS_KERNEL = 3 +_C.TEST.NMS_PADDING = 1 +_C.TEST.PROJECT2IMAGE = False + +_C.TEST.WITH_HEATMAPS = (True,) +_C.TEST.WITH_AE = (True,) + +_C.TEST.LOG_PROGRESS = False + +# debug +_C.DEBUG = CN() +_C.DEBUG.DEBUG = True +_C.DEBUG.SAVE_BATCH_IMAGES_GT = False +_C.DEBUG.SAVE_BATCH_IMAGES_PRED = False +_C.DEBUG.SAVE_HEATMAPS_GT = True +_C.DEBUG.SAVE_HEATMAPS_PRED = True +_C.DEBUG.SAVE_TAGMAPS_PRED = True + + +def update_config(cfg, args): + cfg.defrost() + cfg.merge_from_file(args.cfg) + cfg.merge_from_list(args.opts) + + if not os.path.exists(cfg.DATASET.ROOT): + cfg.DATASET.ROOT = os.path.join( + cfg.DATA_DIR, cfg.DATASET.ROOT + ) + + cfg.MODEL.PRETRAINED = os.path.join( + cfg.DATA_DIR, cfg.MODEL.PRETRAINED + ) + + if cfg.TEST.MODEL_FILE: + cfg.TEST.MODEL_FILE = os.path.join( + cfg.DATA_DIR, cfg.TEST.MODEL_FILE + ) + + if cfg.DATASET.WITH_CENTER: + 
cfg.DATASET.NUM_JOINTS += 1 + cfg.MODEL.NUM_JOINTS = cfg.DATASET.NUM_JOINTS + + if not isinstance(cfg.DATASET.OUTPUT_SIZE, (list, tuple)): + cfg.DATASET.OUTPUT_SIZE = [cfg.DATASET.OUTPUT_SIZE] + if not isinstance(cfg.LOSS.WITH_HEATMAPS_LOSS, (list, tuple)): + cfg.LOSS.WITH_HEATMAPS_LOSS = (cfg.LOSS.WITH_HEATMAPS_LOSS) + + if not isinstance(cfg.LOSS.HEATMAPS_LOSS_FACTOR, (list, tuple)): + cfg.LOSS.HEATMAPS_LOSS_FACTOR = (cfg.LOSS.HEATMAPS_LOSS_FACTOR) + + if not isinstance(cfg.LOSS.WITH_AE_LOSS, (list, tuple)): + cfg.LOSS.WITH_AE_LOSS = (cfg.LOSS.WITH_AE_LOSS) + + if not isinstance(cfg.LOSS.PUSH_LOSS_FACTOR, (list, tuple)): + cfg.LOSS.PUSH_LOSS_FACTOR = (cfg.LOSS.PUSH_LOSS_FACTOR) + + if not isinstance(cfg.LOSS.PULL_LOSS_FACTOR, (list, tuple)): + cfg.LOSS.PULL_LOSS_FACTOR = (cfg.LOSS.PULL_LOSS_FACTOR) + + cfg.freeze() + + +def check_config(cfg): + assert cfg.LOSS.NUM_STAGES == len(cfg.LOSS.WITH_HEATMAPS_LOSS), \ + 'LOSS.NUM_SCALE should be the same as the length of LOSS.WITH_HEATMAPS_LOSS' + assert cfg.LOSS.NUM_STAGES == len(cfg.LOSS.HEATMAPS_LOSS_FACTOR), \ + 'LOSS.NUM_SCALE should be the same as the length of LOSS.HEATMAPS_LOSS_FACTOR' + assert cfg.LOSS.NUM_STAGES == len(cfg.LOSS.WITH_AE_LOSS), \ + 'LOSS.NUM_SCALE should be the same as the length of LOSS.WITH_AE_LOSS' + assert cfg.LOSS.NUM_STAGES == len(cfg.LOSS.PUSH_LOSS_FACTOR), \ + 'LOSS.NUM_SCALE should be the same as the length of LOSS.PUSH_LOSS_FACTOR' + assert cfg.LOSS.NUM_STAGES == len(cfg.LOSS.PULL_LOSS_FACTOR), \ + 'LOSS.NUM_SCALE should be the same as the length of LOSS.PULL_LOSS_FACTOR' + assert cfg.LOSS.NUM_STAGES == len(cfg.TEST.WITH_HEATMAPS), \ + 'LOSS.NUM_SCALE should be the same as the length of TEST.WITH_HEATMAPS' + assert cfg.LOSS.NUM_STAGES == len(cfg.TEST.WITH_AE), \ + 'LOSS.NUM_SCALE should be the same as the length of TEST.WITH_AE' + + +if __name__ == '__main__': + import sys + with open(sys.argv[1], 'w') as f: + print(_C, file=f) diff --git 
a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/config/models.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/config/models.py new file mode 100644 index 0000000..f8d258f --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/config/models.py @@ -0,0 +1,62 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (leoxiaobin@gmail.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from yacs.config import CfgNode as CN + + +# pose_multi_resoluton_net related params +POSE_HIGHER_RESOLUTION_NET = CN() +POSE_HIGHER_RESOLUTION_NET.PRETRAINED_LAYERS = ['*'] +POSE_HIGHER_RESOLUTION_NET.STEM_INPLANES = 64 +POSE_HIGHER_RESOLUTION_NET.FINAL_CONV_KERNEL = 1 + +POSE_HIGHER_RESOLUTION_NET.STAGE1 = CN() +POSE_HIGHER_RESOLUTION_NET.STAGE1.NUM_MODULES = 1 +POSE_HIGHER_RESOLUTION_NET.STAGE1.NUM_BRANCHES = 1 +POSE_HIGHER_RESOLUTION_NET.STAGE1.NUM_BLOCKS = [4] +POSE_HIGHER_RESOLUTION_NET.STAGE1.NUM_CHANNELS = [64] +POSE_HIGHER_RESOLUTION_NET.STAGE1.BLOCK = 'BOTTLENECK' +POSE_HIGHER_RESOLUTION_NET.STAGE1.FUSE_METHOD = 'SUM' + +POSE_HIGHER_RESOLUTION_NET.STAGE2 = CN() +POSE_HIGHER_RESOLUTION_NET.STAGE2.NUM_MODULES = 1 +POSE_HIGHER_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2 +POSE_HIGHER_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4] +POSE_HIGHER_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [24, 48] +POSE_HIGHER_RESOLUTION_NET.STAGE2.BLOCK = 'BOTTLENECK' +POSE_HIGHER_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM' + +POSE_HIGHER_RESOLUTION_NET.STAGE3 = CN() +POSE_HIGHER_RESOLUTION_NET.STAGE3.NUM_MODULES = 1 +POSE_HIGHER_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3 +POSE_HIGHER_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4] +POSE_HIGHER_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [24, 48, 92] 
+POSE_HIGHER_RESOLUTION_NET.STAGE3.BLOCK = 'BOTTLENECK' +POSE_HIGHER_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM' + +POSE_HIGHER_RESOLUTION_NET.STAGE4 = CN() +POSE_HIGHER_RESOLUTION_NET.STAGE4.NUM_MODULES = 1 +POSE_HIGHER_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4 +POSE_HIGHER_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] +POSE_HIGHER_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [24, 48, 92, 192] +POSE_HIGHER_RESOLUTION_NET.STAGE4.BLOCK = 'BOTTLENECK' +POSE_HIGHER_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM' + +POSE_HIGHER_RESOLUTION_NET.DECONV = CN() +POSE_HIGHER_RESOLUTION_NET.DECONV.NUM_DCONVS = 2 +POSE_HIGHER_RESOLUTION_NET.DECONV.NUM_CHANNELS = [32, 32] +POSE_HIGHER_RESOLUTION_NET.DECONV.NUM_BASIC_BLOCKS = 4 +POSE_HIGHER_RESOLUTION_NET.DECONV.KERNEL_SIZE = [2, 2] +POSE_HIGHER_RESOLUTION_NET.DECONV.CAT_OUTPUT = [True, True] + + +MODEL_EXTRAS = { + 'pose_multi_resolution_net_v16': POSE_HIGHER_RESOLUTION_NET, +} diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/core/group.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/core/group.py new file mode 100644 index 0000000..f09be8f --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/core/group.py @@ -0,0 +1,283 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
def py_max_match(scores):
    """Solve the assignment problem for a cost matrix.

    ``Munkres`` is expected to be imported at file level
    (``from munkres import Munkres``).

    :param scores: 2-D cost matrix handed to ``Munkres.compute``.
    :return: ``np.int32`` array built from the ``(row, col)`` pairs that
        ``Munkres.compute`` returns.
    """
    assignment = Munkres().compute(scores)
    return np.array(assignment).astype(np.int32)


def match_by_tag(inp, params):
    """Group the per-joint detections of one image into people by tag.

    :param inp: tuple ``(tag_k, loc_k, val_k)`` for a single image.  Indexing
        in the body implies ``tag_k[j]`` is ``(K, T)``, ``loc_k[j]`` is
        ``(K, 2)`` and ``val_k[j]`` is ``(K,)`` for joint ``j``.
    :param params: a :class:`Params` instance.
    :return: ``np.float32`` array with one entry per grouped person, each of
        shape ``(params.num_joints, 3 + T)`` holding
        ``(loc..., detection value, tag...)`` rows.
    """
    assert isinstance(params, Params), 'params should be class Params()'

    tag_k, loc_k, val_k = inp
    # Template row block for a person none of whose joints is filled yet.
    default_ = np.zeros((params.num_joints, 3 + tag_k.shape[2]))

    joint_dict = {}   # person key -> (num_joints, 3 + T) array
    tag_dict = {}     # person key -> list of tags assigned so far
    for i in range(params.num_joints):
        idx = params.joint_order[i]

        tags = tag_k[idx]
        joints = np.concatenate(
            (loc_k[idx], val_k[idx, :, None], tags), 1
        )
        # Keep only candidates above the detection threshold.
        mask = joints[:, 2] > params.detection_threshold
        tags = tags[mask]
        joints = joints[mask]

        if joints.shape[0] == 0:
            continue

        if i == 0 or len(joint_dict) == 0:
            # Nothing grouped yet: every candidate starts a new person,
            # keyed by the first component of its tag.
            for tag, joint in zip(tags, joints):
                key = tag[0]
                joint_dict.setdefault(key, np.copy(default_))[idx] = joint
                tag_dict[key] = [tag]
            continue

        grouped_keys = list(joint_dict.keys())[:params.max_num_people]
        grouped_tags = [np.mean(tag_dict[k], axis=0) for k in grouped_keys]

        if params.ignore_too_much \
                and len(grouped_keys) == params.max_num_people:
            continue

        # Distance between each candidate tag and each person's mean tag.
        diff = joints[:, None, 3:] - np.array(grouped_tags)[None, :, :]
        diff_normed = np.linalg.norm(diff, ord=2, axis=2)
        diff_saved = np.copy(diff_normed)

        if params.use_detection_val:
            # Rounded tag distance dominates; detection value breaks ties.
            diff_normed = np.round(diff_normed) * 100 - joints[:, 2:3]

        num_added = diff.shape[0]
        num_grouped = diff.shape[1]

        if num_added > num_grouped:
            # Pad with prohibitively expensive dummy columns so that every
            # candidate row receives an assignment.
            diff_normed = np.concatenate(
                (
                    diff_normed,
                    np.zeros((num_added, num_added - num_grouped)) + 1e10
                ),
                axis=1
            )

        pairs = py_max_match(diff_normed)
        for row, col in pairs:
            if (
                row < num_added
                and col < num_grouped
                and diff_saved[row][col] < params.tag_threshold
            ):
                # Close enough to an existing person: attach the joint.
                key = grouped_keys[col]
                joint_dict[key][idx] = joints[row]
                tag_dict[key].append(tags[row])
            else:
                # Otherwise the candidate seeds a new person.
                key = tags[row][0]
                joint_dict.setdefault(key, np.copy(default_))[idx] = \
                    joints[row]
                tag_dict[key] = [tags[row]]

    ans = np.array([joint_dict[k] for k in joint_dict]).astype(np.float32)
    return ans


class Params(object):
    """Grouping parameters read from the experiment config.

    :param cfg: config object exposing ``DATASET.NUM_JOINTS``,
        ``DATASET.MAX_NUM_PEOPLE``, ``DATASET.WITH_CENTER`` and the
        ``TEST.*`` thresholds/flags read below.
    """

    def __init__(self, cfg):
        self.num_joints = cfg.DATASET.NUM_JOINTS
        self.max_num_people = cfg.DATASET.MAX_NUM_PEOPLE

        self.detection_threshold = cfg.TEST.DETECTION_THRESHOLD
        self.tag_threshold = cfg.TEST.TAG_THRESHOLD
        self.use_detection_val = cfg.TEST.USE_DETECTION_VAL
        self.ignore_too_much = cfg.TEST.IGNORE_TOO_MUCH

        with_center = cfg.DATASET.WITH_CENTER
        ignore_center = cfg.TEST.IGNORE_CENTER

        if with_center and ignore_center:
            # The extra center channel is dropped at test time.
            self.num_joints -= 1

        # 1-based order in which joints are grouped; converted to 0-based.
        order = [1, 2, 3, 4, 5, 6, 7, 12, 13, 8, 9, 10, 11, 14, 15, 16, 17]
        if with_center and not ignore_center:
            order = [18] + order
        self.joint_order = [i - 1 for i in order]


class HeatmapParser(object):
    """Turn detection heatmaps and tag maps into grouped keypoints.

    :param cfg: config object; besides what :class:`Params` reads, this uses
        ``MODEL.TAG_PER_JOINT``, ``TEST.NMS_KERNEL`` and ``TEST.NMS_PADDING``.
    """

    def __init__(self, cfg):
        self.params = Params(cfg)
        self.tag_per_joint = cfg.MODEL.TAG_PER_JOINT
        self.pool = torch.nn.MaxPool2d(
            cfg.TEST.NMS_KERNEL, 1, cfg.TEST.NMS_PADDING
        )

    def nms(self, det):
        """Zero every heatmap value that is not a local maximum."""
        is_peak = torch.eq(self.pool(det), det).float()
        return det * is_peak

    def match(self, tag_k, loc_k, val_k):
        """Run :func:`match_by_tag` on every image of the batch."""
        return [
            match_by_tag(per_image, self.params)
            for per_image in zip(tag_k, loc_k, val_k)
        ]

    def top_k(self, det, tag):
        """Pick the ``max_num_people`` strongest peaks per joint.

        :param det: heatmap tensor indexed as ``(image, joint, h, w)``.
        :param tag: tag tensor viewable as ``(image, joint, h * w, T)``.
        :return: dict of numpy arrays ``tag_k``, ``loc_k`` (``(x, y)`` pixel
            indices) and ``val_k`` (peak values).
        """
        det = self.nms(det)
        num_images = det.size(0)
        num_joints = det.size(1)
        h = det.size(2)
        w = det.size(3)
        det = det.view(num_images, num_joints, -1)
        val_k, ind = det.topk(self.params.max_num_people, dim=2)

        tag = tag.view(tag.size(0), tag.size(1), w * h, -1)
        if not self.tag_per_joint:
            tag = tag.expand(-1, self.params.num_joints, -1, -1)

        tag_k = torch.stack(
            [
                torch.gather(tag[:, :, :, i], 2, ind)
                for i in range(tag.size(3))
            ],
            dim=3
        )

        # Flat index -> (column, row).  Integer floor division keeps the row
        # exact and does not depend on how a given torch release treats `/`
        # on integer tensors.
        x = ind % w
        y = (ind // w).long()

        ind_k = torch.stack((x, y), dim=3)

        return {
            'tag_k': tag_k.cpu().numpy(),
            'loc_k': ind_k.cpu().numpy(),
            'val_k': val_k.cpu().numpy()
        }

    def adjust(self, ans, det):
        """Shift each detected joint a quarter pixel towards the higher
        neighbouring heatmap value, then add the 0.5 pixel-centre offset.

        :param ans: list (one entry per image) of arrays indexed as
            ``[person, joint, (x, y, score, ...)]``; modified in place.
        :param det: heatmaps indexed as ``det[image][joint][row, col]``.
        :return: ``ans``.
        """
        for batch_id, people in enumerate(ans):
            for people_id, person in enumerate(people):
                for joint_id, joint in enumerate(person):
                    if not joint[2] > 0:
                        continue
                    col, row = joint[0:2]
                    r, c = int(row), int(col)
                    heat = det[batch_id][joint_id]
                    last_row = heat.shape[0] - 1
                    last_col = heat.shape[1] - 1

                    if heat[r, min(c + 1, last_col)] > heat[r, max(c - 1, 0)]:
                        col += 0.25
                    else:
                        col -= 0.25

                    if heat[min(r + 1, last_row), c] > heat[max(0, r - 1), c]:
                        row += 0.25
                    else:
                        row -= 0.25
                    ans[batch_id][people_id, joint_id, 0:2] = (
                        col + 0.5, row + 0.5
                    )
        return ans

    def refine(self, det, tag, keypoints):
        """Fill in the missing joints of one detected person.

        For every joint, the position maximising "detection score minus
        rounded tag distance to the person's mean tag" is located; it is
        written into ``keypoints`` only where the joint was undetected
        (score ``== 0``) and the located score is positive.

        :param det: numpy.ndarray indexed as ``(joint, row, col)``.
        :param tag: numpy.ndarray ``(joint, row, col)`` or
            ``(joint, row, col, T)``.
        :param keypoints: numpy.ndarray whose rows start with
            ``(x, y, det score)``; modified in place.
        :return: ``keypoints``.
        """
        if len(tag.shape) == 3:
            # Add a trailing tag dimension: (J, H, W) -> (J, H, W, 1).
            tag = tag[:, :, :, None]

        # Tag values at the joints that were detected.
        tags = []
        for i in range(keypoints.shape[0]):
            if keypoints[i, 2] > 0:
                x, y = keypoints[i][:2].astype(np.int32)
                tags.append(tag[i, y, x])

        if not tags:
            # No visible joint means no reference tag; the mean of an empty
            # list would be NaN and the indexing below would fail.
            return keypoints

        # Mean tag of the current person.
        prev_tag = np.mean(tags, axis=0)
        ans = []

        for i in range(keypoints.shape[0]):
            # Score of joint i at all positions.
            tmp = det[i, :, :]
            # Distance of every tag value to the person's mean tag.
            tt = (((tag[i, :, :] - prev_tag[None, None, :]) ** 2)
                  .sum(axis=2) ** 0.5)
            tmp2 = tmp - np.round(tt)

            yy, xx = np.unravel_index(np.argmax(tmp2), tmp.shape)
            # Detection score at the maximum position.
            val = tmp[yy, xx]
            # Offset by 0.5 (pixel centre).
            x = xx + 0.5
            y = yy + 0.5

            # Quarter-pixel offset towards the higher neighbour.
            if tmp[yy, min(xx + 1, tmp.shape[1] - 1)] > tmp[yy, max(xx - 1, 0)]:
                x += 0.25
            else:
                x -= 0.25

            if tmp[min(yy + 1, tmp.shape[0] - 1), xx] > tmp[max(0, yy - 1), xx]:
                y += 0.25
            else:
                y -= 0.25

            ans.append((x, y, val))
        ans = np.array(ans)

        for i in range(det.shape[0]):
            # Add the keypoint only if it was not detected before.
            if ans[i, 2] > 0 and keypoints[i, 2] == 0:
                keypoints[i, :2] = ans[i, :2]
                keypoints[i, 2] = ans[i, 2]

        return keypoints

    def parse(self, det, tag, adjust=True, refine=True):
        """Group keypoints for a batch and optionally adjust/refine them.

        :param det: heatmap tensor, see :meth:`top_k`.
        :param tag: tag tensor, see :meth:`top_k`.
        :param adjust: apply :meth:`adjust` to the grouped result.
        :param refine: apply :meth:`refine` to every person of image 0.
        :return: ``(ans, scores)`` where ``scores`` holds the mean detection
            score of each person of image 0.
        """
        ans = self.match(**self.top_k(det, tag))

        if adjust:
            ans = self.adjust(ans, det)

        scores = [person[:, 2].mean() for person in ans[0]]

        if refine:
            ans = ans[0]
            if len(ans) > 0:
                # Converted once; the arrays are identical for every person.
                det_numpy = det[0].cpu().numpy()
                tag_numpy = tag[0].cpu().numpy()
                if not self.tag_per_joint:
                    tag_numpy = np.tile(
                        tag_numpy, (self.params.num_joints, 1, 1, 1)
                    )
                # For every detected person.
                for i in range(len(ans)):
                    ans[i] = self.refine(det_numpy, tag_numpy, ans[i])
            ans = [ans]

        return ans, scores
def _get_flip_index(cfg):
    """Return the joint permutation that undoes a horizontal flip.

    ``FLIP_CONFIG`` is expected to be imported at file level
    (``from dataset.transforms import FLIP_CONFIG``).
    """
    if 'coco' in cfg.DATASET.DATASET:
        dataset_name = 'COCO'
    elif 'crowd_pose' in cfg.DATASET.DATASET:
        dataset_name = 'CROWDPOSE'
    else:
        raise ValueError('Please implement flip_index for new dataset: %s.' % cfg.DATASET.DATASET)
    if cfg.DATASET.WITH_CENTER:
        return FLIP_CONFIG[dataset_name + '_WITH_CENTER']
    return FLIP_CONFIG[dataset_name]


def _resize_maps(maps, height, width):
    """Bilinearly resize every tensor in ``maps`` to ``(height, width)``."""
    return [
        torch.nn.functional.interpolate(
            m,
            size=(height, width),
            mode='bilinear',
            align_corners=False
        )
        for m in maps
    ]


def _postprocess_maps(cfg, heatmaps, tags, project2image, size_projected):
    """Apply the shared tail of both ``get_*outputs`` functions: drop the
    center channel if configured, then optionally project to image size."""
    if cfg.DATASET.WITH_CENTER and cfg.TEST.IGNORE_CENTER:
        # NOTE(review): the last channel is dropped from the tag maps too,
        # which looks intended only for per-joint tags - confirm.
        heatmaps = [hms[:, :-1] for hms in heatmaps]
        tags = [tms[:, :-1] for tms in tags]

    if project2image and size_projected:
        # size_projected is indexed as (width, height).
        heatmaps = _resize_maps(heatmaps, size_projected[1], size_projected[0])
        tags = _resize_maps(tags, size_projected[1], size_projected[0])

    return heatmaps, tags


def get_outputs(
        cfg, model, image, with_flip=False,
        project2image=False, size_projected=None
):
    """Run a single-output model and split its channels into heatmaps/tags.

    :param cfg: config object (``DATASET``, ``MODEL`` and ``TEST`` are read).
    :param model: callable mapping ``image`` to one tensor whose first
        ``cfg.DATASET.NUM_JOINTS`` channels are heatmaps and the rest tags.
    :param image: input tensor; dimension 3 is flipped for ``with_flip``.
    :param with_flip: also evaluate the horizontally flipped image.
    :param project2image: resize results to ``size_projected``.
    :param size_projected: target size, indexed as ``(width, height)``.
    :return: ``(outputs, heatmaps, tags)`` lists with one entry for the
        plain pass and, if ``with_flip``, one for the flipped pass.
    """
    num_joints = cfg.DATASET.NUM_JOINTS

    outputs = [model(image)]
    heatmaps = [outputs[-1][:, :num_joints]]
    tags = [outputs[-1][:, num_joints:]]

    if with_flip:
        flipped = torch.flip(model(torch.flip(image, [3])), [3])
        outputs.append(flipped)
        flip_index = _get_flip_index(cfg)
        # Flipping swaps left/right joints, so permute channels back.
        heatmaps.append(flipped[:, :num_joints][:, flip_index, :, :])
        flipped_tags = flipped[:, num_joints:]
        if cfg.MODEL.TAG_PER_JOINT:
            flipped_tags = flipped_tags[:, flip_index, :, :]
        tags.append(flipped_tags)

    heatmaps, tags = _postprocess_maps(
        cfg, heatmaps, tags, project2image, size_projected
    )
    return outputs, heatmaps, tags


def get_multi_stage_outputs(
        cfg, model, image, with_flip=False,
        project2image=False, size_projected=None
):
    """Run a multi-stage model and average the heatmaps of its stages.

    Every stage but the last is resized to the last stage's resolution.
    Heatmaps of stages enabled by both ``LOSS.WITH_HEATMAPS_LOSS`` and
    ``TEST.WITH_HEATMAPS`` are averaged; tag maps of stages enabled by both
    ``LOSS.WITH_AE_LOSS`` and ``TEST.WITH_AE`` are collected.

    :param model: callable returning a sequence of per-stage tensors.
    :return: ``(outputs, heatmaps, tags)``; with ``with_flip`` the flipped
        stage outputs are appended to ``outputs`` and a second averaged
        heatmap is appended to ``heatmaps``.
    """
    num_joints = cfg.DATASET.NUM_JOINTS
    heatmaps_avg = 0
    num_heatmaps = 0
    heatmaps = []
    tags = []

    # Copy into a list so flipped outputs can be appended even when the
    # model returns a tuple, without mutating the model's own return value.
    outputs = list(model(image))
    for i, output in enumerate(outputs):
        if len(outputs) > 1 and i != len(outputs) - 1:
            output = torch.nn.functional.interpolate(
                output,
                size=(outputs[-1].size(2), outputs[-1].size(3)),
                mode='bilinear',
                align_corners=False
            )

        # Tag channels start after the heatmap channels, if the stage has any.
        offset_feat = num_joints if cfg.LOSS.WITH_HEATMAPS_LOSS[i] else 0

        if cfg.LOSS.WITH_HEATMAPS_LOSS[i] and cfg.TEST.WITH_HEATMAPS[i]:
            heatmaps_avg += output[:, :num_joints]
            num_heatmaps += 1

        if cfg.LOSS.WITH_AE_LOSS[i] and cfg.TEST.WITH_AE[i]:
            tags.append(output[:, offset_feat:])

    if num_heatmaps > 0:
        heatmaps.append(heatmaps_avg / num_heatmaps)

    if with_flip:
        flip_index = _get_flip_index(cfg)

        heatmaps_avg = 0
        num_heatmaps = 0
        outputs_flip = model(torch.flip(image, [3]))
        for i in range(len(outputs_flip)):
            output = outputs_flip[i]
            if len(outputs_flip) > 1 and i != len(outputs_flip) - 1:
                output = torch.nn.functional.interpolate(
                    output,
                    size=(outputs_flip[-1].size(2), outputs_flip[-1].size(3)),
                    mode='bilinear',
                    align_corners=False
                )
            output = torch.flip(output, [3])
            outputs.append(output)

            offset_feat = num_joints if cfg.LOSS.WITH_HEATMAPS_LOSS[i] else 0

            if cfg.LOSS.WITH_HEATMAPS_LOSS[i] and cfg.TEST.WITH_HEATMAPS[i]:
                heatmaps_avg += \
                    output[:, :num_joints][:, flip_index, :, :]
                num_heatmaps += 1

            if cfg.LOSS.WITH_AE_LOSS[i] and cfg.TEST.WITH_AE[i]:
                tags.append(output[:, offset_feat:])
                if cfg.MODEL.TAG_PER_JOINT:
                    tags[-1] = tags[-1][:, flip_index, :, :]

        # Same guard as the non-flipped pass: with no enabled stage this
        # would otherwise evaluate 0 / 0.
        if num_heatmaps > 0:
            heatmaps.append(heatmaps_avg / num_heatmaps)

    heatmaps, tags = _postprocess_maps(
        cfg, heatmaps, tags, project2image, size_projected
    )
    return outputs, heatmaps, tags


def aggregate_results(
        cfg, scale_factor, final_heatmaps, tags_list, heatmaps, tags
):
    """Fold the results of one test scale into the running aggregate.

    Tags are collected only for scale ``1`` (or when a single scale is
    configured); heatmaps are summed over scales.

    :param final_heatmaps: running heatmap sum, or ``None`` on first call.
        Updated in place when not ``None``.
    :param tags_list: list extended in place with tag maps that get an
        extra trailing dimension.
    :param heatmaps: ``[plain]`` or ``[plain, flipped]`` heatmaps.
    :param tags: tag maps of this scale.
    :return: ``(final_heatmaps, tags_list)``.
    """
    def _to_final_size(t):
        return torch.nn.functional.interpolate(
            t,
            size=(final_heatmaps.size(2), final_heatmaps.size(3)),
            mode='bilinear',
            align_corners=False
        )

    if scale_factor == 1 or len(cfg.TEST.SCALE_FACTOR) == 1:
        if final_heatmaps is not None and not cfg.TEST.PROJECT2IMAGE:
            tags = [_to_final_size(tms) for tms in tags]
        for tms in tags:
            tags_list.append(torch.unsqueeze(tms, dim=4))

    if cfg.TEST.FLIP_TEST:
        heatmaps_avg = (heatmaps[0] + heatmaps[1]) / 2.0
    else:
        heatmaps_avg = heatmaps[0]

    if final_heatmaps is None:
        final_heatmaps = heatmaps_avg
    elif cfg.TEST.PROJECT2IMAGE:
        final_heatmaps += heatmaps_avg
    else:
        final_heatmaps += _to_final_size(heatmaps_avg)

    return final_heatmaps, tags_list
logger = logging.getLogger(__name__)


def make_input(t, requires_grad=False, need_cuda=True):
    """Wrap ``t``, reduce it to its scalar sum and optionally move to CUDA.

    :param t: source tensor.
    :param requires_grad: forwarded to ``torch.autograd.Variable``.
    :param need_cuda: call ``.cuda()`` on the result.
    :return: 0-dim tensor holding ``t.sum()``.
    """
    inp = torch.autograd.Variable(t, requires_grad=requires_grad)
    inp = inp.sum()
    if need_cuda:
        inp = inp.cuda()
    return inp


class HeatmapLoss(nn.Module):
    """Masked mean-squared error between predicted and target heatmaps."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, gt, mask):
        """Return one loss value per batch element.

        :param pred: predicted heatmaps, 4-D.
        :param gt: target heatmaps, same size as ``pred``.
        :param mask: 3-D mask broadcast over the channel dimension.
        """
        assert pred.size() == gt.size()
        loss = ((pred - gt)**2) * mask[:, None, :, :].expand_as(pred)
        # Average over width, height and channels; keep the batch dim.
        loss = loss.mean(dim=3).mean(dim=2).mean(dim=1)
        return loss


class AELoss(nn.Module):
    """Associative-embedding loss (pull within a person, push between people).

    :param loss_type: ``'exp'`` or ``'max'``; selects the push term.
    """

    def __init__(self, loss_type):
        super().__init__()
        self.loss_type = loss_type

    def singleTagLoss(self, pred_tag, joints):
        """
        associative embedding loss for one image

        :param pred_tag: tag tensor indexed by flat position.
        :param joints: per person, rows of ``(flat index, visible flag)``.
        :return: ``(push, pull)`` 0-dim tensors.
        """
        tags = []
        pull = 0
        for joints_per_person in joints:
            tmp = [
                pred_tag[joint[0]]
                for joint in joints_per_person
                if joint[1] > 0
            ]
            if len(tmp) == 0:
                continue
            tmp = torch.stack(tmp)
            # Reference embedding of the person = mean of its joint tags.
            tags.append(torch.mean(tmp, dim=0))
            pull = pull + torch.mean((tmp - tags[-1].expand_as(tmp))**2)

        num_tags = len(tags)
        # Zero placeholders follow the device/dtype of the prediction, so
        # the loss also works on CPU and stacks cleanly in ``forward``.
        if num_tags == 0:
            return pred_tag.new_zeros(()), pred_tag.new_zeros(())
        elif num_tags == 1:
            return pred_tag.new_zeros(()), pull/(num_tags)

        tags = torch.stack(tags)

        # Pairwise differences between person embeddings.
        size = (num_tags, num_tags)
        A = tags.expand(*size)
        B = A.permute(1, 0)

        diff = A - B

        if self.loss_type == 'exp':
            diff = torch.pow(diff, 2)
            push = torch.exp(-diff)
            # The diagonal contributes exp(0) = 1 per person; remove it.
            push = torch.sum(push) - num_tags
        elif self.loss_type == 'max':
            diff = 1 - torch.abs(diff)
            push = torch.clamp(diff, min=0).sum() - num_tags
        else:
            raise ValueError('Unkown ae loss type')

        return push/((num_tags - 1) * num_tags) * 0.5, \
            pull/(num_tags)

    def forward(self, tags, joints):
        """
        accumulate the tag loss for each image in the batch

        :return: ``(pushes, pulls)`` stacked over the batch.
        """
        pushes, pulls = [], []
        joints = joints.cpu().data.numpy()
        batch_size = tags.size(0)
        for i in range(batch_size):
            push, pull = self.singleTagLoss(tags[i], joints[i])
            pushes.append(push)
            pulls.append(pull)
        return torch.stack(pushes), torch.stack(pulls)


class JointsMSELoss(nn.Module):
    """Per-joint MSE over flattened heatmaps, optionally weighted per joint.

    :param use_target_weight: multiply prediction and target by
        ``target_weight[:, joint]`` before the MSE.
    """

    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        # ``reduction='mean'`` is the current spelling of the deprecated
        # ``size_average=True``.
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight):
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
        loss = 0

        for idx in range(num_joints):
            # Drop only the joint axis; a bare squeeze() would also drop a
            # batch (or pixel) axis of size 1 and change the broadcasting
            # against target_weight.
            heatmap_pred = heatmaps_pred[idx].squeeze(1)
            heatmap_gt = heatmaps_gt[idx].squeeze(1)
            if self.use_target_weight:
                loss += 0.5 * self.criterion(
                    heatmap_pred.mul(target_weight[:, idx]),
                    heatmap_gt.mul(target_weight[:, idx])
                )
            else:
                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)

        return loss / num_joints


class LossFactory(nn.Module):
    """Single-stage combination of heatmap loss and AE loss.

    :param cfg: config; ``DATASET.NUM_JOINTS`` and the ``LOSS.*`` switches
        and factors read below.
    """

    def __init__(self, cfg):
        super().__init__()
        self.num_joints = cfg.DATASET.NUM_JOINTS
        self.heatmaps_loss = None
        self.ae_loss = None
        self.heatmaps_loss_factor = 1.0
        self.push_loss_factor = 1.0
        self.pull_loss_factor = 1.0

        if cfg.LOSS.WITH_HEATMAPS_LOSS:
            self.heatmaps_loss = HeatmapLoss()
            self.heatmaps_loss_factor = cfg.LOSS.HEATMAPS_LOSS_FACTOR
        if cfg.LOSS.WITH_AE_LOSS:
            self.ae_loss = AELoss(cfg.LOSS.AE_LOSS_TYPE)
            self.push_loss_factor = cfg.LOSS.PUSH_LOSS_FACTOR
            self.pull_loss_factor = cfg.LOSS.PULL_LOSS_FACTOR

        if self.heatmaps_loss is None and self.ae_loss is None:
            logger.error('At least enable one loss!')

    def forward(self, outputs, heatmaps, masks, joints):
        """Return ``([heatmaps_loss], [push_loss], [pull_loss])``; an entry
        is ``None`` when the corresponding loss is disabled."""
        # TODO(bowen): outputs and heatmaps can be lists of same length
        heatmaps_pred = outputs[:, :self.num_joints]
        tags_pred = outputs[:, self.num_joints:]

        heatmaps_loss = None
        push_loss = None
        pull_loss = None

        if self.heatmaps_loss is not None:
            heatmaps_loss = self.heatmaps_loss(heatmaps_pred, heatmaps, masks)
            heatmaps_loss = heatmaps_loss * self.heatmaps_loss_factor

        if self.ae_loss is not None:
            batch_size = tags_pred.size()[0]
            # Flatten so joints can address tags by flat index.
            tags_pred = tags_pred.contiguous().view(batch_size, -1, 1)

            push_loss, pull_loss = self.ae_loss(tags_pred, joints)
            push_loss = push_loss * self.push_loss_factor
            pull_loss = pull_loss * self.pull_loss_factor

        return [heatmaps_loss], [push_loss], [pull_loss]


class MultiLossFactory(nn.Module):
    """Per-stage combination of heatmap loss and AE loss.

    :param cfg: config; ``MODEL.NUM_JOINTS``, ``LOSS.NUM_STAGES`` and the
        per-stage ``LOSS.*`` lists validated by :meth:`_init_check`.
    """

    def __init__(self, cfg):
        super().__init__()
        # init check
        self._init_check(cfg)

        self.num_joints = cfg.MODEL.NUM_JOINTS
        self.num_stages = cfg.LOSS.NUM_STAGES

        # A ``None`` entry marks a stage for which the loss is disabled.
        self.heatmaps_loss = \
            nn.ModuleList(
                [
                    HeatmapLoss()
                    if with_heatmaps_loss else None
                    for with_heatmaps_loss in cfg.LOSS.WITH_HEATMAPS_LOSS
                ]
            )
        self.heatmaps_loss_factor = cfg.LOSS.HEATMAPS_LOSS_FACTOR

        self.ae_loss = \
            nn.ModuleList(
                [
                    AELoss(cfg.LOSS.AE_LOSS_TYPE) if with_ae_loss else None
                    for with_ae_loss in cfg.LOSS.WITH_AE_LOSS
                ]
            )
        self.push_loss_factor = cfg.LOSS.PUSH_LOSS_FACTOR
        self.pull_loss_factor = cfg.LOSS.PULL_LOSS_FACTOR

    def forward(self, outputs, heatmaps, masks, joints):
        """Return three lists (heatmap, push, pull losses), one entry per
        stage; an entry is ``None`` when that loss is disabled there."""
        # forward check
        self._forward_check(outputs, heatmaps, masks, joints)

        heatmaps_losses = []
        push_losses = []
        pull_losses = []
        for idx in range(len(outputs)):
            # Tag channels start after the heatmap channels, if present.
            offset_feat = 0
            if self.heatmaps_loss[idx] is not None:
                heatmaps_pred = outputs[idx][:, :self.num_joints]
                offset_feat = self.num_joints

                heatmaps_loss = self.heatmaps_loss[idx](
                    heatmaps_pred, heatmaps[idx], masks[idx]
                )
                heatmaps_loss = heatmaps_loss * self.heatmaps_loss_factor[idx]
                heatmaps_losses.append(heatmaps_loss)
            else:
                heatmaps_losses.append(None)

            if self.ae_loss[idx] is not None:
                tags_pred = outputs[idx][:, offset_feat:]
                batch_size = tags_pred.size()[0]
                tags_pred = tags_pred.contiguous().view(batch_size, -1, 1)

                push_loss, pull_loss = self.ae_loss[idx](
                    tags_pred, joints[idx]
                )
                push_loss = push_loss * self.push_loss_factor[idx]
                pull_loss = pull_loss * self.pull_loss_factor[idx]

                push_losses.append(push_loss)
                pull_losses.append(pull_loss)
            else:
                push_losses.append(None)
                pull_losses.append(None)

        return heatmaps_losses, push_losses, pull_losses

    def _init_check(self, cfg):
        """Assert that every per-stage ``LOSS`` option is a list/tuple of
        the expected length."""
        assert isinstance(cfg.LOSS.WITH_HEATMAPS_LOSS, (list, tuple)), \
            'LOSS.WITH_HEATMAPS_LOSS should be a list or tuple'
        assert isinstance(cfg.LOSS.HEATMAPS_LOSS_FACTOR, (list, tuple)), \
            'LOSS.HEATMAPS_LOSS_FACTOR should be a list or tuple'
        assert isinstance(cfg.LOSS.WITH_AE_LOSS, (list, tuple)), \
            'LOSS.WITH_AE_LOSS should be a list or tuple'
        assert isinstance(cfg.LOSS.PUSH_LOSS_FACTOR, (list, tuple)), \
            'LOSS.PUSH_LOSS_FACTOR should be a list or tuple'
        # Previously this repeated the PUSH check, leaving PULL unchecked.
        assert isinstance(cfg.LOSS.PULL_LOSS_FACTOR, (list, tuple)), \
            'LOSS.PULL_LOSS_FACTOR should be a list or tuple'
        assert len(cfg.LOSS.WITH_HEATMAPS_LOSS) == cfg.LOSS.NUM_STAGES, \
            'LOSS.WITH_HEATMAPS_LOSS and LOSS.NUM_STAGE should have same length, got {} vs {}.'.\
            format(len(cfg.LOSS.WITH_HEATMAPS_LOSS), cfg.LOSS.NUM_STAGES)
        assert len(cfg.LOSS.WITH_HEATMAPS_LOSS) == len(cfg.LOSS.HEATMAPS_LOSS_FACTOR), \
            'LOSS.WITH_HEATMAPS_LOSS and LOSS.HEATMAPS_LOSS_FACTOR should have same length, got {} vs {}.'.\
            format(len(cfg.LOSS.WITH_HEATMAPS_LOSS), len(cfg.LOSS.HEATMAPS_LOSS_FACTOR))
        assert len(cfg.LOSS.WITH_AE_LOSS) == cfg.LOSS.NUM_STAGES, \
            'LOSS.WITH_AE_LOSS and LOSS.NUM_STAGE should have same length, got {} vs {}.'.\
            format(len(cfg.LOSS.WITH_AE_LOSS), cfg.LOSS.NUM_STAGES)
        assert len(cfg.LOSS.WITH_AE_LOSS) == len(cfg.LOSS.PUSH_LOSS_FACTOR), \
            'LOSS.WITH_AE_LOSS and LOSS.PUSH_LOSS_FACTOR should have same length, got {} vs {}.'. \
            format(len(cfg.LOSS.WITH_AE_LOSS), len(cfg.LOSS.PUSH_LOSS_FACTOR))
        assert len(cfg.LOSS.WITH_AE_LOSS) == len(cfg.LOSS.PULL_LOSS_FACTOR), \
            'LOSS.WITH_AE_LOSS and LOSS.PULL_LOSS_FACTOR should have same length, got {} vs {}.'. \
            format(len(cfg.LOSS.WITH_AE_LOSS), len(cfg.LOSS.PULL_LOSS_FACTOR))

    def _forward_check(self, outputs, heatmaps, masks, joints):
        """Assert that all forward inputs are lists with one entry per stage."""
        assert isinstance(outputs, list), \
            'outputs should be a list, got {} instead.'.format(type(outputs))
        assert isinstance(heatmaps, list), \
            'heatmaps should be a list, got {} instead.'.format(type(heatmaps))
        assert isinstance(masks, list), \
            'masks should be a list, got {} instead.'.format(type(masks))
        assert isinstance(joints, list), \
            'joints should be a list, got {} instead.'.format(type(joints))
        assert len(outputs) == self.num_stages, \
            'len(outputs) and num_stages should been same, got {} vs {}.'.format(len(outputs), self.num_stages)
        assert len(outputs) == len(heatmaps), \
            'outputs and heatmaps should have same length, got {} vs {}.'.format(len(outputs), len(heatmaps))
        assert len(outputs) == len(masks), \
            'outputs and masks should have same length, got {} vs {}.'.format(len(outputs), len(masks))
        assert len(outputs) == len(joints), \
            'outputs and joints should have same length, got {} vs {}.'.format(len(outputs), len(joints))
        assert len(outputs) == len(self.heatmaps_loss), \
            'outputs and heatmaps_loss should have same length, got {} vs {}.'. \
            format(len(outputs), len(self.heatmaps_loss))
        assert len(outputs) == len(self.ae_loss), \
            'outputs and ae_loss should have same length, got {} vs {}.'. \
            format(len(outputs), len(self.ae_loss))


def test_ae_loss():
    """Smoke-run :class:`AELoss` on a tiny hand-built example and print it."""
    import numpy as np
    t = torch.tensor(
        # ``np.float`` was removed from NumPy; ``np.float64`` is what it
        # aliased.
        np.arange(0, 32).reshape(1, 2, 4, 4).astype(np.float64)*0.1,
        requires_grad=True
    )
    t.register_hook(lambda x: print('t', x))

    ae_loss = AELoss(loss_type='exp')

    joints = np.zeros((2, 2, 2))
    joints[0, 0] = (3, 1)
    joints[1, 0] = (10, 1)
    joints[0, 1] = (22, 1)
    joints[1, 1] = (30, 1)
    # Explicit dtype conversion; the array above is floating point.
    joints = torch.as_tensor(joints, dtype=torch.long)
    joints = joints.view(1, 2, 2, 2)

    t = t.contiguous().view(1, -1, 1)
    l = ae_loss(t, joints)

    print(l)


if __name__ == '__main__':
    test_ae_loss()
+# Written by Bin Xiao (leoxiaobin@gmail.com) +# Modified by Bowen Cheng (bcheng9@illinois.edu) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import os +import time + +from utils.utils import AverageMeter +from utils.vis import save_debug_images + + +def do_train(cfg, model, data_loader, loss_factory, optimizer, epoch, + output_dir, tb_log_dir, writer_dict, fp16=False): + logger = logging.getLogger("Training") + + batch_time = AverageMeter() + data_time = AverageMeter() + + heatmaps_loss_meter = [AverageMeter() for _ in range(cfg.LOSS.NUM_STAGES)] + push_loss_meter = [AverageMeter() for _ in range(cfg.LOSS.NUM_STAGES)] + pull_loss_meter = [AverageMeter() for _ in range(cfg.LOSS.NUM_STAGES)] + + # switch to train mode + model.train() + + end = time.time() + for i, (images, heatmaps, masks, joints) in enumerate(data_loader): + # measure data loading time + data_time.update(time.time() - end) + + # compute output + outputs = model(images) + + heatmaps = list(map(lambda x: x.cuda(non_blocking=True), heatmaps)) + masks = list(map(lambda x: x.cuda(non_blocking=True), masks)) + joints = list(map(lambda x: x.cuda(non_blocking=True), joints)) + + # loss = loss_factory(outputs, heatmaps, masks) + heatmaps_losses, push_losses, pull_losses = \ + loss_factory(outputs, heatmaps, masks, joints) + + loss = 0 + for idx in range(cfg.LOSS.NUM_STAGES): + if heatmaps_losses[idx] is not None: + heatmaps_loss = heatmaps_losses[idx].mean(dim=0) + heatmaps_loss_meter[idx].update( + heatmaps_loss.item(), images.size(0) + ) + loss = loss + heatmaps_loss + if push_losses[idx] is not None: + push_loss = push_losses[idx].mean(dim=0) + push_loss_meter[idx].update( + push_loss.item(), images.size(0) + ) + loss = loss + push_loss + if pull_losses[idx] is not None: + pull_loss = pull_losses[idx].mean(dim=0) + 
pull_loss_meter[idx].update( + pull_loss.item(), images.size(0) + ) + loss = loss + pull_loss + + # compute gradient and do update step + optimizer.zero_grad() + if fp16: + optimizer.backward(loss) + else: + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % cfg.PRINT_FREQ == 0 and cfg.RANK == 0: + msg = 'Epoch: [{0}][{1}/{2}]\t' \ + 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \ + 'Speed: {speed:.1f} samples/s\t' \ + 'Data: {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \ + '{heatmaps_loss}{push_loss}{pull_loss}'.format( + epoch, i, len(data_loader), + batch_time=batch_time, + speed=images.size(0)/batch_time.val, + data_time=data_time, + heatmaps_loss=_get_loss_info(heatmaps_loss_meter, 'heatmaps'), + push_loss=_get_loss_info(push_loss_meter, 'push'), + pull_loss=_get_loss_info(pull_loss_meter, 'pull') + ) + logger.info(msg) + + writer = writer_dict['writer'] + global_steps = writer_dict['train_global_steps'] + for idx in range(cfg.LOSS.NUM_STAGES): + writer.add_scalar( + 'train_stage{}_heatmaps_loss'.format(i), + heatmaps_loss_meter[idx].val, + global_steps + ) + writer.add_scalar( + 'train_stage{}_push_loss'.format(idx), + push_loss_meter[idx].val, + global_steps + ) + writer.add_scalar( + 'train_stage{}_pull_loss'.format(idx), + pull_loss_meter[idx].val, + global_steps + ) + writer_dict['train_global_steps'] = global_steps + 1 + + prefix = '{}_{}'.format(os.path.join(output_dir, 'train'), i) + for scale_idx in range(len(outputs)): + prefix_scale = prefix + '_output_{}'.format( + cfg.DATASET.OUTPUT_SIZE[scale_idx] + ) + save_debug_images( + cfg, images, heatmaps[scale_idx], masks[scale_idx], + outputs[scale_idx], prefix_scale + ) + + +def _get_loss_info(loss_meters, loss_name): + msg = '' + for i, meter in enumerate(loss_meters): + msg += 'Stage{i}-{name}: {meter.val:.3e} ({meter.avg:.3e})\t'.format( + i=i, name=loss_name, meter=meter + ) + + return msg diff --git 
logger = logging.getLogger(__name__)


class CocoDataset(Dataset):
    """MS COCO keypoint dataset backed by ``pycocotools``.

    ``cv2``, ``json`` (``json_tricks``), ``COCOeval`` and ``zipreader`` are
    expected to be imported at file level.

    Args:
        root (string): Root directory where dataset is located to.
        dataset (string): Dataset name(train2017, val2017, test2017).
        data_format(string): Data format for reading('jpg', 'zip')
        transform (callable, optional): A function/transform that takes in an opencv image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, dataset, data_format, transform=None,
                 target_transform=None):
        from pycocotools.coco import COCO
        self.name = 'COCO'
        self.root = root
        self.dataset = dataset
        self.data_format = data_format
        self.coco = COCO(self._get_anno_file_name())
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

        cats = [cat['name']
                for cat in self.coco.loadCats(self.coco.getCatIds())]
        self.classes = ['__background__'] + cats
        logger.info('=> classes: {}'.format(self.classes))
        self.num_classes = len(self.classes)
        # Lookup tables between class names, local and COCO category ids.
        self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
        self._class_to_coco_ind = dict(zip(cats, self.coco.getCatIds()))
        self._coco_ind_to_class_ind = {
            self._class_to_coco_ind[cls]: self._class_to_ind[cls]
            for cls in self.classes[1:]
        }

    def _get_anno_file_name(self):
        """Return the annotation file path for ``self.dataset``."""
        # example: root/annotations/person_keypoints_tran2017.json
        # image_info_test-dev2017.json
        if 'test' in self.dataset:
            pattern = 'image_info_{}.json'
        else:
            pattern = 'person_keypoints_{}.json'
        return os.path.join(
            self.root, 'annotations', pattern.format(self.dataset)
        )

    def _get_image_path(self, file_name):
        """Return the image path, or a ``<archive>.zip@<file>`` locator when
        ``data_format`` is ``'zip'``."""
        images_dir = os.path.join(self.root, 'images')
        # Every test split reads from the 'test2017' images.
        dataset = 'test2017' if 'test' in self.dataset else self.dataset
        if self.data_format == 'zip':
            return os.path.join(images_dir, dataset) + '.zip@' + file_name
        else:
            return os.path.join(images_dir, dataset, file_name)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.

        Raises:
            FileNotFoundError: if the image could not be read.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        file_name = coco.loadImgs(img_id)[0]['file_name']
        image_path = self._get_image_path(file_name)
        read_flags = cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION

        if self.data_format == 'zip':
            img = zipreader.imread(image_path, read_flags)
        else:
            img = cv2.imread(image_path, read_flags)

        if img is None:
            # cv2.imread signals failure by returning None; report the path
            # instead of failing later inside cvtColor.
            raise FileNotFoundError(
                'Failed to read image: {}'.format(image_path)
            )

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.ids)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def processKeypoints(self, keypoints):
        """Return a copy of ``keypoints`` (values are left unchanged)."""
        return keypoints.copy()

    def evaluate(self, cfg, preds, scores, output_dir,
                 *args, **kwargs):
        '''
        Perform evaluation on COCO keypoint task
        :param cfg: cfg dictionary
        :param preds: prediction, image x person x (keypoints); each
            keypoint row is (x, y, score, tag)
        :param scores: per image, per person score
        :param output_dir: output directory
        :param args:
        :param kwargs:
        :return: ``(name -> value mapping, AP)``; ``({'Null': 0}, 0)`` for
            test splits, which are only written to disk
        '''
        res_folder = os.path.join(output_dir, 'results')
        os.makedirs(res_folder, exist_ok=True)
        res_file = os.path.join(
            res_folder, 'keypoints_%s_results.json' % self.dataset)

        kpts = defaultdict(list)
        for idx, _kpts in enumerate(preds):
            img_id = self.ids[idx]
            file_name = self.coco.loadImgs(img_id)[0]['file_name']
            # The image id is parsed from the 12 characters that precede
            # the 4-character extension of the file name.
            image_key = int(file_name[-16:-4])
            for idx_kpt, kpt in enumerate(_kpts):
                area = (np.max(kpt[:, 0]) - np.min(kpt[:, 0])) * (np.max(kpt[:, 1]) - np.min(kpt[:, 1]))
                kpt = self.processKeypoints(kpt)
                if cfg.DATASET.WITH_CENTER and not cfg.TEST.IGNORE_CENTER:
                    # Drop the trailing center joint.
                    kpt = kpt[:-1]

                kpts[image_key].append(
                    {
                        'keypoints': kpt[:, 0:3],
                        'score': scores[idx][idx_kpt],
                        'tags': kpt[:, 3],
                        'image': image_key,
                        'area': area
                    }
                )

        # do not use nms, keep all detections: image x person x (keypoints)
        oks_nmsed_kpts = list(kpts.values())

        self._write_coco_keypoint_results(
            oks_nmsed_kpts, res_file
        )

        if 'test' not in self.dataset:
            info_str = self._do_python_keypoint_eval(
                res_file, res_folder
            )
            name_value = OrderedDict(info_str)
            return name_value, name_value['AP']
        else:
            return {'Null': 0}, 0

    def _write_coco_keypoint_results(self, keypoints, res_file):
        """Write the detections of the first non-background class to
        ``res_file`` as JSON, then check that the file can be parsed."""
        data_pack = [
            {
                'cat_id': self._class_to_coco_ind[cls],
                'cls_ind': cls_ind,
                'cls': cls,
                'ann_type': 'keypoints',
                'keypoints': keypoints
            }
            for cls_ind, cls in enumerate(self.classes) if not cls == '__background__'
        ]

        results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
        logger.info('=> Writing results json to %s' % res_file)
        with open(res_file, 'w') as f:
            json.dump(results, f, sort_keys=True, indent=4)
        try:
            with open(res_file) as f:
                json.load(f)
        except Exception:
            # Best-effort repair: replace the last line with a closing
            # bracket and rewrite the file.
            with open(res_file, 'r') as f:
                content = list(f)
            content[-1] = ']'
            with open(res_file, 'w') as f:
                f.writelines(content)

    def _coco_keypoint_results_one_category_kernel(self, data_pack):
        """Flatten per-image detections into COCO result dictionaries.

        :param data_pack: dict with ``'cat_id'`` and ``'keypoints'`` (list
            over images of lists of person dicts holding ``'keypoints'``,
            ``'image'`` and ``'score'``).
        :return: list of dicts with ``image_id``, ``category_id``,
            ``keypoints`` (flat x, y, score triples), ``score`` and ``bbox``.
        """
        cat_id = data_pack['cat_id']
        keypoints = data_pack['keypoints']
        cat_results = []
        num_joints = 17

        for img_kpts in keypoints:
            if len(img_kpts) == 0:
                continue

            _key_points = np.array(
                [person['keypoints'] for person in img_kpts]
            )
            # ``np.float`` no longer exists in NumPy; float64 is what it
            # aliased.
            key_points = np.zeros(
                (_key_points.shape[0], num_joints * 3),
                dtype=np.float64
            )

            for ipt in range(num_joints):
                key_points[:, ipt * 3 + 0] = _key_points[:, ipt, 0]
                key_points[:, ipt * 3 + 1] = _key_points[:, ipt, 1]
                key_points[:, ipt * 3 + 2] = _key_points[:, ipt, 2]  # keypoints score.

            for k, person in enumerate(img_kpts):
                kpt = key_points[k].reshape((num_joints, 3))
                left_top = np.amin(kpt, axis=0)
                right_bottom = np.amax(kpt, axis=0)

                w = right_bottom[0] - left_top[0]
                h = right_bottom[1] - left_top[1]

                cat_results.append({
                    'image_id': person['image'],
                    'category_id': cat_id,
                    'keypoints': list(key_points[k]),
                    'score': person['score'],
                    'bbox': [left_top[0], left_top[1], w, h]
                })

        return cat_results

    def _do_python_keypoint_eval(self, res_file, res_folder):
        """Run ``COCOeval`` on ``res_file`` and return ``(name, value)``
        pairs for the ten summary statistics."""
        coco_dt = self.coco.loadRes(res_file)
        coco_eval = COCOeval(self.coco, coco_dt, 'keypoints')
        coco_eval.params.useSegm = None
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        stats_names = ['AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']

        return [
            (name, coco_eval.stats[ind])
            for ind, name in enumerate(stats_names)
        ]
b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/COCOKeypoints.py @@ -0,0 +1,151 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (leoxiaobin@gmail.com) +# Modified by Bowen Cheng (bcheng9@illinois.edu) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +import numpy as np + +import pycocotools +from .COCODataset import CocoDataset +from .target_generators import HeatmapGenerator + + +logger = logging.getLogger(__name__) + + +class CocoKeypoints(CocoDataset): + def __init__(self, + cfg, + dataset_name, + remove_images_without_annotations, + heatmap_generator, + joints_generator, + transforms=None): + super().__init__(cfg.DATASET.ROOT, + dataset_name, + cfg.DATASET.DATA_FORMAT) + + if cfg.DATASET.WITH_CENTER: + assert cfg.DATASET.NUM_JOINTS == 18, 'Number of joint with center for COCO is 18' + else: + assert cfg.DATASET.NUM_JOINTS == 17, 'Number of joint for COCO is 17' + + self.num_scales = self._init_check(heatmap_generator, joints_generator) + + self.num_joints = cfg.DATASET.NUM_JOINTS + self.with_center = cfg.DATASET.WITH_CENTER + self.num_joints_without_center = self.num_joints - 1 \ + if self.with_center else self.num_joints + self.scale_aware_sigma = cfg.DATASET.SCALE_AWARE_SIGMA + self.base_sigma = cfg.DATASET.BASE_SIGMA + self.base_size = cfg.DATASET.BASE_SIZE + self.int_sigma = cfg.DATASET.INT_SIGMA + + if remove_images_without_annotations: + self.ids = [ + img_id + for img_id in self.ids + if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0 + ] + + self.transforms = transforms + self.heatmap_generator = heatmap_generator + self.joints_generator = joints_generator + + def __getitem__(self, idx): + img, anno = super().__getitem__(idx) + + mask = 
self.get_mask(anno, idx) + + anno = [ + obj for obj in anno + if obj['iscrowd'] == 0 or obj['num_keypoints'] > 0 + ] + + # TODO(bowen): to generate scale-aware sigma, modify `get_joints` to associate a sigma to each joint + joints = self.get_joints(anno) + + mask_list = [mask.copy() for _ in range(self.num_scales)] + joints_list = [joints.copy() for _ in range(self.num_scales)] + target_list = list() + + if self.transforms: + img, mask_list, joints_list = self.transforms( + img, mask_list, joints_list + ) + + for scale_id in range(self.num_scales): + target_t = self.heatmap_generator[scale_id](joints_list[scale_id]) + joints_t = self.joints_generator[scale_id](joints_list[scale_id]) + + target_list.append(target_t.astype(np.float32)) + mask_list[scale_id] = mask_list[scale_id].astype(np.float32) + joints_list[scale_id] = joints_t.astype(np.int32) + + return img, target_list, mask_list, joints_list + + def get_joints(self, anno): + num_people = len(anno) + + if self.scale_aware_sigma: + joints = np.zeros((num_people, self.num_joints, 4)) + else: + joints = np.zeros((num_people, self.num_joints, 3)) + + for i, obj in enumerate(anno): + joints[i, :self.num_joints_without_center, :3] = \ + np.array(obj['keypoints']).reshape([-1, 3]) + if self.with_center: + joints_sum = np.sum(joints[i, :-1, :2], axis=0) + num_vis_joints = len(np.nonzero(joints[i, :-1, 2])[0]) + if num_vis_joints > 0: + joints[i, -1, :2] = joints_sum / num_vis_joints + joints[i, -1, 2] = 1 + if self.scale_aware_sigma: + # get person box + box = obj['bbox'] + size = max(box[2], box[3]) + sigma = size / self.base_size * self.base_sigma + if self.int_sigma: + sigma = int(np.round(sigma + 0.5)) + assert sigma > 0, sigma + joints[i, :, 3] = sigma + + return joints + + def get_mask(self, anno, idx): + coco = self.coco + img_info = coco.loadImgs(self.ids[idx])[0] + + m = np.zeros((img_info['height'], img_info['width'])) + + for obj in anno: + if obj['iscrowd']: + rle = pycocotools.mask.frPyObjects( + 
obj['segmentation'], img_info['height'], img_info['width']) + m += pycocotools.mask.decode(rle) + elif obj['num_keypoints'] == 0: + rles = pycocotools.mask.frPyObjects( + obj['segmentation'], img_info['height'], img_info['width']) + for rle in rles: + m += pycocotools.mask.decode(rle) + + return m < 0.5 + + def _init_check(self, heatmap_generator, joints_generator): + assert isinstance(heatmap_generator, (list, tuple)), 'heatmap_generator should be a list or tuple' + assert isinstance(joints_generator, (list, tuple)), 'joints_generator should be a list or tuple' + assert len(heatmap_generator) == len(joints_generator), \ + 'heatmap_generator and joints_generator should have same length,'\ + 'got {} vs {}.'.format( + len(heatmap_generator), len(joints_generator) + ) + return len(heatmap_generator) diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/CrowdPoseDataset.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/CrowdPoseDataset.py new file mode 100644 index 0000000..3329aaf --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/CrowdPoseDataset.py @@ -0,0 +1,296 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bowen Cheng (bcheng9@illinois.edu) and Bin Xiao (leoxiaobin@gmail.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict +from collections import OrderedDict +import logging +import os +import os.path + +import cv2 +import json_tricks as json +import numpy as np +from torch.utils.data import Dataset + +from crowdposetools.cocoeval import COCOeval +from utils import zipreader + +logger = logging.getLogger(__name__) + + +class CrowdPoseDataset(Dataset): + """`CrowdPose`_ Dataset. 
+ + Args: + root (string): Root directory where dataset is located to. + dataset (string): Dataset name(train2017, val2017, test2017). + data_format(string): Data format for reading('jpg', 'zip') + transform (callable, optional): A function/transform that takes in an opencv image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + def __init__(self, root, dataset, data_format, transform=None, + target_transform=None): + from crowdposetools.coco import COCO + self.name = 'CROWDPOSE' + self.root = root + self.dataset = dataset + self.data_format = data_format + self.coco = COCO(self._get_anno_file_name()) + self.ids = list(self.coco.imgs.keys()) + self.transform = transform + self.target_transform = target_transform + + cats = [cat['name'] + for cat in self.coco.loadCats(self.coco.getCatIds())] + self.classes = ['__background__'] + cats + logger.info('=> classes: {}'.format(self.classes)) + self.num_classes = len(self.classes) + self._class_to_ind = dict(zip(self.classes, range(self.num_classes))) + self._class_to_coco_ind = dict(zip(cats, self.coco.getCatIds())) + self._coco_ind_to_class_ind = dict( + [ + (self._class_to_coco_ind[cls], self._class_to_ind[cls]) + for cls in self.classes[1:] + ] + ) + + def _get_anno_file_name(self): + # example: root/json/crowdpose_{train,val,test}.json + return os.path.join( + self.root, + 'json', + 'crowdpose_{}.json'.format( + self.dataset + ) + ) + + def _get_image_path(self, file_name): + images_dir = os.path.join(self.root, 'images') + if self.data_format == 'zip': + return images_dir + '.zip@' + file_name + else: + return os.path.join(images_dir, file_name) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. 
+ """ + coco = self.coco + img_id = self.ids[index] + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + + file_name = coco.loadImgs(img_id)[0]['file_name'] + + if self.data_format == 'zip': + img = zipreader.imread( + self._get_image_path(file_name), + cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION + ) + else: + img = cv2.imread( + self._get_image_path(file_name), + cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION + ) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.ids) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + def processKeypoints(self, keypoints): + tmp = keypoints.copy() + if keypoints[:, 2].max() > 0: + p = keypoints[keypoints[:, 2] > 0][:, :2].mean(axis=0) + num_keypoints = keypoints.shape[0] + for i in range(num_keypoints): + tmp[i][0:3] = [ + float(keypoints[i][0]), + float(keypoints[i][1]), + float(keypoints[i][2]) + ] + + return tmp + + def evaluate(self, cfg, preds, scores, output_dir, + *args, **kwargs): + ''' + Perform evaluation on COCO keypoint task + :param cfg: cfg dictionary + :param preds: prediction + :param output_dir: output directory + :param args: + :param kwargs: + :return: + ''' + res_folder = os.path.join(output_dir, 'results') + if not os.path.exists(res_folder): + os.makedirs(res_folder) + res_file = os.path.join( + res_folder, 'keypoints_%s_results.json' % 
self.dataset) + + # preds is a list of: image x person x (keypoints) + # keypoints: num_joints * 4 (x, y, score, tag) + kpts = defaultdict(list) + for idx, _kpts in enumerate(preds): + img_id = self.ids[idx] + file_name = self.coco.loadImgs(img_id)[0]['file_name'] + for idx_kpt, kpt in enumerate(_kpts): + area = (np.max(kpt[:, 0]) - np.min(kpt[:, 0])) * (np.max(kpt[:, 1]) - np.min(kpt[:, 1])) + kpt = self.processKeypoints(kpt) + # if self.with_center: + if cfg.DATASET.WITH_CENTER and not cfg.TEST.IGNORE_CENTER: + kpt = kpt[:-1] + + kpts[int(file_name.split('.')[0])].append( + { + 'keypoints': kpt[:, 0:3], + 'score': scores[idx][idx_kpt], + 'tags': kpt[:, 3], + 'image': int(file_name.split('.')[0]), + 'area': area + } + ) + + # rescoring and oks nms + oks_nmsed_kpts = [] + # image x person x (keypoints) + for img in kpts.keys(): + # person x (keypoints) + img_kpts = kpts[img] + # person x (keypoints) + # do not use nms, keep all detections + keep = [] + if len(keep) == 0: + oks_nmsed_kpts.append(img_kpts) + else: + oks_nmsed_kpts.append([img_kpts[_keep] for _keep in keep]) + + self._write_coco_keypoint_results( + oks_nmsed_kpts, res_file + ) + + # CrowdPose `test` set has annotation. 
+ info_str = self._do_python_keypoint_eval( + res_file, res_folder + ) + name_value = OrderedDict(info_str) + return name_value, name_value['AP'] + + def _write_coco_keypoint_results(self, keypoints, res_file): + data_pack = [ + { + 'cat_id': self._class_to_coco_ind[cls], + 'cls_ind': cls_ind, + 'cls': cls, + 'ann_type': 'keypoints', + 'keypoints': keypoints + } + for cls_ind, cls in enumerate(self.classes) if not cls == '__background__' + ] + + results = self._coco_keypoint_results_one_category_kernel(data_pack[0]) + logger.info('=> Writing results json to %s' % res_file) + with open(res_file, 'w') as f: + json.dump(results, f, sort_keys=True, indent=4) + try: + json.load(open(res_file)) + except Exception: + content = [] + with open(res_file, 'r') as f: + for line in f: + content.append(line) + content[-1] = ']' + with open(res_file, 'w') as f: + for c in content: + f.write(c) + + def _coco_keypoint_results_one_category_kernel(self, data_pack): + cat_id = data_pack['cat_id'] + keypoints = data_pack['keypoints'] + cat_results = [] + num_joints = 14 + + for img_kpts in keypoints: + if len(img_kpts) == 0: + continue + + _key_points = np.array( + [img_kpts[k]['keypoints'] for k in range(len(img_kpts))] + ) + key_points = np.zeros( + (_key_points.shape[0], num_joints * 3), + dtype=np.float + ) + + for ipt in range(num_joints): + key_points[:, ipt * 3 + 0] = _key_points[:, ipt, 0] + key_points[:, ipt * 3 + 1] = _key_points[:, ipt, 1] + key_points[:, ipt * 3 + 2] = _key_points[:, ipt, 2] # keypoints score. 
+ + for k in range(len(img_kpts)): + kpt = key_points[k].reshape((num_joints, 3)) + left_top = np.amin(kpt, axis=0) + right_bottom = np.amax(kpt, axis=0) + + w = right_bottom[0] - left_top[0] + h = right_bottom[1] - left_top[1] + + cat_results.append({ + 'image_id': img_kpts[k]['image'], + 'category_id': cat_id, + 'keypoints': list(key_points[k]), + 'score': img_kpts[k]['score'], + 'bbox': list([left_top[0], left_top[1], w, h]) + }) + + return cat_results + + def _do_python_keypoint_eval(self, res_file, res_folder): + coco_dt = self.coco.loadRes(res_file) + coco_eval = COCOeval(self.coco, coco_dt, 'keypoints') + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + stats_names = ['AP', 'Ap .5', 'AP .75', 'AR', 'AR .5', 'AR .75', 'AP (easy)', 'AP (medium)', 'AP (hard)'] + stats_index = [0, 1, 2, 5, 6, 7, 8, 9, 10] + + info_str = [] + for ind, name in enumerate(stats_names): + info_str.append((name, coco_eval.stats[stats_index[ind]])) + # info_str.append(coco_eval.stats[ind]) + + return info_str diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/CrowdPoseKeypoints.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/CrowdPoseKeypoints.py new file mode 100644 index 0000000..120b763 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/CrowdPoseKeypoints.py @@ -0,0 +1,139 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bowen Cheng (bcheng9@illinois.edu) and Bin Xiao (leoxiaobin@gmail.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +import numpy as np + +import crowdposetools +from .CrowdPoseDataset import CrowdPoseDataset +from .target_generators import HeatmapGenerator + + +logger = logging.getLogger(__name__) + + +class CrowdPoseKeypoints(CrowdPoseDataset): + def __init__(self, + cfg, + dataset_name, + remove_images_without_annotations, + heatmap_generator, + joints_generator, + transforms=None): + super().__init__(cfg.DATASET.ROOT, + dataset_name, + cfg.DATASET.DATA_FORMAT) + + if cfg.DATASET.WITH_CENTER: + assert cfg.DATASET.NUM_JOINTS == 15, 'Number of joint with center for CrowdPose is 15' + else: + assert cfg.DATASET.NUM_JOINTS == 14, 'Number of joint for CrowdPose is 14' + + self.num_scales = self._init_check(heatmap_generator, joints_generator) + + self.num_joints = cfg.DATASET.NUM_JOINTS + self.with_center = cfg.DATASET.WITH_CENTER + self.num_joints_without_center = self.num_joints - 1 \ + if self.with_center else self.num_joints + self.scale_aware_sigma = cfg.DATASET.SCALE_AWARE_SIGMA + self.base_sigma = cfg.DATASET.BASE_SIGMA + self.base_size = cfg.DATASET.BASE_SIZE + self.int_sigma = cfg.DATASET.INT_SIGMA + + if remove_images_without_annotations: + self.ids = [ + img_id + for img_id in self.ids + if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0 + ] + + self.transforms = transforms + self.heatmap_generator = heatmap_generator + self.joints_generator = joints_generator + + def __getitem__(self, idx): + img, anno = super().__getitem__(idx) + + mask = self.get_mask(anno, idx) + + anno = [ + obj for obj in anno + if obj['iscrowd'] == 0 or obj['num_keypoints'] > 0 + ] + + # TODO(bowen): to generate scale-aware sigma, modify `get_joints` to associate a sigma to each joint + 
joints = self.get_joints(anno) + + mask_list = [mask.copy() for _ in range(self.num_scales)] + joints_list = [joints.copy() for _ in range(self.num_scales)] + target_list = list() + + if self.transforms: + img, mask_list, joints_list = self.transforms( + img, mask_list, joints_list + ) + + for scale_id in range(self.num_scales): + target_t = self.heatmap_generator[scale_id](joints_list[scale_id]) + joints_t = self.joints_generator[scale_id](joints_list[scale_id]) + + target_list.append(target_t.astype(np.float32)) + mask_list[scale_id] = mask_list[scale_id].astype(np.float32) + joints_list[scale_id] = joints_t.astype(np.int32) + + return img, target_list, mask_list, joints_list + + def get_joints(self, anno): + num_people = len(anno) + + if self.scale_aware_sigma: + joints = np.zeros((num_people, self.num_joints, 4)) + else: + joints = np.zeros((num_people, self.num_joints, 3)) + + for i, obj in enumerate(anno): + joints[i, :self.num_joints_without_center, :3] = \ + np.array(obj['keypoints']).reshape([-1, 3]) + if self.with_center: + joints_sum = np.sum(joints[i, :-1, :2], axis=0) + num_vis_joints = len(np.nonzero(joints[i, :-1, 2])[0]) + if num_vis_joints > 0: + joints[i, -1, :2] = joints_sum / num_vis_joints + joints[i, -1, 2] = 1 + if self.scale_aware_sigma: + # get person box + box = obj['bbox'] + size = max(box[2], box[3]) + sigma = size / self.base_size * self.base_sigma + if self.int_sigma: + sigma = int(np.round(sigma + 0.5)) + assert sigma > 0, sigma + joints[i, :, 3] = sigma + + return joints + + def get_mask(self, anno, idx): + coco = self.coco + img_info = coco.loadImgs(self.ids[idx])[0] + + m = np.zeros((img_info['height'], img_info['width'])) + + return m < 0.5 + + def _init_check(self, heatmap_generator, joints_generator): + assert isinstance(heatmap_generator, (list, tuple)), 'heatmap_generator should be a list or tuple' + assert isinstance(joints_generator, (list, tuple)), 'joints_generator should be a list or tuple' + assert len(heatmap_generator) 
== len(joints_generator), \ + 'heatmap_generator and joints_generator should have same length,'\ + 'got {} vs {}.'.format( + len(heatmap_generator), len(joints_generator) + ) + return len(heatmap_generator) diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/__init__.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/__init__.py new file mode 100644 index 0000000..29af8b2 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/__init__.py @@ -0,0 +1,60 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (leoxiaobin@gmail.com) +# ------------------------------------------------------------------------------ + +from .COCOKeypoints import CocoKeypoints as coco +#from .CrowdPoseKeypoints import CrowdPoseKeypoints as crowd_pose +from .build import make_dataloader +from .build import make_test_dataloader + +# dataset dependent configuration for visualization +coco_part_labels = [ + 'nose', 'eye_l', 'eye_r', 'ear_l', 'ear_r', + 'sho_l', 'sho_r', 'elb_l', 'elb_r', 'wri_l', 'wri_r', + 'hip_l', 'hip_r', 'kne_l', 'kne_r', 'ank_l', 'ank_r' +] +coco_part_idx = { + b: a for a, b in enumerate(coco_part_labels) +} +coco_part_orders = [ + ('nose', 'eye_l'), ('eye_l', 'eye_r'), ('eye_r', 'nose'), + ('eye_l', 'ear_l'), ('eye_r', 'ear_r'), ('ear_l', 'sho_l'), + ('ear_r', 'sho_r'), ('sho_l', 'sho_r'), ('sho_l', 'hip_l'), + ('sho_r', 'hip_r'), ('hip_l', 'hip_r'), ('sho_l', 'elb_l'), + ('elb_l', 'wri_l'), ('sho_r', 'elb_r'), ('elb_r', 'wri_r'), + ('hip_l', 'kne_l'), ('kne_l', 'ank_l'), ('hip_r', 'kne_r'), + ('kne_r', 'ank_r') +] + +crowd_pose_part_labels = [ + 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow', + 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', + 'left_knee', 'right_knee', 'left_ankle', 'right_ankle', + 'head', 'neck' +] +crowd_pose_part_idx = { + b: a for a, b in 
enumerate(crowd_pose_part_labels) +} +crowd_pose_part_orders = [ + ('head', 'neck'), ('neck', 'left_shoulder'), ('neck', 'right_shoulder'), + ('left_shoulder', 'right_shoulder'), ('left_shoulder', 'left_hip'), + ('right_shoulder', 'right_hip'), ('left_hip', 'right_hip'), ('left_shoulder', 'left_elbow'), + ('left_elbow', 'left_wrist'), ('right_shoulder', 'right_elbow'), ('right_elbow', 'right_wrist'), + ('left_hip', 'left_knee'), ('left_knee', 'left_ankle'), ('right_hip', 'right_knee'), + ('right_knee', 'right_ankle') +] + +VIS_CONFIG = { + 'COCO': { + 'part_labels': coco_part_labels, + 'part_idx': coco_part_idx, + 'part_orders': coco_part_orders + }, + 'CROWDPOSE': { + 'part_labels': crowd_pose_part_labels, + 'part_idx': crowd_pose_part_idx, + 'part_orders': crowd_pose_part_orders + } +} diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/build.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/build.py new file mode 100644 index 0000000..95be9eb --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/build.py @@ -0,0 +1,108 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bin Xiao (leoxiaobin@gmail.com) +# Modified by Bowen Cheng (bcheng9@illinois.edu) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch.utils.data + +# from .COCODataset import CocoDataset as coco +# from .COCOKeypoints import CocoKeypoints as coco_kpt +# from .CrowdPoseDataset import CrowdPoseDataset as crowd_pose +# from .CrowdPoseKeypoints import CrowdPoseKeypoints as crowd_pose_kpt +from .transforms import build_transforms +from .target_generators import HeatmapGenerator +from .target_generators import ScaleAwareHeatmapGenerator +from .target_generators import JointsGenerator + + +def build_dataset(cfg, is_train): + transforms = build_transforms(cfg, is_train) + + if cfg.DATASET.SCALE_AWARE_SIGMA: + _HeatmapGenerator = ScaleAwareHeatmapGenerator + else: + _HeatmapGenerator = HeatmapGenerator + + heatmap_generator = [ + _HeatmapGenerator( + output_size, cfg.DATASET.NUM_JOINTS, cfg.DATASET.SIGMA + ) for output_size in cfg.DATASET.OUTPUT_SIZE + ] + joints_generator = [ + JointsGenerator( + cfg.DATASET.MAX_NUM_PEOPLE, + cfg.DATASET.NUM_JOINTS, + output_size, + cfg.MODEL.TAG_PER_JOINT + ) for output_size in cfg.DATASET.OUTPUT_SIZE + ] + + dataset_name = cfg.DATASET.TRAIN if is_train else cfg.DATASET.TEST + + dataset = eval(cfg.DATASET.DATASET)( + cfg, + dataset_name, + is_train, + heatmap_generator, + joints_generator, + transforms + ) + + return dataset + + +def make_dataloader(cfg, is_train=True, distributed=False): + if is_train: + images_per_gpu = cfg.TRAIN.IMAGES_PER_GPU + shuffle = True + else: + images_per_gpu = cfg.TEST.IMAGES_PER_GPU + shuffle = False + images_per_batch = images_per_gpu * len(cfg.GPUS) + + dataset = build_dataset(cfg, is_train) + + if is_train and distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + dataset + ) + shuffle = False + else: + 
train_sampler = None + + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=images_per_batch, + shuffle=shuffle, + num_workers=cfg.WORKERS, + pin_memory=cfg.PIN_MEMORY, + sampler=train_sampler + ) + + return data_loader + + +def make_test_dataloader(cfg): + transforms = None + dataset = eval(cfg.DATASET.DATASET_TEST)( + cfg.DATASET.ROOT, + cfg.DATASET.TEST, + cfg.DATASET.DATA_FORMAT, + transforms + ) + + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=0, + pin_memory=False + ) + + return data_loader, dataset diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/target_generators/__init__.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/target_generators/__init__.py new file mode 100644 index 0000000..323fcb0 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/target_generators/__init__.py @@ -0,0 +1,5 @@ +from .target_generators import HeatmapGenerator +from .target_generators import ScaleAwareHeatmapGenerator +from .target_generators import JointsGenerator + +__all__ = ['HeatmapGenerator', 'ScaleAwareHeatmapGenerator', 'JointsGenerator'] diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/target_generators/target_generators.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/target_generators/target_generators.py new file mode 100644 index 0000000..e8e3165 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/target_generators/target_generators.py @@ -0,0 +1,115 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bin Xiao (leoxiaobin@gmail.com) +# Modified by Bowen Cheng (bcheng9@illinois.edu) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +class HeatmapGenerator(): + def __init__(self, output_res, num_joints, sigma=-1): + self.output_res = output_res + self.num_joints = num_joints + if sigma < 0: + sigma = self.output_res/64 + self.sigma = sigma + size = 6*sigma + 3 + x = np.arange(0, size, 1, float) + y = x[:, np.newaxis] + x0, y0 = 3*sigma + 1, 3*sigma + 1 + self.g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) + + def __call__(self, joints): + hms = np.zeros((self.num_joints, self.output_res, self.output_res), + dtype=np.float32) + sigma = self.sigma + for p in joints: + for idx, pt in enumerate(p): + if pt[2] > 0: + x, y = int(pt[0]), int(pt[1]) + if x < 0 or y < 0 or \ + x >= self.output_res or y >= self.output_res: + continue + + ul = int(np.round(x - 3 * sigma - 1)), int(np.round(y - 3 * sigma - 1)) + br = int(np.round(x + 3 * sigma + 2)), int(np.round(y + 3 * sigma + 2)) + + c, d = max(0, -ul[0]), min(br[0], self.output_res) - ul[0] + a, b = max(0, -ul[1]), min(br[1], self.output_res) - ul[1] + + cc, dd = max(0, ul[0]), min(br[0], self.output_res) + aa, bb = max(0, ul[1]), min(br[1], self.output_res) + hms[idx, aa:bb, cc:dd] = np.maximum( + hms[idx, aa:bb, cc:dd], self.g[a:b, c:d]) + return hms + + +class ScaleAwareHeatmapGenerator(): + def __init__(self, output_res, num_joints): + self.output_res = output_res + self.num_joints = num_joints + + def get_gaussian_kernel(self, sigma): + size = 6*sigma + 3 + x = np.arange(0, size, 1, float) + y = x[:, np.newaxis] + x0, y0 = 3*sigma + 1, 3*sigma + 1 + g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) + return g + + def __call__(self, joints): + hms = np.zeros((self.num_joints, self.output_res, 
self.output_res), + dtype=np.float32) + for p in joints: + sigma = p[0, 3] + g = self.get_gaussian_kernel(sigma) + for idx, pt in enumerate(p): + if pt[2] > 0: + x, y = int(pt[0]), int(pt[1]) + if x < 0 or y < 0 or \ + x >= self.output_res or y >= self.output_res: + continue + + ul = int(np.round(x - 3 * sigma - 1)), int(np.round(y - 3 * sigma - 1)) + br = int(np.round(x + 3 * sigma + 2)), int(np.round(y + 3 * sigma + 2)) + + c, d = max(0, -ul[0]), min(br[0], self.output_res) - ul[0] + a, b = max(0, -ul[1]), min(br[1], self.output_res) - ul[1] + + cc, dd = max(0, ul[0]), min(br[0], self.output_res) + aa, bb = max(0, ul[1]), min(br[1], self.output_res) + hms[idx, aa:bb, cc:dd] = np.maximum( + hms[idx, aa:bb, cc:dd], g[a:b, c:d]) + return hms + + +class JointsGenerator(): + def __init__(self, max_num_people, num_joints, output_res, tag_per_joint): + self.max_num_people = max_num_people + self.num_joints = num_joints + self.output_res = output_res + self.tag_per_joint = tag_per_joint + + def __call__(self, joints): + visible_nodes = np.zeros((self.max_num_people, self.num_joints, 2)) + output_res = self.output_res + for i in range(len(joints)): + tot = 0 + for idx, pt in enumerate(joints[i]): + x, y = int(pt[0]), int(pt[1]) + if pt[2] > 0 and x >= 0 and y >= 0 \ + and x < self.output_res and y < self.output_res: + if self.tag_per_joint: + visible_nodes[i][tot] = \ + (idx * output_res**2 + y * output_res + x, 1) + else: + visible_nodes[i][tot] = \ + (y * output_res + x, 1) + tot += 1 + return visible_nodes diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/transforms/__init__.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/transforms/__init__.py new file mode 100644 index 0000000..f8074b2 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/transforms/__init__.py @@ -0,0 +1,8 @@ +from .transforms import Compose +from .transforms import RandomAffineTransform +from .transforms import ToTensor +from 
.transforms import Normalize +from .transforms import RandomHorizontalFlip + +from .build import build_transforms +from .build import FLIP_CONFIG diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/transforms/build.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/transforms/build.py new file mode 100644 index 0000000..6b18320 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/transforms/build.py @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (leoxiaobin@gmail.com) +# Modified by Bowen Cheng (bcheng9@illinois.edu) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from . import transforms as T + + +FLIP_CONFIG = { + 'COCO': [ + 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15 + ], + 'COCO_WITH_CENTER': [ + 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 17 + ], + 'CROWDPOSE': [ + 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 12, 13 + ], + 'CROWDPOSE_WITH_CENTER': [ + 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 12, 13, 14 + ] +} + + +def build_transforms(cfg, is_train=True): + assert is_train is True, 'Please only use build_transforms for training.' 
+ assert isinstance(cfg.DATASET.OUTPUT_SIZE, (list, tuple)), 'DATASET.OUTPUT_SIZE should be list or tuple' + if is_train: + max_rotation = cfg.DATASET.MAX_ROTATION + min_scale = cfg.DATASET.MIN_SCALE + max_scale = cfg.DATASET.MAX_SCALE + max_translate = cfg.DATASET.MAX_TRANSLATE + input_size = cfg.DATASET.INPUT_SIZE + output_size = cfg.DATASET.OUTPUT_SIZE + flip = cfg.DATASET.FLIP + scale_type = cfg.DATASET.SCALE_TYPE + else: + scale_type = cfg.DATASET.SCALE_TYPE + max_rotation = 0 + min_scale = 1 + max_scale = 1 + max_translate = 0 + input_size = 512 + output_size = [128] + flip = 0 + + # coco_flip_index = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] + # if cfg.DATASET.WITH_CENTER: + # coco_flip_index.append(17) + if 'coco' in cfg.DATASET.DATASET: + dataset_name = 'COCO' + elif 'crowd_pose' in cfg.DATASET.DATASET: + dataset_name = 'CROWDPOSE' + else: + raise ValueError('Please implement flip_index for new dataset: %s.' % cfg.DATASET.DATASET) + if cfg.DATASET.WITH_CENTER: + coco_flip_index = FLIP_CONFIG[dataset_name + '_WITH_CENTER'] + else: + coco_flip_index = FLIP_CONFIG[dataset_name] + + transforms = T.Compose( + [ + T.RandomAffineTransform( + input_size, + output_size, + max_rotation, + min_scale, + max_scale, + scale_type, + max_translate, + scale_aware_sigma=cfg.DATASET.SCALE_AWARE_SIGMA + ), + T.RandomHorizontalFlip(coco_flip_index, output_size, flip), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ] + ) + + return transforms diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/transforms/transforms.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/transforms/transforms.py new file mode 100644 index 0000000..0be0ecc --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/dataset/transforms/transforms.py @@ -0,0 +1,182 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the 
class Compose(object):
    """Chain several ``(image, mask, joints)`` transforms into one callable.

    Args:
        transforms: Sequence of callables; each one takes and returns the
            triple ``(image, mask, joints)``.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask, joints):
        # The output triple of each transform is the input of the next one.
        for t in self.transforms:
            image, mask, joints = t(image, mask, joints)
        return image, mask, joints

    def __repr__(self):
        # NOTE(review): the indent inside " {0}" is copied as it appears in
        # the reviewed text, whose whitespace had been collapsed -- confirm
        # the intended width against the pristine file.
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += " {0}".format(t)
        format_string += "\n)"
        return format_string


class ToTensor(object):
    """Convert the image with ``F.to_tensor``; mask and joints pass through."""

    def __call__(self, image, mask, joints):
        return F.to_tensor(image), mask, joints


class Normalize(object):
    """Normalize the image with ``F.normalize`` using a fixed mean and std.

    Args:
        mean: Mean handed to ``F.normalize``.
        std: Standard deviation handed to ``F.normalize``.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, mask, joints):
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, mask, joints


class RandomHorizontalFlip(object):
    """Mirror image, masks and joints left-to-right with probability ``prob``.

    Args:
        flip_index: Index sequence used to reorder axis 1 of every joints
            array when a flip happens (swaps left/right joints).
        output_size: One output width, or a list/tuple with one width per
            entry of the ``mask`` / ``joints`` lists.
        prob: Probability of flipping; compared against ``random.random()``.
    """

    def __init__(self, flip_index, output_size, prob=0.5):
        self.flip_index = flip_index
        self.prob = prob
        # A tuple is a valid multi-resolution spec too; it used to be wrapped
        # as ``[tuple]``, which broke the length checks in ``__call__``.
        if isinstance(output_size, tuple):
            output_size = list(output_size)
        elif not isinstance(output_size, list):
            output_size = [output_size]
        self.output_size = output_size

    def __call__(self, image, mask, joints):
        """Return ``(image, mask, joints)``, flipped or untouched.

        ``mask`` and ``joints`` are lists that are updated in place, one
        entry per configured output size.
        """
        assert isinstance(mask, list)
        assert isinstance(joints, list)
        assert len(mask) == len(joints)
        assert len(mask) == len(self.output_size)

        if random.random() < self.prob:
            # ``.copy()`` materialises the reversed view as a fresh
            # C-contiguous array (no negative strides left behind).
            image = image[:, ::-1].copy()
            for i, _output_size in enumerate(self.output_size):
                mask[i] = mask[i][:, ::-1].copy()
                # Fancy indexing copies, so the caller's array is not
                # touched by the in-place x update below.
                joints[i] = joints[i][:, self.flip_index]
                joints[i][:, :, 0] = _output_size - joints[i][:, :, 0] - 1

        return image, mask, joints
class RandomAffineTransform(object):
    """Apply one random scale / rotation / translation to image, masks, joints.

    The same randomly drawn affine is rendered once per output resolution
    (for the masks and joints) and once at ``input_size`` (for the image).

    Args:
        input_size: Side length of the square output image.
        output_size: One side length, or a list/tuple of side lengths, one per
            entry of the ``mask`` / ``joints`` lists.
        max_rotation: Rotation is drawn uniformly from
            ``[-max_rotation, max_rotation)`` degrees.
        min_scale: Lower bound of the uniform scale factor.
        max_scale: Upper bound of the uniform scale factor.
        scale_type: ``'long'`` or ``'short'``; which image side defines the
            base scale.
        max_translate: If positive, the centre is shifted by a random integer
            offset bounded by ``max_translate * scale`` on each axis.
        scale_aware_sigma: If true, column 3 of each joints array is divided
            by the drawn scale factor (presumably a per-joint sigma --
            confirm against the dataset code).
    """

    def __init__(self,
                 input_size,
                 output_size,
                 max_rotation,
                 min_scale,
                 max_scale,
                 scale_type,
                 max_translate,
                 scale_aware_sigma=False):
        self.input_size = input_size
        # Accept tuples as multi-resolution specs as well; wrapping a tuple
        # into ``[tuple]`` would fail the length checks in ``__call__``.
        if isinstance(output_size, tuple):
            output_size = list(output_size)
        elif not isinstance(output_size, list):
            output_size = [output_size]
        self.output_size = output_size

        self.max_rotation = max_rotation
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.scale_type = scale_type
        self.max_translate = max_translate
        self.scale_aware_sigma = scale_aware_sigma

    def _get_affine_matrix(self, center, scale, res, rot=0):
        """Build the 3x3 matrix mapping source pixels into a ``res`` canvas.

        Args:
            center: ``(x, y)`` source point that lands on the canvas centre.
            scale: Source extent in units of 200 pixels.
            res: ``(height, width)`` of the target canvas.
            rot: Rotation in degrees about the canvas centre.
        """
        # Generate transformation matrix
        h = 200 * scale
        t = np.zeros((3, 3))
        t[0, 0] = float(res[1]) / h
        t[1, 1] = float(res[0]) / h
        t[0, 2] = res[1] * (-float(center[0]) / h + .5)
        t[1, 2] = res[0] * (-float(center[1]) / h + .5)
        t[2, 2] = 1
        if rot != 0:
            rot = -rot  # To match direction of rotation from cropping
            rot_mat = np.zeros((3, 3))
            rot_rad = rot * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
            rot_mat[2, 2] = 1
            # Need to rotate around center: shift the canvas centre to the
            # origin, rotate, then shift back.
            t_mat = np.eye(3)
            t_mat[0, 2] = -res[1] / 2
            t_mat[1, 2] = -res[0] / 2
            t_inv = t_mat.copy()
            t_inv[:2, 2] *= -1
            t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
        return t

    def _affine_joints(self, joints, mat):
        """Apply the 2x3 affine ``mat`` to ``(..., 2)`` points, keeping shape."""
        joints = np.array(joints)
        shape = joints.shape
        joints = joints.reshape(-1, 2)
        # Append a homogeneous 1 to every point before multiplying.
        homogeneous = np.concatenate(
            (joints, joints[:, 0:1] * 0 + 1), axis=1)
        return np.dot(homogeneous, mat.T).reshape(shape)

    def __call__(self, image, mask, joints):
        """Return the warped ``(image, mask, joints)``.

        ``mask`` and ``joints`` are lists updated in place, one entry per
        configured output size.

        Raises:
            ValueError: If ``scale_type`` is neither ``'long'`` nor ``'short'``.
        """
        assert isinstance(mask, list)
        assert isinstance(joints, list)
        assert len(mask) == len(joints)
        assert len(mask) == len(self.output_size)

        height, width = image.shape[:2]

        center = np.array((width / 2, height / 2))
        if self.scale_type == 'long':
            scale = max(height, width) / 200
        elif self.scale_type == 'short':
            scale = min(height, width) / 200
        else:
            raise ValueError('Unknown scale type: {}'.format(self.scale_type))
        aug_scale = np.random.random() * (self.max_scale - self.min_scale) \
            + self.min_scale
        scale *= aug_scale
        aug_rot = (np.random.random() * 2 - 1) * self.max_rotation

        if self.max_translate > 0:
            # Use explicit integer bounds rather than handing floats to
            # randint. A bound that truncates to 0 means "no room to shift";
            # randint(0, 0) would raise "low >= high", so skip the shift.
            bound = int(self.max_translate * scale)
            if bound > 0:
                dx = np.random.randint(-bound, bound)
                dy = np.random.randint(-bound, bound)
                center[0] += dx
                center[1] += dy

        for i, _output_size in enumerate(self.output_size):
            mat_output = self._get_affine_matrix(
                center, scale, (_output_size, _output_size), aug_rot
            )[:2]
            # Warp as uint8, then re-binarise the interpolated result.
            mask[i] = cv2.warpAffine(
                (mask[i] * 255).astype(np.uint8), mat_output,
                (_output_size, _output_size)
            ) / 255
            mask[i] = (mask[i] > 0.5).astype(np.float32)

            joints[i][:, :, 0:2] = self._affine_joints(
                joints[i][:, :, 0:2], mat_output
            )
            if self.scale_aware_sigma:
                joints[i][:, :, 3] = joints[i][:, :, 3] / aug_scale

        mat_input = self._get_affine_matrix(
            center, scale, (self.input_size, self.input_size), aug_rot
        )[:2]
        image = cv2.warpAffine(
            image, mat_input, (self.input_size, self.input_size)
        )

        return image, mask, joints
reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +# following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +# disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +# products derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
# TODO: Update overflow check + downscale to use Carl's fused kernel.
class FP16_Optimizer(object):
    """
    :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
    and manage static or dynamic loss scaling and master weights in a manner
    transparent to the user. For standard use, only two lines must be changed:
    creating the :class:`FP16_Optimizer` instance, and changing the call to
    ``backward``.

    Example::

        model = torch.nn.Linear(D_in, D_out).cuda().half()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        # Name the FP16_Optimizer instance to replace the existing optimizer
        # (recommended but not required):
        optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
        ...
        # loss.backward() becomes:
        optimizer.backward(loss)
        ...

    Example with dynamic loss scaling::

        ...
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
                                   # optional arg to control dynamic loss scaling behavior
                                   # dynamic_loss_args={'scale_window' : 500})
                                   # Usually, dynamic_loss_args is not necessary.

    Args:
        init_optimizer (torch.optim.optimizer): Existing optimizer created
            with the parameters to optimize. Internally,
            :class:`FP16_Optimizer` replaces the passed optimizer's fp16
            parameters, if any, with fp32 master parameters copied from the
            original ones. :class:`FP16_Optimizer` also stores references to
            the original fp16 parameters, and updates these fp16 parameters
            from the master fp32 copy at the end of each :attr:`step`.
        static_loss_scale (float, optional, default=1.0): Loss scale used
            internally to scale gradients computed by the model. Any fp16
            gradients will be copied to fp32, then downscaled before being
            applied to the fp32 master params, so ``static_loss_scale``
            should not affect learning rate.
        dynamic_loss_scale (bool, optional, default=False): Use dynamic loss
            scaling. If True, this will override any ``static_loss_scale``
            option.
        dynamic_loss_args (dict, optional, default=None): Dict of kwargs that
            will be forwarded to the internal :class:`DynamicLossScaler`
            instance's constructor. Keys of this dict must match kwargs
            accepted by :class:`DynamicLossScaler`'s constructor. If
            ``dynamic_loss_args`` is unspecified,
            :class:`DynamicLossScaler`'s defaults will be used.
        verbose (bool, optional, default=True): By default, FP16_Optimizer's
            constructor prints out the parameters and parameter groups it is
            ingesting, as a sanity check. If this becomes annoying (e.g. for
            large models), it can be disabled by passing ``verbose=False``.
            ``verbose=False`` will not disable printing when the loss scale
            is readjusted during dynamic loss scaling.

    Raises:
        SystemError: If CUDA is not available when the wrapper is constructed.
        TypeError: If a wrapped parameter that requires grad is neither a
            ``torch.cuda.FloatTensor`` nor a ``torch.cuda.HalfTensor``.

    ``init_optimizer`` is expected to have been constructed in the ordinary way.
    It is recommended (although not required) that the newly constructed
    :class:`FP16_Optimizer` instance be named to replace ``init_optimizer``,
    for two reasons:
    First, it means that references to the same name later in the file will
    not have to change.
    Second, :class:`FP16_Optimizer` reserves the right (as an implementation
    detail) to modify ``init_optimizer``. If you do choose a unique name for
    the new :class:`FP16_Optimizer` instance, you should only work with this
    new instance, because the preexisting optimizer might no longer behave as
    expected.

    ``init_optimizer`` may be any Pytorch optimizer.
    It may contain a mixture of fp16 and fp32 parameters organized into any
    number of ``param_groups`` with different hyperparameters. The
    :class:`FP16_Optimizer` constructor will ingest these ``param_groups``
    and remember them.

    Calls to ::

        loss.backward()

    must be replaced with ::

        optimizer.backward(loss)

    because :class:`FP16_Optimizer` requires ownership of the backward pass
    to implement loss scaling and copies to master gradients.

    .. note::
        Loss scaling, either static or dynamic, is orthogonal to learning
        rate, because gradients are downscaled before being applied. This
        means that adjusting the loss scale, or using dynamic loss scaling,
        should not require retuning the learning rate or any other
        hyperparameters.

    **Advanced options**

    **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that
    receives a closure. See docstring for :attr:`step`.

    **Gradient clipping**: Use :attr:`clip_master_grads`.

    **Multiple losses**: If your model accumulates gradients from multiple
    losses, this can be made more efficient by supplying
    ``update_master_grads=False`` to :attr:`backward`. See docstring for
    :attr:`backward`.

    **Manually adjusting loss scale**: The current loss scale can be
    retrieved or set via ::

        print(optimizer.loss_scale)
        optimizer.loss_scale = new_loss_scale

    For static loss scaling, manually adjusting the loss scale over time is a
    reasonable thing to do. During later epochs, gradients may become
    smaller, and a higher loss scale may be required, analogous to scheduling
    the learning rate. Dynamic loss scaling is more subtle (see
    :class:`DynamicLossScaler`) and in this case, manually adjusting the loss
    scale is not recommended.

    **Multi_GPU training**: If the wrapped ``init_optimizer`` was created
    from a model wrapped in Pytorch DistributedDataParallel or Apex
    DistributedDataParallel, :class:`FP16_Optimizer` should still work as
    intended.
    """

    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True):
        # ``is_available`` is a function and must be called; the bare
        # function object is always truthy, so without the call this guard
        # could never fire.
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")

        self.verbose = verbose

        self.optimizer = init_optimizer
        # init_state_dict sets up an alternative way to cast per-param state tensors.
        # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
        # init_state_dict = init_optimizer.state_dict()

        # Parallel lists with one entry per optimizer param group:
        #   fp16_groups            - the model's original fp16 params
        #   fp32_from_fp16_groups  - fp32 master copies of those params
        #   fp32_from_fp32_groups  - params that were fp32 to begin with
        self.fp16_groups = []
        self.fp32_from_fp16_groups = []
        self.fp32_from_fp32_groups = []
        for i, param_group in enumerate(self.optimizer.param_groups):
            self.maybe_print("FP16_Optimizer processing param group {}:".format(i))
            fp16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_fp16_params_this_group = []
            # A separate index name keeps the group index ``i`` intact.
            for j, param in enumerate(param_group['params']):
                if param.requires_grad:
                    if param.type() == 'torch.cuda.HalfTensor':
                        self.maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                                         .format(param.size()))
                        fp16_params_this_group.append(param)
                        master_param = param.detach().clone().float()
                        master_param.requires_grad = True
                        # The wrapped optimizer now steps the fp32 master copy.
                        param_group['params'][j] = master_param
                        fp32_from_fp16_params_this_group.append(master_param)
                        # Reset existing state dict key to the new master param.
                        # We still need to recast per-param state tensors, if any, to FP32.
                        if param in self.optimizer.state:
                            self.optimizer.state[master_param] = self.optimizer.state.pop(param)
                    elif param.type() == 'torch.cuda.FloatTensor':
                        self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                                         .format(param.size()))
                        fp32_params_this_group.append(param)
                        param_group['params'][j] = param
                    else:
                        raise TypeError("Wrapped parameters must be either "
                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                        "Received {}".format(param.type()))

            self.fp16_groups.append(fp16_params_this_group)
            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())
        # alternative way to cast per-param state tensors:
        # self.optimizer.load_state_dict(init_state_dict)

        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            if dynamic_loss_args is not None:
                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
            else:
                self.loss_scaler = DynamicLossScaler()
        else:
            self.dynamic_loss_scale = False
            self.loss_scaler = LossScaler(static_loss_scale)

        self.overflow = False
        self.first_closure_call_this_step = True

        self.clip_grad_norm = clip_grad_norm

    def maybe_print(self, msg):
        """Print ``msg`` only when the instance was created with ``verbose=True``."""
        if self.verbose:
            print(msg)

    def __getstate__(self):
        raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")

    def __setstate__(self, state):
        raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")

    def zero_grad(self, set_grads_to_None=False):
        """
        Zero fp32 and fp16 parameter grads.

        Args:
            set_grads_to_None (bool, optional, default=False): Drop the
                ``.grad`` tensors (set them to ``None``) instead of zeroing
                them in place.
        """
        # In principle, only the .grad attributes of the model params need to be zeroed,
        # because gradients are copied into the FP32 master params. However, we zero
        # all gradients owned by the optimizer, just to be safe:
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

        # Zero fp16 gradients owned by the model:
        for fp16_group in self.fp16_groups:
            for param in fp16_group:
                if set_grads_to_None:
                    param.grad = None
                else:
                    if param.grad is not None:
                        param.grad.detach_()  # as in torch.optim.optimizer.zero_grad()
                        param.grad.zero_()

    def _check_overflow(self):
        """Ask the loss scaler whether any model grad overflowed; store it in ``self.overflow``."""
        params = []
        for group in self.fp16_groups:
            for param in group:
                params.append(param)
        for group in self.fp32_from_fp32_groups:
            for param in group:
                params.append(param)
        self.overflow = self.loss_scaler.has_overflow(params)

    def _update_scale(self, has_overflow=False):
        """Forward the overflow flag to the loss scaler's ``update_scale``."""
        self.loss_scaler.update_scale(has_overflow)

    def _master_params_to_model_params(self):
        """Copy the fp32 master params back into the model's fp16 params, group by group."""
        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
            master_params_to_model_params(fp16_group, fp32_from_fp16_group)

    # To consider: Integrate distributed with this wrapper by registering a hook on each variable
    # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
    def _model_grads_to_master_grads(self):
        """Copy the model's fp16 grads onto the fp32 master params, group by group."""
        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
            model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)

    def _downscale_master(self):
        """Divide every master grad by the current loss scale (no-op when the scale is 1)."""
        if self.loss_scale != 1.0:
            for group in self.optimizer.param_groups:
                for param in group['params']:
                    if param.grad is not None:
                        param.grad.data.mul_(1. / self.loss_scale)

    def clip_master_grads(self, max_norm, norm_type=2):
        """
        Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Total norm of the current fp32 gradients (viewed as a single vector).

        .. warning::
            Returns -1 if the most recently computed fp16 gradients overflowed
            (that is, if ``self.overflow`` is ``True``).
        """
        if not self.overflow:
            fp32_params = []
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    fp32_params.append(param)
            return self.clip_grad_norm(fp32_params, max_norm, norm_type)
        else:
            return -1

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['loss_scaler'] = self.loss_scaler
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['overflow'] = self.overflow
        state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step
        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.loss_scaler = state_dict['loss_scaler']
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.overflow = state_dict['overflow']
        self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date. There are two options.
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        # We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
            for current, saved in zip(current_group, saved_group):
                current.data.copy_(saved.data)

    def step(self, closure=None):  # could add clip option.
        """
        If no closure is supplied, :attr:`step` should be called after
        ``fp16_optimizer_obj.backward(loss)``.
        :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to
        :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params
        originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run
        another forward pass using their model.

        If a closure is supplied, :attr:`step` may be called without a prior call to
        :attr:`backward(loss)`.
        This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.
        However, the user should take care that any ``loss.backward()`` call within the closure
        has been replaced by ``fp16_optimizer_obj.backward(loss)``.

        Args:
            closure (optional): Closure that will be supplied to the underlying optimizer
                originally passed to :class:`FP16_Optimizer`'s constructor. closure should call
                :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss,
                call :attr:`backward(loss)`, and return the loss.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns, or ``None`` when the
            step is skipped because the last backward pass overflowed.

        Example with closure::

            # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
            # existing pytorch optimizer.
            for input, target in dataset:
                def closure():
                    optimizer.zero_grad()
                    output = model(input)
                    loss = loss_fn(output, target)
                    # loss.backward() becomes:
                    optimizer.backward(loss)
                    return loss
                optimizer.step(closure)

        .. warning::
            Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.

        .. _`ordinary Pytorch optimizer use`:
            http://pytorch.org/docs/master/optim.html#optimizer-step-closure
        """
        # The scaler is told about the overflow before the early return, so
        # it gets updated even when the parameter update is skipped.
        self._update_scale(self.overflow)

        if self.overflow:
            # Overflow: skip this step entirely.
            return

        if closure is not None:
            retval = self._step_with_closure(closure)
        else:
            retval = self.optimizer.step()

        self._master_params_to_model_params()

        return retval

    def _step_with_closure(self, closure):
        """Run the wrapped optimizer's ``step`` with a closure that keeps the fp16 params fresh."""
        def wrapped_closure():
            if self.first_closure_call_this_step:
                # We expect that the fp16 params are initially fresh on entering self.step(),
                # so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
                # is called within self.optimizer.step().
                self.first_closure_call_this_step = False
            else:
                # If self.optimizer.step() internally calls wrapped_closure more than once,
                # it may update the fp32 params after each call. However, self.optimizer
                # doesn't know about the fp16 params at all. If the fp32 params get updated,
                # we can't rely on self.optimizer to refresh the fp16 params. We need
                # to handle that manually:
                self._master_params_to_model_params()
            # Our API expects the user to give us ownership of the backward() call by
            # replacing all calls to loss.backward() with optimizer.backward(loss).
            # This requirement holds whether or not the call to backward() is made within a closure.
            # If the user is properly calling optimizer.backward(loss) within "closure,"
            # calling closure() here will give the fp32 master params fresh gradients
            # for the optimizer to play with, so all wrapped_closure needs to do is call
            # closure() and return the loss.
            temp_loss = closure()
            # Re-run the closure with an updated loss scale until the
            # backward pass no longer overflows.
            while self.overflow:
                self._update_scale(self.overflow)
                temp_loss = closure()
            return temp_loss

        retval = self.optimizer.step(wrapped_closure)

        self.first_closure_call_this_step = True

        return retval

    def backward(self, loss, update_master_grads=True):
        """
        :attr:`backward` performs the following conceptual steps:

        1. fp32_loss = loss.float() (see first Note below)
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes
           of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model
           was defined).
        4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note),
           which are guaranteed to be fp32.
        5. Finally, master grads are divided by loss_scale.

        In this way, after :attr:`backward`, the master params have fresh gradients,
        and :attr:`step` may be called.

        .. note::
            :attr:`backward` internally converts the loss to fp32 before applying the loss scale.
            This provides some additional safety against overflow if the user has supplied an
            fp16 loss value.
            However, for maximum overflow safety, the user should
            compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
            :attr:`backward`.

        .. warning::
            The gradients found in a model's leaves after the call to
            :attr:`backward` should not be regarded as valid in general,
            because it's possible
            they have been scaled (and in the case of dynamic loss scaling,
            the scale factor may change over time).
            If the user wants to inspect gradients after a call to :attr:`backward`,
            only the master gradients should be regarded as valid. These can be retrieved via
            :attr:`inspect_master_grad_data()`.

        Args:
            loss: The loss output by the user's model. loss may be either float or half
                (but see first Note above).
            update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32
                grads on this call. By setting this to False, the user can delay the copy, which
                is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is
                being called on multiple losses in one iteration. If set to False, the user
                becomes responsible for calling :attr:`update_master_grads` before calling
                :attr:`step`.

        Example::

            # Ordinary operation:
            optimizer.backward(loss)

            # Naive operation with multiple losses (technically valid, but less efficient):
            # fp32 grads will be correct after the second call, but
            # the first call incurs an unnecessary fp16->fp32 grad copy.
            optimizer.backward(loss1)
            optimizer.backward(loss2)

            # More efficient way to handle multiple losses:
            # The fp16->fp32 grad copy is delayed until fp16 grads from all
            # losses have been accumulated.
            optimizer.backward(loss1, update_master_grads=False)
            optimizer.backward(loss2, update_master_grads=False)
            optimizer.update_master_grads()
        """
        # To consider: try multiple backward passes using retain_grad=True to find
        # a loss scale that works. After you find a loss scale that works, do a final dummy
        # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
        # discarding the iteration, but probably wouldn't improve overall efficiency.
        self.loss_scaler.backward(loss.float())
        if update_master_grads:
            self.update_master_grads()

    def update_master_grads(self):
        """
        Copy the ``.grad`` attribute from stored references to fp16 parameters to
        the ``.grad`` attribute of the fp32 master parameters that are directly
        updated by the optimizer. :attr:`update_master_grads` only needs to be called if
        ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
        """
        if self.dynamic_loss_scale:
            self._check_overflow()
            if self.overflow:
                # Overflowed grads are not copied to the master params.
                return
        self._model_grads_to_master_grads()
        self._downscale_master()

    def inspect_master_grad_data(self):
        """
        When running with :class:`FP16_Optimizer`,
        ``.grad`` attributes of a model's fp16 leaves should not be
        regarded as truthful, because they might be scaled.
        After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,
        the fp32 master params' ``.grad``
        attributes will contain valid gradients properly divided by the loss scale. However,
        because :class:`FP16_Optimizer` flattens some parameters, accessing them may be
        nonintuitive. :attr:`inspect_master_grad_data`
        allows those gradients to be viewed with shapes corresponding to their associated model leaves.

        Returns:
            List of lists (one list for each parameter group). The list for each parameter group
            is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.

        Raises:
            NotImplementedError: Always; this method is a placeholder.
        """
        raise NotImplementedError("Currently not implemented, working on it...")

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        return self.loss_scaler.loss_scale

    def _set_loss_scale(self, value):
        self.loss_scaler.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +# products derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
class tofp32(nn.Module):
    """
    Model wrapper that casts its input to single precision::

        def forward(self, input):
            return input.float()

    A ``list`` or ``tuple`` of tensors is converted element by element and
    returned as a container of the same kind (``list`` -> ``list``,
    ``tuple`` -> ``tuple``).
    """

    def __init__(self):
        super(tofp32, self).__init__()

    def forward(self, input):
        # A wrapped model may return several tensors at once; cast each one
        # rather than calling ``.float()`` on the container itself.
        if isinstance(input, list):
            return [x.float() for x in input]
        if isinstance(input, tuple):
            return tuple(x.float() for x in input)
        return input.float()
+ Args: + model (torch.nn.Module): Existing Pytorch model + flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. + Returns: + A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. + Example:: + model_params, master_params = prep_param_lists(model) + .. warning:: + Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. + .. _`Training Neural Networks with Mixed Precision: Real Examples`: + http://on-demand.gputechconf.com/gtc/2018/video/S81012/ + """ + model_params = [param for param in model.parameters() if param.requires_grad] + + if flat_master: + # Give the user some more useful error messages + try: + # flatten_dense_tensors returns a contiguous flat array. + # http://pytorch.org/docs/master/_modules/torch/_utils.html + master_params = _flatten_dense_tensors([param.data for param in model_params]).float() + except: + print("Error in prep_param_lists: model may contain a mixture of parameters " + "of different types. 
def model_grads_to_master_grads(model_params, master_params, flat_master=False):
    """
    Copy model gradients to master gradients.

    Args:
        model_params: List of model parameters created by :func:`prep_param_lists`.
        master_params: List of FP32 master parameters created by
            :func:`prep_param_lists`.  If ``master_params`` was created with
            ``flat_master=True``, ``flat_master=True`` should also be supplied
            here.
        flat_master: Whether ``master_params`` holds a single flattened tensor.
    """
    if flat_master:
        # The flattening may incur one more deep copy than is necessary.
        flat_grads = _flatten_dense_tensors([p.grad.data for p in model_params])
        master_params[0].grad.data.copy_(flat_grads)
        return

    for src, dst in zip(model_params, master_params):
        if src.grad is None:
            # No model gradient: clear the master gradient as well.
            dst.grad = None
            continue
        if dst.grad is None:
            # Allocate a gradient buffer shaped like the master parameter.
            dst.grad = Variable(dst.data.new(*dst.data.size()))
        dst.grad.data.copy_(src.grad.data)
def to_python_float(t):
    """
    Return the value held by ``t`` as a plain Python scalar.

    Objects exposing ``item()`` are converted through it; anything else is
    assumed to be indexable and its first element is returned.
    """
    if not hasattr(t, 'item'):
        # Backward-compatibility path for objects without ``item()``.
        return t[0]
    return t.item()
class LossScaler:
    """
    Class that manages a static loss scale.  This class is intended to interact with
    :class:`FP16_Optimizer`, and should not be directly manipulated by the user.

    Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
    :class:`FP16_Optimizer`'s constructor.

    Because the scale is static, overflow is never reported and
    :meth:`update_scale` is a no-op.

    Args:
        scale (float, optional, default=1): The loss scale.
    """

    def __init__(self, scale=1):
        self.cur_scale = scale

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        """Always report "no overflow"; a static scale is never adjusted."""
        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        """Always return ``False``.

        Declared as a static method so that both ``LossScaler._has_inf_or_nan(x)``
        and ``instance._has_inf_or_nan(x)`` work; without the decorator the
        instance form raised ``TypeError`` because ``self`` was bound to ``x``.
        """
        return False

    def update_scale(self, overflow):
        """Do nothing: the loss scale is fixed."""
        pass

    @property
    def loss_scale(self):
        """The current (constant) loss scale."""
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        """Return ``grad_in`` with every entry multiplied by the loss scale."""
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        """Multiply ``loss`` by the loss scale and back-propagate the result."""
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward()
def has_overflow(self, params):
    """
    Report whether any gradient in ``params`` contains inf or NaN.

    Args:
        params: Iterable of objects exposing a ``.grad`` attribute.  Entries
            whose ``.grad`` is ``None`` are skipped.

    Returns:
        ``True`` as soon as ``self._has_inf_or_nan`` flags one gradient's
        ``.data``; ``False`` if none is flagged.
    """
    return any(
        p.grad is not None and self._has_inf_or_nan(p.grad.data)
        for p in params
    )
def update_scale(self, overflow):
    """
    Adjust the loss scale after one iteration.

    Args:
        overflow: Boolean indicating whether the gradient overflowed.

    On overflow the scale is divided by ``scale_factor`` (but never drops
    below 1) and the current iteration is recorded as the last overflow.
    Otherwise the scale is multiplied by ``scale_factor`` whenever the number
    of iterations since the last overflow is a multiple of ``scale_window``.
    The iteration counter advances in both cases.
    """
    if overflow:
        reduced = self.cur_scale / self.scale_factor
        self.cur_scale = max(reduced, 1)
        self.last_overflow_iter = self.cur_iter
    elif (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
        self.cur_scale *= self.scale_factor
    self.cur_iter += 1
+ x = Variable(torch.randn(N, D_in), requires_grad=False) + y = Variable(torch.randn(N, D_out), requires_grad=False) + w1 = Variable(torch.randn(D_in, H), requires_grad=True) + w2 = Variable(torch.randn(H, D_out), requires_grad=True) + parameters = [w1, w2] + learning_rate = 1e-6 + optimizer = torch.optim.SGD(parameters, lr=learning_rate) + loss_scaler = DynamicLossScaler() + for t in range(500): + y_pred = x.mm(w1).clamp(min=0).mm(w2) + loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale + print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) + print('Iter {} scaled loss: {}'.format(t, loss.data[0])) + print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) + # Run backprop + optimizer.zero_grad() + loss.backward() + # Check for overflow + has_overflow = DynamicLossScaler.has_overflow(parameters) + # If no overflow, unscale grad and update as usual + if not has_overflow: + for param in parameters: + param.grad.data.mul_(1. / loss_scaler.loss_scale) + optimizer.step() + # Otherwise, don't do anything -- ie, skip iteration + else: + print('OVERFLOW!') + # Update loss scale for next iteration + loss_scaler.update_scale(has_overflow) +""" diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/models/__init__.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/models/__init__.py new file mode 100644 index 0000000..11df676 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/models/__init__.py @@ -0,0 +1,11 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import models.pose_higher_hrnet diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/models/pose_higher_hrnet.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/models/pose_higher_hrnet.py new file mode 100644 index 0000000..cd6255a --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/models/pose_higher_hrnet.py @@ -0,0 +1,570 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (leoxiaobin@gmail.com) +# Modified by Bowen Cheng (bcheng9@illinois.edu) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging + +import torch +import torch.nn as nn + + +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = 
self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(True) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, 
num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + 
fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + nn.BatchNorm2d(num_inchannels[i]), + nn.Upsample(scale_factor=2**(j-i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3), + nn.ReLU(True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class PoseHigherResolutionNet(nn.Module): + + def __init__(self, cfg, **kwargs): + self.inplanes = 64 + extra = cfg.MODEL.EXTRA + super(PoseHigherResolutionNet, self).__init__() + + # stem net + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.layer1 = 
self._make_layer(Bottleneck, 64, 4) + + self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=False) + + self.final_layers = self._make_final_layers(cfg, pre_stage_channels[0]) + self.deconv_layers = self._make_deconv_layers( + cfg, pre_stage_channels[0]) + + self.num_deconvs = extra.DECONV.NUM_DECONVS + self.deconv_config = cfg.MODEL.EXTRA.DECONV + self.loss_config = cfg.LOSS + + self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS'] + + def _make_final_layers(self, cfg, input_channels): + dim_tag = cfg.MODEL.NUM_JOINTS if cfg.MODEL.TAG_PER_JOINT else 1 + extra = cfg.MODEL.EXTRA + + final_layers = [] + output_channels = cfg.MODEL.NUM_JOINTS + dim_tag \ + if cfg.LOSS.WITH_AE_LOSS[0] else cfg.MODEL.NUM_JOINTS + final_layers.append(nn.Conv2d( + in_channels=input_channels, + 
out_channels=output_channels, + kernel_size=extra.FINAL_CONV_KERNEL, + stride=1, + padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0 + )) + + deconv_cfg = extra.DECONV + for i in range(deconv_cfg.NUM_DECONVS): + input_channels = deconv_cfg.NUM_CHANNELS[i] + output_channels = cfg.MODEL.NUM_JOINTS + dim_tag \ + if cfg.LOSS.WITH_AE_LOSS[i+1] else cfg.MODEL.NUM_JOINTS + final_layers.append(nn.Conv2d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=extra.FINAL_CONV_KERNEL, + stride=1, + padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0 + )) + + return nn.ModuleList(final_layers) + + def _make_deconv_layers(self, cfg, input_channels): + dim_tag = cfg.MODEL.NUM_JOINTS if cfg.MODEL.TAG_PER_JOINT else 1 + extra = cfg.MODEL.EXTRA + deconv_cfg = extra.DECONV + + deconv_layers = [] + for i in range(deconv_cfg.NUM_DECONVS): + if deconv_cfg.CAT_OUTPUT[i]: + final_output_channels = cfg.MODEL.NUM_JOINTS + dim_tag \ + if cfg.LOSS.WITH_AE_LOSS[i] else cfg.MODEL.NUM_JOINTS + input_channels += final_output_channels + output_channels = deconv_cfg.NUM_CHANNELS[i] + deconv_kernel, padding, output_padding = \ + self._get_deconv_cfg(deconv_cfg.KERNEL_SIZE[i]) + + layers = [] + layers.append(nn.Sequential( + nn.ConvTranspose2d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=deconv_kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False), + nn.BatchNorm2d(output_channels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + )) + for _ in range(cfg.MODEL.EXTRA.DECONV.NUM_BASIC_BLOCKS): + layers.append(nn.Sequential( + BasicBlock(output_channels, output_channels), + )) + deconv_layers.append(nn.Sequential(*layers)) + input_channels = output_channels + + return nn.ModuleList(deconv_layers) + + def _get_deconv_cfg(self, deconv_kernel): + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 
+ + return deconv_kernel, padding, output_padding + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + nn.BatchNorm2d(num_channels_cur_layer[i]), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = 
def init_weights(self, pretrained='', verbose=True):
    """
    Initialise the network weights and optionally load a pretrained checkpoint.

    Every ``Conv2d`` / ``ConvTranspose2d`` weight is drawn from a normal
    distribution with ``std=0.001`` and its bias (when present) is zeroed;
    every ``BatchNorm2d`` gets weight 1 and bias 0.

    If ``pretrained`` names an existing file, its state dict is loaded
    non-strictly.  An entry is kept only when its top-level module name is
    listed in ``self.pretrained_layers`` (or the first entry of that list is
    ``'*'``) and the entry's name matches a parameter or buffer of this model.

    Args:
        pretrained: Path to a checkpoint file; anything that is not an
            existing file skips the loading step.
        verbose: Log each tensor that is taken from the checkpoint.
    """
    logger.info('=> init weights from normal distribution')
    for m in self.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, std=0.001)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    parameters_names = {name for name, _ in self.named_parameters()}
    buffers_names = {name for name, _ in self.named_buffers()}

    if not os.path.isfile(pretrained):
        return

    # NOTE(security): torch.load unpickles the file, so only load
    # checkpoints that come from a trusted source.
    pretrained_state_dict = torch.load(pretrained)
    logger.info('=> loading pretrained model {}'.format(pretrained))

    layers = self.pretrained_layers
    # Compare by value: ``is '*'`` tested object identity, which only works
    # by accident of string interning.  The length guard avoids an
    # IndexError when no layers are configured.
    load_all = len(layers) > 0 and layers[0] == '*'

    need_init_state_dict = {}
    for name, m in pretrained_state_dict.items():
        if not (load_all or name.split('.')[0] in layers):
            continue
        if name in parameters_names or name in buffers_names:
            if verbose:
                logger.info(
                    '=> init {} from {}'.format(name, pretrained)
                )
            need_init_state_dict[name] = m
    self.load_state_dict(need_init_state_dict, strict=False)
def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    """Build the 2x3 affine matrix between an image region and an output.

    The source region is centred at ``center`` (offset by ``shift`` times
    the pixel scale), has a width of ``scale[0] * 200`` pixels and is
    rotated by ``rot`` degrees.  It is mapped onto an output of size
    ``output_size`` given as ``(width, height)``.

    :param center: region centre ``(x, y)``; any 2-element sequence.
    :param scale: a scalar (used for both axes) or a 2-element sequence.
    :param rot: rotation in degrees.
    :param output_size: ``(width, height)`` of the destination.
    :param shift: offset of the centre as a fraction of the pixel scale.
        The default array is shared between calls and is never modified.
    :param inv: when truthy, return the destination-to-source transform.
    :return: the matrix produced by ``cv2.getAffineTransform``.
    """
    # Normalise the inputs to arrays.  A list `scale` used to fail on
    # `scale * 200.0`, a tuple `scale` was stacked into a 2x2 array, and a
    # list `center` was concatenated with `src_dir` instead of added.
    center = np.asarray(center)
    scale = np.asarray(scale)
    if scale.ndim == 0:
        scale = np.array([scale, scale])

    scale_tmp = scale * 200.0
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    # Three point pairs define the transform: the centre, a point half a
    # width away along the (rotated) direction, and a third point derived
    # from those two by get_3rd_point.
    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        return cv2.getAffineTransform(np.float32(dst), np.float32(src))
    return cv2.getAffineTransform(np.float32(src), np.float32(dst))
def get_final_preds(grouped_joints, center, scale, heatmap_size):
    """Transform each grouped person's joints with ``transform_preds``.

    Only ``grouped_joints[0]`` is read; every entry in it is passed to
    ``transform_preds`` together with ``center``, ``scale`` and
    ``heatmap_size``.

    :return: list with one transformed joints array per person, in order.
    """
    # The original pre-allocated a zeros array per person and overwrote it
    # straight away; the allocation is dropped.
    return [
        transform_preds(person, center, scale, heatmap_size)
        for person in grouped_joints[0]
    ]
def setup_logger(final_output_dir, rank, phase):
    """Configure the root logger for file and console output.

    The log file is ``<phase>_<timestamp>_rank<rank>.log`` inside
    ``final_output_dir``.  ``logging.basicConfig`` only takes effect when
    the root logger has no handlers yet, so a later call does not redirect
    output to a new file.

    :param final_output_dir: directory that receives the log file.
    :param rank: process rank, embedded in the file name.
    :param phase: run phase (e.g. ``'train'``), embedded in the file name.
    :return: ``(root_logger, time_str)`` where ``time_str`` is the
        ``%Y-%m-%d-%H-%M`` timestamp used in the file name.
    """
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}_rank{}.log'.format(phase, time_str, rank)
    final_log_file = os.path.join(final_output_dir, log_file)
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file), format=head)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Attach the console handler once only.  This function can run more
    # than once per process, and an unconditional addHandler made every
    # message appear on the console once per call.
    console_name = 'setup_logger_console'
    if not any(h.get_name() == console_name for h in logger.handlers):
        console = logging.StreamHandler()
        console.set_name(console_name)
        logger.addHandler(console)

    return logger, time_str
def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth.tar'):
    """Write a training checkpoint and, optionally, the best model weights.

    ``states`` is always saved to ``output_dir/filename``.  When ``is_best``
    is true and ``states`` has a ``'best_state_dict'`` entry, that entry is
    also saved to ``output_dir/model_best.pth.tar``.

    :param states: mapping to serialise with ``torch.save``.
    :param is_best: whether this checkpoint is the best one so far.
    :param output_dir: existing directory to write into.
    :param filename: file name of the full checkpoint.
    """
    torch.save(states, os.path.join(output_dir, filename))

    # Guard on the key that is actually read.  The original tested for
    # 'state_dict' and then indexed 'best_state_dict', which raised KeyError
    # when only the former was supplied.
    if is_best and 'best_state_dict' in states:
        torch.save(
            states['best_state_dict'],
            os.path.join(output_dir, 'model_best.pth.tar')
        )
isinstance(module, nn.Linear): + flops = (torch.prod(torch.LongTensor(list(output.size()))) \ + * input[0].size(1)).item() + + if isinstance(input[0], list): + input = input[0] + if isinstance(output, list): + output = output[0] + + summary.append( + ModuleDetails( + name=layer_name, + input_size=list(input[0].size()), + output_size=list(output.size()), + num_parameters=params, + multiply_adds=flops) + ) + + if not isinstance(module, nn.ModuleList) \ + and not isinstance(module, nn.Sequential) \ + and module != model: + hooks.append(module.register_forward_hook(hook)) + + model.eval() + model.apply(add_hooks) + + space_len = item_length + + model(*input_tensors) + for hook in hooks: + hook.remove() + + details = '' + if verbose: + details = "Model Summary" + \ + os.linesep + \ + "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format( + ' ' * (space_len - len("Name")), + ' ' * (space_len - len("Input Size")), + ' ' * (space_len - len("Output Size")), + ' ' * (space_len - len("Parameters")), + ' ' * (space_len - len("Multiply Adds (Flops)"))) \ + + os.linesep + '-' * space_len * 5 + os.linesep + params_sum = 0 + flops_sum = 0 + for layer in summary: + params_sum += layer.num_parameters + if layer.multiply_adds != "Not Available": + flops_sum += layer.multiply_adds + if verbose: + details += "{}{}{}{}{}{}{}{}{}{}".format( + layer.name, + ' ' * (space_len - len(layer.name)), + layer.input_size, + ' ' * (space_len - len(str(layer.input_size))), + layer.output_size, + ' ' * (space_len - len(str(layer.output_size))), + layer.num_parameters, + ' ' * (space_len - len(str(layer.num_parameters))), + layer.multiply_adds, + ' ' * (space_len - len(str(layer.multiply_adds)))) \ + + os.linesep + '-' * space_len * 5 + os.linesep + + details += os.linesep \ + + "Total Parameters: {:,}".format(params_sum) \ + + os.linesep + '-' * space_len * 5 + os.linesep + details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,}".format(flops_sum) \ + + 
os.linesep + '-' * space_len * 5 + os.linesep + details += "Number of Layers" + os.linesep + for layer in layer_instances: + details += "{} : {} layers ".format(layer, layer_instances[layer]) + + return details + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count if self.count != 0 else 0 diff --git a/data_processing/HigherHRNet-Human-Pose-Estimation/lib/utils/vis.py b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/utils/vis.py new file mode 100644 index 0000000..69a1f77 --- /dev/null +++ b/data_processing/HigherHRNet-Human-Pose-Estimation/lib/utils/vis.py @@ -0,0 +1,238 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bin Xiao (leoxiaobin@gmail.com) +# Modified by Bowen Cheng (bcheng9@illinois.edu) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import cv2 +import numpy as np +import torchvision + +from dataset import VIS_CONFIG + + +def add_joints(image, joints, color, dataset='COCO'): + part_idx = VIS_CONFIG[dataset]['part_idx'] + part_orders = VIS_CONFIG[dataset]['part_orders'] + + def link(a, b, color): + if part_idx[a] < joints.shape[0] and part_idx[b] < joints.shape[0]: + jointa = joints[part_idx[a]] + jointb = joints[part_idx[b]] + if jointa[2] > 0 and jointb[2] > 0: + cv2.line( + image, + (int(jointa[0]), int(jointa[1])), + (int(jointb[0]), int(jointb[1])), + color, + 2 + ) + + # add joints + for joint in joints: + if joint[2] > 0: + cv2.circle(image, (int(joint[0]), int(joint[1])), 1, color, 2) + + # add link + for pair in part_orders: + link(pair[0], pair[1], color) + + return image + + +def save_valid_image(image, joints, file_name, dataset='COCO'): + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + for person in joints: + color = np.random.randint(0, 255, size=3) + color = [int(i) for i in color] + add_joints(image, person, color, dataset=dataset) + + cv2.imwrite(file_name, image) + + +def make_heatmaps(image, heatmaps): + heatmaps = heatmaps.mul(255)\ + .clamp(0, 255)\ + .byte()\ + .cpu().numpy() + + num_joints, height, width = heatmaps.shape + image_resized = cv2.resize(image, (int(width), int(height))) + + image_grid = np.zeros((height, (num_joints+1)*width, 3), dtype=np.uint8) + + for j in range(num_joints): + # add_joints(image_resized, joints[:, j, :]) + heatmap = heatmaps[j, :, :] + colored_heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) + image_fused = colored_heatmap*0.7 + image_resized*0.3 + + width_begin = width * (j+1) + width_end = width * (j+2) + image_grid[:, 
def save_batch_maps(
        batch_image,
        batch_maps,
        batch_mask,
        file_name,
        map_type='heatmap',
        normalize=True
):
    """Render a batch of heatmaps or tagmaps into one image file.

    Each sample produces one row built by ``make_heatmaps`` or
    ``make_tagmaps``; the rows are stacked vertically and written with
    ``cv2.imwrite``.

    :param batch_image: image tensor indexed by sample; each item is
        permuted from channel-first to channel-last before rendering.
    :param batch_maps: 4-D tensor read as
        ``(batch, num_joints, map_height, map_width)``.
    :param batch_mask: optional per-sample mask multiplied into the first
        ``map_width`` columns of each row, or ``None``.
        NOTE(review): presumably ``(batch, map_height, map_width)`` -
        confirm against the caller.
    :param file_name: destination path.
    :param map_type: ``'heatmap'`` or ``'tagmap'``.
    :param normalize: rescale a copy of ``batch_image`` to ``[0, 1)`` using
        its global minimum and maximum.
    :raises ValueError: if ``map_type`` is not one of the two known values.
    """
    renderers = {'heatmap': make_heatmaps, 'tagmap': make_tagmaps}
    if map_type not in renderers:
        # Previously an unknown type left `image_with_hms` unbound and
        # failed with UnboundLocalError inside the loop.
        raise ValueError(
            "map_type must be 'heatmap' or 'tagmap', got {!r}".format(
                map_type))
    render = renderers[map_type]

    if normalize:
        # Work on a copy so the caller's tensor is untouched; the local
        # names avoid shadowing the builtins min/max.
        batch_image = batch_image.clone()
        lowest = float(batch_image.min())
        highest = float(batch_image.max())
        batch_image.add_(-lowest).div_(highest - lowest + 1e-5)

    batch_size = batch_maps.size(0)
    num_joints = batch_maps.size(1)
    map_height = batch_maps.size(2)
    map_width = batch_maps.size(3)

    # One row per sample; the first cell of a row is the image itself,
    # followed by one cell per joint.
    grid_image = np.zeros(
        (batch_size * map_height, (num_joints + 1) * map_width, 3),
        dtype=np.uint8
    )

    for i in range(batch_size):
        image = batch_image[i].mul(255)\
                              .clamp(0, 255)\
                              .byte()\
                              .permute(1, 2, 0)\
                              .cpu().numpy()
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        height_begin = map_height * i
        height_end = map_height * (i + 1)

        grid_image[height_begin:height_end, :, :] = render(
            image, batch_maps[i])
        if batch_mask is not None:
            mask = np.expand_dims(batch_mask[i].byte().cpu().numpy(), -1)
            grid_image[height_begin:height_end, :map_width, :] = \
                grid_image[height_begin:height_end, :map_width, :] * mask

    cv2.imwrite(file_name, grid_image)
def xmlread(filename):
    """Parse an XML document stored inside a zip archive.

    ``filename`` has the form ``<zip path>@<member>``.  Opened archives are
    kept in the module-level ``_xml_path_zip`` / ``_xml_zfile`` lists and
    reused on later calls for the same zip path.

    :param filename: combined archive path and member name.
    :return: the root ``xml.etree.ElementTree.Element`` of the document.
    :raises ValueError: if ``filename`` contains no ``'@'``.
    :raises FileNotFoundError: if the zip path is not an existing file.
    """
    global _xml_path_zip
    global _xml_zfile
    path = filename
    # str.find returns -1 when '@' is absent.  The original used str.index,
    # which raises instead, so its `== -1` check could never trigger.
    pos_at = path.find('@')
    if pos_at == -1:
        raise ValueError(
            "character '@' is not found from the given path '%s'" % (path))
    path_zip = path[0: pos_at]
    # NOTE(review): this skips '@' and the character after it (presumably a
    # leading '/'), whereas imread skips only '@' - confirm which one is
    # intended.
    path_xml = path[pos_at + 2:]
    if not os.path.isfile(path_zip):
        raise FileNotFoundError("zip file '%s' is not found" % (path_zip))

    # `xrange` does not exist on Python 3, so the cached branch raised
    # NameError as soon as one archive had been opened.
    for cached_path, cached_zip in zip(_xml_path_zip, _xml_zfile):
        if cached_path == path_zip:
            return ET.fromstring(cached_zip.read(path_xml))

    _xml_path_zip.append(path_zip)
    print("read new xml file '%s'" % (path_zip))
    _xml_zfile.append(zipfile.ZipFile(path_zip, 'r'))
    # ZipFile.read returns the member's bytes and leaves no member handle
    # open, unlike the previous open(...).read() pair.
    return ET.fromstring(_xml_zfile[-1].read(path_xml))
def main():
    """Concatenate the CrowdPose train and val annotation files.

    Reads ``crowdpose_train.json`` and ``crowdpose_val.json`` from the
    directory given by ``--data_dir`` and writes ``crowdpose_trainval.json``
    next to them.  The output takes ``categories`` from the train file and
    holds the train entries followed by the val entries for both ``images``
    and ``annotations``.
    """
    args = parse_args()

    def load_split(name):
        # Close the file deterministically; json.load(open(...)) left the
        # handle open until garbage collection.
        with open(os.path.join(args.data_dir, name)) as f:
            return json.load(f)

    train_dataset = load_split('crowdpose_train.json')
    val_dataset = load_split('crowdpose_val.json')

    trainval_dataset = {
        'categories': train_dataset['categories'],
        'images': train_dataset['images'] + val_dataset['images'],
        'annotations': (train_dataset['annotations']
                        + val_dataset['annotations']),
    }

    with open(os.path.join(args.data_dir, 'crowdpose_trainval.json'), 'w') as f:
        json.dump(trainval_dataset, f)
keypoints network') + # general + parser.add_argument('--cfg', + help='experiment configure file name', + required=True, + type=str) + + parser.add_argument('opts', + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER) + + # distributed training + parser.add_argument('--gpu', + help='gpu id for multiprocessing training', + type=str) + parser.add_argument('--world-size', + default=1, + type=int, + help='number of nodes for distributed training') + parser.add_argument('--dist-url', + default='tcp://127.0.0.1:23456', + type=str, + help='url used to set up distributed training') + parser.add_argument('--rank', + default=0, + type=int, + help='node rank for distributed training') + + args = parser.parse_args() + + return args + + +def main(): + args = parse_args() + update_config(cfg, args) + + cfg.defrost() + cfg.RANK = args.rank + cfg.freeze() + + logger, final_output_dir, tb_log_dir = create_logger( + cfg, args.cfg, 'train' + ) + + logger.info(pprint.pformat(args)) + logger.info(cfg) + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. 
def main_worker(
        gpu, ngpus_per_node, args, final_output_dir, tb_log_dir
):
    """Run the complete training loop inside one process.

    ``main`` either calls this function directly (single process) or starts
    it once per GPU through ``mp.spawn`` when
    ``cfg.MULTIPROCESSING_DISTRIBUTED`` is set.

    Args:
        gpu: Device handed over by the launcher: the process index when
            started through ``mp.spawn``, otherwise the comma-joined
            ``cfg.GPUS`` string built by ``main``.
        ngpus_per_node: Number of CUDA devices on this node.
        args: Parsed command-line namespace. ``args.gpu`` is always written
            and ``args.rank`` is rewritten for distributed runs.
        final_output_dir: Directory that receives the copied model source,
            the per-epoch checkpoints and the final weights.
        tb_log_dir: Directory for the TensorBoard ``SummaryWriter``.
    """
    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print("Warning: if --fp16 is not used, static_loss_scale will be ignored.")

    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if cfg.MULTIPROCESSING_DISTRIBUTED:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        print('Init process group: dist_url: {}, world_size: {}, rank: {}'.format(
            args.dist_url, args.world_size, args.rank))
        dist.init_process_group(
            backend=cfg.DIST_BACKEND,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank
        )

    update_config(cfg, args)

    # setup logger
    logger, _ = setup_logger(final_output_dir, args.rank, 'train')

    # NOTE(security): ``eval`` runs whatever ``cfg.MODEL.NAME`` contains, so
    # the configuration file must come from a trusted source.
    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=True
    )

    # Only one process per node copies the model file and traces the graph.
    is_node_master = not cfg.MULTIPROCESSING_DISTRIBUTED or (
        cfg.MULTIPROCESSING_DISTRIBUTED
        and args.rank % ngpus_per_node == 0
    )

    # copy model file
    if is_node_master:
        this_dir = os.path.dirname(__file__)
        shutil.copy2(
            os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
            final_output_dir
        )

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    if is_node_master:
        dump_input = torch.rand(
            (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE)
        )
        writer_dict['writer'].add_graph(model, (dump_input, ))
        # logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.MODEL.SYNC_BN and not args.distributed:
        print('Warning: Sync BatchNorm is only supported in distributed training.')

    if args.distributed:
        if cfg.MODEL.SYNC_BN:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            # args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu]
            )
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()

    # The branch above that only moves the model to one GPU leaves it
    # unwrapped, so ``model.module`` does not exist there. Resolve the
    # underlying network once and use it wherever bare weights are saved.
    if isinstance(model, (torch.nn.DataParallel,
                          torch.nn.parallel.DistributedDataParallel)):
        bare_model = model.module
    else:
        bare_model = model

    # define loss function (criterion) and optimizer
    loss_factory = MultiLossFactory(cfg).cuda()

    # Data loading code
    train_loader = make_dataloader(
        cfg, is_train=True, distributed=args.distributed
    )
    logger.info(train_loader.dataset)

    best_perf = -1
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)

    if cfg.FP16.ENABLED:
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=cfg.FP16.STATIC_LOSS_SCALE,
            dynamic_loss_scale=cfg.FP16.DYNAMIC_LOSS_SCALE
        )

    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(
        final_output_dir, 'checkpoint.pth.tar')
    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    # FP16_Optimizer wraps the real optimizer; the scheduler needs the inner one.
    scheduled_optimizer = optimizer.optimizer if cfg.FP16.ENABLED else optimizer
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        scheduled_optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
        last_epoch=last_epoch
    )

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        # train one epoch
        do_train(cfg, model, train_loader, loss_factory, optimizer, epoch,
                 final_output_dir, tb_log_dir, writer_dict, fp16=cfg.FP16.ENABLED)

        # In PyTorch 1.1.0 and later, you should call `lr_scheduler.step()` after `optimizer.step()`.
        lr_scheduler.step()

        # No validation is run here: the epoch index is the performance
        # indicator, so the most recent epoch always counts as the best.
        perf_indicator = epoch
        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        if not cfg.MULTIPROCESSING_DISTRIBUTED or (
                cfg.MULTIPROCESSING_DISTRIBUTED
                and args.rank == 0
        ):
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint({
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': bare_model.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(
        final_output_dir, 'final_state{}.pth.tar'.format(gpu)
    )

    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(bare_model.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
def parse_args():
    """Parse the command line of the keypoint extraction script.

    Returns:
        argparse.Namespace: ``cfg`` (path of the experiment configuration),
        ``input_dir`` (dataset directory) and ``opts`` (all remaining
        arguments, used as configuration overrides).
    """
    parser = argparse.ArgumentParser(description='Test keypoints network')
    # general
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)
    # The help text previously repeated the --cfg description.
    parser.add_argument('--input_dir',
                        help='dataset directory; images are read from '
                             '<input_dir>/images and keypoints are written '
                             'to <input_dir>/2d_pose_result_hrnet.json',
                        required=True,
                        type=str)

    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    return args
def main():
    """Estimate 2D keypoints for every image under ``<input_dir>/images``.

    Results are collected in a dict keyed by image file name and written to
    ``<input_dir>/2d_pose_result_hrnet.json``. If that file already exists
    it is loaded first and images it already contains are skipped, so an
    interrupted run can be resumed. Images that cannot be decoded are
    deleted from disk and skipped.
    """
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid'
    )

    logger.info(pprint.pformat(args))

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # NOTE(security): ``eval`` runs whatever ``cfg.MODEL.NAME`` contains, so
    # the configuration file must come from a trusted source.
    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=False
    )

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(
            final_output_dir, 'model_best.pth.tar'
        )
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    if cfg.MODEL.NAME == 'pose_hourglass':
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )
    else:
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )

    parser = HeatmapParser(cfg)
    results = {}

    json_save_path = os.path.join(args.input_dir, '2d_pose_result_hrnet.json')
    print('json_save_path', json_save_path)

    # Resume from an earlier run: images already present are skipped below.
    if os.path.exists(json_save_path):
        with open(json_save_path, 'r') as f:
            results = json.load(f)

    meta_save_path = os.path.join(args.input_dir, 'meta_data.json')
    print('meta_save_path', meta_save_path)

    for image_path in tqdm(glob.glob(os.path.join(args.input_dir, 'images', '*'))):
        image_name = os.path.basename(image_path)
        if image_name in results:
            continue

        try:
            img = cv2.imread(
                image_path,
                cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
            )
            # ``cv2.imread`` yields None for undecodable files, which makes
            # the colour conversion raise.
            image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except Exception as err:
            # Unreadable images are removed from the dataset. Without the
            # ``continue`` the loop would go on with the previous image (or
            # an undefined name) and store its keypoints under this name.
            logger.warning(
                'removing unreadable image {}: {}'.format(image_path, err))
            os.remove(image_path)
            continue

        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR)
        )

        with torch.no_grad():
            final_heatmaps = None
            tags_list = []
            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR, reverse=True)):
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR)
                )
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size
                )

                final_heatmaps, tags_list = aggregate_results(
                    cfg, s, final_heatmaps, tags_list, heatmaps, tags
                )

            # average the heatmaps accumulated over all test scales
            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
            tags = torch.cat(tags_list, dim=4)
            grouped, scores = parser.parse(
                final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE
            )

            final_results = get_final_preds(
                grouped, center, scale,
                [final_heatmaps.size(3), final_heatmaps.size(2)]
            )

        # One (17, 5) slab per detected person, converted to nested lists so
        # the result can be serialised as JSON.
        res = np.zeros((len(final_results), 17, 5))

        for person_id, kpts in enumerate(final_results):
            res[person_id, :, :] = kpts

        results[image_name] = res.tolist()

    with open(json_save_path, 'w') as f:
        json.dump(results, f)
def parse_args():
    """Read the experiment configuration path and overrides from ``sys.argv``.

    Returns:
        argparse.Namespace: ``cfg`` holds the configuration file name and
        ``opts`` holds every remaining command-line token.
    """
    cli = argparse.ArgumentParser(description='Test keypoints network')

    # general
    cli.add_argument(
        '--cfg',
        type=str,
        required=True,
        help='experiment configure file name',
    )

    # everything after the known options is handed through untouched
    cli.add_argument(
        'opts',
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )

    return cli.parse_args()
def main():
    """Estimate 2D keypoints for a fixed folder of images and dump them as JSON.

    The image folder and the output file are hard-coded below. The result
    is a dict keyed by image file name; each value is a nested list with one
    (17, 5) entry per detected person.
    """
    # Raw strings: the mixed-separator Windows paths contain ``\s``, ``\d``,
    # ``\m`` and ``\i``, which are invalid escape sequences in a normal
    # literal. The resulting path text is unchanged.
    image_glob = r'F:/full-head-dataset\skeleton_estimation/3DCrowdNet_RELEASE\demo\my_input\images/*'
    json_save_path = r'F:/full-head-dataset\skeleton_estimation/3DCrowdNet_RELEASE\demo\my_input/2d_pose_result_hrnet.json'

    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid'
    )

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # NOTE(security): ``eval`` runs whatever ``cfg.MODEL.NAME`` contains, so
    # the configuration file must come from a trusted source.
    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=False
    )

    dump_input = torch.rand(
        (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE)
    )
    logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(
            final_output_dir, 'model_best.pth.tar'
        )
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    if cfg.MODEL.NAME == 'pose_hourglass':
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )
    else:
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )

    parser = HeatmapParser(cfg)
    results = {}

    for image_path in glob.glob(image_glob):
        img = cv2.imread(
            image_path,
            cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
        )
        if img is None:
            # ``cv2.imread`` signals an undecodable file with None; skip it
            # rather than crash in the colour conversion.
            logger.warning('skipping unreadable image {}'.format(image_path))
            continue
        image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR)
        )

        with torch.no_grad():
            final_heatmaps = None
            tags_list = []
            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR, reverse=True)):
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR)
                )
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size
                )

                final_heatmaps, tags_list = aggregate_results(
                    cfg, s, final_heatmaps, tags_list, heatmaps, tags
                )

            # average the heatmaps accumulated over all test scales
            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
            tags = torch.cat(tags_list, dim=4)
            grouped, scores = parser.parse(
                final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE
            )

            final_results = get_final_preds(
                grouped, center, scale,
                [final_heatmaps.size(3), final_heatmaps.size(2)]
            )

        # One (17, 5) slab per detected person, converted to nested lists so
        # the result can be serialised as JSON.
        res = np.zeros((len(final_results), 17, 5))

        for person_id, kpts in enumerate(final_results):
            res[person_id, :, :] = kpts

        results[os.path.basename(image_path)] = res.tolist()

    with open(json_save_path, 'w') as f:
        json.dump(results, f)
utils.vis import save_valid_image +from utils.transforms import resize_align_multi_scale +from utils.transforms import get_final_preds +from utils.transforms import get_multi_scale_size + +torch.multiprocessing.set_sharing_strategy('file_system') + + +def parse_args(): + parser = argparse.ArgumentParser(description='Test keypoints network') + # general + parser.add_argument('--cfg', + help='experiment configure file name', + required=True, + type=str) + + parser.add_argument('opts', + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER) + + args = parser.parse_args() + + return args + + +# markdown format output +def _print_name_value(logger, name_value, full_arch_name): + names = name_value.keys() + values = name_value.values() + num_values = len(name_value) + logger.info( + '| Arch ' + + ' '.join(['| {}'.format(name) for name in names]) + + ' |' + ) + logger.info('|---' * (num_values+1) + '|') + + if len(full_arch_name) > 15: + full_arch_name = full_arch_name[:8] + '...' 
+ logger.info( + '| ' + full_arch_name + ' ' + + ' '.join(['| {:.3f}'.format(value) for value in values]) + + ' |' + ) + + +def main(): + args = parse_args() + update_config(cfg, args) + check_config(cfg) + + logger, final_output_dir, tb_log_dir = create_logger( + cfg, args.cfg, 'valid' + ) + + logger.info(pprint.pformat(args)) + logger.info(cfg) + + # cudnn related setting + cudnn.benchmark = cfg.CUDNN.BENCHMARK + torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC + torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED + + model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')( + cfg, is_train=False + ) + + dump_input = torch.rand( + (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE) + ) + logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE)) + + if cfg.FP16.ENABLED: + model = network_to_half(model) + + if cfg.TEST.MODEL_FILE: + logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE)) + model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True) + else: + model_state_file = os.path.join( + final_output_dir, 'model_best.pth.tar' + ) + logger.info('=> loading model from {}'.format(model_state_file)) + model.load_state_dict(torch.load(model_state_file)) + + model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda() + model.eval() + + data_loader, test_dataset = make_test_dataloader(cfg) + + if cfg.MODEL.NAME == 'pose_hourglass': + transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + ] + ) + else: + transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) + ] + ) + + parser = HeatmapParser(cfg) + all_preds = [] + all_scores = [] + + pbar = tqdm(total=len(test_dataset)) if cfg.TEST.LOG_PROGRESS else None + for i, (images, annos) in enumerate(data_loader): + assert 1 == images.size(0), 'Test batch size should be 1' + + image = images[0].cpu().numpy() + # size 
at scale 1.0 + base_size, center, scale = get_multi_scale_size( + image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR) + ) + + with torch.no_grad(): + final_heatmaps = None + tags_list = [] + for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR, reverse=True)): + input_size = cfg.DATASET.INPUT_SIZE + image_resized, center, scale = resize_align_multi_scale( + image, input_size, s, min(cfg.TEST.SCALE_FACTOR) + ) + image_resized = transforms(image_resized) + image_resized = image_resized.unsqueeze(0).cuda() + + outputs, heatmaps, tags = get_multi_stage_outputs( + cfg, model, image_resized, cfg.TEST.FLIP_TEST, + cfg.TEST.PROJECT2IMAGE, base_size + ) + + final_heatmaps, tags_list = aggregate_results( + cfg, s, final_heatmaps, tags_list, heatmaps, tags + ) + + final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR)) + tags = torch.cat(tags_list, dim=4) + grouped, scores = parser.parse( + final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE + ) + + final_results = get_final_preds( + grouped, center, scale, + [final_heatmaps.size(3), final_heatmaps.size(2)] + ) + + if cfg.TEST.LOG_PROGRESS: + pbar.update() + + if i % cfg.PRINT_FREQ == 0: + prefix = '{}_{}'.format(os.path.join(final_output_dir, 'result_valid'), i) + # logger.info('=> write {}'.format(prefix)) + save_valid_image(image, final_results, '{}.jpg'.format(prefix), dataset=test_dataset.name) + # save_debug_images(cfg, image_resized, None, None, outputs, prefix) + + all_preds.append(final_results) + all_scores.append(scores) + + if cfg.TEST.LOG_PROGRESS: + pbar.close() + + name_values, _ = test_dataset.evaluate( + cfg, all_preds, all_scores, final_output_dir + ) + + if isinstance(name_values, list): + for name_value in name_values: + _print_name_value(logger, name_value, cfg.MODEL.NAME) + else: + _print_name_value(logger, name_values, cfg.MODEL.NAME) + + +if __name__ == '__main__': + main() diff --git a/data_processing/LICENSE b/data_processing/LICENSE new file mode 100644 index 
0000000..9eebacf --- /dev/null +++ b/data_processing/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Hongsuk Choi + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
class _MissingConfigKey(AttributeError, KeyError):
    """Raised for a missing ``Config`` attribute.

    It derives from ``AttributeError`` so ``hasattr`` and three-argument
    ``getattr`` work, and from ``KeyError`` so existing callers that catch
    ``KeyError`` keep working.
    """


class Config(dict):
    """Configuration dict whose keys can also be read and set as attributes."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Mapping it
        # straight onto ``dict.__getitem__`` would leak a bare ``KeyError``,
        # which ``hasattr`` and ``getattr(obj, name, default)`` do not treat
        # as "attribute missing".
        try:
            return self[name]
        except KeyError:
            raise _MissingConfigKey(name) from None

    __setattr__ = dict.__setitem__

    @classmethod
    def load(cls, file):
        """Build a configuration from the JSON document stored at ``file``.

        Args:
            file: Path of a JSON file.

        Returns:
            An instance of ``cls`` filled with the parsed content.
        """
        with open(file, 'r') as f:
            config = json.loads(f.read())
        return cls(config)
def dupes(config):
    """Group images by average hash and optionally delete the duplicates.

    Every regular file directly inside the configured directories is hashed
    in a worker pool. Files sharing a hash form one duplicate group.

    Args:
        config: Mapping of parsed command-line options. The keys read here
            are ``paths``, ``recurse``, ``show``, ``delete`` and
            ``noprompt``.
    """
    # Work on a copy: extending the list in place would silently modify the
    # caller's ``config['paths']``.
    paths = list(config['paths'])
    if config['recurse']:
        for path in config['paths']:
            for root, dirs, _ in os.walk(path):
                paths.extend(os.path.join(root, name) for name in dirs)

    files = []
    for path in paths:
        for entry in os.listdir(path):
            fpath = os.path.join(path, entry)
            if not os.path.isdir(fpath):
                files.append(fpath)

    manager = mp.Manager()
    managed_locker = manager.Lock()
    managed_dict = manager.dict()
    # The context manager shuts the worker pool down once all hashes are in.
    with mp.Pool(int(mp.cpu_count())) as pool:
        results = [
            pool.apply_async(async_hash, args=(fpath, managed_dict, managed_locker))
            for fpath in files
        ]

        pbar = progressbar.ProgressBar(max_value=len(files))
        for i, p in enumerate(results):
            p.get()
            pbar.update(i)
        pbar.finish()

    count = 0
    for v in managed_dict.values():
        if len(v) == 1:
            continue

        # show image in v
        if config['show']:
            images = [np.array(Image.open(fpath)) for fpath in v]
            images = np.concatenate(images, axis=1)
            images = cv2.cvtColor(images, cv2.COLOR_RGB2BGR)
            images = cv2.resize(images, (images.shape[1] // 4, images.shape[0] // 4))
            cv2.imshow('images', images)
            cv2.waitKey(0)

        # The first path of a group is kept and only the others are
        # candidates for deletion, as the --delete help text describes.
        # NOTE(review): the indentation of this part was reconstructed from
        # a whitespace-mangled source; confirm against the original file.
        for fpath in v[1:]:
            confirm = config['noprompt']

            if not config['noprompt'] and config['delete']:
                # the prompt used to print a literal "%s"
                print("Delete %s? [y/n]" % fpath)
                confirm = sys.stdin.readline().strip() == 'y'
            if config['delete'] and confirm:
                count += 1
                os.unlink(fpath)

    print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Deleted %d files" % count)


def async_hash(fpath, result_dict, result_lock):
    """Hash one file and record its path under that hash.

    Args:
        fpath: Path of the file to hash.
        result_dict: Shared mapping from hash string to list of paths.
        result_lock: Lock guarding updates of ``result_dict``.
    """
    try:
        h = "%s" % imagehash.average_hash(Image.open(fpath))
    except Exception as e:
        # Best effort: files that cannot be opened as images are skipped,
        # but reported instead of being dropped silently.
        print("skip %s: %s" % (fpath, e), file=sys.stderr)
        return

    # Read, append and write back under the lock. Reading outside of it
    # lets two workers overwrite each other's list and lose entries.
    with result_lock:
        sims = result_dict.get(h, [])
        sims.append(fpath)
        result_dict[h] = sims
if __name__ == '__main__':
    import argparse

    # Images whose Laplacian variance is below this value count as blurred.
    BLUR_VARIANCE_THRESHOLD = 4

    parser = argparse.ArgumentParser(description='Test keypoints network')
    # general

    # The help text previously described a configuration file.
    parser.add_argument('--input_dir',
                        help='dataset directory; blurred images are removed '
                             'from <input_dir>/aligned_images',
                        required=True,
                        type=str)

    args = parser.parse_args()

    image_list = glob.glob(os.path.join(args.input_dir, 'aligned_images/*'))

    for image_path in tqdm(image_list):
        image = cv2.imread(image_path)
        if image is None:
            # ``cv2.imread`` yields None for files it cannot decode; skip
            # them instead of aborting the whole pass in ``cvtColor``.
            print('skip unreadable image', image_path)
            continue

        # The variance of the Laplacian serves as the sharpness score.
        img2gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        imageVar = cv2.Laplacian(img2gray, cv2.CV_64F).var()
        if imageVar < BLUR_VARIANCE_THRESHOLD:
            os.remove(image_path)
b/data_processing/MANIQA/timm/data/constants.py @@ -0,0 +1,7 @@ +DEFAULT_CROP_PCT = 0.875 +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) +IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) +IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) +IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) diff --git a/data_processing/MANIQA/timm/models/__init__.py b/data_processing/MANIQA/timm/models/__init__.py new file mode 100644 index 0000000..2ef4918 --- /dev/null +++ b/data_processing/MANIQA/timm/models/__init__.py @@ -0,0 +1,59 @@ +from .beit import * +from .byoanet import * +from .byobnet import * +from .cait import * +from .coat import * +from .convit import * +from .convmixer import * +from .convnext import * +from .crossvit import * +from .cspnet import * +from .densenet import * +from .dla import * +from .dpn import * +from .efficientnet import * +from .ghostnet import * +from .gluon_resnet import * +from .gluon_xception import * +from .hardcorenas import * +from .hrnet import * +from .inception_resnet_v2 import * +from .inception_v3 import * +from .inception_v4 import * +from .levit import * +from .mlp_mixer import * +from .mobilenetv3 import * +from .nasnet import * +from .nest import * +from .nfnet import * +from .pit import * +from .pnasnet import * +from .regnet import * +from .res2net import * +from .resnest import * +from .resnet import * +from .resnetv2 import * +from .rexnet import * +from .selecsls import * +from .senet import * +from .sknet import * +from .swin_transformer import * +from .tnt import * +from .tresnet import * +from .twins import * +from .vgg import * +from .visformer import * +from .vision_transformer import * +from .vision_transformer_hybrid import * +from .vovnet import * +from .xception import * +from .xception_aligned import * +from .xcit import * + +from .factory import create_model, split_model_name, safe_model_name +from .helpers import load_checkpoint, 
resume_checkpoint, model_parameters +from .layers import TestTimePoolHead, apply_test_time_pool +from .layers import convert_splitbn_model +from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit +from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ + has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained diff --git a/data_processing/MANIQA/timm/models/convnext.py b/data_processing/MANIQA/timm/models/convnext.py new file mode 100644 index 0000000..5f75647 --- /dev/null +++ b/data_processing/MANIQA/timm/models/convnext.py @@ -0,0 +1,427 @@ +""" ConvNeXt + +Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf + +Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below + +Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman +""" +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# This source code is licensed under the MIT license +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_module +from .helpers import named_apply, build_model_with_cfg +from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp +from .registry import register_model + + +__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"), + convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"), + convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"), + convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), + + convnext_tiny_hnf=_cfg(url=''), + + convnext_base_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'), + convnext_large_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'), + convnext_xlarge_in22ft1k=_cfg( + 
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'), + + convnext_base_384_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + convnext_large_384_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + convnext_xlarge_384_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + convnext_base_in22k=_cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841), + convnext_large_in22k=_cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841), + convnext_xlarge_in22k=_cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841), +) + + +def _is_contiguous(tensor: torch.Tensor) -> bool: + # jit is oh so lovely :/ + # if torch.jit.is_tracing(): + # return True + if torch.jit.is_scripting(): + return tensor.is_contiguous() + else: + return tensor.is_contiguous(memory_format=torch.contiguous_format) + + +@register_notrace_module +class LayerNorm2d(nn.LayerNorm): + r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). 
+ """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__(normalized_shape, eps=eps) + + def forward(self, x) -> torch.Tensor: + if _is_contiguous(x): + return F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + else: + s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) + x = (x - u) * torch.rsqrt(s + self.eps) + x = x * self.weight[:, None, None] + self.bias[:, None, None] + return x + + +class ConvNeXtBlock(nn.Module): + """ ConvNeXt Block + There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + + Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate + choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear + is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + ls_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None): + super().__init__() + if not norm_layer: + norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + mlp_layer = ConvMlp if conv_mlp else Mlp + self.use_conv_mlp = conv_mlp + self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = norm_layer(dim) + self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv_dw(x) + if self.use_conv_mlp: + x = self.norm(x) + x = self.mlp(x) + else: + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.mlp(x) + x = x.permute(0, 3, 1, 2) + if self.gamma is not None: + x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + x = self.drop_path(x) + shortcut + return x + + +class ConvNeXtStage(nn.Module): + + def __init__( + self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False, + norm_layer=None, cl_norm_layer=None, cross_stage=False): + super().__init__() + + if in_chs != out_chs or stride > 1: + self.downsample = nn.Sequential( + norm_layer(in_chs), + nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride), + ) + else: + self.downsample = nn.Identity() + + dp_rates = dp_rates or [0.] * depth + self.blocks = nn.Sequential(*[ConvNeXtBlock( + dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp, + norm_layer=norm_layer if conv_mlp else cl_norm_layer) + for j in range(depth)] + ) + + def forward(self, x): + x = self.downsample(x) + x = self.blocks(x) + return x + + +class ConvNeXt(nn.Module): + r""" ConvNeXt + A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_rate (float): Head dropout rate + drop_path_rate (float): Stochastic depth rate. Default: 0. + ls_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
+ """ + + def __init__( + self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4, + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, + head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., + ): + super().__init__() + assert output_stride == 32 + if norm_layer is None: + norm_layer = partial(LayerNorm2d, eps=1e-6) + cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + else: + assert conv_mlp,\ + 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' + cl_norm_layer = norm_layer + + self.num_classes = num_classes + self.drop_rate = drop_rate + self.feature_info = [] + + # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size), + norm_layer(dims[0]) + ) + + self.stages = nn.Sequential() + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + curr_stride = patch_size + prev_chs = dims[0] + stages = [] + # 4 feature resolution stages, each consisting of multiple residual blocks + for i in range(4): + stride = 2 if i > 0 else 1 + # FIXME support dilation / output_stride + curr_stride *= stride + out_chs = dims[i] + stages.append(ConvNeXtStage( + prev_chs, out_chs, stride=stride, + depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, + norm_layer=norm_layer, cl_norm_layer=cl_norm_layer) + ) + prev_chs = out_chs + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) + + self.num_features = prev_chs + if head_norm_first: + # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights) + self.norm_pre = 
norm_layer(self.num_features) # final norm layer, before pooling + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + else: + # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) + self.norm_pre = nn.Identity() + self.head = nn.Sequential(OrderedDict([ + ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), + ('norm', norm_layer(self.num_features)), + ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) + ])) + + named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes=0, global_pool='avg'): + if isinstance(self.head, ClassifierHead): + # norm -> global pool -> fc + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + else: + # pool -> norm -> fc + self.head = nn.Sequential(OrderedDict([ + ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), + ('norm', self.head.norm), + ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) + ])) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm_pre(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _init_weights(module, name=None, head_init_scale=1.0): + if isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=.02) + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + nn.init.constant_(module.bias, 0) + if name and 'head.' 
in name: + module.weight.data.mul_(head_init_scale) + module.bias.data.mul_(head_init_scale) + + +def checkpoint_filter_fn(state_dict, model): + """ Remap FB checkpoints -> timm """ + if 'model' in state_dict: + state_dict = state_dict['model'] + out_dict = {} + import re + for k, v in state_dict.items(): + k = k.replace('downsample_layers.0.', 'stem.') + k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) + k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k) + k = k.replace('dwconv', 'conv_dw') + k = k.replace('pwconv', 'mlp.fc') + k = k.replace('head.', 'head.fc.') + if k.startswith('norm.'): + k = k.replace('norm', 'head.norm') + if v.ndim == 2 and 'head' not in k: + model_shape = model.state_dict()[k].shape + v = v.reshape(model_shape) + out_dict[k] = v + return out_dict + + +def _create_convnext(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + ConvNeXt, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs) + return model + + +@register_model +def convnext_tiny(pretrained=False, **kwargs): + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) + model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_tiny_hnf(pretrained=False, **kwargs): + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, **kwargs) + model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_small(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) + model = _create_convnext('convnext_small', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_base(pretrained=False, **kwargs): + model_args = dict(depths=[3, 
3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_base', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_large(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + model = _create_convnext('convnext_large', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_base_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_base_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_large_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + model = _create_convnext('convnext_large_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_xlarge_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) + model = _create_convnext('convnext_xlarge_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_base_384_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_base_384_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_large_384_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + model = _create_convnext('convnext_large_384_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) + model = _create_convnext('convnext_xlarge_384_in22ft1k', pretrained=pretrained, **model_args) + 
return model + + +@register_model +def convnext_base_in22k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_large_in22k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_xlarge_in22k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) + model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args) + return model + + + diff --git a/data_processing/MANIQA/timm/models/crossvit.py b/data_processing/MANIQA/timm/models/crossvit.py new file mode 100644 index 0000000..acf9a47 --- /dev/null +++ b/data_processing/MANIQA/timm/models/crossvit.py @@ -0,0 +1,519 @@ +""" CrossViT Model + +@inproceedings{ + chen2021crossvit, + title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}}, + author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda}, + booktitle={International Conference on Computer Vision (ICCV)}, + year={2021} +} + +Paper link: https://arxiv.org/abs/2103.14899 +Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py + +NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408 + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" + +# Copyright IBM All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + + +""" +Modifed from Timm. 
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + +""" +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.hub +from functools import partial +from typing import List + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_function +from .helpers import build_model_with_cfg +from .layers import DropPath, to_2tuple, trunc_normal_, _assert +from .registry import register_model +from .vision_transformer import Mlp, Block + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True, + 'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'), + 'classifier': ('head.0', 'head.1'), + **kwargs + } + + +default_cfgs = { + 'crossvit_15_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth'), + 'crossvit_15_dagger_240': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth', + first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), + ), + 'crossvit_15_dagger_408': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth', + input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0, + ), + 'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'), + 'crossvit_18_dagger_240': _cfg( + 
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth', + first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), + ), + 'crossvit_18_dagger_408': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth', + input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0, + ), + 'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'), + 'crossvit_9_dagger_240': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth', + first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), + ), + 'crossvit_base_240': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth'), + 'crossvit_small_240': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth'), + 'crossvit_tiny_240': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth'), +} + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + if multi_conv: + if patch_size[0] == 12: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, 
embed_dim // 4, kernel_size=7, stride=4, padding=3), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1), + ) + elif patch_size[0] == 16: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), + ) + else: + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + _assert(H == self.img_size[0], + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") + _assert(W == self.img_size[1], + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class CrossAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.wq = nn.Linear(dim, dim, bias=qkv_bias) + self.wk = nn.Linear(dim, dim, bias=qkv_bias) + self.wv = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + # B1C -> B1H(C/H) -> BH1(C/H) + q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + # BNC -> BNH(C/H) -> BHN(C/H) + k = self.wk(x).reshape(B, N, self.num_heads, C // 
self.num_heads).permute(0, 2, 1, 3) + # BNC -> BNH(C/H) -> BHN(C/H) + v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossAttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = CrossAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + x = x[:, 0:1, ...] 
+ self.drop_path(self.attn(self.norm1(x))) + + return x + + +class MultiScaleBlock(nn.Module): + + def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + + num_branches = len(dim) + self.num_branches = num_branches + # different branch could have different embedding size, the first one is the base + self.blocks = nn.ModuleList() + for d in range(num_branches): + tmp = [] + for i in range(depth[d]): + tmp.append(Block( + dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer)) + if len(tmp) != 0: + self.blocks.append(nn.Sequential(*tmp)) + + if len(self.blocks) == 0: + self.blocks = None + + self.projs = nn.ModuleList() + for d in range(num_branches): + if dim[d] == dim[(d + 1) % num_branches] and False: + tmp = [nn.Identity()] + else: + tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])] + self.projs.append(nn.Sequential(*tmp)) + + self.fusion = nn.ModuleList() + for d in range(num_branches): + d_ = (d + 1) % num_branches + nh = num_heads[d_] + if depth[-1] == 0: # backward capability: + self.fusion.append( + CrossAttentionBlock( + dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer)) + else: + tmp = [] + for _ in range(depth[-1]): + tmp.append(CrossAttentionBlock( + dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer)) + self.fusion.append(nn.Sequential(*tmp)) + + self.revert_projs = nn.ModuleList() + for d in range(num_branches): + if dim[(d + 1) % num_branches] == dim[d] and False: + tmp = [nn.Identity()] + else: + tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(), + nn.Linear(dim[(d + 1) % 
num_branches], dim[d])] + self.revert_projs.append(nn.Sequential(*tmp)) + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + + outs_b = [] + for i, block in enumerate(self.blocks): + outs_b.append(block(x[i])) + + # only take the cls token out + proj_cls_token = torch.jit.annotate(List[torch.Tensor], []) + for i, proj in enumerate(self.projs): + proj_cls_token.append(proj(outs_b[i][:, 0:1, ...])) + + # cross attention + outs = [] + for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)): + tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1) + tmp = fusion(tmp) + reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...]) + tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1) + outs.append(tmp) + return outs + + +def _compute_num_patches(img_size, patches): + return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)] + + +@register_notrace_function +def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript + """ + Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing. + Args: + x (Tensor): input image + ss (tuple[int, int]): height and width to scale to + crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. 
class CrossViT(nn.Module):
    """Multi-branch vision transformer whose branches exchange class tokens.

    Each branch gets its own (optionally rescaled) copy of the input image, its
    own patch embedding, class token and positional embedding. The branches are
    then processed jointly by a stack of ``MultiScaleBlock`` modules, and each
    branch ends in its own norm layer and classifier head.

    NOTE(review): ``forward_features`` selects ``cls_token_0``/``cls_token_1``
    and ``pos_embed_0``/``pos_embed_1`` explicitly, so only two branches are
    actually supported even though the constructor loops over
    ``len(patch_size)`` branches.
    """

    def __init__(
            self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
            embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
            qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False, crop_scale=False,
    ):
        """Build the per-branch embeddings, the multi-scale blocks and the heads.

        Args:
            img_size: base input size; expanded to a 2-tuple.
            img_scale: per-branch scale factor applied to ``img_size``.
            patch_size: per-branch patch size; its length sets the branch count.
            in_chans: number of input image channels.
            num_classes: classifier width; ``<= 0`` installs ``nn.Identity`` heads.
            embed_dim: per-branch embedding width.
            depth: one config per ``MultiScaleBlock``.
            num_heads: per-branch attention head count.
            mlp_ratio: MLP ratios forwarded to every ``MultiScaleBlock``.
            qkv_bias: forwarded to every ``MultiScaleBlock``.
            drop_rate: dropout probability after adding positional embeddings;
                also forwarded to the blocks.
            attn_drop_rate: forwarded to the blocks as ``attn_drop``.
            drop_path_rate: upper end of the linearly spaced drop-path schedule.
            norm_layer: factory for the normalisation layers.
            multi_conv: forwarded to ``PatchEmbed``.
            crop_scale: crop instead of interpolate when rescaling the input.
        """
        super().__init__()

        self.num_classes = num_classes
        self.img_size = to_2tuple(img_size)
        img_scale = to_2tuple(img_scale)
        # One (H, W) pair per branch: the base size multiplied by that branch's scale.
        self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
        self.crop_scale = crop_scale  # crop instead of interpolate for scale
        num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
        self.num_branches = len(patch_size)
        self.embed_dim = embed_dim
        self.num_features = embed_dim[0]  # to pass the tests
        self.patch_embed = nn.ModuleList()

        # hard-coded for torch jit script
        # Parameters are registered as individually named attributes
        # (pos_embed_0, cls_token_0, ...) rather than in a container.
        for i in range(self.num_branches):
            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))

        for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
            self.patch_embed.append(
                PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))

        self.pos_drop = nn.Dropout(p=drop_rate)

        total_depth = sum([sum(x[-2:]) for x in depth])
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]  # stochastic depth decay rule
        dpr_ptr = 0
        self.blocks = nn.ModuleList()
        for idx, block_cfg in enumerate(depth):
            # Each MultiScaleBlock consumes a contiguous slice of the drop-path schedule.
            curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
            dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
            blk = MultiScaleBlock(
                embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, norm_layer=norm_layer)
            dpr_ptr += curr_depth
            self.blocks.append(blk)

        self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
        self.head = nn.ModuleList([
            nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
            for i in range(self.num_branches)])

        for i in range(self.num_branches):
            trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
            trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; used as the callback for ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of the class-token and trainable positional-embedding parameters."""
        out = set()
        for i in range(self.num_branches):
            out.add(f'cls_token_{i}')
            pe = getattr(self, f'pos_embed_{i}', None)
            # Positional embeddings are listed only when present and trainable.
            if pe is not None and pe.requires_grad:
                out.add(f'pos_embed_{i}')
        return out

    def get_classifier(self):
        """Return the ``nn.ModuleList`` holding one head per branch."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace every branch head with a new one of width ``num_classes``.

        ``global_pool`` is accepted but not used by this method.
        """
        self.num_classes = num_classes
        self.head = nn.ModuleList(
            [nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
             range(self.num_branches)])

    def forward_features(self, x):
        """Return a list with the normalised class-token feature of each branch."""
        B = x.shape[0]
        xs = []
        for i, patch_embed in enumerate(self.patch_embed):
            x_ = x
            ss = self.img_size_scaled[i]
            # Bring the input to this branch's resolution (crop or interpolate).
            x_ = scale_image(x_, ss, self.crop_scale)
            x_ = patch_embed(x_)
            cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
            cls_tokens = cls_tokens.expand(B, -1, -1)
            # The class token is prepended, so it sits at sequence index 0.
            x_ = torch.cat((cls_tokens, x_), dim=1)
            pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1  # hard-coded for torch jit script
            x_ = x_ + pos_embed
            x_ = self.pos_drop(x_)
            xs.append(x_)

        for i, blk in enumerate(self.blocks):
            xs = blk(xs)

        # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
        xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
        # Keep only the token at index 0 (the class token) of every branch.
        return [xo[:, 0] for xo in xs]

    def forward(self, x):
        """Run the model.

        Returns the mean of the per-branch head outputs, or the list of
        per-branch features unchanged when the heads are ``nn.Identity``.
        """
        xs = self.forward_features(x)
        ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
        if not isinstance(self.head[0], nn.Identity):
            ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
        return ce_logits
0], [1, 4, 0], [1, 4, 0]], + num_heads=[6, 6], mlp_ratio=[4, 4, 1], **kwargs) + model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_base_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], + num_heads=[12, 12], mlp_ratio=[4, 4, 1], **kwargs) + model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_9_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], + num_heads=[4, 4], mlp_ratio=[3, 3, 1], **kwargs) + model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_15_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], + num_heads=[6, 6], mlp_ratio=[3, 3, 1], **kwargs) + model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_18_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], + num_heads=[7, 7], mlp_ratio=[3, 3, 1], **kwargs) + model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_9_dagger_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], + num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs) + model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, 
@register_model
def crossvit_15_dagger_240(pretrained=False, **kwargs):
    """Create the ``crossvit_15_dagger_240`` variant (``multi_conv`` patch embedding).

    Extra keyword arguments are merged into the fixed architecture settings
    and forwarded to ``_create_crossvit``.
    """
    # dict(...) is kept on purpose: a caller kwarg that collides with one of
    # the fixed settings raises TypeError instead of silently overriding it.
    return _create_crossvit(
        variant='crossvit_15_dagger_240',
        pretrained=pretrained,
        **dict(
            img_scale=(1.0, 224/240),
            patch_size=[12, 16],
            embed_dim=[192, 384],
            depth=[[1, 5, 0] for _ in range(3)],
            num_heads=[6, 6],
            mlp_ratio=[3, 3, 1],
            multi_conv=True,
            **kwargs,
        ),
    )
def _cfg(url='', **kwargs):
    """Return the default pretrained-config dict for a CSPNet-family model.

    Args:
        url: location of the pretrained weights ('' when none are published).
        **kwargs: entries that override or extend the defaults below.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 256, 256),
        pool_size=(8, 8),
        crop_pct=0.887,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.conv1.conv',
        classifier='head.fc',
    )
    # Caller-supplied entries win over the defaults.
    cfg.update(kwargs)
    return cfg
def create_stem(
        in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='',
        act_layer=None, norm_layer=None, aa_layer=None):
    """Build the network stem and its feature-info record.

    Args:
        in_chans: number of input channels.
        out_chs: output width of the stem, or a list/tuple with one width per
            stem conv (must not be empty).
        kernel_size: kernel size of every stem conv.
        stride: stride of the first conv; later convs use stride 1.
        pool: truthy to append a pooling stage after the convs.
        act_layer, norm_layer: forwarded to every ``ConvBnAct``.
        aa_layer: optional anti-aliasing layer factory; when given, the max
            pool runs at stride 1 and this layer performs the stride-2 step.

    Returns:
        ``(stem, feature_info)`` where ``feature_info`` names the last conv and
        records its channel count and the first conv's stride.
    """
    widths = out_chs if isinstance(out_chs, (tuple, list)) else [out_chs]
    assert len(widths)

    stem = nn.Sequential()
    channels = in_chans
    for position, width in enumerate(widths, start=1):
        last_conv = f'conv{position}'
        # Only the first conv downsamples.
        conv_stride = stride if position == 1 else 1
        stem.add_module(last_conv, ConvBnAct(
            channels, width, kernel_size, stride=conv_stride,
            act_layer=act_layer, norm_layer=norm_layer))
        channels = width

    if pool:
        anti_aliased = aa_layer is not None
        stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1 if anti_aliased else 2, padding=1))
        if anti_aliased:
            stem.add_module('aa', aa_layer(channels=channels, stride=2))

    feature_info = dict(num_chs=channels, reduction=stride, module='.'.join(('stem', last_conv)))
    return stem, feature_info
class CrossStage(nn.Module):
    """Cross-stage-partial stage.

    Pipeline: optional 3x3 downsampling conv -> 1x1 expansion conv -> channel
    split in two halves -> second half through the block stack and a 1x1
    transition -> concatenation with the untouched first half -> final 1x1
    transition.
    """

    def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
                 groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
                 block_fn=ResBottleneck, **block_kwargs):
        super().__init__()
        first_dilation = first_dilation or dilation
        expanded_chs = int(round(out_chs * exp_ratio))
        half_chs = expanded_chs // 2  # output of conv_exp is always split in two
        block_out_chs = int(round(out_chs * block_ratio))
        shared_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))

        # A downsampling conv is only needed when the stage strides or changes dilation.
        if stride != 1 or first_dilation != dilation:
            # grow downsample channels to output channels when requested
            width = out_chs if down_growth else in_chs
            self.conv_down = ConvBnAct(
                in_chs, width, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                aa_layer=block_kwargs.get('aa_layer', None), **shared_kwargs)
        else:
            width = in_chs
            self.conv_down = None

        # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
        # there is also special case for the first stage for some of the model that results in uneven split
        # across the two paths. I did it this way for simplicity for now.
        self.conv_exp = ConvBnAct(width, expanded_chs, kernel_size=1, apply_act=not cross_linear, **shared_kwargs)

        self.blocks = nn.Sequential()
        width = half_chs
        for index in range(depth):
            # A missing schedule or a zero rate both mean "no drop path" for this block.
            rate = block_dpr[index] if block_dpr else None
            drop_path = DropPath(rate) if rate else None
            self.blocks.add_module(str(index), block_fn(
                width, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs))
            width = block_out_chs

        # transition convs
        self.conv_transition_b = ConvBnAct(width, half_chs, kernel_size=1, **shared_kwargs)
        self.conv_transition = ConvBnAct(expanded_chs, out_chs, kernel_size=1, **shared_kwargs)

    def forward(self, x):
        """Apply the stage to ``x`` and return the merged feature map."""
        if self.conv_down is not None:
            x = self.conv_down(x)
        x = self.conv_exp(x)
        half = x.shape[1] // 2
        shortcut = x[:, :half]
        processed = self.conv_transition_b(self.blocks(x[:, half:])).contiguous()
        return self.conv_transition(torch.cat([shortcut, processed], dim=1))
block_ratio)) + self.blocks = nn.Sequential() + for i in range(depth): + drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None + self.blocks.add_module(str(i), block_fn( + prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + prev_chs = block_out_chs + + def forward(self, x): + x = self.conv_down(x) + x = self.blocks(x) + return x + + +def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.): + # get per stage args for stage and containing blocks, calculate strides to meet target output_stride + num_stages = len(cfg['depth']) + if 'groups' not in cfg: + cfg['groups'] = (1,) * num_stages + if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)): + cfg['down_growth'] = (cfg['down_growth'],) * num_stages + if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)): + cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages + cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \ + [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])] + stage_strides = [] + stage_dilations = [] + stage_first_dilations = [] + dilation = 1 + for cfg_stride in cfg['stride']: + stage_first_dilations.append(dilation) + if curr_stride >= output_stride: + dilation *= cfg_stride + stride = 1 + else: + stride = cfg_stride + curr_stride *= stride + stage_strides.append(stride) + stage_dilations.append(dilation) + cfg['stride'] = stage_strides + cfg['dilation'] = stage_dilations + cfg['first_dilation'] = stage_first_dilations + stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())] + return stage_args + + +class CspNet(nn.Module): + """Cross Stage Partial base model. 
+ + Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929 + Ref Impl: https://github.com/WongKinYiu/CrossStagePartialNetworks + + NOTE: There are differences in the way I handle the 1x1 'expansion' conv in this impl vs the + darknet impl. I did it this way for simplicity and less special cases. + """ + + def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., + act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0., + zero_init_last_bn=True, stage_fn=CrossStage, block_fn=ResBottleneck): + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert output_stride in (8, 16, 32) + layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + + # Construct the stem + self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args) + self.feature_info = [stem_feat_info] + prev_chs = stem_feat_info['num_chs'] + curr_stride = stem_feat_info['reduction'] # reduction does not include pool + if cfg['stem']['pool']: + curr_stride *= 2 + + # Construct the stages + per_stage_args = _cfg_to_stage_args( + cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate) + self.stages = nn.Sequential() + for i, sa in enumerate(per_stage_args): + self.stages.add_module( + str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn)) + prev_chs = sa['out_chs'] + curr_stride *= sa['stride'] + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + + # Construct the head + self.num_features = prev_chs + self.head = ClassifierHead( + in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + 
def _create_cspnet(variant, pretrained=False, **kwargs):
    """Instantiate a ``CspNet`` variant through ``build_model_with_cfg``.

    The architecture config is looked up by the part of ``variant`` before the
    first underscore, while the pretrained config uses the full variant name.
    A caller-supplied ``out_indices`` entry is removed from ``kwargs`` and
    passed on through ``feature_cfg``.
    """
    arch_name, _, _ = variant.partition('_')
    # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]
    default_indices = tuple(range(6 if 'darknet' in variant else 5))
    feature_cfg = dict(
        flatten_sequential=True,
        out_indices=kwargs.pop('out_indices', default_indices),
    )
    return build_model_with_cfg(
        CspNet, variant, pretrained,
        default_cfg=default_cfgs[variant],
        model_cfg=model_cfgs[arch_name],
        feature_cfg=feature_cfg,
        **kwargs)
@register_model
def darknet53(pretrained=False, **kwargs):
    """Create the ``darknet53`` model: ``DarkStage`` stages built from ``DarkBlock`` blocks."""
    model = _create_cspnet(
        'darknet53',
        pretrained=pretrained,
        block_fn=DarkBlock,
        stage_fn=DarkStage,
        **kwargs,
    )
    return model
def _cfg(url=''):
    """Return the default pretrained-config dict for a DenseNet model.

    Args:
        url: location of the pretrained weights ('' when none are published).
    """
    return dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='features.conv0',
        classifier='classifier',
    )
norm_layer=BatchNormAct2d, + drop_rate=0., memory_efficient=False): + super(DenseLayer, self).__init__() + self.add_module('norm1', norm_layer(num_input_features)), + self.add_module('conv1', nn.Conv2d( + num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), + self.add_module('norm2', norm_layer(bn_size * growth_rate)), + self.add_module('conv2', nn.Conv2d( + bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), + self.drop_rate = float(drop_rate) + self.memory_efficient = memory_efficient + + def bottleneck_fn(self, xs): + # type: (List[torch.Tensor]) -> torch.Tensor + concated_features = torch.cat(xs, 1) + bottleneck_output = self.conv1(self.norm1(concated_features)) # noqa: T484 + return bottleneck_output + + # todo: rewrite when torchscript supports any + def any_requires_grad(self, x): + # type: (List[torch.Tensor]) -> bool + for tensor in x: + if tensor.requires_grad: + return True + return False + + @torch.jit.unused # noqa: T484 + def call_checkpoint_bottleneck(self, x): + # type: (List[torch.Tensor]) -> torch.Tensor + def closure(*xs): + return self.bottleneck_fn(xs) + + return cp.checkpoint(closure, *x) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (torch.Tensor) + pass + + # torchscript does not yet support *args, so we overload method + # allowing it to take either a List[Tensor] or single Tensor + def forward(self, x): # noqa: F811 + if isinstance(x, torch.Tensor): + prev_features = [x] + else: + prev_features = x + + if self.memory_efficient and self.any_requires_grad(prev_features): + if torch.jit.is_scripting(): + raise Exception("Memory Efficient not supported in JIT") + bottleneck_output = self.call_checkpoint_bottleneck(prev_features) + else: + bottleneck_output = self.bottleneck_fn(prev_features) + + new_features 
class DenseTransition(nn.Sequential):
    """Transition between dense blocks: norm -> 1x1 conv -> stride-2 pooling.

    Args:
        num_input_features: channels entering the transition.
        num_output_features: channels produced by the 1x1 conv.
        norm_layer: factory called with ``num_input_features``.
        aa_layer: optional factory called as
            ``aa_layer(num_output_features, stride=2)``; when omitted a 2x2
            average pool with stride 2 is used instead.
    """

    def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None):
        super().__init__()
        # Attribute assignment registers the submodules under the same names
        # and in the same order as explicit add_module calls would.
        self.norm = norm_layer(num_input_features)
        self.conv = nn.Conv2d(
            num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
        if aa_layer is None:
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            self.pool = aa_layer(num_output_features, stride=2)
Much more memory efficient, + but slower. Default: *False*. See `"paper" `_ + """ + + def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='', + num_classes=1000, in_chans=3, global_pool='avg', + norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False, + aa_stem_only=True): + self.num_classes = num_classes + self.drop_rate = drop_rate + super(DenseNet, self).__init__() + + # Stem + deep_stem = 'deep' in stem_type # 3x3 deep stem + num_init_features = growth_rate * 2 + if aa_layer is None: + stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + else: + stem_pool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=num_init_features, stride=2)]) + if deep_stem: + stem_chs_1 = stem_chs_2 = growth_rate + if 'tiered' in stem_type: + stem_chs_1 = 3 * (growth_rate // 4) + stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4) + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)), + ('norm0', norm_layer(stem_chs_1)), + ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)), + ('norm1', norm_layer(stem_chs_2)), + ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)), + ('norm2', norm_layer(num_init_features)), + ('pool0', stem_pool), + ])) + else: + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), + ('norm0', norm_layer(num_init_features)), + ('pool0', stem_pool), + ])) + self.feature_info = [ + dict(num_chs=num_init_features, reduction=2, module=f'features.norm{2 if deep_stem else 0}')] + current_stride = 4 + + # DenseBlocks + num_features = num_init_features + for i, num_layers in enumerate(block_config): + block = DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + 
growth_rate=growth_rate, + norm_layer=norm_layer, + drop_rate=drop_rate, + memory_efficient=memory_efficient + ) + module_name = f'denseblock{(i + 1)}' + self.features.add_module(module_name, block) + num_features = num_features + num_layers * growth_rate + transition_aa_layer = None if aa_stem_only else aa_layer + if i != len(block_config) - 1: + self.feature_info += [ + dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)] + current_stride *= 2 + trans = DenseTransition( + num_input_features=num_features, num_output_features=num_features // 2, + norm_layer=norm_layer, aa_layer=transition_aa_layer) + self.features.add_module(f'transition{i + 1}', trans) + num_features = num_features // 2 + + # Final batch norm + self.features.add_module('norm5', norm_layer(num_features)) + + self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')] + self.num_features = num_features + + # Linear layer + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + # Official init from torch repo. + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.constant_(m.bias, 0) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + return self.features(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + # both classifier and block drop? 
+ # if self.drop_rate > 0.: + # x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + return x + + +def _filter_torchvision_pretrained(state_dict): + pattern = re.compile( + r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + return state_dict + + +def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs): + kwargs['growth_rate'] = growth_rate + kwargs['block_config'] = block_config + return build_model_with_cfg( + DenseNet, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained, + **kwargs) + + +@register_model +def densenet121(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenetblur121d(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep', + aa_layer=BlurPool2d, **kwargs) + return model + + +@register_model +def densenet121d(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', + pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet169(pretrained=False, **kwargs): + r"""Densenet-169 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( 
+ 'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet201(pretrained=False, **kwargs): + r"""Densenet-201 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet161(pretrained=False, **kwargs): + r"""Densenet-161 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet264(pretrained=False, **kwargs): + r"""Densenet-264 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet264d_iabn(pretrained=False, **kwargs): + r"""Densenet-264 model with deep stem and Inplace-ABN + """ + def norm_act_fn(num_features, **kwargs): + return create_norm_act('iabn', num_features, **kwargs) + model = _create_densenet( + 'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep', + norm_layer=norm_act_fn, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tv_densenet121(pretrained=False, **kwargs): + r"""Densenet-121 model with original Torchvision weights, from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'tv_densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs) + return model diff --git a/data_processing/MANIQA/timm/models/dla.py b/data_processing/MANIQA/timm/models/dla.py new file mode 100644 index 0000000..f6e4dd2 --- /dev/null +++ b/data_processing/MANIQA/timm/models/dla.py @@ -0,0 +1,443 @@ +""" Deep Layer Aggregation and DLA w/ 
Res2Net +DLA original adapted from Official Pytorch impl at: +DLA Paper: `Deep Layer Aggregation` - https://arxiv.org/abs/1707.06484 + +Res2Net additions from: https://github.com/gasvn/Res2Net/ +Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169 +""" +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['DLA'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'base_layer.0', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'dla34': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla34-ba72cf86.pth'), + 'dla46_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla46_c-2bfd52c3.pth'), + 'dla46x_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla46x_c-d761bae7.pth'), + 'dla60x_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60x_c-b870c45c.pth'), + 'dla60': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60-24839fc4.pth'), + 'dla60x': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60x-d15cacda.pth'), + 'dla102': 
_cfg(url='http://dl.yf.io/dla/models/imagenet/dla102-d94d9790.pth'), + 'dla102x': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102x-ad62be81.pth'), + 'dla102x2': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102x2-262837b6.pth'), + 'dla169': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla169-0914e092.pth'), + 'dla60_res2net': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net_dla60_4s-d88db7f9.pth'), + 'dla60_res2next': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next_dla60_4s-d327927b.pth'), +} + + +class DlaBasic(nn.Module): + """DLA Basic""" + + def __init__(self, inplanes, planes, stride=1, dilation=1, **_): + super(DlaBasic, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.stride = stride + + def forward(self, x, shortcut=None): + if shortcut is None: + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += shortcut + out = self.relu(out) + + return out + + +class DlaBottleneck(nn.Module): + """DLA/DLA-X Bottleneck""" + expansion = 2 + + def __init__(self, inplanes, outplanes, stride=1, dilation=1, cardinality=1, base_width=64): + super(DlaBottleneck, self).__init__() + self.stride = stride + mid_planes = 
int(math.floor(outplanes * (base_width / 64)) * cardinality) + mid_planes = mid_planes // self.expansion + + self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(mid_planes) + self.conv2 = nn.Conv2d( + mid_planes, mid_planes, kernel_size=3, stride=stride, padding=dilation, + bias=False, dilation=dilation, groups=cardinality) + self.bn2 = nn.BatchNorm2d(mid_planes) + self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(outplanes) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, shortcut=None): + if shortcut is None: + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += shortcut + out = self.relu(out) + + return out + + +class DlaBottle2neck(nn.Module): + """ Res2Net/Res2NeXT DLA Bottleneck + Adapted from https://github.com/gasvn/Res2Net/blob/master/dla.py + """ + expansion = 2 + + def __init__(self, inplanes, outplanes, stride=1, dilation=1, scale=4, cardinality=8, base_width=4): + super(DlaBottle2neck, self).__init__() + self.is_first = stride > 1 + self.scale = scale + mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality) + mid_planes = mid_planes // self.expansion + self.width = mid_planes + + self.conv1 = nn.Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(mid_planes * scale) + + num_scale_convs = max(1, scale - 1) + convs = [] + bns = [] + for _ in range(num_scale_convs): + convs.append(nn.Conv2d( + mid_planes, mid_planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, groups=cardinality, bias=False)) + bns.append(nn.BatchNorm2d(mid_planes)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + if self.is_first: + self.pool = 
nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) + + self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(outplanes) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, shortcut=None): + if shortcut is None: + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + spx = torch.split(out, self.width, 1) + spo = [] + for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): + sp = spx[i] if i == 0 or self.is_first else sp + spx[i] + sp = conv(sp) + sp = bn(sp) + sp = self.relu(sp) + spo.append(sp) + if self.scale > 1: + spo.append(self.pool(spx[-1]) if self.is_first else spx[-1]) + out = torch.cat(spo, 1) + + out = self.conv3(out) + out = self.bn3(out) + + out += shortcut + out = self.relu(out) + + return out + + +class DlaRoot(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, shortcut): + super(DlaRoot, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.shortcut = shortcut + + def forward(self, *x): + children = x + x = self.conv(torch.cat(x, 1)) + x = self.bn(x) + if self.shortcut: + x += children[0] + x = self.relu(x) + + return x + + +class DlaTree(nn.Module): + def __init__(self, levels, block, in_channels, out_channels, stride=1, + dilation=1, cardinality=1, base_width=64, + level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False): + super(DlaTree, self).__init__() + if root_dim == 0: + root_dim = 2 * out_channels + if level_root: + root_dim += in_channels + self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity() + self.project = nn.Identity() + cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width) + if levels == 1: + self.tree1 = block(in_channels, out_channels, stride, **cargs) + self.tree2 = 
block(out_channels, out_channels, 1, **cargs) + if in_channels != out_channels: + # NOTE the official impl/weights have project layers in levels > 1 case that are never + # used, I've moved the project layer here to avoid wasted params but old checkpoints will + # need strict=False while loading. + self.project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), + nn.BatchNorm2d(out_channels)) + else: + cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut)) + self.tree1 = DlaTree( + levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs) + self.tree2 = DlaTree( + levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs) + if levels == 1: + self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut) + self.level_root = level_root + self.root_dim = root_dim + self.levels = levels + + def forward(self, x, shortcut=None, children=None): + children = [] if children is None else children + bottom = self.downsample(x) + shortcut = self.project(bottom) + if self.level_root: + children.append(bottom) + x1 = self.tree1(x, shortcut) + if self.levels == 1: + x2 = self.tree2(x1) + x = self.root(x2, x1, *children) + else: + children.append(x1) + x = self.tree2(x1, children=children) + return x + + +class DLA(nn.Module): + def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3, + cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False, + drop_rate=0.0, global_pool='avg'): + super(DLA, self).__init__() + self.channels = channels + self.num_classes = num_classes + self.cardinality = cardinality + self.base_width = base_width + self.drop_rate = drop_rate + assert output_stride == 32 # FIXME support dilation + + self.base_layer = nn.Sequential( + nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(channels[0]), + nn.ReLU(inplace=True)) + self.level0 = 
self._make_conv_level(channels[0], channels[0], levels[0]) + self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2) + cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root) + self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs) + self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs) + self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs) + self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs) + self.feature_info = [ + dict(num_chs=channels[0], reduction=1, module='level0'), # rare to have a meaningful stride 1 level + dict(num_chs=channels[1], reduction=2, module='level1'), + dict(num_chs=channels[2], reduction=4, module='level2'), + dict(num_chs=channels[3], reduction=8, module='level3'), + dict(num_chs=channels[4], reduction=16, module='level4'), + dict(num_chs=channels[5], reduction=32, module='level5'), + ] + + self.num_features = channels[-1] + self.global_pool, self.fc = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): + modules = [] + for i in range(convs): + modules.extend([ + nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1, + padding=dilation, bias=False, dilation=dilation), + nn.BatchNorm2d(planes), + nn.ReLU(inplace=True)]) + inplanes = planes + return nn.Sequential(*modules) + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + + def forward_features(self, x): + x = self.base_layer(x) + x = self.level0(x) + x = self.level1(x) + x = self.level2(x) + x = self.level3(x) + x = self.level4(x) + x = self.level5(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + x = self.flatten(x) + return x + + +def _create_dla(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + DLA, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_strict=False, + feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)), + **kwargs) + + +@register_model +def dla60_res2net(pretrained=False, **kwargs): + model_kwargs = dict( + levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), + block=DlaBottle2neck, cardinality=1, base_width=28, **kwargs) + return _create_dla('dla60_res2net', pretrained, **model_kwargs) + + +@register_model +def dla60_res2next(pretrained=False,**kwargs): + model_kwargs = dict( + levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), + block=DlaBottle2neck, cardinality=8, base_width=4, **kwargs) + return 
_create_dla('dla60_res2next', pretrained, **model_kwargs) + + +@register_model +def dla34(pretrained=False, **kwargs): # DLA-34 + model_kwargs = dict( + levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512], + block=DlaBasic, **kwargs) + return _create_dla('dla34', pretrained, **model_kwargs) + + +@register_model +def dla46_c(pretrained=False, **kwargs): # DLA-46-C + model_kwargs = dict( + levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, **kwargs) + return _create_dla('dla46_c', pretrained, **model_kwargs) + + +@register_model +def dla46x_c(pretrained=False, **kwargs): # DLA-X-46-C + model_kwargs = dict( + levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla46x_c', pretrained, **model_kwargs) + + +@register_model +def dla60x_c(pretrained=False, **kwargs): # DLA-X-60-C + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla60x_c', pretrained, **model_kwargs) + + +@register_model +def dla60(pretrained=False, **kwargs): # DLA-60 + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, **kwargs) + return _create_dla('dla60', pretrained, **model_kwargs) + + +@register_model +def dla60x(pretrained=False, **kwargs): # DLA-X-60 + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla60x', pretrained, **model_kwargs) + + +@register_model +def dla102(pretrained=False, **kwargs): # DLA-102 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, shortcut_root=True, **kwargs) + return _create_dla('dla102', pretrained, **model_kwargs) + + +@register_model +def 
dla102x(pretrained=False, **kwargs): # DLA-X-102 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs) + return _create_dla('dla102x', pretrained, **model_kwargs) + + +@register_model +def dla102x2(pretrained=False, **kwargs): # DLA-X-102 64 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs) + return _create_dla('dla102x2', pretrained, **model_kwargs) + + +@register_model +def dla169(pretrained=False, **kwargs): # DLA-169 + model_kwargs = dict( + levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, shortcut_root=True, **kwargs) + return _create_dla('dla169', pretrained, **model_kwargs) diff --git a/data_processing/MANIQA/timm/models/dpn.py b/data_processing/MANIQA/timm/models/dpn.py new file mode 100644 index 0000000..c4e380b --- /dev/null +++ b/data_processing/MANIQA/timm/models/dpn.py @@ -0,0 +1,317 @@ +""" PyTorch implementation of DualPathNetworks +Based on original MXNet implementation https://github.com/cypw/DPNs with +many ideas from another PyTorch implementation https://github.com/oyam/pytorch-DPNs. + +This implementation is compatible with the pretrained weights from cypw's MXNet implementation. 
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict +from functools import partial +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import BatchNormAct2d, ConvBnAct, create_conv2d, create_classifier +from .registry import register_model + +__all__ = ['DPN'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD, + 'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'dpn68': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'), + 'dpn68b': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dpn68b_ra-a31ca160.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'dpn92': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'), + 'dpn98': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'), + 'dpn131': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth'), + 'dpn107': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth') +} + + 
+class CatBnAct(nn.Module): + def __init__(self, in_chs, norm_layer=BatchNormAct2d): + super(CatBnAct, self).__init__() + self.bn = norm_layer(in_chs, eps=0.001) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (Tuple[torch.Tensor, torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (torch.Tensor) + pass + + def forward(self, x): + if isinstance(x, tuple): + x = torch.cat(x, dim=1) + return self.bn(x) + + +class BnActConv2d(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, stride, groups=1, norm_layer=BatchNormAct2d): + super(BnActConv2d, self).__init__() + self.bn = norm_layer(in_chs, eps=0.001) + self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, groups=groups) + + def forward(self, x): + return self.conv(self.bn(x)) + + +class DualPathBlock(nn.Module): + def __init__( + self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False): + super(DualPathBlock, self).__init__() + self.num_1x1_c = num_1x1_c + self.inc = inc + self.b = b + if block_type == 'proj': + self.key_stride = 1 + self.has_proj = True + elif block_type == 'down': + self.key_stride = 2 + self.has_proj = True + else: + assert block_type == 'normal' + self.key_stride = 1 + self.has_proj = False + + self.c1x1_w_s1 = None + self.c1x1_w_s2 = None + if self.has_proj: + # Using different member names here to allow easier parameter key matching for conversion + if self.key_stride == 2: + self.c1x1_w_s2 = BnActConv2d( + in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2) + else: + self.c1x1_w_s1 = BnActConv2d( + in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1) + + self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1) + self.c3x3_b = BnActConv2d( + in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3, stride=self.key_stride, groups=groups) + if b: + self.c1x1_c = 
CatBnAct(in_chs=num_3x3_b) + self.c1x1_c1 = create_conv2d(num_3x3_b, num_1x1_c, kernel_size=1) + self.c1x1_c2 = create_conv2d(num_3x3_b, inc, kernel_size=1) + else: + self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1) + self.c1x1_c1 = None + self.c1x1_c2 = None + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] + pass + + def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(x, tuple): + x_in = torch.cat(x, dim=1) + else: + x_in = x + if self.c1x1_w_s1 is None and self.c1x1_w_s2 is None: + # self.has_proj == False, torchscript requires condition on module == None + x_s1 = x[0] + x_s2 = x[1] + else: + # self.has_proj == True + if self.c1x1_w_s1 is not None: + # self.key_stride = 1 + x_s = self.c1x1_w_s1(x_in) + else: + # self.key_stride = 2 + x_s = self.c1x1_w_s2(x_in) + x_s1 = x_s[:, :self.num_1x1_c, :, :] + x_s2 = x_s[:, self.num_1x1_c:, :, :] + x_in = self.c1x1_a(x_in) + x_in = self.c3x3_b(x_in) + x_in = self.c1x1_c(x_in) + if self.c1x1_c1 is not None: + # self.b == True, using None check for torchscript compat + out1 = self.c1x1_c1(x_in) + out2 = self.c1x1_c2(x_in) + else: + out1 = x_in[:, :self.num_1x1_c, :, :] + out2 = x_in[:, self.num_1x1_c:, :, :] + resid = x_s1 + out1 + dense = torch.cat([x_s2, out2], dim=1) + return resid, dense + + +class DPN(nn.Module): + def __init__(self, small=False, num_init_features=64, k_r=96, groups=32, + b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32, + num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', fc_act=nn.ELU): + super(DPN, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + self.b = b + assert output_stride == 32 # FIXME look into dilation support + 
norm_layer = partial(BatchNormAct2d, eps=.001) + fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act, inplace=False) + bw_factor = 1 if small else 4 + blocks = OrderedDict() + + # conv1 + blocks['conv1_1'] = ConvBnAct( + in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer) + blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')] + + # conv2 + bw = 64 * bw_factor + inc = inc_sec[0] + r = (k_r * bw) // (64 * bw_factor) + blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[0] + 1): + blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=4, module=f'features.conv2_{k_sec[0]}')] + + # conv3 + bw = 128 * bw_factor + inc = inc_sec[1] + r = (k_r * bw) // (64 * bw_factor) + blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[1] + 1): + blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=8, module=f'features.conv3_{k_sec[1]}')] + + # conv4 + bw = 256 * bw_factor + inc = inc_sec[2] + r = (k_r * bw) // (64 * bw_factor) + blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[2] + 1): + blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=16, module=f'features.conv4_{k_sec[2]}')] + + # conv5 + bw = 512 * bw_factor + inc = inc_sec[3] + r = (k_r * bw) // (64 * bw_factor) + blocks['conv5_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) + in_chs = bw + 3 * inc + for i in 
range(2, k_sec[3] + 1): + blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')] + + blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer) + + self.num_features = in_chs + self.features = nn.Sequential(blocks) + + # Using 1x1 conv for the FC layer to allow the extra pooling scheme + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + + def forward_features(self, x): + return self.features(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + x = self.flatten(x) + return x + + +def _create_dpn(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + DPN, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_concat=True, flatten_sequential=True), + **kwargs) + + +@register_model +def dpn68(pretrained=False, **kwargs): + model_kwargs = dict( + small=True, num_init_features=10, k_r=128, groups=32, + k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) + return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn68b(pretrained=False, **kwargs): + model_kwargs = dict( + small=True, num_init_features=10, k_r=128, groups=32, + b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) + return 
_create_dpn('dpn68b', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn92(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=64, k_r=96, groups=32, + k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs) + return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn98(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=96, k_r=160, groups=40, + k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs) + return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn131(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=128, k_r=160, groups=40, + k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs) + return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn107(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=128, k_r=200, groups=50, + k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs) + return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs) diff --git a/data_processing/MANIQA/timm/models/efficientnet.py b/data_processing/MANIQA/timm/models/efficientnet.py new file mode 100644 index 0000000..cb65ffb --- /dev/null +++ b/data_processing/MANIQA/timm/models/efficientnet.py @@ -0,0 +1,2318 @@ +""" The EfficientNet Family in PyTorch + +An implementation of EfficientNet that covers a variety of related models with efficient architectures: + +* EfficientNet-V2 + - `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + +* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports) + - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 + - CondConv: Conditionally Parameterized Convolutions for Efficient Inference -
https://arxiv.org/abs/1904.04971 + - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 + - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252 + +* MixNet (Small, Medium, and Large) + - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 + +* MNasNet B1, A1 (SE), Small + - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 + +* FBNet-C + - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 + +* Single-Path NAS Pixel1 + - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 + +* TinyNet + - Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets - https://arxiv.org/abs/2010.14819 + - Definitions & weights borrowed from https://github.com/huawei-noah/CV-Backbones/tree/master/tinynet_pytorch + +* And likely more... + +The majority of the above models (EfficientNet*, MixNet, MnasNet) and original weights were made available +by Mingxing Tan, Quoc Le, and other members of their Google Brain team. Thanks for consistently releasing +the models and weights open source!
+ +Hacked together by / Copyright 2019, Ross Wightman +""" +from functools import partial +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .efficientnet_blocks import SqueezeExcite +from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ + round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT +from .features import FeatureInfo, FeatureHooks +from .helpers import build_model_with_cfg, default_cfg_for_features +from .layers import create_conv2d, create_classifier +from .registry import register_model + +__all__ = ['EfficientNet', 'EfficientNetFeatures'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'mnasnet_050': _cfg(url=''), + 'mnasnet_075': _cfg(url=''), + 'mnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'), + 'mnasnet_140': _cfg(url=''), + + 'semnasnet_050': _cfg(url=''), + 'semnasnet_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/semnasnet_075-18710866.pth'), + 'semnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'), + 'semnasnet_140': _cfg(url=''), + 'mnasnet_small': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_small_lamb-aff75073.pth'), + + 'mobilenetv2_035': _cfg( + url=''), + 'mobilenetv2_050': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_050-3d30d450.pth', + interpolation='bicubic', + ), + 'mobilenetv2_075': _cfg( + url=''), + 'mobilenetv2_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth'), + 'mobilenetv2_110d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth'), + 'mobilenetv2_120d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth'), + 'mobilenetv2_140': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth'), + + 'fbnetc_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + interpolation='bilinear'), + 'spnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', + interpolation='bilinear'), + + # NOTE experimenting with alternate attention + 'efficientnet_b0': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'), + 'efficientnet_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + test_input_size=(3, 256, 256), crop_pct=1.0), + 'efficientnet_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', + input_size=(3, 256, 256), pool_size=(8, 8), test_input_size=(3, 288, 288), crop_pct=1.0), + 'efficientnet_b3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', + input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + 'efficientnet_b4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth', + input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), crop_pct=1.0), + 'efficientnet_b5': _cfg( + url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'efficientnet_b6': _cfg( + url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'efficientnet_b7': _cfg( + url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'efficientnet_b8': _cfg( + url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + 'efficientnet_l2': _cfg( + url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961), + + 'efficientnet_es': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'), + 'efficientnet_em': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_em_ra2-66250f76.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_el': _cfg( + url='https://github.com/DeGirum/pruned-models/releases/download/efficientnet_v1.0/efficientnet_el.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'efficientnet_es_pruned': _cfg( + url='https://github.com/DeGirum/pruned-models/releases/download/efficientnet_v1.0/efficientnet_es_pruned75.pth'), + 'efficientnet_el_pruned': _cfg( + url='https://github.com/DeGirum/pruned-models/releases/download/efficientnet_v1.0/efficientnet_el_pruned70.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'efficientnet_cc_b0_4e': _cfg(url=''), + 'efficientnet_cc_b0_8e': _cfg(url=''), + 'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'efficientnet_lite0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth'), + 'efficientnet_lite1': _cfg( + url='', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_lite2': _cfg( + url='', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'efficientnet_lite3': _cfg( + url='', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_lite4': _cfg( + url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + + 
'efficientnet_b1_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb1_pruned_9ebb3fe6.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'efficientnet_b2_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb2_pruned_203f55bc.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'efficientnet_b3_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + + 'efficientnetv2_rw_t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_t_agc-3620981a.pth', + input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0), + 'gc_efficientnetv2_rw_t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gc_efficientnetv2_rw_t_agc-927a0bde.pth', + input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0), + 'efficientnetv2_rw_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth', + input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), + 'efficientnetv2_rw_m': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_rw_m_agc-3d90cb1e.pth', + input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), + + 'efficientnetv2_s': _cfg( + url='', + input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), + 'efficientnetv2_m': _cfg( + url='', + input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), + 'efficientnetv2_l': _cfg( + url='', + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'efficientnetv2_xl': _cfg( + url='', + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnet_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + + 'tf_efficientnet_b0_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)), + 'tf_efficientnet_b1_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 
'tf_efficientnet_b2_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 600, 600), pool_size=(19, 
19), crop_pct=0.949), + 'tf_efficientnet_b8_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + + 'tf_efficientnet_b0_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 
'tf_efficientnet_b6_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_l2_ns_475': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', + input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936), + 'tf_efficientnet_l2_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', + input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96), + + 'tf_efficientnet_es': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 224, 224), ), + 'tf_efficientnet_em': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_el': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), 
pool_size=(10, 10), crop_pct=0.904), + + 'tf_efficientnet_cc_b0_4e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b0_8e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b1_8e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'tf_efficientnet_lite0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 
260, 260), pool_size=(9, 9), crop_pct=0.890, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, interpolation='bilinear'), + 'tf_efficientnet_lite4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'), + + 'tf_efficientnetv2_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnetv2_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnetv2_s_in21ft1k': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m_in21ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21ft1k-bf41664a.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnetv2_l_in21ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21ft1k-60127a9d.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnetv2_xl_in21ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21ft1k-06c35c48.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnetv2_s_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m_in21k': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnetv2_l_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnetv2_xl_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21k-fd7e8abf.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnetv2_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth', + input_size=(3, 192, 192), test_input_size=(3, 224, 224), pool_size=(6, 6)), + 'tf_efficientnetv2_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth', + input_size=(3, 192, 192), test_input_size=(3, 240, 240), pool_size=(6, 6), crop_pct=0.882), + 'tf_efficientnetv2_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b2-847de54e.pth', + input_size=(3, 208, 208), test_input_size=(3, 260, 260), 
pool_size=(7, 7), crop_pct=0.890), + 'tf_efficientnetv2_b3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b3-57773f13.pth', + input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904), + + 'mixnet_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'), + 'mixnet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'), + 'mixnet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'), + 'mixnet_xl': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth'), + 'mixnet_xxl': _cfg(), + + 'tf_mixnet_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'), + 'tf_mixnet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'), + 'tf_mixnet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'), + + "tinynet_a": _cfg( + input_size=(3, 192, 192), pool_size=(6, 6), # int(224 * 0.86) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_a.pth'), + "tinynet_b": _cfg( + input_size=(3, 188, 188), 
pool_size=(6, 6), # int(224 * 0.84) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_b.pth'), + "tinynet_c": _cfg( + input_size=(3, 184, 184), pool_size=(6, 6), # int(224 * 0.825) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_c.pth'), + "tinynet_d": _cfg( + input_size=(3, 152, 152), pool_size=(5, 5), # int(224 * 0.68) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_d.pth'), + "tinynet_e": _cfg( + input_size=(3, 106, 106), pool_size=(4, 4), # int(224 * 0.475) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_e.pth'), +} + + +class EfficientNet(nn.Module): + """ (Generic) EfficientNet + + A flexible and performant PyTorch implementation of efficient network architectures, including: + * EfficientNet-V2 Small, Medium, Large, XL & B0-B3 + * EfficientNet B0-B8, L2 + * EfficientNet-EdgeTPU + * EfficientNet-CondConv + * MixNet S, M, L, XL + * MnasNet A1, B1, and small + * MobileNet-V2 + * FBNet C + * Single-Path NAS Pixel1 + + """ + + def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False, + output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None, + se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'): + super(EfficientNet, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + + # Stem + if not fix_stem: + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = 
norm_layer(stem_size) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = builder.features + head_chs = builder.in_chs + + # Head + Pooling + self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type) + self.bn2 = norm_layer(self.num_features) + self.act2 = act_layer(inplace=True) + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + efficientnet_init_weights(self) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool]) + layers.extend([nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +class EfficientNetFeatures(nn.Module): + """ EfficientNet Feature Extractor + + A work-in-progress feature extraction module for EfficientNet, to use as a backbone for segmentation + and object detection models. 
+ """ + + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, + stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels, + act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): + super(EfficientNetFeatures, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite + self.drop_rate = drop_rate + + # Stem + if not fix_stem: + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, + feature_location=feature_location) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = FeatureInfo(builder.features, out_indices) + self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} + + efficientnet_init_weights(self) + + # Register feature extraction hooks with FeatureHooks helper + self.feature_hooks = None + if feature_location != 'bottleneck': + hooks = self.feature_info.get_dicts(keys=('module', 'hook_type')) + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + + def forward(self, x) -> List[torch.Tensor]: + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + if self.feature_hooks is None: + features = [] + if 0 in self._stage_out_idx: + features.append(x) # add stem out + for i, b in enumerate(self.blocks): + x = b(x) + if i + 1 in self._stage_out_idx: + features.append(x) + return features + else: + self.blocks(x) + out = self.feature_hooks.get_output(x.device) + return list(out.values()) + + +def 
_create_effnet(variant, pretrained=False, **kwargs): + features_only = False + model_cls = EfficientNet + kwargs_filter = None + if kwargs.pop('features_only', False): + features_only = True + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool') + model_cls = EfficientNetFeatures + model = build_model_with_cfg( + model_cls, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model + + +def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-a1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r2_k3_s2_e6_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r4_k3_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. 
+ + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r3_k5_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. 
+ """ + arch_def = [ + ['ds_r1_k3_s1_c8'], + ['ir_r1_k3_s2_e3_c16'], + ['ir_r2_k3_s2_e6_c16'], + ['ir_r4_k5_s2_e6_c32_se0.25'], + ['ir_r3_k3_s1_e6_c32_se0.25'], + ['ir_r3_k5_s2_e6_c88_se0.25'], + ['ir_r1_k3_s1_e6_c144'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=8, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mobilenet_v2( + variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs): + """ Generate MobileNet-V2 network + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['ds_r1_k3_s1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r3_k3_s2_e6_c32'], + ['ir_r4_k3_s2_e6_c64'], + ['ir_r3_k3_s1_e6_c96'], + ['ir_r3_k3_s2_e6_c160'], + ['ir_r1_k3_s1_e6_c320'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head), + num_features=1280 if fix_stem_head else max(1280, round_chs_fn(1280)), + stem_size=32, + fix_stem=fix_stem_head, + round_chs_fn=round_chs_fn, + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'relu6'), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNet-C + + Paper: https://arxiv.org/abs/1812.03443 + Ref Impl: 
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, + it was used to confirm some building block details + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], + ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], + ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], + ['ir_r4_k5_s2_e6_c184'], + ['ir_r1_k3_s1_e6_c352'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + num_features=1984, # paper suggests this, but is not 100% clear + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates the Single-Path NAS model from search targeted for Pixel1 phone. + + Paper: https://arxiv.org/abs/1904.02877 + + Args: + channel_multiplier: multiplier to number of channels per layer. 
+ """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], + # stage 4, 14x14in + ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet model. + + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + 
['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-EdgeTPU model + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu + """ + + arch_def = [ + # NOTE `fc` is present to override a mismatch between stem channels and in chs not + # present in other models + ['er_r1_k3_s1_e4_c24_fc24_noskip'], + ['er_r2_k3_s2_e8_c32'], + ['er_r4_k3_s2_e8_c48'], + ['ir_r5_k5_s2_e8_c96'], + ['ir_r4_k5_s1_e8_c144'], + ['ir_r2_k5_s2_e8_c192'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'relu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet_condconv( + variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + """Creates an EfficientNet-CondConv model. 
+ + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ] + # NOTE unlike official impl, this one uses `cc` option where x is the base number of experts for each stage and + # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'swish'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet-Lite model. 
+ + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r2_k5_s2_e6_c40'], + ['ir_r3_k3_s2_e6_c80'], + ['ir_r3_k5_s1_e6_c112'], + ['ir_r4_k5_s2_e6_c192'], + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + act_layer=resolve_act_layer(kwargs, 'relu6'), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_base( + variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 base model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + arch_def = [ + ['cn_r1_k3_s1_e1_c16_skip'], + ['er_r2_k3_s2_e4_c32'], + ['er_r2_k3_s2_e4_c48'], + ['ir_r3_k3_s2_e4_c96_se0.25'], + ['ir_r5_k3_s1_e6_c112_se0.25'], 
+ ['ir_r8_k3_s2_e6_c192_se0.25'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_s( + variant, channel_multiplier=1.0, depth_multiplier=1.0, rw=False, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Small model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + + NOTE: `rw` flag sets up 'small' variant to behave like my initial v2 small model, + before ref the impl was released. 
+ """ + arch_def = [ + ['cn_r2_k3_s1_e1_c24_skip'], + ['er_r4_k3_s2_e4_c48'], + ['er_r4_k3_s2_e4_c64'], + ['ir_r6_k3_s2_e4_c128_se0.25'], + ['ir_r9_k3_s1_e6_c160_se0.25'], + ['ir_r15_k3_s2_e6_c256_se0.25'], + ] + num_features = 1280 + if rw: + # my original variant, based on paper figure differs from the official release + arch_def[0] = ['er_r2_k3_s1_e1_c24'] + arch_def[-1] = ['ir_r15_k3_s2_e6_c272_se0.25'] + num_features = 1792 + + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_chs_fn(num_features), + stem_size=24, + round_chs_fn=round_chs_fn, + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Medium model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + + arch_def = [ + ['cn_r3_k3_s1_e1_c24_skip'], + ['er_r5_k3_s2_e4_c48'], + ['er_r5_k3_s2_e4_c80'], + ['ir_r7_k3_s2_e4_c160_se0.25'], + ['ir_r14_k3_s1_e6_c176_se0.25'], + ['ir_r18_k3_s2_e6_c304_se0.25'], + ['ir_r5_k3_s1_e6_c512_se0.25'], + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=1280, + stem_size=24, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, 
pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Large model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + + arch_def = [ + ['cn_r4_k3_s1_e1_c32_skip'], + ['er_r7_k3_s2_e4_c64'], + ['er_r7_k3_s2_e4_c96'], + ['ir_r10_k3_s2_e4_c192_se0.25'], + ['ir_r19_k3_s1_e6_c224_se0.25'], + ['ir_r25_k3_s2_e6_c384_se0.25'], + ['ir_r7_k3_s1_e6_c640_se0.25'], + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=1280, + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Xtra-Large model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + + arch_def = [ + ['cn_r4_k3_s1_e1_c32_skip'], + ['er_r8_k3_s2_e4_c64'], + ['er_r8_k3_s2_e4_c96'], + ['ir_r16_k3_s2_e4_c192_se0.25'], + ['ir_r24_k3_s1_e6_c256_se0.25'], + ['ir_r32_k3_s2_e6_c512_se0.25'], + ['ir_r8_k3_s1_e6_c640_se0.25'], + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=1280, + stem_size=32, + 
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Small model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1536, + stem_size=16, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Medium-Large model. 
+ + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c24'], # relu + # stage 1, 112x112 in + ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=1536, + stem_size=24, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_tinynet( + variant, model_width=1.0, depth_multiplier=1.0, pretrained=False, **kwargs +): + """Creates a TinyNet model. 
+ """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=max(1280, round_channels(1280, model_width, 8, None)), + stem_size=32, + fix_stem=True, + round_chs_fn=partial(round_channels, multiplier=model_width), + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +@register_model +def mnasnet_050(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.5. """ + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_075(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.75. """ + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_100(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_b1(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. 
""" + return mnasnet_100(pretrained, **kwargs) + + +@register_model +def mnasnet_140(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.4 """ + model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_050(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ + model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_075(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ + model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_100(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_a1(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + return semnasnet_100(pretrained, **kwargs) + + +@register_model +def semnasnet_140(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ + model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_small(pretrained=False, **kwargs): + """ MNASNet Small, depth multiplier of 1.0. 
""" + model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_035(pretrained=False, **kwargs): + """ MobileNet V2 w/ 0.35 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_035', 0.35, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_050(pretrained=False, **kwargs): + """ MobileNet V2 w/ 0.5 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_075(pretrained=False, **kwargs): + """ MobileNet V2 w/ 0.75 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_100(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.0 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_140(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.4 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_110d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers""" + model = _gen_mobilenet_v2( + 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_120d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """ + model = _gen_mobilenet_v2( + 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetc_100(pretrained=False, **kwargs): + """ FBNet-C """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_fbnetc('fbnetc_100', 
1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def spnasnet_100(pretrained=False, **kwargs): + """ Single-Path NAS Pixel1""" + model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b2a(pretrained=False, **kwargs): + """ EfficientNet-B2 @ 288x288 w/ 1.0 test crop""" + # WARN this model def is deprecated, different train/test res + test crop handled by default_cfg now + return efficientnet_b2(pretrained=pretrained, **kwargs) + + +@register_model +def efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b3a(pretrained=False, **kwargs): + """ EfficientNet-B3 @ 320x320 w/ 1.0 test crop-pct """ + # WARN this model def is deprecated, different train/test res + test crop 
handled by default_cfg now + return efficientnet_b3(pretrained=pretrained, **kwargs) + + +@register_model +def efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 """ + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 """ + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 """ + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_l2(pretrained=False, **kwargs): + """ EfficientNet-L2.""" + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +@register_model 
+def efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. """ + model = _gen_efficientnet_edge( + 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + +@register_model +def efficientnet_es_pruned(pretrained=False, **kwargs): + """ EfficientNet-Edge Small Pruned. For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0""" + model = _gen_efficientnet_edge( + 'efficientnet_es_pruned', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + +@register_model +def efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. """ + model = _gen_efficientnet_edge( + 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. """ + model = _gen_efficientnet_edge( + 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + +@register_model +def efficientnet_el_pruned(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large pruned. 
For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0""" + model = _gen_efficientnet_edge( + 'efficientnet_el_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + +@register_model +def efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 4 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite1', 
channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b1_pruned(pretrained=False, **kwargs): + """ EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + variant = 'efficientnet_b1_pruned' + model = _gen_efficientnet( + variant, channel_multiplier=1.0, depth_multiplier=1.1, pruned=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b2_pruned(pretrained=False, **kwargs): + """ EfficientNet-B2 Pruned. 
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pruned=True, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b3_pruned(pretrained=False, **kwargs): + """ EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pruned=True, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_rw_t(pretrained=False, **kwargs): + """ EfficientNet-V2 Tiny (Custom variant, tiny not in paper). """ + model = _gen_efficientnetv2_s( + 'efficientnetv2_rw_t', channel_multiplier=0.8, depth_multiplier=0.9, rw=False, pretrained=pretrained, **kwargs) + return model + + +@register_model +def gc_efficientnetv2_rw_t(pretrained=False, **kwargs): + """ EfficientNet-V2 Tiny w/ Global Context Attn (Custom variant, tiny not in paper). """ + model = _gen_efficientnetv2_s( + 'gc_efficientnetv2_rw_t', channel_multiplier=0.8, depth_multiplier=0.9, + rw=False, se_layer='gc', pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_rw_s(pretrained=False, **kwargs): + """ EfficientNet-V2 Small (RW variant). + NOTE: This is my initial (pre official code release) w/ some differences. 
+ See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official w/ PyTorch vs TF padding + """ + model = _gen_efficientnetv2_s('efficientnetv2_rw_s', rw=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_rw_m(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium (RW variant). + """ + model = _gen_efficientnetv2_s( + 'efficientnetv2_rw_m', channel_multiplier=1.2, depth_multiplier=(1.2,) * 4 + (1.6,) * 2, rw=True, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_s(pretrained=False, **kwargs): + """ EfficientNet-V2 Small. """ + model = _gen_efficientnetv2_s('efficientnetv2_s', pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_m(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium. """ + model = _gen_efficientnetv2_m('efficientnetv2_m', pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_l(pretrained=False, **kwargs): + """ EfficientNet-V2 Large. """ + model = _gen_efficientnetv2_l('efficientnetv2_l', pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_xl(pretrained=False, **kwargs): + """ EfficientNet-V2 Xtra-Large. """ + model = _gen_efficientnetv2_xl('efficientnetv2_xl', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1. 
Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6. 
Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0_ap(pretrained=False, **kwargs): + """ EfficientNet-B0 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1_ap(pretrained=False, **kwargs): + """ EfficientNet-B1 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2_ap(pretrained=False, **kwargs): + """ EfficientNet-B2 AdvProp. 
Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3_ap(pretrained=False, **kwargs): + """ EfficientNet-B3 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4_ap(pretrained=False, **kwargs): + """ EfficientNet-B4 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5_ap(pretrained=False, **kwargs): + """ EfficientNet-B5 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6_ap(pretrained=False, **kwargs): + """ EfficientNet-B6 AdvProp. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7_ap(pretrained=False, **kwargs): + """ EfficientNet-B7 AdvProp. 
Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b8_ap(pretrained=False, **kwargs): + """ EfficientNet-B8 AdvProp. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0_ns(pretrained=False, **kwargs): + """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1_ns(pretrained=False, **kwargs): + """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2_ns(pretrained=False, **kwargs): + """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3_ns(pretrained=False, **kwargs): + """ EfficientNet-B3 NoisyStudent. 
Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4_ns(pretrained=False, **kwargs): + """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5_ns(pretrained=False, **kwargs): + """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6_ns(pretrained=False, **kwargs): + """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7_ns(pretrained=False, **kwargs): + """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent @ 475x475. 
Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_l2_ns(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 4 Experts. 
Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = 
_gen_efficientnet_lite( + 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + + +@register_model +def tf_efficientnetv2_s(pretrained=False, **kwargs): + """ EfficientNet-V2 Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_s('tf_efficientnetv2_s', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_m(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium. 
Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_m('tf_efficientnetv2_m', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_l(pretrained=False, **kwargs): + """ EfficientNet-V2 Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_l('tf_efficientnetv2_l', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_s_in21ft1k(pretrained=False, **kwargs): + """ EfficientNet-V2 Small. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_s('tf_efficientnetv2_s_in21ft1k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_m_in21ft1k(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_m('tf_efficientnetv2_m_in21ft1k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_l_in21ft1k(pretrained=False, **kwargs): + """ EfficientNet-V2 Large. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_l('tf_efficientnetv2_l_in21ft1k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_xl_in21ft1k(pretrained=False, **kwargs): + """ EfficientNet-V2 Xtra-Large. Pretrained on ImageNet-21k, fine-tuned on 1k. 
Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_xl('tf_efficientnetv2_xl_in21ft1k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_s_in21k(pretrained=False, **kwargs): + """ EfficientNet-V2 Small w/ ImageNet-21k pretrained weights. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_s('tf_efficientnetv2_s_in21k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_m_in21k(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium w/ ImageNet-21k pretrained weights. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_m('tf_efficientnetv2_m_in21k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_l_in21k(pretrained=False, **kwargs): + """ EfficientNet-V2 Large w/ ImageNet-21k pretrained weights. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_l('tf_efficientnetv2_l_in21k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_xl_in21k(pretrained=False, **kwargs): + """ EfficientNet-V2 Xtra-Large w/ ImageNet-21k pretrained weights. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_xl('tf_efficientnetv2_xl_in21k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b0(pretrained=False, **kwargs): + """ EfficientNet-V2-B0. 
@register_model
def tf_efficientnetv2_b0(pretrained=False, **kwargs):
    """EfficientNet-V2-B0. Tensorflow compatible variant (TF default `bn_eps`, 'same' padding)."""
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnetv2_base('tf_efficientnetv2_b0', pretrained=pretrained, **kwargs)


@register_model
def tf_efficientnetv2_b1(pretrained=False, **kwargs):
    """EfficientNet-V2-B1. Tensorflow compatible variant (TF default `bn_eps`, 'same' padding)."""
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnetv2_base(
        'tf_efficientnetv2_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)


@register_model
def tf_efficientnetv2_b2(pretrained=False, **kwargs):
    """EfficientNet-V2-B2. Tensorflow compatible variant (TF default `bn_eps`, 'same' padding)."""
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnetv2_base(
        'tf_efficientnetv2_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)


@register_model
def tf_efficientnetv2_b3(pretrained=False, **kwargs):
    """EfficientNet-V2-B3. Tensorflow compatible variant (TF default `bn_eps`, 'same' padding)."""
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnetv2_base(
        'tf_efficientnetv2_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)


@register_model
def mixnet_s(pretrained=False, **kwargs):
    """Creates a MixNet Small model (channel multiplier 1.0)."""
    return _gen_mixnet_s(
        'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)


@register_model
def mixnet_m(pretrained=False, **kwargs):
    """Creates a MixNet Medium model (channel multiplier 1.0)."""
    return _gen_mixnet_m(
        'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
@register_model
def mixnet_l(pretrained=False, **kwargs):
    """Creates a MixNet Large model.

    Built with the MixNet Medium generator using a 1.3 channel multiplier.
    """
    return _gen_mixnet_m(
        'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)


@register_model
def mixnet_xl(pretrained=False, **kwargs):
    """Creates a MixNet Extra-Large model.

    Not a paper spec, experimental def by RW w/ depth scaling
    (MixNet Medium generator, channel multiplier 1.6, depth multiplier 1.2).
    """
    return _gen_mixnet_m(
        'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs)


@register_model
def mixnet_xxl(pretrained=False, **kwargs):
    """Creates a MixNet Double Extra Large model.

    Not a paper spec, experimental def by RW w/ depth scaling
    (MixNet Medium generator, channel multiplier 2.4, depth multiplier 1.3).
    """
    return _gen_mixnet_m(
        'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs)


@register_model
def tf_mixnet_s(pretrained=False, **kwargs):
    """Creates a MixNet Small model. Tensorflow compatible variant (TF default `bn_eps`, 'same' padding)."""
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mixnet_s(
        'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)


@register_model
def tf_mixnet_m(pretrained=False, **kwargs):
    """Creates a MixNet Medium model. Tensorflow compatible variant (TF default `bn_eps`, 'same' padding)."""
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mixnet_m(
        'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
@register_model
def tf_mixnet_l(pretrained=False, **kwargs):
    """Creates a MixNet Large model. Tensorflow compatible variant (TF default `bn_eps`, 'same' padding).

    Built with the MixNet Medium generator using a 1.3 channel multiplier.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mixnet_m(
        'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)


# NOTE(review): the two positional numbers handed to _gen_tinynet below are presumably its
# width / depth scaling factors -- confirm against the _gen_tinynet signature.

@register_model
def tinynet_a(pretrained=False, **kwargs):
    """Build the `tinynet_a` variant (generator args 1.0, 1.2)."""
    return _gen_tinynet('tinynet_a', 1.0, 1.2, pretrained=pretrained, **kwargs)


@register_model
def tinynet_b(pretrained=False, **kwargs):
    """Build the `tinynet_b` variant (generator args 0.75, 1.1)."""
    return _gen_tinynet('tinynet_b', 0.75, 1.1, pretrained=pretrained, **kwargs)


@register_model
def tinynet_c(pretrained=False, **kwargs):
    """Build the `tinynet_c` variant (generator args 0.54, 0.85)."""
    return _gen_tinynet('tinynet_c', 0.54, 0.85, pretrained=pretrained, **kwargs)


@register_model
def tinynet_d(pretrained=False, **kwargs):
    """Build the `tinynet_d` variant (generator args 0.54, 0.695)."""
    return _gen_tinynet('tinynet_d', 0.54, 0.695, pretrained=pretrained, **kwargs)


@register_model
def tinynet_e(pretrained=False, **kwargs):
    """Build the `tinynet_e` variant (generator args 0.51, 0.6)."""
    return _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained=pretrained, **kwargs)
class SqueezeExcite(nn.Module):
    """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family

    The input is averaged over its last two dims, passed through a 1x1 reduce conv,
    an activation and a 1x1 expand conv, and the gated result multiplies the input.

    Args:
        in_chs (int): input channels to layer
        rd_ratio (float): ratio of squeeze reduction, used only when `rd_channels` is None
        rd_channels (int): explicit number of reduced channels; overrides `rd_ratio`
        act_layer (nn.Module): activation layer of containing block
        gate_layer (Callable): attention gate function
        force_act_layer (nn.Module): override block's activation fn if this is set/bound
        rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs (defaults to `round`)
    """

    def __init__(
            self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
            gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
        super().__init__()
        if rd_channels is None:
            round_fn = rd_round_fn if rd_round_fn else round
            rd_channels = round_fn(in_chs * rd_ratio)
        squeeze_act = force_act_layer if force_act_layer else act_layer
        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
        self.act1 = create_act_layer(squeeze_act, inplace=True)
        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # squeeze: mean over dims 2 and 3, keeping them as size-1 dims
        pooled = x.mean((2, 3), keepdim=True)
        # excite: reduce -> act -> expand, then gate the original input
        excitation = self.conv_expand(self.act1(self.conv_reduce(pooled)))
        return x * self.gate(excitation)


class ConvBnAct(nn.Module):
    """ Conv + Norm Layer + Activation w/ optional skip connection

    The skip connection is only active when `skip` is set, `stride == 1` and
    `in_chs == out_chs`.
    """
    def __init__(
            self, in_chs, out_chs, kernel_size, stride=1, dilation=1, pad_type='',
            skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.):
        super().__init__()
        self.has_residual = skip and stride == 1 and in_chs == out_chs
        self.drop_path_rate = drop_path_rate
        self.conv = create_conv2d(
            in_chs, out_chs, kernel_size,
            stride=stride, dilation=dilation, padding=pad_type)
        self.bn1 = norm_layer(out_chs)
        self.act1 = act_layer(inplace=True)

    def feature_info(self, location):
        """Return the feature hook description (module, hook_type, num_chs) for `location`."""
        num_chs = self.conv.out_channels
        if location == 'expansion':
            # output of conv after act, same as block output
            return dict(module='act1', hook_type='forward', num_chs=num_chs)
        # location == 'bottleneck', block output
        return dict(module='', hook_type='', num_chs=num_chs)

    def forward(self, x):
        out = self.act1(self.bn1(self.conv(x)))
        if not self.has_residual:
            return out
        if self.drop_path_rate > 0.:
            out = drop_path(out, self.drop_path_rate, self.training)
        out += x
        return out
class DepthwiseSeparableConv(nn.Module):
    """ DepthwiseSeparable block

    Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
    (factor of 1.0). This is an alternative to having a IR with an optional first pw conv.

    Layout: depthwise conv -> norm -> act -> optional SE -> pointwise conv -> norm -> optional act.
    The residual is added only when `stride == 1`, `in_chs == out_chs` and `noskip` is False.
    """
    def __init__(
            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
            noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
            se_layer=None, drop_path_rate=0.):
        super().__init__()
        shape_preserved = stride == 1 and in_chs == out_chs
        self.has_residual = shape_preserved and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv
        self.drop_path_rate = drop_path_rate

        # Depth-wise convolution
        self.conv_dw = create_conv2d(
            in_chs, in_chs, dw_kernel_size,
            stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
        self.bn1 = norm_layer(in_chs)
        self.act1 = act_layer(inplace=True)

        # Squeeze-and-excitation (identity when no se_layer is given)
        if se_layer:
            self.se = se_layer(in_chs, act_layer=act_layer)
        else:
            self.se = nn.Identity()

        # Point-wise convolution, with an activation only when pw_act is set
        self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_layer(out_chs)
        if self.has_pw_act:
            self.act2 = act_layer(inplace=True)
        else:
            self.act2 = nn.Identity()

    def feature_info(self, location):
        """Return the feature hook description (module, hook_type, num_chs) for `location`."""
        if location == 'expansion':
            # after SE, input to PW
            return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
        # location == 'bottleneck', block output
        return dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)

    def forward(self, x):
        out = self.act1(self.bn1(self.conv_dw(x)))
        out = self.se(out)
        out = self.act2(self.bn2(self.conv_pw(out)))

        if self.has_residual:
            if self.drop_path_rate > 0.:
                out = drop_path(out, self.drop_path_rate, self.training)
            out += x
        return out
class InvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE

    Originally used in MobileNet-V2 - https://arxiv.org/abs/1801.04381v4, this layer is often
    referred to as 'MBConv' for (Mobile inverted bottleneck conv) and is also used in
      * MNasNet - https://arxiv.org/abs/1807.11626
      * EfficientNet - https://arxiv.org/abs/1905.11946
      * MobileNet-V3 - https://arxiv.org/abs/1905.02244

    Layout: point-wise expansion -> depth-wise conv -> optional SE -> point-wise linear projection.
    The residual is added only when `in_chs == out_chs`, `stride == 1` and `noskip` is False.
    """

    def __init__(
            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
        super().__init__()
        extra_conv_args = conv_kwargs or {}
        mid_chs = make_divisible(in_chs * exp_ratio)
        shape_preserved = in_chs == out_chs and stride == 1
        self.has_residual = shape_preserved and not noskip
        self.drop_path_rate = drop_path_rate

        # Point-wise expansion
        self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **extra_conv_args)
        self.bn1 = norm_layer(mid_chs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution
        self.conv_dw = create_conv2d(
            mid_chs, mid_chs, dw_kernel_size,
            stride=stride, dilation=dilation, padding=pad_type, depthwise=True, **extra_conv_args)
        self.bn2 = norm_layer(mid_chs)
        self.act2 = act_layer(inplace=True)

        # Squeeze-and-excitation (identity when no se_layer is given)
        if se_layer:
            self.se = se_layer(mid_chs, act_layer=act_layer)
        else:
            self.se = nn.Identity()

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **extra_conv_args)
        self.bn3 = norm_layer(out_chs)

    def feature_info(self, location):
        """Return the feature hook description (module, hook_type, num_chs) for `location`."""
        if location == 'expansion':
            # after SE, input to PWL
            return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
        # location == 'bottleneck', block output
        return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)

    def forward(self, x):
        # Point-wise expansion
        out = self.act1(self.bn1(self.conv_pw(x)))
        # Depth-wise convolution
        out = self.act2(self.bn2(self.conv_dw(out)))
        # Squeeze-and-excitation
        out = self.se(out)
        # Point-wise linear projection
        out = self.bn3(self.conv_pwl(out))

        if self.has_residual:
            if self.drop_path_rate > 0.:
                out = drop_path(out, self.drop_path_rate, self.training)
            out += x

        return out
class CondConvResidual(InvertedResidual):
    """ Inverted residual block w/ CondConv routing

    Routing weights are computed from the globally average-pooled block input and
    passed to each of the three convolutions on every forward call.
    """

    def __init__(
            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):

        # set before the parent constructor runs, as in the original ordering
        self.num_experts = num_experts
        expert_conv_args = dict(num_experts=self.num_experts)

        super().__init__(
            in_chs, out_chs,
            dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
            act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer,
            conv_kwargs=expert_conv_args, drop_path_rate=drop_path_rate)

        self.routing_fn = nn.Linear(in_chs, self.num_experts)

    def forward(self, x):
        # CondConv routing
        pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
        routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))

        # Point-wise expansion
        out = self.act1(self.bn1(self.conv_pw(x, routing_weights)))
        # Depth-wise convolution
        out = self.act2(self.bn2(self.conv_dw(out, routing_weights)))
        # Squeeze-and-excitation
        out = self.se(out)
        # Point-wise linear projection
        out = self.bn3(self.conv_pwl(out, routing_weights))

        if self.has_residual:
            if self.drop_path_rate > 0.:
                out = drop_path(out, self.drop_path_rate, self.training)
            out += x
        return out
class EdgeResidual(nn.Module):
    """ Residual block with expansion convolution followed by pointwise-linear w/ stride

    Originally introduced in `EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML`
        - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html

    This layer is also called FusedMBConv in the MobileDet, EfficientNet-X, and EfficientNet-V2 papers
      * MobileDet - https://arxiv.org/abs/2004.14525
      * EfficientNet-X - https://arxiv.org/abs/2102.05610
      * EfficientNet-V2 - https://arxiv.org/abs/2104.00298

    The expanded width is derived from `force_in_chs` when it is > 0, otherwise from `in_chs`.
    The residual is added only when `in_chs == out_chs`, `stride == 1` and `noskip` is False.
    """

    def __init__(
            self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
        super().__init__()
        base_chs = force_in_chs if force_in_chs > 0 else in_chs
        mid_chs = make_divisible(base_chs * exp_ratio)
        shape_preserved = in_chs == out_chs and stride == 1
        self.has_residual = shape_preserved and not noskip
        self.drop_path_rate = drop_path_rate

        # Expansion convolution
        self.conv_exp = create_conv2d(
            in_chs, mid_chs, exp_kernel_size,
            stride=stride, dilation=dilation, padding=pad_type)
        self.bn1 = norm_layer(mid_chs)
        self.act1 = act_layer(inplace=True)

        # Squeeze-and-excitation (identity when no se_layer is given)
        if se_layer:
            self.se = se_layer(mid_chs, act_layer=act_layer)
        else:
            self.se = nn.Identity()

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_layer(out_chs)

    def feature_info(self, location):
        """Return the feature hook description (module, hook_type, num_chs) for `location`."""
        if location == 'expansion':
            # after SE, before PWL
            return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
        # location == 'bottleneck', block output
        return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)

    def forward(self, x):
        # Expansion convolution
        out = self.act1(self.bn1(self.conv_exp(x)))
        # Squeeze-and-excitation
        out = self.se(out)
        # Point-wise linear projection
        out = self.bn2(self.conv_pwl(out))

        if self.has_residual:
            if self.drop_path_rate > 0.:
                out = drop_path(out, self.drop_path_rate, self.training)
            out += x

        return out
# PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
#  .99 in official TF TPU impl
#  .9997 (/w .999 in search space) for paper
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)


def get_bn_args_tf():
    """Return a fresh copy of the TF-default BatchNorm args (`momentum`, `eps`)."""
    return dict(_BN_ARGS_TF)


def resolve_bn_args(kwargs):
    """Pop `bn_momentum` / `bn_eps` out of `kwargs` and return them as BatchNorm args.

    Both keys are always removed from `kwargs`; a key whose value is missing or None
    is left out of the returned dict. Result keys are `momentum` and `eps`.
    """
    popped = (
        ('momentum', kwargs.pop('bn_momentum', None)),
        ('eps', kwargs.pop('bn_eps', None)),
    )
    return {name: value for name, value in popped if value is not None}


def resolve_act_layer(kwargs, default='relu'):
    """Pop `act_layer` from `kwargs` (falling back to `default`) and resolve it via `get_act_layer`."""
    act_name = kwargs.pop('act_layer', default)
    return get_act_layer(act_name)


def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
    """Scale `channels` by `multiplier` and round the result with `make_divisible`.

    A falsy `multiplier` (0, None) returns `channels` untouched.
    """
    if multiplier:
        return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)
    return channels


def _log_info_if(msg, condition):
    """Log `msg` at INFO level on the module logger when `condition` is truthy."""
    if not condition:
        return
    _logger.info(msg)


def _parse_ksize(ss):
    """Parse a kernel-size string: all digits -> int, otherwise a '.'-separated list of ints."""
    return int(ss) if ss.isdigit() else [int(part) for part in ss.split('.')]
def _decode_block_str(block_str):
    """ Decode block definition string

    Gets a dict of block args and a repeat count from a string notation of arguments.
    E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip

    All args can exist in any order with the exception of the leading string which
    is assumed to indicate the block type.

    leading string - block type (
      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act,
      er = EdgeResidual, cn = ConvBnAct)
    r - number of repeat blocks,
    k - kernel size (depthwise kernel for ir/ds/dsa, expansion kernel for er),
    s - stride,
    e - expansion ratio,
    c - output channels,
    se - squeeze/excitation ratio
    n - activation fn ('re', 'r6', 'hs', 'sw', or 'mi'); unrecognized codes are ignored
    a - expansion kernel size (ir), default 1
    p - point-wise kernel size, default 1
    fc - forced input channels used for the expansion width (er), default 0
    cc - number of CondConv experts (ir)
    skip / noskip - force or disable the skip connection

    Options without a digit in them (other than the ones above) are silently ignored.

    Args:
        block_str: a string representation of block arguments.
    Returns:
        A tuple `(block_args, num_repeat)`: the dict of block args for the block type
        and the integer repeat count taken from the `r` option.
    Raises:
        AssertionError: if `block_str` is not a str or the block type is unknown.
        KeyError: if an option required by the block type (e.g. `r`, `k`, `c`, `s`) is missing.
    """
    assert isinstance(block_str, str)
    ops = block_str.split('_')
    block_type = ops[0]  # take the block type off the front
    ops = ops[1:]
    options = {}
    skip = None
    for op in ops:
        # string options being checked on individual basis, combine if they grow
        if op == 'noskip':
            skip = False  # force no skip connection
        elif op == 'skip':
            skip = True  # force a skip connection
        elif op.startswith('n'):
            # activation fn
            key = op[0]
            v = op[1:]
            if v == 're':
                value = get_act_layer('relu')
            elif v == 'r6':
                value = get_act_layer('relu6')
            elif v == 'hs':
                value = get_act_layer('hard_swish')
            elif v == 'sw':
                value = get_act_layer('swish')  # aka SiLU
            elif v == 'mi':
                value = get_act_layer('mish')
            else:
                continue
            options[key] = value
        else:
            # all numeric options: split at the first digit into (key, value-string)
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

    # if act_layer is None, the model default (passed to model init) will be used
    act_layer = options['n'] if 'n' in options else None
    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
    force_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def

    num_repeat = int(options['r'])
    # each type of block has different valid arguments, fill accordingly
    if block_type == 'ir':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=_parse_ksize(options['k']),
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            se_ratio=float(options['se']) if 'se' in options else 0.,
            stride=int(options['s']),
            act_layer=act_layer,
            noskip=skip is False,
        )
        if 'cc' in options:
            block_args['num_experts'] = int(options['cc'])
    elif block_type == 'ds' or block_type == 'dsa':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            se_ratio=float(options['se']) if 'se' in options else 0.,
            stride=int(options['s']),
            act_layer=act_layer,
            pw_act=block_type == 'dsa',
            noskip=block_type == 'dsa' or skip is False,
        )
    elif block_type == 'er':
        block_args = dict(
            block_type=block_type,
            exp_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            force_in_chs=force_in_chs,
            se_ratio=float(options['se']) if 'se' in options else 0.,
            stride=int(options['s']),
            act_layer=act_layer,
            noskip=skip is False,
        )
    elif block_type == 'cn':
        block_args = dict(
            block_type=block_type,
            kernel_size=int(options['k']),
            out_chs=int(options['c']),
            stride=int(options['s']),
            act_layer=act_layer,
            skip=skip is True,
        )
    else:
        # Raised explicitly instead of `assert False`: an assert is stripped under `python -O`,
        # which would let execution fall through to an UnboundLocalError on `block_args`.
        # AssertionError is kept so the exception type seen by callers is unchanged.
        raise AssertionError('Unknown block type (%s)' % block_type)

    return block_args, num_repeat
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    """ Per-stage depth scaling

    Scales the block repeats in each stage. This depth scaling impl maintains
    compatibility with the EfficientNet scaling method, while allowing sensible
    scaling for other models that may have multiple block arg definitions in each stage.

    Args:
        stack_args: block arg dicts of one stage, one per block definition
        repeats: repeat count for each entry of `stack_args`
        depth_multiplier: factor applied to the summed repeat count of the stage
        depth_trunc: 'round' rounds the scaled total (min 1); anything else uses ceil
    Returns:
        A flat list with a deep copy of each block arg dict per scaled repeat.
    """
    # The total repeat count of the stage is what gets scaled; a stage may hold
    # several block definitions, so sum them first.
    remaining = sum(repeats)
    scaled_total = remaining * depth_multiplier
    if depth_trunc == 'round':
        # Rounding lets stages with few repeats stay proportionally smaller for longer,
        # useful when single-repeat stages should stay that way as long as possible.
        remaining_scaled = max(1, round(scaled_total))
    else:
        # EfficientNet default: truncate via 'ceil', so any multiplier > 1.0
        # deepens every stage.
        remaining_scaled = int(math.ceil(scaled_total))

    # Hand out the scaled count proportionally, walking the definitions back to front so
    # the first block (which makes the least sense to repeat) is least likely to be scaled.
    per_block = []
    for count in reversed(repeats):
        share = max(1, round((count / remaining * remaining_scaled)))
        per_block.append(share)
        remaining -= count
        remaining_scaled -= share
    per_block.reverse()

    # Expand every block definition into its scaled number of independent copies.
    return [
        deepcopy(block_args)
        for block_args, copies in zip(stack_args, per_block)
        for _ in range(copies)
    ]
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
    """Decode an architecture definition into per-stage lists of block arg dicts.

    Args:
        arch_def: list of stages, each stage a list of block definition strings
        depth_multiplier: a single factor for all stages, or a tuple with one factor per stage
        depth_trunc: truncation mode forwarded to the per-stage depth scaling
        experts_multiplier: when > 1, multiplies `num_experts` of blocks that define it
        fix_first_last: when set, the first and last stage are scaled with a factor of 1.0
    Returns:
        A list (one entry per stage) of depth-scaled block arg lists.
    """
    num_stages = len(arch_def)
    if isinstance(depth_multiplier, tuple):
        assert len(depth_multiplier) == num_stages
    else:
        depth_multiplier = (depth_multiplier,) * num_stages
    last_stage_idx = num_stages - 1

    decoded_stages = []
    for stage_idx, (block_strings, stage_multiplier) in enumerate(zip(arch_def, depth_multiplier)):
        assert isinstance(block_strings, list)
        stage_block_args, stage_repeats = [], []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            block_args, num_repeat = _decode_block_str(block_str)
            if block_args.get('num_experts', 0) > 0 and experts_multiplier > 1:
                block_args['num_experts'] *= experts_multiplier
            stage_block_args.append(block_args)
            stage_repeats.append(num_repeat)
        if fix_first_last and stage_idx in (0, last_stage_idx):
            # first / last stage keep their defined depth
            stage_multiplier = 1.0
        decoded_stages.append(
            _scale_stage_depth(stage_block_args, stage_repeats, stage_multiplier, depth_trunc))
    return decoded_stages
class EfficientNetBuilder:
    """ Build Trunk Blocks

    Turns decoded block arg dicts into block modules, one `nn.Sequential` per stage, while
    tracking stride / dilation and recording feature extraction points in `self.features`.

    This ended up being somewhat of a cross between
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
    and
    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

    """
    def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False,
                 act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
        """
        Args:
            output_stride: largest cumulative stride allowed; a block stride that would exceed it
                is converted to dilation in `__call__`
            pad_type: passed to every block as `pad_type`
            round_chs_fn: applied to each block's `out_chs` (and `force_in_chs` when set)
            se_from_exp: when False, a block's se_ratio is divided by its `exp_ratio`
            act_layer: activation used for blocks whose own `act_layer` is None
            norm_layer: passed to every block as `norm_layer`
            se_layer: attention layer spec, resolved through `get_attn`
            drop_path_rate: maximum drop path rate, scaled per block by `block_idx / block_count`
            feature_location: one of 'bottleneck', 'expansion', ''; 'depthwise' is a deprecated
                alias for 'expansion'
        """
        self.output_stride = output_stride
        self.pad_type = pad_type
        self.round_chs_fn = round_chs_fn
        self.se_from_exp = se_from_exp  # calculate se channel reduction from expanded (mid) chs
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.se_layer = get_attn(se_layer)
        try:
            self.se_layer(8, rd_ratio=1.0)  # test if attn layer accepts rd_ratio arg
            self.se_has_ratio = True
        except TypeError:
            # the probe call raised TypeError, so rd_ratio is not passed to se_layer later
            self.se_has_ratio = False
        self.drop_path_rate = drop_path_rate
        if feature_location == 'depthwise':
            # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
            _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
            feature_location = 'expansion'
        self.feature_location = feature_location
        assert feature_location in ('bottleneck', 'expansion', '')
        self.verbose = _DEBUG_BUILDER

        # state updated during build, consumed by model
        self.in_chs = None
        self.features = []

    def _make_block(self, ba, block_idx, block_count):
        """Instantiate one block from its arg dict `ba`.

        `ba` is modified in place (`block_type` and `se_ratio` are popped, builder-level args
        are filled in), and `self.in_chs` is updated to the block's output channels.

        Args:
            ba: block arg dict for a single block
            block_idx: global index of the block, used to scale the drop path rate
            block_count: total block count used as the drop path scaling denominator
        Returns:
            The constructed block module.
        """
        # drop path rate grows linearly with the block's global position
        drop_path_rate = self.drop_path_rate * block_idx / block_count
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
        if 'force_in_chs' in ba and ba['force_in_chs']:
            # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
            ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
        assert ba['act_layer'] is not None
        ba['norm_layer'] = self.norm_layer
        ba['drop_path_rate'] = drop_path_rate
        if bt != 'cn':
            # 'cn' block args carry no se_ratio; all other block types do
            se_ratio = ba.pop('se_ratio')
            if se_ratio and self.se_layer is not None:
                if not self.se_from_exp:
                    # adjust se_ratio by expansion ratio if calculating se channels from block input
                    se_ratio /= ba.get('exp_ratio', 1.0)
                if self.se_has_ratio:
                    ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
                else:
                    ba['se_layer'] = self.se_layer

        if bt == 'ir':
            _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            # a non-zero num_experts selects the CondConv variant
            block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
        elif bt == 'ds' or bt == 'dsa':
            _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = DepthwiseSeparableConv(**ba)
        elif bt == 'er':
            _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = EdgeResidual(**ba)
        elif bt == 'cn':
            _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = ConvBnAct(**ba)
        else:
            assert False, 'Uknkown block type (%s) while building model.' % bt

        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
        return block

    def __call__(self, in_chs, model_block_args):
        """ Build the blocks

        The block arg dicts are modified in place while building (`stride` may be reset,
        `dilation` is added, and `_make_block` pops / adds further keys).

        Args:
            in_chs: Number of input-channels passed to first block
            model_block_args: A list of lists, outer list defines stages, inner
                list contains block arg dicts defining block configuration(s)
        Return:
             List of block stacks (each stack wrapped in nn.Sequential)
        """
        _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
        self.in_chs = in_chs
        total_block_count = sum([len(x) for x in model_block_args])
        total_block_idx = 0
        # the running stride starts at 2 -- presumably accounting for a stride-2 stem outside
        # this builder; TODO confirm against the calling model
        current_stride = 2
        current_dilation = 1
        stages = []
        if model_block_args[0][0]['stride'] > 1:
            # if the first block starts with a stride, we need to extract first level feat from stem
            feature_info = dict(
                module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
                hook_type='forward' if self.feature_location != 'bottleneck' else '')
            self.features.append(feature_info)

        # outer list of block_args defines the stacks
        for stack_idx, stack_args in enumerate(model_block_args):
            # NOTE(review): last_stack is assigned but not read anywhere in this method
            last_stack = stack_idx + 1 == len(model_block_args)
            _log_info_if('Stack: {}'.format(stack_idx), self.verbose)
            assert isinstance(stack_args, list)

            blocks = []
            # each stack (stage of blocks) contains a list of block arguments
            for block_idx, block_args in enumerate(stack_args):
                last_block = block_idx + 1 == len(stack_args)
                _log_info_if(' Block: {}'.format(block_idx), self.verbose)

                assert block_args['stride'] in (1, 2)
                if block_idx >= 1:   # only the first block in any stack can have a stride > 1
                    block_args['stride'] = 1

                # features are taken from the last block of a stack when it is the final stack
                # or the next stack begins with a strided block
                extract_features = False
                if last_block:
                    next_stack_idx = stack_idx + 1
                    extract_features = next_stack_idx >= len(model_block_args) or \
                        model_block_args[next_stack_idx][0]['stride'] > 1

                next_dilation = current_dilation
                if block_args['stride'] > 1:
                    next_output_stride = current_stride * block_args['stride']
                    if next_output_stride > self.output_stride:
                        # stride would overshoot output_stride: keep resolution, dilate later blocks
                        next_dilation = current_dilation * block_args['stride']
                        block_args['stride'] = 1
                        _log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format(
                            self.output_stride), self.verbose)
                    else:
                        current_stride = next_output_stride
                # this block uses the dilation in effect before any change made just above
                block_args['dilation'] = current_dilation
                if next_dilation != current_dilation:
                    current_dilation = next_dilation

                # create the block
                block = self._make_block(block_args, total_block_idx, total_block_count)
                blocks.append(block)

                # stash feature module name and channel info for model feature extraction
                if extract_features:
                    feature_info = dict(
                        stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
                    module_name = f'blocks.{stack_idx}.{block_idx}'
                    leaf_name = feature_info.get('module', '')
                    feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
                    self.features.append(feature_info)

                total_block_idx += 1  # incr global block idx (across all stacks)
            stages.append(nn.Sequential(*blocks))
        return stages
def _init_weight_goog(m, n='', fix_group_fanout=True):
    """ Weight initialization as per Tensorflow official implementations.

    Modules of any other type are left untouched. Parameters that are absent
    (a `None` bias on a conv / linear layer, `None` weight / bias on a non-affine
    BatchNorm2d) are skipped instead of raising.

    Args:
        m (nn.Module): module to init
        n (str): module name; a Linear whose name contains 'routing_fn' includes fan-in in its init range
        fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs

    Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
    * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
    * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
    """
    if isinstance(m, CondConv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        init_weight_fn = get_condconv_initializer(
            lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
        init_weight_fn(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # a BatchNorm2d built with affine=False has weight/bias set to None
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        fan_out = m.weight.size(0)  # fan-out
        fan_in = 0
        if 'routing_fn' in n:
            fan_in = m.weight.size(1)
        init_range = 1.0 / math.sqrt(fan_in + fan_out)
        nn.init.uniform_(m.weight, -init_range, init_range)
        # a Linear built with bias=False has bias set to None
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def efficientnet_init_weights(model: nn.Module, init_fn=None):
    """Apply `init_fn(module, name)` to every module yielded by `model.named_modules()`.

    Args:
        model: module tree to initialize
        init_fn: callable taking `(module, name)`; defaults to `_init_weight_goog` when falsy
    """
    init_fn = init_fn or _init_weight_goog
    for n, m in model.named_modules():
        init_fn(m, n)
def split_model_name(model_name):
    """ Split an optional ``source:`` prefix off a model name.

    Returns a ``(source_name, model_name)`` tuple. ``source_name`` is '' when the name has no ':'.
    Only the first ':' separates the source; a source other than 'timm' / 'hf_hub' fails the assert.
    """
    source_name, separator, remainder = model_name.partition(':')
    if not separator:
        return '', model_name
    assert source_name in ('timm', 'hf_hub')
    return source_name, remainder


def safe_model_name(model_name, remove_source=True):
    """ Return the model name with every non-alphanumeric character replaced by '_'.

    Trailing underscores are stripped. With ``remove_source`` the ``source:`` prefix is dropped first.
    """
    if remove_source:
        _, model_name = split_model_name(model_name)
    sanitized = [ch if ch.isalnum() else '_' for ch in model_name]
    return ''.join(sanitized).rstrip('_')
impl/obeyed yet) + no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) + + Keyword Args: + drop_rate (float): dropout rate for training (default: 0.0) + global_pool (str): global pool type (default: 'avg') + **: other kwargs are model specific + """ + source_name, model_name = split_model_name(model_name) + + # handle backwards compat with drop_connect -> drop_path change + drop_connect_rate = kwargs.pop('drop_connect_rate', None) + if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: + print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'." + " Setting drop_path to %f." % drop_connect_rate) + kwargs['drop_path_rate'] = drop_connect_rate + + # Parameters that aren't supported by all models or are intended to only override model defaults if set + # should default to None in command line args/cfg. Remove them if they are present and not set so that + # non-supporting models don't break and default args remain in effect. + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + if source_name == 'hf_hub': + # For model names specified in the form `hf_hub:path/architecture_name#revision`, + # load model weights + default_cfg from Hugging Face hub. 
class FeatureInfo:
    """ Holds one info dict per feature level together with the selected output indices.

    Every info dict must provide 'num_chs' (> 0), 'reduction' (non-decreasing across the list)
    and 'module'; additional model specific fields are allowed.
    """

    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
        last_reduction = 1
        for entry in feature_info:
            # sanity check the mandatory fields, there may be additional fields depending on the model
            assert 'num_chs' in entry and entry['num_chs'] > 0
            assert 'reduction' in entry and entry['reduction'] >= last_reduction
            last_reduction = entry['reduction']
            assert 'module' in entry
        self.out_indices = out_indices
        self.info = feature_info

    def from_other(self, out_indices: Tuple[int]):
        """ Create a FeatureInfo over a deep copy of this info list with different output indices.
        """
        info_copy = deepcopy(self.info)
        return FeatureInfo(info_copy, out_indices)

    def get(self, key, idx=None):
        """ Get value by key at specified index (indices)
        if idx == None, returns value for key at each output index
        if idx is an integer, return value for that feature module index (ignoring output indices)
        if idx is a list/tuple, return value for each module index (ignoring output indices)
        """
        if isinstance(idx, (tuple, list)):
            return [self.info[i][key] for i in idx]
        if idx is not None:
            return self.info[idx][key]
        return [self.info[i][key] for i in self.out_indices]

    def get_dicts(self, keys=None, idx=None):
        """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
        """
        def pick(i):
            # the stored dict itself is returned (not a copy) when no key subset is requested
            entry = self.info[i]
            return entry if keys is None else {k: entry[k] for k in keys}

        if idx is None:
            return [pick(i) for i in self.out_indices]
        if isinstance(idx, (tuple, list)):
            return [pick(i) for i in idx]
        return pick(idx)

    def channels(self, idx=None):
        """ feature channels accessor
        """
        return self.get('num_chs', idx)

    def reduction(self, idx=None):
        """ feature reduction (output stride) accessor
        """
        return self.get('reduction', idx)

    def module_name(self, idx=None):
        """ feature module name accessor
        """
        return self.get('module', idx)

    def __getitem__(self, item):
        return self.info[item]

    def __len__(self):
        return len(self.info)
+ """ + + def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'): + # setup feature hooks + modules = {k: v for k, v in named_modules} + for i, h in enumerate(hooks): + hook_name = h['module'] + m = modules[hook_name] + hook_id = out_map[i] if out_map else hook_name + hook_fn = partial(self._collect_output_hook, hook_id) + hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type + if hook_type == 'forward_pre': + m.register_forward_pre_hook(hook_fn) + elif hook_type == 'forward': + m.register_forward_hook(hook_fn) + else: + assert False, "Unsupported hook type" + self._feature_outputs = defaultdict(OrderedDict) + + def _collect_output_hook(self, hook_id, *args): + x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre + if isinstance(x, tuple): + x = x[0] # unwrap input tuple + self._feature_outputs[x.device][hook_id] = x + + def get_output(self, device) -> Dict[str, torch.tensor]: + output = self._feature_outputs[device] + self._feature_outputs[device] = OrderedDict() # clear after reading + return output + + +def _module_list(module, flatten_sequential=False): + # a yield/iter would be better for this but wouldn't be compatible with torchscript + ml = [] + for name, module in module.named_children(): + if flatten_sequential and isinstance(module, nn.Sequential): + # first level of Sequential containers is flattened into containing model + for child_name, child_module in module.named_children(): + combined = [name, child_name] + ml.append(('_'.join(combined), '.'.join(combined), child_module)) + else: + ml.append((name, name, module)) + return ml + + +def _get_feature_info(net, out_indices): + feature_info = getattr(net, 'feature_info') + if isinstance(feature_info, FeatureInfo): + return feature_info.from_other(out_indices) + elif isinstance(feature_info, (list, tuple)): + return FeatureInfo(net.feature_info, out_indices) + else: + assert False, "Provided feature_info is not valid" + + +def 
_get_return_layers(feature_info, out_map): + module_names = feature_info.module_name() + return_layers = {} + for i, name in enumerate(module_names): + return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] + return return_layers + + +class FeatureDictNet(nn.ModuleDict): + """ Feature extractor with OrderedDict return + + Wrap a model and extract features as specified by the out indices, the network is + partially re-built from contained modules. + + There is a strong assumption that the modules have been registered into the model in the same + order as they are used. There should be no reuse of the same nn.Module more than once, including + trivial modules like `self.relu = nn.ReLU`. + + Only submodules that are directly assigned to the model class (`model.feature1`) or at most + one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. + All Sequential containers that are directly assigned to the original model will have their + modules assigned to this module with the name `model.features.1` being changed to `model.features_1` + + Arguments: + model (nn.Module): model from which we will extract the features + out_indices (tuple[int]): model output indices to extract features for + out_map (sequence): list or tuple specifying desired return id for each out index, + otherwise str(index) is used + feature_concat (bool): whether to concatenate intermediate features that are lists or tuples + vs select element [0] + flatten_sequential (bool): whether to flatten sequential modules assigned to model + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureDictNet, self).__init__() + self.feature_info = _get_feature_info(model, out_indices) + self.concat = feature_concat + self.return_layers = {} + return_layers = _get_return_layers(self.feature_info, out_map) + modules = _module_list(model, 
class FeatureListNet(FeatureDictNet):
    """ Feature extractor with list return

    Same construction as FeatureDictNet; forward returns the collected features as a list of
    tensors instead of a dict. The class exists only to appease Torchscript typing constraints,
    in eager Python a member bool could have selected List[Tensor] vs Dict[id, Tensor].
    """
    def __init__(
            self, model,
            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
        super().__init__(
            model,
            out_indices=out_indices,
            out_map=out_map,
            feature_concat=feature_concat,
            flatten_sequential=flatten_sequential,
        )

    def forward(self, x) -> (List[torch.Tensor]):
        collected = self._collect(x)
        return list(collected.values())
+ + If `no_rewrite` is True, features are extracted via hooks without modifying the underlying + network in any way. + + If `no_rewrite` is False, the model will be re-written as in the + FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. + + FIXME this does not currently work with Torchscript, see FeatureHooks class + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False, + feature_concat=False, flatten_sequential=False, default_hook_type='forward'): + super(FeatureHookNet, self).__init__() + assert not torch.jit.is_scripting() + self.feature_info = _get_feature_info(model, out_indices) + self.out_as_dict = out_as_dict + layers = OrderedDict() + hooks = [] + if no_rewrite: + assert not flatten_sequential + if hasattr(model, 'reset_classifier'): # make sure classifier is removed? + model.reset_classifier(0) + layers['body'] = model + hooks.extend(self.feature_info.get_dicts()) + else: + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type + for f in self.feature_info.get_dicts()} + for new_name, old_name, module in modules: + layers[new_name] = module + for fn, fm in module.named_modules(prefix=old_name): + if fn in remaining: + hooks.append(dict(module=fn, hook_type=remaining[fn])) + del remaining[fn] + if not remaining: + break + assert not remaining, f'Return layers ({remaining}) are not present in model' + self.update(layers) + self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) + + def forward(self, x): + for name, module in self.items(): + x = module(x) + out = self.hooks.get_output(x.device) + return out if self.out_as_dict else list(out.values()) diff --git a/data_processing/MANIQA/timm/models/fx_features.py b/data_processing/MANIQA/timm/models/fx_features.py new file mode 100644 index 0000000..5a25ee3 --- /dev/null +++ 
b/data_processing/MANIQA/timm/models/fx_features.py @@ -0,0 +1,73 @@ +""" PyTorch FX Based Feature Extraction Helpers +Using https://pytorch.org/vision/stable/feature_extraction.html +""" +from typing import Callable +from torch import nn + +from .features import _get_feature_info + +try: + from torchvision.models.feature_extraction import create_feature_extractor + has_fx_feature_extraction = True +except ImportError: + has_fx_feature_extraction = False + +# Layers we went to treat as leaf modules +from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath +from .layers.non_local_attn import BilinearAttnTransform +from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame + +# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here +# BUT modules from timm.models should use the registration mechanism below +_leaf_modules = { + BatchNormAct2d, # reason: flow control for jit scripting + BilinearAttnTransform, # reason: flow control t <= 1 + BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1] + # Reason: get_same_padding has a max which raises a control flow error + Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, + CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) + DropPath, # reason: TypeError: rand recieved Proxy in `size` argument +} + +try: + from .layers import InplaceAbn + _leaf_modules.add(InplaceAbn) +except ImportError: + pass + + +def register_notrace_module(module: nn.Module): + """ + Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
class FeatureGraphNet(nn.Module):
    """ Feature extractor built on torchvision's FX based ``create_feature_extractor``.

    Arguments:
        model (nn.Module): model from which we will extract the features
        out_indices (tuple[int]): feature indices (into the model's feature_info) to return
        out_map (sequence): optional return id for each entry of ``out_indices``, matched by position;
            the feature's module name is used as the id when not given
    """
    def __init__(self, model, out_indices, out_map=None):
        super().__init__()
        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
        self.feature_info = _get_feature_info(model, out_indices)
        index_to_id = None
        if out_map is not None:
            assert len(out_map) == len(out_indices)
            # out_map has one entry per out index, so it must be looked up by position within
            # out_indices; indexing it with the feature index fails for e.g. out_indices=(1, 3)
            index_to_id = {feat_idx: out_map[pos] for pos, feat_idx in enumerate(out_indices)}
        return_nodes = {
            info['module']: info['module'] if index_to_id is None else index_to_id[i]
            for i, info in enumerate(self.feature_info) if i in out_indices}
        self.graph_module = create_feature_extractor(
            model, return_nodes,
            tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})

    def forward(self, x):
        # the graph module returns a dict keyed by return id; only the values are exposed
        return list(self.graph_module(x).values())
https://arxiv.org/abs/1911.11907 +The train script of the model is similar to that of MobileNetV3 +Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch +""" +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .layers import SelectAdaptivePool2d, Linear, make_divisible +from .efficientnet_blocks import SqueezeExcite, ConvBnAct +from .helpers import build_model_with_cfg +from .registry import register_model + + +__all__ = ['GhostNet'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'ghostnet_050': _cfg(url=''), + 'ghostnet_100': _cfg( + url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'), + 'ghostnet_130': _cfg(url=''), +} + + +_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4)) + + +class GhostModule(nn.Module): + def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True): + super(GhostModule, self).__init__() + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels * (ratio - 1) + + self.primary_conv = nn.Sequential( + nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), + nn.BatchNorm2d(init_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + + self.cheap_operation = nn.Sequential( + nn.Conv2d(init_channels, new_channels, dw_size, 1, 
class GhostBottleneck(nn.Module):
    """ Ghost bottleneck w/ optional SE

    Two GhostModules (expansion with ReLU, projection without) wrap an optional strided depth-wise
    conv + BN (only when stride > 1) and an optional squeeze-excite layer (only when se_ratio > 0).
    The shortcut is an identity when channels and resolution are unchanged, otherwise a
    depth-wise conv / BN / point-wise conv / BN projection.
    """

    def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,
                 stride=1, act_layer=nn.ReLU, se_ratio=0.):
        super().__init__()
        self.stride = stride
        use_se = se_ratio is not None and se_ratio > 0.
        strided = stride > 1
        dw_padding = (dw_kernel_size - 1) // 2

        # Point-wise expansion
        self.ghost1 = GhostModule(in_chs, mid_chs, relu=True)

        # Depth-wise convolution, present only for strided blocks
        self.conv_dw = nn.Conv2d(
            mid_chs, mid_chs, dw_kernel_size, stride=stride,
            padding=dw_padding, groups=mid_chs, bias=False) if strided else None
        self.bn_dw = nn.BatchNorm2d(mid_chs) if strided else None

        # Squeeze-and-excitation
        self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if use_se else None

        # Point-wise linear projection
        self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)

        # shortcut: an empty Sequential acts as identity
        shortcut_layers = []
        if in_chs != out_chs or stride != 1:
            shortcut_layers = [
                nn.Conv2d(
                    in_chs, in_chs, dw_kernel_size, stride=stride,
                    padding=dw_padding, groups=in_chs, bias=False),
                nn.BatchNorm2d(in_chs),
                nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_chs),
            ]
        self.shortcut = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        residual = x

        # 1st ghost bottleneck
        x = self.ghost1(x)

        # Depth-wise convolution
        if self.conv_dw is not None:
            x = self.bn_dw(self.conv_dw(x))

        # Squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # 2nd ghost bottleneck
        x = self.ghost2(x)

        # in-place add, as in the original implementation
        x += self.shortcut(residual)
        return x
class GhostNet(nn.Module):
    """ GhostNet: conv stem, stages of GhostBottleneck blocks and a pooled 1x1-conv head.

    Args:
        cfgs: list of stages, each a list of ``[k, exp_size, c, se_ratio, s]`` block settings
            (depth-wise kernel size, expansion channels, output channels, SE ratio, stride);
            channel counts are scaled by ``width``
        num_classes: number of classifier outputs, ``<= 0`` replaces the classifier with identity
        width: channel multiplier for stem, expansion and output channels
        dropout: dropout probability applied before the classifier
        in_chans: number of input channels
        output_stride: must be 32, dilation is not supported
        global_pool: pool type passed to SelectAdaptivePool2d; a falsy value also disables flatten
    """
    def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32, global_pool='avg'):
        super(GhostNet, self).__init__()
        # setting of inverted residual blocks
        assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
        self.cfgs = cfgs
        self.num_classes = num_classes
        self.dropout = dropout
        self.feature_info = []

        # building first layer
        stem_chs = make_divisible(16 * width, 4)
        self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False)
        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='conv_stem'))
        self.bn1 = nn.BatchNorm2d(stem_chs)
        self.act1 = nn.ReLU(inplace=True)
        prev_chs = stem_chs

        # building inverted residual blocks
        stages = nn.ModuleList([])
        block = GhostBottleneck
        stage_idx = 0
        net_stride = 2
        for cfg in self.cfgs:
            layers = []
            s = 1  # stride seen by the check below if a stage has no blocks
            for k, exp_size, c, se_ratio, s in cfg:
                out_chs = make_divisible(c * width, 4)
                mid_chs = make_divisible(exp_size * width, 4)
                layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio))
                prev_chs = out_chs
            # a strided stage closes a resolution level: record the features feeding into it
            if s > 1:
                net_stride *= 2
                self.feature_info.append(dict(
                    num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}'))
            stages.append(nn.Sequential(*layers))
            stage_idx += 1

        # final 1x1 ConvBnAct is sized from the expansion width of the last configured block
        out_chs = make_divisible(exp_size * width, 4)
        stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1)))
        self.pool_dim = prev_chs = out_chs

        self.blocks = nn.Sequential(*stages)

        # building last several layers
        self.num_features = out_chs = 1280
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
        self.act2 = nn.ReLU(inplace=True)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()

    def get_classifier(self):
        """ Return the classifier module. """
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        """ Replace the pooling, flatten and classifier layers for a new number of classes. """
        self.num_classes = num_classes
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        # the classifier consumes the conv_head output (num_features channels), not the
        # pool_dim channels that enter conv_head
        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """ Run stem, blocks, global pool and the conv head; returns the pre-classifier features. """
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.flatten(x)
        if self.dropout > 0.:
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.classifier(x)
        return x
ghostnet_130(pretrained=False, **kwargs): + """ GhostNet-1.3x """ + model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs) + return model diff --git a/data_processing/MANIQA/timm/models/gluon_resnet.py b/data_processing/MANIQA/timm/models/gluon_resnet.py new file mode 100644 index 0000000..027a10b --- /dev/null +++ b/data_processing/MANIQA/timm/models/gluon_resnet.py @@ -0,0 +1,248 @@ +"""Pytorch impl of MxNet Gluon ResNet/(SE)ResNeXt variants +This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-ResNeXt additions +and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py) +by Ross Wightman +""" + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import SEModule +from .registry import register_model +from .resnet import ResNet, Bottleneck, BasicBlock + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'gluon_resnet18_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet18_v1b-0757602b.pth'), + 'gluon_resnet34_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet34_v1b-c6d82d59.pth'), + 'gluon_resnet50_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'), + 
'gluon_resnet101_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'), + 'gluon_resnet152_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'), + 'gluon_resnet50_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth', + first_conv='conv1.0'), + 'gluon_resnet50_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth', + first_conv='conv1.0'), + 'gluon_resnet50_v1s': 
_cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth', + first_conv='conv1.0'), + 'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'), + 'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'), + 'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'), + 'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'), + 'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'), + 'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'), + 'gluon_senet154': 
def _create_resnet(variant, pretrained=False, **kwargs):
    """ Build a ResNet for ``variant`` via build_model_with_cfg, using its ``default_cfgs`` entry.

    Remaining keyword arguments are forwarded unchanged.
    """
    variant_cfg = default_cfgs[variant]
    return build_model_with_cfg(ResNet, variant, pretrained, default_cfg=variant_cfg, **kwargs)
@register_model
def gluon_resnet101_v1c(pretrained=False, **kwargs):
    """Build the ``gluon_resnet101_v1c`` variant.

    Uses ``Bottleneck`` blocks with stage depths ``[3, 4, 23, 3]`` and a
    ``'deep'`` stem of width 32; remaining ``kwargs`` are passed through to
    ``_create_resnet``.
    """
    return _create_resnet(
        'gluon_resnet101_v1c', pretrained,
        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', **kwargs)
@register_model
def gluon_resnext50_32x4d(pretrained=False, **kwargs):
    """Build the ``gluon_resnext50_32x4d`` variant.

    Uses ``Bottleneck`` blocks with stage depths ``[3, 4, 6, 3]``,
    ``cardinality=32`` and ``base_width=4``; remaining ``kwargs`` are passed
    through to ``_create_resnet``.
    """
    return _create_resnet(
        'gluon_resnext50_32x4d', pretrained,
        block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
@register_model
def gluon_senet154(pretrained=False, **kwargs):
    """Build the ``gluon_senet154`` variant.

    Uses ``Bottleneck`` blocks with stage depths ``[3, 8, 36, 3]``,
    ``cardinality=64``, ``base_width=4``, a ``'deep'`` stem, 3x3 downsample
    kernels, ``block_reduce_first=2`` and ``SEModule`` as the block attention
    layer; remaining ``kwargs`` are passed through to ``_create_resnet``.
    """
    return _create_resnet(
        'gluon_senet154', pretrained,
        block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
        down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer=SEModule), **kwargs)
# Pretrained-weight configuration, keyed by the registered model variant name.
default_cfgs = {
    'gluon_xception65': {
        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
        'input_size': (3, 299, 299),
        'crop_pct': 0.903,
        'pool_size': (10, 10),
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'num_classes': 1000,
        # Attribute names of the stem conv and the classifier on Xception65.
        'first_conv': 'conv1',
        'classifier': 'fc'
        # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
    },
}
class Xception65(nn.Module):
    """Modified Aligned Xception.

    NOTE: only the 65 layer version is included here, the 71 layer variant
    was not correct and had no pretrained weights

    Args:
        num_classes: number of classifier outputs passed to ``create_classifier``.
        in_chans: number of input channels of the first convolution.
        output_stride: one of 32, 16 or 8; selects the strides/dilations used
            by ``block3``, the middle flow and the exit flow.
        norm_layer: callable building a normalisation layer from ``num_features``.
        drop_rate: dropout probability applied before the classifier in
            ``forward`` (0 disables it).
        global_pool: pooling type passed to ``create_classifier``.

    Raises:
        NotImplementedError: if ``output_stride`` is not 8, 16 or 32.
    """

    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
                 drop_rate=0., global_pool='avg'):
        super(Xception65, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        # A smaller output stride trades striding for dilation in the later stages.
        if output_stride == 32:
            entry_block3_stride = 2
            exit_block20_stride = 2
            middle_dilation = 1
            exit_dilation = (1, 1)
        elif output_stride == 16:
            entry_block3_stride = 2
            exit_block20_stride = 1
            middle_dilation = 1
            exit_dilation = (1, 2)
        elif output_stride == 8:
            entry_block3_stride = 1
            exit_block20_stride = 1
            middle_dilation = 2
            exit_dilation = (2, 4)
        else:
            raise NotImplementedError('output_stride must be 8, 16 or 32, got %r' % (output_stride,))

        # Entry flow
        self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = norm_layer(num_features=32)
        self.act1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = norm_layer(num_features=64)
        self.act2 = nn.ReLU(inplace=True)

        self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer)
        self.block1_act = nn.ReLU(inplace=True)
        self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer)
        self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer)

        # Middle flow: 16 identical blocks named block4 .. block19
        self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
            728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)]))

        # Exit flow
        self.block20 = Block(
            728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer)
        self.block20_act = nn.ReLU(inplace=True)

        self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
        self.bn3 = norm_layer(num_features=1536)
        self.act3 = nn.ReLU(inplace=True)

        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
        self.bn4 = norm_layer(num_features=1536)
        self.act4 = nn.ReLU(inplace=True)

        self.num_features = 2048
        self.conv5 = SeparableConv2d(
            1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
        self.bn5 = norm_layer(num_features=self.num_features)
        self.act5 = nn.ReLU(inplace=True)
        # Module names (with channel count / reduction) exposed for feature extraction.
        self.feature_info = [
            dict(num_chs=64, reduction=2, module='act2'),
            dict(num_chs=128, reduction=4, module='block1_act'),
            dict(num_chs=256, reduction=8, module='block3.rep.act1'),
            dict(num_chs=728, reduction=16, module='block20.rep.act1'),
            dict(num_chs=2048, reduction=32, module='act5'),
        ]

        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

    def get_classifier(self):
        """Return the classifier module (``self.fc``)."""
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Rebuild the pooling + classifier head for ``num_classes`` outputs."""
        self.num_classes = num_classes
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        """Run entry, middle and exit flows and return the final feature map."""
        # Entry flow
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act2(x)

        x = self.block1(x)
        x = self.block1_act(x)
        # c1 = x
        x = self.block2(x)
        # c2 = x
        x = self.block3(x)

        # Middle flow
        x = self.mid(x)
        # c3 = x

        # Exit flow
        x = self.block20(x)
        x = self.block20_act(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.act3(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.act4(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.act5(x)
        return x

    def forward(self, x):
        """Compute features, pool them, optionally apply dropout, then classify."""
        x = self.forward_features(x)
        x = self.global_pool(x)
        if self.drop_rate:
            # F.dropout is not in-place by default: the result must be kept,
            # otherwise drop_rate silently has no effect.
            x = F.dropout(x, self.drop_rate, training=self.training)
        x = self.fc(x)
        return x
def _cfg(url='', **kwargs):
    """Return a default pretrained-config dict for a hardcorenas variant.

    Starts from the shared defaults below (1000 classes, 3x224x224 input,
    bilinear interpolation, ImageNet mean/std, ``conv_stem`` / ``classifier``
    layer names) and lets any ``kwargs`` override or extend them.
    """
    cfg = dict(
        url=url, num_classes=1000, input_size=(3, 224, 224), pool_size=(1, 1),
        crop_pct=0.875, interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        first_conv='conv_stem', classifier='classifier')
    cfg.update(kwargs)
    return cfg
def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
    """Creates a hardcorenas model

    Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS
    Paper: https://arxiv.org/abs/2102.11646

    Args:
        pretrained: forwarded to ``build_model_with_cfg``.
        variant: key into ``default_cfgs`` naming the model variant.
        arch_def: block architecture definition, decoded with ``decode_arch_def``.
        **kwargs: extra model arguments. A truthy ``features_only`` entry
            switches the model class to ``MobileNetV3Features``.

    Returns:
        The model built by ``build_model_with_cfg``.
    """
    num_features = 1280
    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
    # NOTE(review): resolve_bn_args / resolve_act_layer receive `kwargs` itself and
    # presumably consume their own keys from it -- keep them ahead of `**kwargs`.
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=num_features,
        stem_size=32,
        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
        act_layer=resolve_act_layer(kwargs, 'hard_swish'),
        se_layer=se_layer,
        **kwargs,
    )

    features_only = False
    model_cls = MobileNetV3
    kwargs_filter = None
    if model_kwargs.pop('features_only', False):
        features_only = True
        # Classifier-head arguments to strip before building the feature backbone
        # (each name listed once; the original repeated 'global_pool').
        kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
        model_cls = MobileNetV3Features
    model = build_model_with_cfg(
        model_cls, variant, pretrained,
        default_cfg=default_cfgs[variant],
        pretrained_strict=not features_only,
        kwargs_filter=kwargs_filter,
        **model_kwargs)
    if features_only:
        model.default_cfg = default_cfg_for_features(model.default_cfg)
    return model
@register_model
def hardcorenas_b(pretrained=False, **kwargs):
    """ hardcorenas_B

    Builds the ``hardcorenas_b`` variant from its fixed stage layout via
    ``_gen_hardcorenas``; extra ``kwargs`` are forwarded.
    """
    stages = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25', 'ir_r1_k3_s1_e3_c24_nre'],
        ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'],
        ['ir_r1_k5_s2_e3_c80', 'ir_r1_k5_s1_e3_c80', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
        ['ir_r1_k5_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_b', arch_def=stages, **kwargs)
@register_model
def hardcorenas_e(pretrained=False, **kwargs):
    """ hardcorenas_E

    Builds the ``hardcorenas_e`` variant from its fixed stage layout via
    ``_gen_hardcorenas``; extra ``kwargs`` are forwarded.
    """
    stages = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
        ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25',
         'ir_r1_k3_s1_e3_c40_nre_se0.25'],
        ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e6_c80_se0.25'],
        ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25',
         'ir_r1_k5_s1_e3_c112_se0.25'],
        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
         'ir_r1_k3_s1_e6_c192_se0.25'],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_e', arch_def=stages, **kwargs)
def load_state_dict(checkpoint_path, use_ema=False):
    """Load a model state dict from a checkpoint file on disk.

    If the loaded checkpoint is a dict, the state dict is taken from the first
    matching key: ``'state_dict_ema'`` / ``'model_ema'`` (only when ``use_ema``
    is set and the entry is not None), then ``'state_dict'``, then ``'model'``.
    Keys of a state dict found this way have a leading ``'module.'`` prefix
    removed. If no key matches, the loaded object is returned as-is.

    Args:
        checkpoint_path: path of the checkpoint file.
        use_ema: prefer the EMA weights when the checkpoint contains them.

    Returns:
        The state dict (an ``OrderedDict`` when extracted from a known key).

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is empty or not a file.
    """
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        message = "No checkpoint found at '{}'".format(checkpoint_path)
        _logger.error(message)
        raise FileNotFoundError(message)

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict_key = ''
    if isinstance(checkpoint, dict):
        if use_ema and checkpoint.get('state_dict_ema', None) is not None:
            state_dict_key = 'state_dict_ema'
        elif use_ema and checkpoint.get('model_ema', None) is not None:
            state_dict_key = 'model_ema'
        elif 'state_dict' in checkpoint:
            state_dict_key = 'state_dict'
        elif 'model' in checkpoint:
            state_dict_key = 'model'

    if state_dict_key:
        prefix = 'module.'
        state_dict = OrderedDict()
        for k, v in checkpoint[state_dict_key].items():
            # Strip the `module.` prefix only. Matching the full prefix (with the
            # dot) keeps keys such as `modules.0.weight` or `module_list.0.bias`
            # intact instead of chopping their first 7 characters.
            name = k[len(prefix):] if k.startswith(prefix) else k
            state_dict[name] = v
    else:
        state_dict = checkpoint
    _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
    return state_dict
checkpoints incremented before save + + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + else: + model.load_state_dict(checkpoint) + if log_info: + _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + return resume_epoch + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False): + r"""Loads a custom (read non .pth) weight file + + Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls + a passed in custom load fun, or the `load_pretrained` model member fn. + + If the object is already present in `model_dir`, it's deserialized and returned. + The default value of `model_dir` is ``/checkpoints`` where + `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + model: The instantiated model to load weights into + default_cfg (dict): Default pretrained model cfg + load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named + 'laod_pretrained' on the model will be called if it exists + progress (bool, optional): whether or not to display a progress bar to stderr. Default: False + check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention + ``filename-.ext`` where ```` is the first eight or more + digits of the SHA256 hash of the contents of the file. The hash is used to + ensure unique names and to verify the contents of the file. Default: False + """ + default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {} + pretrained_url = default_cfg.get('url', None) + if not pretrained_url: + _logger.warning("No pretrained weights exist for this model. 
def adapt_input_conv(in_chans, conv_weight):
    """Adapt a first-conv weight of shape ``(O, I, J, K)`` to ``in_chans`` inputs.

    * ``in_chans == 1``: if ``I > 3`` (``I`` must be a multiple of 3) the input
      channels are grouped in threes and each group is summed, giving
      ``(O, I // 3, J, K)``; otherwise all input channels are summed into one.
    * ``in_chans == 3``: the weight is returned with its values unchanged.
    * any other ``in_chans``: requires ``I == 3``; the three channels are tiled,
      cut to ``in_chans`` and scaled by ``3 / in_chans``.

    The result is cast back to the dtype of the incoming weight.

    Raises:
        NotImplementedError: if ``in_chans`` is neither 1 nor 3 and ``I != 3``.
    """
    orig_dtype = conv_weight.dtype
    # Some weights are in torch.half, ensure it's float for sum on CPU
    weight = conv_weight.float()
    out_ch, in_ch, k_h, k_w = weight.shape

    if in_chans == 1:
        if in_ch > 3:
            assert weight.shape[1] % 3 == 0
            # For models with space2depth stems
            grouped = weight.reshape(out_ch, in_ch // 3, 3, k_h, k_w)
            weight = grouped.sum(dim=2, keepdim=False)
        else:
            weight = weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        if in_ch != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        # NOTE this strategy should be better than random init, but there could be other combinations of
        # the original RGB input layer weights that'd work better for specific cases.
        n_tiles = int(math.ceil(in_chans / 3))
        weight = weight.repeat(1, n_tiles, 1, 1)[:, :in_chans, :, :]
        weight *= (3 / float(in_chans))

    return weight.to(orig_dtype)
Using random initialization.") + return + if pretrained_url: + _logger.info(f'Loading pretrained weights from url ({pretrained_url})') + state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu') + elif hf_hub_id and has_hf_hub(necessary=True): + _logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})') + state_dict = load_state_dict_from_hf(hf_hub_id) + if filter_fn is not None: + # for backwards compat with filter fn that take one arg, try one first, the two + try: + state_dict = filter_fn(state_dict) + except TypeError: + state_dict = filter_fn(state_dict, model) + + input_convs = default_cfg.get('first_conv', None) + if input_convs is not None and in_chans != 3: + if isinstance(input_convs, str): + input_convs = (input_convs,) + for input_conv_name in input_convs: + weight_name = input_conv_name + '.weight' + try: + state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) + _logger.info( + f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + _logger.warning( + f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') + + classifiers = default_cfg.get('classifier', None) + label_offset = default_cfg.get('label_offset', 0) + if classifiers is not None: + if isinstance(classifiers, str): + classifiers = (classifiers,) + if num_classes != default_cfg['num_classes']: + for classifier_name in classifiers: + # completely discard fully connected if model num_classes doesn't match pretrained weights + state_dict.pop(classifier_name + '.weight', None) + state_dict.pop(classifier_name + '.bias', None) + strict = False + elif label_offset > 0: + for classifier_name in classifiers: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + 
def extract_layer(model, layer):
    """Resolve the dotted path ``layer`` to a sub-module of ``model``.

    If ``model`` has a ``module`` attribute and the path does not start with
    ``'module'``, the lookup starts from ``model.module``; if ``model`` has no
    such attribute but the path starts with ``'module'``, that first component
    is dropped. Numeric components index into the current module, other
    components are attribute lookups. The walk stops early, returning the
    module reached so far, at the first component that does not exist.
    """
    parts = layer.split('.')
    is_wrapped = hasattr(model, 'module')
    current = model
    if is_wrapped and parts[0] != 'module':
        current = model.module
    if not is_wrapped and parts[0] == 'module':
        parts = parts[1:]

    for part in parts:
        if not hasattr(current, part):
            return current
        current = current[int(part)] if part.isdigit() else getattr(current, part)
    return current
def adapt_model_from_file(parent_module, model_variant):
    """Adapt ``parent_module`` using the spec stored in ``pruned/<model_variant>.txt``.

    The file is looked up in a ``pruned`` directory next to this source file;
    its stripped contents are handed to ``adapt_model_from_string``.
    """
    pruned_dir = os.path.join(os.path.dirname(__file__), 'pruned')
    adapt_file = os.path.join(pruned_dir, model_variant + '.txt')
    with open(adapt_file, 'r') as f:
        model_string = f.read().strip()
    return adapt_model_from_string(parent_module, model_string)


def default_cfg_for_features(default_cfg):
    """Return a deep copy of ``default_cfg`` without the classifier-specific entries.

    The keys ``num_classes``, ``crop_pct``, ``classifier`` and ``global_pool``
    are dropped (when present); the input mapping is left untouched.
    """
    # add default final pool size?
    classifier_only_keys = ('num_classes', 'crop_pct', 'classifier', 'global_pool')
    feature_cfg = deepcopy(default_cfg)
    for key in classifier_only_keys:
        feature_cfg.pop(key, None)
    return feature_cfg
def set_default_kwargs(kwargs, names, default_cfg):
    """Fill ``kwargs`` with defaults taken from ``default_cfg`` for each name in ``names``.

    Values already present in ``kwargs`` are never overwritten. ``img_size`` and
    ``in_chans`` are both derived from the single ``input_size`` entry of
    ``default_cfg`` (its last two items and its first item respectively); every
    other name is copied from ``default_cfg`` when its value is not ``None``.
    """
    for name in names:
        # for legacy reasons, model __init__args uses img_size + in_chans as separate args while
        # default_cfg has one input_size=(C, H ,W) entry
        if name in ('img_size', 'in_chans'):
            input_size = default_cfg.get('input_size', None)
            if input_size is None:
                continue
            assert len(input_size) == 3
            derived = input_size[-2:] if name == 'img_size' else input_size[0]
            kwargs.setdefault(name, derived)
            continue
        default_val = default_cfg.get(name, None)
        if default_val is not None:
            kwargs.setdefault(name, default_val)


def filter_kwargs(kwargs, names):
    """Remove every key listed in ``names`` from ``kwargs`` in place (missing keys are ignored)."""
    if kwargs and names:
        for name in names:
            kwargs.pop(name, None)
def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        default_cfg: dict,
        model_cfg: Optional[Any] = None,
        feature_cfg: Optional[dict] = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Optional[Callable] = None,
        pretrained_custom_load: bool = False,
        kwargs_filter: Optional[Tuple[str, ...]] = None,
        **kwargs):
    """ Build model with specified default_cfg and optional model_cfg

    This helper fn aids in the construction of a model including:
      * handling default_cfg and associated pretrained weight loading
      * passing through optional model_cfg for models with config based arch spec
      * features_only model adaptation
      * pruning config / model adaptation

    Args:
        model_cls (nn.Module): model class
        variant (str): model variant name
        pretrained (bool): load pretrained weights
        default_cfg (dict): model's default pretrained/task config
        model_cfg (Optional[Dict]): model's architecture config
        feature_cfg (Optional[Dict]): feature extraction adapter config (copied, never modified)
        pretrained_strict (bool): load pretrained weights strictly
        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
        pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
        **kwargs: model args passed through to model __init__

    Returns:
        The constructed model, wrapped in a feature extraction module when
        ``features_only=True`` was passed in kwargs.

    Raises:
        AssertionError: if ``feature_cfg['feature_cls']`` is a string that names no known wrapper.
    """
    pruned = kwargs.pop('pruned', False)
    features = False
    # Work on a private shallow copy: the feature setup below uses setdefault()/pop() and
    # must not alter a dict the caller may reuse for another build.
    feature_cfg = dict(feature_cfg) if feature_cfg else {}
    default_cfg = deepcopy(default_cfg) if default_cfg else {}
    update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
    default_cfg.setdefault('architecture', variant)

    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    # Build the model
    model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
    model.default_cfg = default_cfg

    if pruned:
        model = adapt_model_from_file(model, variant)

    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        if pretrained_custom_load:
            load_custom_pretrained(model)
        else:
            load_pretrained(
                model,
                num_classes=num_classes_pretrained,
                in_chans=kwargs.get('in_chans', 3),
                filter_fn=pretrained_filter_fn,
                strict=pretrained_strict)

    # Wrap the model in a feature extraction module if enabled
    if features:
        feature_cls = FeatureListNet
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                elif feature_cls == 'fx':
                    feature_cls = FeatureGraphNet
                else:
                    # Raised explicitly (same exception type as the former `assert False`) so the
                    # check still fires when Python runs with -O and assert statements are removed.
                    raise AssertionError(f'Unknown feature class {feature_cls}')
        model = feature_cls(model, **feature_cfg)
        model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg

    return model


def model_parameters(model, exclude_head=False):
    """Return the parameters of ``model``, optionally without the last two.

    With ``exclude_head=True`` a list of all parameters except the final two (in
    ``model.parameters()`` order) is returned; otherwise the ``model.parameters()``
    iterator itself is returned.
    """
    if exclude_head:
        # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
        return list(model.parameters())[:-2]
    else:
        return model.parameters()


def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
    """Call ``fn(module=..., name=...)`` on every descendant of ``module`` and return ``module``.

    Names are dotted paths built from ``named_children()``. The root itself is only
    visited when ``include_root`` is True; ``depth_first`` selects whether a module is
    visited after (True) or before (False) its children.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
    """Yield ``(dotted_name, module)`` pairs for the descendants of ``module``.

    Traversal order and root handling follow the same ``depth_first`` /
    ``include_root`` rules as ``named_apply``.
    """
    if not depth_first and include_root:
        yield name, module
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        yield from named_modules(
            module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        yield name, module
def _cfg(url='', **kwargs):
    """Return the default pretrained config dict for an HRNet variant.

    Any keyword argument overrides (or extends) the built-in defaults.
    """
    cfg = dict(
        url=url,
        num_classes=1000, input_size=(3, 224, 224), pool_size=(7, 7),
        crop_pct=0.875, interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        first_conv='conv1', classifier='classifier',
    )
    cfg.update(kwargs)
    return cfg
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w40-7cd397a4.pth'), + 'hrnet_w44': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w44-c9ac8c18.pth'), + 'hrnet_w48': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w48-abd2e6ab.pth'), + 'hrnet_w64': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w64-b47cc881.pth'), +} + +cfg_cls = dict( + hrnet_w18_small=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(1,), + NUM_CHANNELS=(32,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(16, 32), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=1, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(16, 32, 64), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=1, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(16, 32, 64, 128), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w18_small_v2=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(2,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=3, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=2, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), 
+ ), + + hrnet_w18=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w30=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(30, 60), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(30, 60, 120), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(30, 60, 120, 240), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w32=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(32, 64), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(32, 64, 128), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(32, 64, 128, 256), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w40=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + 
NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(40, 80), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(40, 80, 160), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(40, 80, 160, 320), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w44=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(44, 88), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(44, 88, 176), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(44, 88, 176, 352), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w48=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(48, 96), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(48, 96, 192), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(48, 96, 192, 384), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w64=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + 
class HighResolutionModule(nn.Module):
    """Multi-branch module: one stack of blocks per branch, followed by cross-branch fusion.

    Each branch is an ``nn.Sequential`` of ``blocks`` instances. For more than one
    branch, ``forward`` sums, for every output branch, the (re-sampled) outputs of
    all branches and applies a ReLU.
    """

    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        """
        Args:
            num_branches: number of parallel branches.
            blocks: block class used to build every branch; must expose ``expansion``.
            num_blocks: per-branch number of blocks.
            num_inchannels: per-branch input channel counts. The list is kept by
                reference and updated in place while the branches are built.
            num_channels: per-branch block width (before ``blocks.expansion``).
            fuse_method: stored on the instance; not read anywhere else in this class.
            multi_scale_output: when False only the fuse layers for output branch 0 are built.

        Raises:
            ValueError: if the per-branch sequences do not have ``num_branches`` entries.
        """
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        # Branches must be built before the fuse layers: building them rewrites
        # self.num_inchannels to the branch output widths the fuse layers read.
        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.fuse_act = nn.ReLU(False)

    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
        """Log and raise ``ValueError`` when a per-branch sequence has the wrong length."""
        error_msg = ''
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks))
        elif num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels))
        elif num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(num_branches, len(num_inchannels))
        if error_msg:
            _logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        """Build the block stack for one branch and record its output width.

        Side effect: ``self.num_inchannels[branch_index]`` is overwritten with
        ``num_channels[branch_index] * block.expansion``.
        """
        downsample = None
        # A 1x1 conv + BN projection is only needed when the first block changes
        # the stride or the channel count.
        if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion,
                    kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=_BN_MOMENTUM),
            )

        layers = [block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)]
        self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
        # Remaining blocks take the updated width as input and use the block's default stride/downsample.
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Return an ``nn.ModuleList`` with one block stack per branch."""
        branches = []
        for i in range(num_branches):
            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build the layers that map branch ``j`` output onto output branch ``i``.

        Returns ``nn.Identity()`` for a single branch, otherwise a ``ModuleList``
        (indexed by ``i``) of ``ModuleList`` (indexed by ``j``).
        """
        if self.num_branches == 1:
            return nn.Identity()

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Higher-index branch: 1x1 conv to branch i's width, BN, then
                    # nearest-neighbour upsampling by 2 ** (j - i).
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
                        nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM),
                        nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(nn.Identity())
                else:
                    # Lower-index branch: a chain of (i - j) 3x3 stride-2 convs. Only the
                    # last one switches to branch i's width and it has no ReLU.
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM),
                                nn.ReLU(False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return the per-branch channel list (updated to the branch output widths during construction)."""
        return self.num_inchannels

    def forward(self, x: List[torch.Tensor]):
        """Run every branch on its input and fuse the results.

        Args:
            x: one tensor per branch. For more than one branch the list entries are
                overwritten in place with the branch outputs.

        Returns:
            A list with one fused tensor per set of fuse layers (a single-element
            list holding the branch output when there is only one branch).
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i, branch in enumerate(self.branches):
            x[i] = branch(x[i])

        x_fuse = []
        for i, fuse_outer in enumerate(self.fuse_layers):
            # Branch 0 seeds the sum; it is used directly when it is the target branch.
            y = x[0] if i == 0 else fuse_outer[0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + fuse_outer[j](x[j])
            x_fuse.append(self.fuse_act(y))

        return x_fuse
def _make_head(self, pre_stage_channels, incre_only=False):
    """Build the head modules that sit on top of the last stage.

    Sets ``self.head_channels`` and returns ``(incre_modules, downsamp_modules,
    final_layer)``; the last two are ``None`` when ``incre_only`` is True.

    Args:
        pre_stage_channels: channel count of each branch coming out of the last stage.
        incre_only: build only the per-branch channel-increase modules.
    """
    head_block = Bottleneck
    self.head_channels = [32, 64, 128, 256]

    # Increasing the #channels on each resolution
    # from C, 2C, 4C, 8C to 128, 256, 512, 1024
    incre_modules = nn.ModuleList([
        self._make_layer(head_block, in_chs, self.head_channels[idx], 1, stride=1)
        for idx, in_chs in enumerate(pre_stage_channels)
    ])
    if incre_only:
        return incre_modules, None, None

    # downsampling modules
    stages = []
    for idx in range(len(pre_stage_channels) - 1):
        chs_in = self.head_channels[idx] * head_block.expansion
        chs_out = self.head_channels[idx + 1] * head_block.expansion
        stages.append(nn.Sequential(
            nn.Conv2d(
                in_channels=chs_in, out_channels=chs_out, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(chs_out, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ))
    downsamp_modules = nn.ModuleList(stages)

    final_conv = nn.Conv2d(
        in_channels=self.head_channels[3] * head_block.expansion,
        out_channels=self.num_features, kernel_size=1, stride=1, padding=0)
    final_layer = nn.Sequential(
        final_conv,
        nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM),
        nn.ReLU(inplace=True),
    )

    return incre_modules, downsamp_modules, final_layer
outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=_BN_MOMENTUM), + ) + + layers = [block(inplanes, planes, stride, downsample)] + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + reset_multi_scale_output = multi_scale_output or i < num_modules - 1 + modules.append(HighResolutionModule( + num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + 
class HighResolutionNetFeatures(HighResolutionNet):
    """HighResolutionNet feature extraction

    The design of HRNet makes it easy to grab feature maps, this class provides a simple wrapper to do so.
    It would be more complicated to use the FeatureNet helpers.

    The `feature_location=incre` allows grabbing increased channel count features using part of the
    classification head. If `feature_location=''` the default HRNet features are returned. First stem
    conv is used for stride 2 features.
    """

    def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0,
                 feature_location='incre', out_indices=(0, 1, 2, 3, 4)):
        """
        Args:
            cfg: architecture config forwarded to ``HighResolutionNet``.
            in_chans, num_classes, global_pool, drop_rate: forwarded to ``HighResolutionNet``.
            feature_location: ``'incre'`` or ``''``; forwarded as the base class ``head`` argument.
            out_indices: indices of the feature maps returned by ``forward`` (0 is the
                output of the first stem conv, ``i + 1`` is branch ``i`` of the last stage).

        Raises:
            AssertionError: if ``feature_location`` is not one of the supported values.
        """
        # Explicit raise (same exception type as the former assert statement) so the
        # check is still enforced when Python runs with -O.
        if feature_location not in ('incre', ''):
            raise AssertionError(f'Unsupported feature_location {feature_location!r}')
        super(HighResolutionNetFeatures, self).__init__(
            cfg, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool,
            drop_rate=drop_rate, head=feature_location)
        self.feature_info = FeatureInfo(self.feature_info, out_indices)
        self._out_idx = set(out_indices)

    def forward_features(self, x):
        """Not available on the feature-extraction variant; always raises ``AssertionError``."""
        # `assert False` would be removed under -O and silently return None instead.
        raise AssertionError('Not supported')

    def forward(self, x) -> List[torch.Tensor]:
        """Return the feature maps selected by ``out_indices``, in increasing index order."""
        out = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        if 0 in self._out_idx:
            out.append(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act2(x)
        x = self.stages(x)
        if self.incre_modules is not None:
            x = [incre(f) for f, incre in zip(x, self.incre_modules)]
        for i, f in enumerate(x):
            if i + 1 in self._out_idx:
                out.append(f)
        return out
+@register_model +def hrnet_w30(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w30', pretrained, **kwargs) + + +@register_model +def hrnet_w32(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w32', pretrained, **kwargs) + + +@register_model +def hrnet_w40(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w40', pretrained, **kwargs) + + +@register_model +def hrnet_w44(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w44', pretrained, **kwargs) + + +@register_model +def hrnet_w48(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w48', pretrained, **kwargs) + + +@register_model +def hrnet_w64(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w64', pretrained, **kwargs) diff --git a/data_processing/MANIQA/timm/models/hub.py b/data_processing/MANIQA/timm/models/hub.py new file mode 100644 index 0000000..65e7ba9 --- /dev/null +++ b/data_processing/MANIQA/timm/models/hub.py @@ -0,0 +1,171 @@ +import json +import logging +import os +from functools import partial +from pathlib import Path +from typing import Union + +import torch +from torch.hub import HASH_REGEX, download_url_to_file, urlparse +try: + from torch.hub import get_dir +except ImportError: + from torch.hub import _get_torch_home as get_dir + +from timm import __version__ +try: + from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url + cached_download = partial(cached_download, library_name="timm", library_version=__version__) + _has_hf_hub = True +except ImportError: + cached_download = None + _has_hf_hub = False + +_logger = logging.getLogger(__name__) + + +def get_cache_dir(child_dir=''): + """ + Returns the location of the directory where models are cached (and creates it if necessary). 
+ """ + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') + + hub_dir = get_dir() + child_dir = () if not child_dir else (child_dir,) + model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) + os.makedirs(model_dir, exist_ok=True) + return model_dir + + +def download_cached_file(url, check_hash=True, progress=False): + parts = urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(get_cache_dir(), filename) + if not os.path.exists(cached_file): + _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + return cached_file + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def hf_split(hf_id): + rev_split = hf_id.split('@') + assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' 
+ hf_model_id = rev_split[0] + hf_revision = rev_split[-1] if len(rev_split) > 1 else None + return hf_model_id, hf_revision + + +def load_cfg_from_json(json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + +def _download_from_hf(model_id: str, filename: str): + hf_model_id, hf_revision = hf_split(model_id) + url = hf_hub_url(hf_model_id, filename, revision=hf_revision) + return cached_download(url, cache_dir=get_cache_dir('hf')) + + +def load_model_config_from_hf(model_id: str): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, 'config.json') + default_cfg = load_cfg_from_json(cached_file) + default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation + model_name = default_cfg.get('architecture') + return default_cfg, model_name + + +def load_state_dict_from_hf(model_id: str): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, 'pytorch_model.bin') + state_dict = torch.load(cached_file, map_location='cpu') + return state_dict + + +def save_for_hf(model, save_directory, model_config=None): + assert has_hf_hub(True) + model_config = model_config or {} + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + + weights_path = save_directory / 'pytorch_model.bin' + torch.save(model.state_dict(), weights_path) + + config_path = save_directory / 'config.json' + hf_config = model.default_cfg + hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes) + hf_config['num_features'] = model_config.pop('num_features', model.num_features) + hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])]) + hf_config.update(model_config) + + with config_path.open('w') as f: + json.dump(hf_config, f, indent=2) + + +def push_to_hf_hub( + model, + local_dir, + repo_namespace_or_url=None, + commit_message='Add model', 
+ use_auth_token=True, + git_email=None, + git_user=None, + revision=None, + model_config=None, +): + if repo_namespace_or_url: + repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:] + else: + if isinstance(use_auth_token, str): + token = use_auth_token + else: + token = HfFolder.get_token() + + if token is None: + raise ValueError( + "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and " + "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own " + "token as the `use_auth_token` argument." + ) + + repo_owner = HfApi().whoami(token)['name'] + repo_name = Path(local_dir).name + + repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}' + + repo = Repository( + local_dir, + clone_from=repo_url, + use_auth_token=use_auth_token, + git_user=git_user, + git_email=git_email, + revision=revision, + ) + + # Prepare a default model card that includes the necessary tags to enable inference. + readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}' + with repo.commit(commit_message): + # Save model weights and config. + save_for_hf(model, repo.local_dir, model_config=model_config) + + # Save a model card if it doesn't exist. 
+ readme_path = Path(repo.local_dir) / 'README.md' + if not readme_path.exists(): + readme_path.write_text(readme_text) + + return repo.git_remote_url() diff --git a/data_processing/MANIQA/timm/models/inception_resnet_v2.py b/data_processing/MANIQA/timm/models/inception_resnet_v2.py new file mode 100644 index 0000000..7167284 --- /dev/null +++ b/data_processing/MANIQA/timm/models/inception_resnet_v2.py @@ -0,0 +1,358 @@ +""" Pytorch Inception-Resnet-V2 implementation +Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is +based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['InceptionResnetV2'] + +default_cfgs = { + # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz + 'inception_resnet_v2': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + 'label_offset': 1, # 1001 classes in pretrained weights + }, + # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz + 'ens_adv_inception_resnet_v2': { + 'url': 
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + 'label_offset': 1, # 1001 classes in pretrained weights + } +} + + +class BasicConv2d(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_planes, eps=.001) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed_5b(nn.Module): + def __init__(self): + super(Mixed_5b, self).__init__() + + self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(192, 48, kernel_size=1, stride=1), + BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(192, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(192, 64, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block35(nn.Module): + def __init__(self, scale=1.0): + super(Block35, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + 
BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1), + BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1) + ) + + self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_6a(nn.Module): + def __init__(self): + super(Mixed_6a, self).__init__() + + self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 256, kernel_size=1, stride=1), + BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Block17(nn.Module): + def __init__(self, scale=1.0): + super(Block17, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 128, kernel_size=1, stride=1), + BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)) + ) + + self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_7a(nn.Module): + def __init__(self): + super(Mixed_7a, self).__init__() + + self.branch0 = 
nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1), + BasicConv2d(288, 320, kernel_size=3, stride=2) + ) + + self.branch3 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block8(nn.Module): + + def __init__(self, scale=1.0, no_relu=False): + super(Block8, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(2080, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)), + BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + ) + + self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) + self.relu = None if no_relu else nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + if self.relu is not None: + out = self.relu(out) + return out + + +class InceptionResnetV2(nn.Module): + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'): + super(InceptionResnetV2, self).__init__() + self.drop_rate = drop_rate + self.num_classes = num_classes + self.num_features = 1536 + assert output_stride == 32 + + self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) + self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) + self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.feature_info = 
[dict(num_chs=64, reduction=2, module='conv2d_2b')] + + self.maxpool_3a = nn.MaxPool2d(3, stride=2) + self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) + self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) + self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')] + + self.maxpool_5a = nn.MaxPool2d(3, stride=2) + self.mixed_5b = Mixed_5b() + self.repeat = nn.Sequential( + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17) + ) + self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')] + + self.mixed_6a = Mixed_6a() + self.repeat_1 = nn.Sequential( + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10) + ) + self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')] + + self.mixed_7a = Mixed_7a() + self.repeat_2 = nn.Sequential( + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20) + ) + self.block8 = Block8(no_relu=True) + self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1) + self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')] + + self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.classif + + def reset_classifier(self, 
num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv2d_1a(x) + x = self.conv2d_2a(x) + x = self.conv2d_2b(x) + x = self.maxpool_3a(x) + x = self.conv2d_3b(x) + x = self.conv2d_4a(x) + x = self.maxpool_5a(x) + x = self.mixed_5b(x) + x = self.repeat(x) + x = self.mixed_6a(x) + x = self.repeat_1(x) + x = self.mixed_7a(x) + x = self.repeat_2(x) + x = self.block8(x) + x = self.conv2d_7b(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classif(x) + return x + + +def _create_inception_resnet_v2(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + InceptionResnetV2, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def inception_resnet_v2(pretrained=False, **kwargs): + r"""InceptionResnetV2 model architecture from the + `"InceptionV4, Inception-ResNet..." ` paper. + """ + return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs) + + +@register_model +def ens_adv_inception_resnet_v2(pretrained=False, **kwargs): + r""" Ensemble Adversarially trained InceptionResnetV2 model architecture + As per https://arxiv.org/abs/1705.07204 and + https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models. 
+ """ + return _create_inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs) diff --git a/data_processing/MANIQA/timm/models/inception_v3.py b/data_processing/MANIQA/timm/models/inception_v3.py new file mode 100644 index 0000000..cbb1107 --- /dev/null +++ b/data_processing/MANIQA/timm/models/inception_v3.py @@ -0,0 +1,470 @@ +""" Inception-V3 + +Originally from torchvision Inception3 model +Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .registry import register_model +from .layers import trunc_normal_, create_classifier, Linear + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + # original PyTorch weights, ported from Tensorflow but modified + 'inception_v3': _cfg( + url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', + has_aux=True), # checkpoint has aux logit layer weights + # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) + 'tf_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', + num_classes=1000, has_aux=False, label_offset=1), + # my port of Tensorflow adversarially trained Inception V3 from + # 
http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz + 'adv_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth', + num_classes=1000, has_aux=False, label_offset=1), + # from gluon pretrained models, best performing in terms of accuracy/loss metrics + # https://gluon-cv.mxnet.io/model_zoo/classification.html + 'gluon_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth', + mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults + std=IMAGENET_DEFAULT_STD, # also works well with inception defaults + has_aux=False, + ) +} + + +class InceptionA(nn.Module): + + def __init__(self, in_channels, pool_features, conv_block=None): + super(InceptionA, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) + + self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) + self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) + + self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = 
self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionB(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionB, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3(x) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + + outputs = [branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionC(nn.Module): + + def __init__(self, in_channels, channels_7x7, conv_block=None): + super(InceptionC, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) + + c7 = channels_7x7 + self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) + + self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) + + self.branch_pool = conv_block(in_channels, 
192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionD(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionD, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) + + self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3_1(x) + branch3x3 = self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + outputs = [branch3x3, branch7x7x3, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionE(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionE, self).__init__() + if conv_block is None: + conv_block = 
BasicConv2d + self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) + + self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) + self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) + self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) + self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes, conv_block=None): + super(InceptionAux, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.conv0 = conv_block(in_channels, 128, kernel_size=1) + self.conv1 = conv_block(128, 768, kernel_size=5) + self.conv1.stddev = 0.01 + self.fc = Linear(768, num_classes) + self.fc.stddev = 0.001 + + def forward(self, x): + # N x 768 x 17 x 17 + x = F.avg_pool2d(x, kernel_size=5, stride=3) + # N x 768 x 5 x 5 + x = self.conv0(x) + # N x 128 x 5 x 5 + x = self.conv1(x) + # N x 768 x 1 x 1 + # 
Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) + # N x 768 x 1 x 1 + x = torch.flatten(x, 1) + # N x 768 + x = self.fc(x) + # N x 1000 + return x + + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) + + +class InceptionV3(nn.Module): + """Inception-V3 with no AuxLogits + FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False): + super(InceptionV3, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + self.aux_logits = aux_logits + + self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) + self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) + self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Mixed_5b = InceptionA(192, pool_features=32) + self.Mixed_5c = InceptionA(256, pool_features=64) + self.Mixed_5d = InceptionA(288, pool_features=64) + self.Mixed_6a = InceptionB(288) + self.Mixed_6b = InceptionC(768, channels_7x7=128) + self.Mixed_6c = InceptionC(768, channels_7x7=160) + self.Mixed_6d = InceptionC(768, channels_7x7=160) + self.Mixed_6e = InceptionC(768, channels_7x7=192) + if aux_logits: + self.AuxLogits = InceptionAux(768, num_classes) + else: + self.AuxLogits = None + self.Mixed_7a = InceptionD(768) + self.Mixed_7b = InceptionE(1280) + self.Mixed_7c = InceptionE(2048) + self.feature_info = [ + 
dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'), + dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'), + dict(num_chs=288, reduction=8, module='Mixed_5d'), + dict(num_chs=768, reduction=16, module='Mixed_6e'), + dict(num_chs=2048, reduction=32, module='Mixed_7c'), + ] + + self.num_features = 2048 + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + trunc_normal_(m.weight, std=stddev) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward_preaux(self, x): + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = self.Pool1(x) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = self.Pool2(x) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + return x + + def forward_postaux(self, x): + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + return x + + def forward_features(self, x): + x = self.forward_preaux(x) + x = self.forward_postaux(x) + return x + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def 
forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +class InceptionV3Aux(InceptionV3): + """InceptionV3 with AuxLogits + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=True): + super(InceptionV3Aux, self).__init__( + num_classes, in_chans, drop_rate, global_pool, aux_logits) + + def forward_features(self, x): + x = self.forward_preaux(x) + aux = self.AuxLogits(x) if self.training else None + x = self.forward_postaux(x) + return x, aux + + def forward(self, x): + x, aux = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x, aux + + +def _create_inception_v3(variant, pretrained=False, **kwargs): + default_cfg = default_cfgs[variant] + aux_logits = kwargs.pop('aux_logits', False) + if aux_logits: + assert not kwargs.pop('features_only', False) + model_cls = InceptionV3Aux + load_strict = default_cfg['has_aux'] + else: + model_cls = InceptionV3 + load_strict = not default_cfg['has_aux'] + return build_model_with_cfg( + model_cls, variant, pretrained, + default_cfg=default_cfg, + pretrained_strict=load_strict, + **kwargs) + + +@register_model +def inception_v3(pretrained=False, **kwargs): + # original PyTorch weights, ported from Tensorflow but modified + model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_inception_v3(pretrained=False, **kwargs): + # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) + model = _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs) + return model + + +@register_model +def adv_inception_v3(pretrained=False, **kwargs): + # my port of 
@register_model
def gluon_inception_v3(pretrained=False, **kwargs):
    """Inception-V3 using the 'gluon_inception_v3' configuration.

    Per the original author's note these weights come from the gluon pretrained
    models and are the best performing in terms of accuracy/loss metrics:
    https://gluon-cv.mxnet.io/model_zoo/classification.html
    """
    return _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` (eps=0.001) and in-place ``ReLU``."""

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_planes, eps=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # conv -> batch norm -> relu, applied in that order
        return self.relu(self.bn(self.conv(x)))
class ReductionA(nn.Module):
    """Stride-2 reduction block that concatenates three parallel branches.

    Branches (all fed the same 384-channel input):
        * ``branch0``: 3x3 stride-2 conv, 384 -> 384 channels
        * ``branch1``: 1x1 conv (384 -> 192), 3x3 conv with padding 1 (192 -> 224),
          3x3 stride-2 conv (224 -> 256)
        * ``branch2``: 3x3 stride-2 max pool
    The branch outputs are concatenated along dim 1.
    """

    def __init__(self):
        super().__init__()
        self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)

        conv_stack = [
            BasicConv2d(384, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
            BasicConv2d(224, 256, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*conv_stack)

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        branches = (self.branch0, self.branch1, self.branch2)
        return torch.cat([branch(x) for branch in branches], 1)
class ReductionB(nn.Module):
    """Stride-2 reduction block that concatenates three parallel branches.

    Branches (all fed the same 1024-channel input):
        * ``branch0``: 1x1 conv (1024 -> 192), 3x3 stride-2 conv (192 -> 192)
        * ``branch1``: 1x1 conv (1024 -> 256), 1x7 conv (256 -> 256), 7x1 conv
          (256 -> 320), 3x3 stride-2 conv (320 -> 320)
        * ``branch2``: 3x3 stride-2 max pool
    The branch outputs are concatenated along dim 1.
    """

    def __init__(self):
        super().__init__()

        short_stack = [
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        ]
        self.branch0 = nn.Sequential(*short_stack)

        long_stack = [
            BasicConv2d(1024, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(320, 320, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*long_stack)

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        branches = (self.branch0, self.branch1, self.branch2)
        return torch.cat([branch(x) for branch in branches], 1)
class InceptionV4(nn.Module):
    """Inception-V4 classification network.

    ``self.features`` is a flat ``nn.Sequential``: three ``BasicConv2d`` stem layers,
    ``Mixed3a``/``Mixed4a``/``Mixed5a``, 4x ``InceptionA``, ``ReductionA``,
    7x ``InceptionB``, ``ReductionB`` and 3x ``InceptionC``. The pooling layer and the
    ``last_linear`` classifier are produced by ``create_classifier``.

    Args:
        num_classes: number of classifier outputs, passed to ``create_classifier``.
        in_chans: input channels of the first stem convolution.
        output_stride: only 32 is accepted (asserted).
        drop_rate: dropout probability applied to the pooled features when > 0.
        global_pool: pooling type passed to ``create_classifier``.
    """

    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
        super().__init__()
        assert output_stride == 32
        self.drop_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = 1536

        # Layers are created in exactly the order they appear in the Sequential so
        # that the child indices referenced by ``feature_info`` stay valid.
        layers = [
            BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3, stride=1),
            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed3a(),
            Mixed4a(),
            Mixed5a(),
        ]
        layers.extend(InceptionA() for _ in range(4))
        layers.append(ReductionA())  # Mixed6a
        layers.extend(InceptionB() for _ in range(7))
        layers.append(ReductionB())  # Mixed7a
        layers.extend(InceptionC() for _ in range(3))
        self.features = nn.Sequential(*layers)

        taps = (
            (64, 2, 'features.2'),
            (160, 4, 'features.3'),
            (384, 8, 'features.9'),
            (1024, 16, 'features.17'),
            (1536, 32, 'features.21'),
        )
        self.feature_info = [
            dict(num_chs=chs, reduction=red, module=mod) for chs, red, mod in taps
        ]
        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    def get_classifier(self):
        """Return the final classifier module (``last_linear``)."""
        return self.last_linear

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Rebuild the pooling layer and classifier for ``num_classes`` outputs."""
        self.num_classes = num_classes
        head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
        self.global_pool, self.last_linear = head

    def forward_features(self, x):
        """Apply the convolutional trunk only (no pooling / classifier)."""
        return self.features(x)

    def forward(self, x):
        feats = self.forward_features(x)
        pooled = self.global_pool(feats)
        if self.drop_rate > 0:
            pooled = F.dropout(pooled, p=self.drop_rate, training=self.training)
        return self.last_linear(pooled)
NonLocalAttn, BatNonLocalAttn +from .norm import GroupNorm, LayerNorm2d +from .norm_act import BatchNormAct2d, GroupNormAct +from .padding import get_padding, get_same_padding, pad_same +from .patch_embed import PatchEmbed +from .pool2d_same import AvgPool2dSame, create_pool2d +from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from .selective_kernel import SelectiveKernel +from .separable_conv import SeparableConv2d, SeparableConvBnAct +from .space_to_depth import SpaceToDepthModule +from .split_attn import SplitAttn +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .trace_utils import _assert, _float_to_int +from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ diff --git a/data_processing/MANIQA/timm/models/layers/activations.py b/data_processing/MANIQA/timm/models/layers/activations.py new file mode 100644 index 0000000..e16b3bd --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/activations.py @@ -0,0 +1,145 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. 
class Sigmoid(nn.Module):
    """Sigmoid activation module with an ``inplace`` option.

    PyTorch has this, but not with a consistent inplace argument interface.
    When ``inplace`` is true the input tensor is overwritten (``sigmoid_``) and
    returned; otherwise a new tensor is returned.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            return x.sigmoid_()
        return x.sigmoid()
def hard_sigmoid(x, inplace: bool = False):
    """Hard sigmoid: ``relu6(x + 3) / 6``.

    With ``inplace=True`` the computation is done with in-place add/clamp/div on
    ``x`` itself, and ``x`` is returned; otherwise ``x`` is left untouched and a new
    tensor is returned.
    """
    if not inplace:
        return F.relu6(x + 3.) / 6.
    # In-place variant: shift, clamp to [0, 6], then scale, all on the same storage.
    x.add_(3.)
    x.clamp_(0., 6.)
    x.div_(6.)
    return x
@torch.jit.script
def hard_sigmoid_jit(x, inplace: bool = False):
    """Scripted hard sigmoid: ``clamp(x + 3, 0, 6) / 6``.

    ``inplace`` exists only for interface compatibility and is not used; the input
    is never modified.
    """
    # Equivalent to F.relu6(x + 3.) / 6.; the original author notes clamp seemed
    # ever so slightly faster.
    shifted = x + 3
    return shifted.clamp(min=0, max=6) / 6.
@torch.jit.script
def hard_swish_jit(x, inplace: bool = False):
    """Scripted hard swish: ``x * clamp(x + 3, 0, 6) / 6``.

    ``inplace`` exists only for interface compatibility and is not used; the input
    is never modified.
    """
    # Equivalent to x * (F.relu6(x + 3.) / 6); the original author notes clamp
    # seemed ever so slightly faster.
    gate = (x + 3).clamp(min=0, max=6) / 6.
    return x * gate
@torch.jit.script
def swish_jit_bwd(x, grad_output):
    """Gradient of ``x * sigmoid(x)`` w.r.t. ``x``, scaled by ``grad_output``.

    Computes ``grad_output * (s * (1 + x * (1 - s)))`` with ``s = sigmoid(x)``.
    """
    sig = torch.sigmoid(x)
    local_grad = sig * (1 + x * (1 - sig))
    return grad_output * local_grad
@torch.jit.script
def hard_sigmoid_jit_bwd(x, grad_output):
    """Gradient of the hard sigmoid w.r.t. ``x``, scaled by ``grad_output``.

    The slope is ``1/6`` where ``-3 <= x <= 3`` (bounds inclusive) and ``0`` elsewhere.
    """
    in_linear_region = (x >= -3.) & (x <= 3.)
    slope = torch.ones_like(x) * in_linear_region / 6.
    return grad_output * slope
@torch.jit.script
def hard_mish_jit_fwd(x):
    """Hard Mish forward: ``0.5 * x * clamp(x + 2, 0, 2)``."""
    gate = (x + 2).clamp(min=0, max=2)
    return 0.5 * x * gate
class HardMishJitAutoFn(torch.autograd.Function):
    """Custom autograd function for Hard Mish built on the scripted fwd/bwd kernels.

    Only the input tensor is saved for the backward pass; the gradient is
    recomputed from it by ``hard_mish_jit_bwd``.

    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """

    @staticmethod
    def forward(ctx, x):
        # Keep just the input; the activation output is not stored.
        ctx.save_for_backward(x)
        return hard_mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        return hard_mish_jit_bwd(saved_input, grad_output)
class SelectAdaptivePool2d(nn.Module):
    """Selectable global pooling layer with dynamic input kernel size.

    ``pool_type`` selects the pooling module:
        * ``''`` or any other falsy value such as ``None`` - identity pass-through
        * ``'fast'`` - ``FastAdaptiveAvgPool2d`` (requires ``output_size == 1`` and
          performs the optional flatten itself)
        * ``'avg'`` - ``nn.AdaptiveAvgPool2d``
        * ``'avgmax'`` - ``AdaptiveAvgMaxPool2d``
        * ``'catavgmax'`` - ``AdaptiveCatAvgMaxPool2d``
        * ``'max'`` - ``nn.AdaptiveMaxPool2d``

    Args:
        output_size: target spatial size handed to the adaptive pooling module.
        pool_type: one of the keys above.
        flatten: if true, the pooled output is flattened from dim 1 onwards.

    Raises:
        AssertionError: for an unknown ``pool_type``, or ``'fast'`` with
            ``output_size != 1``.
    """
    def __init__(self, output_size=1, pool_type='fast', flatten=False):
        super(SelectAdaptivePool2d, self).__init__()
        # Convert falsy values (e.g. None) to empty string for consistent TS typing,
        # and branch on the normalised value so that None selects the pass-through
        # instead of falling into the "invalid pool type" assertion.
        pool_type = pool_type or ''
        self.pool_type = pool_type
        self.flatten = nn.Flatten(1) if flatten else nn.Identity()
        if not pool_type:
            self.pool = nn.Identity()  # pass through
        elif pool_type == 'fast':
            assert output_size == 1
            # The fast pool handles flattening itself (via keepdim), so the separate
            # flatten stage becomes a no-op.
            self.pool = FastAdaptiveAvgPool2d(flatten)
            self.flatten = nn.Identity()
        elif pool_type == 'avg':
            self.pool = nn.AdaptiveAvgPool2d(output_size)
        elif pool_type == 'avgmax':
            self.pool = AdaptiveAvgMaxPool2d(output_size)
        elif pool_type == 'catavgmax':
            self.pool = AdaptiveCatAvgMaxPool2d(output_size)
        elif pool_type == 'max':
            self.pool = nn.AdaptiveMaxPool2d(output_size)
        else:
            assert False, 'Invalid pool type: %s' % pool_type

    def is_identity(self):
        """Return True when no pooling was selected (empty ``pool_type``)."""
        return not self.pool_type

    def forward(self, x):
        x = self.pool(x)
        x = self.flatten(x)
        return x

    def feat_mult(self):
        """Return the channel multiplier of the selected pooling (see ``adaptive_pool_feat_mult``)."""
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + 'pool_type=' + self.pool_type \
               + ', flatten=' + str(self.flatten) + ')'
class RotaryEmbedding(nn.Module):
    """ Rotary position embedding

    NOTE: This is the original author's initial attempt at implementing rotary
    embedding for spatial use; it has not been well tested and will likely change.
    It will be moved to its own file.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """
    def __init__(self, dim, max_freq=4):
        super().__init__()
        self.dim = dim
        # dim // 4 frequency bands, powers of two spaced over [0, max_freq - 1];
        # non-persistent, so the buffer is not part of the state dict.
        exponents = torch.linspace(0., max_freq - 1, self.dim // 4)
        self.register_buffer('bands', 2 ** exponents, persistent=False)

    def get_embed(self, shape: torch.Size, device: torch.device = None, dtype: torch.dtype = None):
        """Return ``(sin, cos)`` tables for a grid of the given spatial ``shape``.

        NOTE: shape arg should include spatial dim only

        Each table has one row per grid position (``shape.numel()`` rows); every
        value is repeated twice along the last dim. ``device`` / ``dtype`` fall back
        to those of the ``bands`` buffer when falsy.
        """
        device = device or self.bands.device
        dtype = dtype or self.bands.dtype
        if not isinstance(shape, torch.Size):
            shape = torch.Size(shape)
        num_positions = shape.numel()
        # One coordinate axis in [-1, 1] per spatial dim.
        axes = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in shape]
        coords = torch.stack(torch.meshgrid(axes), dim=-1)
        coords = coords.unsqueeze(-1)
        angles = coords * math.pi * self.bands
        tables = []
        for trig in (angles.sin(), angles.cos()):
            tables.append(trig.reshape(num_positions, -1).repeat_interleave(2, -1))
        sin, cos = tables
        return sin, cos

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        # NOTE(review): whether the (positions, features) tables broadcast against a
        # channel-first ``x`` depends on the caller's layout -- confirm before use.
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)
-RW + """ + def __init__( + self, + in_features: int, + out_features: int = None, + embed_dim: int = None, + num_heads: int = 4, + qkv_bias: bool = True, + ): + super().__init__() + embed_dim = embed_dim or in_features + out_features = out_features or in_features + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, out_features) + self.num_heads = num_heads + assert embed_dim % num_heads == 0 + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim ** -0.5 + self.pos_embed = RotaryEmbedding(self.head_dim) + + trunc_normal_(self.qkv.weight, std=in_features ** -0.5) + nn.init.zeros_(self.qkv.bias) + + def forward(self, x): + B, _, H, W = x.shape + N = H * W + sin_emb, cos_emb = self.pos_embed.get_embed(x.shape[2:]) + x = x.reshape(B, -1, N).permute(0, 2, 1) + + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + + x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = x[0], x[1], x[2] + + qc, q = q[:, :, :1], q[:, :, 1:] + q = apply_rot_embed(q, sin_emb, cos_emb) + q = torch.cat([qc, q], dim=2) + + kc, k = k[:, :, :1], k[:, :, 1:] + k = apply_rot_embed(k, sin_emb, cos_emb) + k = torch.cat([kc, k], dim=2) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + x = self.proj(x) + return x[:, 0] + + +class AttentionPool2d(nn.Module): + """ Attention based 2D feature pooling w/ learned (absolute) pos embedding. + This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. + + It was based on impl in CLIP by OpenAI + https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + + NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network. 
+ """ + def __init__( + self, + in_features: int, + feat_size: Union[int, Tuple[int, int]], + out_features: int = None, + embed_dim: int = None, + num_heads: int = 4, + qkv_bias: bool = True, + ): + super().__init__() + + embed_dim = embed_dim or in_features + out_features = out_features or in_features + assert embed_dim % num_heads == 0 + self.feat_size = to_2tuple(feat_size) + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, out_features) + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim ** -0.5 + + spatial_dim = self.feat_size[0] * self.feat_size[1] + self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features)) + trunc_normal_(self.pos_embed, std=in_features ** -0.5) + trunc_normal_(self.qkv.weight, std=in_features ** -0.5) + nn.init.zeros_(self.qkv.bias) + + def forward(self, x): + B, _, H, W = x.shape + N = H * W + assert self.feat_size[0] == H + assert self.feat_size[1] == W + x = x.reshape(B, -1, N).permute(0, 2, 1) + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + x = x + self.pos_embed.unsqueeze(0).to(x.dtype) + + x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = x[0], x[1], x[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + x = self.proj(x) + return x[:, 0] diff --git a/data_processing/MANIQA/timm/models/layers/blur_pool.py b/data_processing/MANIQA/timm/models/layers/blur_pool.py new file mode 100644 index 0000000..ca4ce75 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/blur_pool.py @@ -0,0 +1,42 @@ +""" +BlurPool layer inspired by + - Kornia's Max_BlurPool2d + - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` + +Hacked together by Chris Ha and Ross Wightman +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np 
+from .padding import get_padding + + +class BlurPool2d(nn.Module): + r"""Creates a module that computes blurs and downsample a given feature map. + See :cite:`zhang2019shiftinvar` for more details. + Corresponds to the Downsample class, which does blurring and subsampling + + Args: + channels = Number of input channels + filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. + stride (int): downsampling filter stride + + Returns: + torch.Tensor: the transformed tensor. + """ + def __init__(self, channels, filt_size=3, stride=2) -> None: + super(BlurPool2d, self).__init__() + assert filt_size > 1 + self.channels = channels + self.filt_size = filt_size + self.stride = stride + self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 + coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) + blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) + self.register_buffer('filt', blur_filter, persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding, 'reflect') + return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1]) diff --git a/data_processing/MANIQA/timm/models/layers/bottleneck_attn.py b/data_processing/MANIQA/timm/models/layers/bottleneck_attn.py new file mode 100644 index 0000000..c3db464 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/bottleneck_attn.py @@ -0,0 +1,157 @@ +""" Bottleneck Self Attention (Bottleneck Transformers) + +Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + +@misc{2101.11605, +Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani}, +Title = {Bottleneck Transformers for Visual Recognition}, +Year = {2021}, +} + +Based on ref gist at: 
https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + +This impl is a WIP but given that it is based on the ref gist likely not too far off. + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import to_2tuple, make_divisible +from .weight_init import trunc_normal_ +from .trace_utils import _assert + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, heads, height, width, dim) + rel_k: (2 * width - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, 2 * W -1) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, W - 1]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1) + x = x_pad[:, :W, W - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + """ + def __init__(self, feat_size, dim_head, scale): + super().__init__() + self.height, self.width = to_2tuple(feat_size) + self.dim_head = 
dim_head + self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale) + self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale) + + def forward(self, q): + B, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(B, self.height, self.width, -1) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. + q = q.transpose(1, 2) + rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) + + rel_logits = rel_logits_h + rel_logits_w + rel_logits = rel_logits.reshape(B, HW, HW) + return rel_logits + + +class BottleneckAttn(nn.Module): + """ Bottleneck Attention + Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + + The internal dimensions of the attention module are controlled by the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query and key (qk) dimensions are determined by + * num_heads * dim_head if dim_head is not None + * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + stride (int): output stride of the module, avg pool used if stride == 2 (default: 1). + num_heads (int): parallel attention heads (default: 4) + dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) + qkv_bias (bool): add bias to q, k, and v projections + scale_pos_embed (bool): scale the position embedding as well as Q @ K + """ + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None, + qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False): + super().__init__() + assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + self.num_heads = num_heads + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v + self.scale = self.dim_head_qk ** -0.5 + self.scale_pos_embed = scale_pos_embed + + self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) + + # NOTE I'm only supporting relative pos embedding for now + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + self.reset_parameters() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in + trunc_normal_(self.pos_embed.height_rel, std=self.scale) + trunc_normal_(self.pos_embed.width_rel, std=self.scale) + + def forward(self, x): + B, C, H, W = x.shape + _assert(H == self.pos_embed.height, '') + _assert(W == self.pos_embed.width, '') + + x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W + + # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v + # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted. 
+ q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1) + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2) + k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k + v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) + + if self.scale_pos_embed: + attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W + else: + attn = (q @ k) * self.scale + self.pos_embed(q) + attn = attn.softmax(dim=-1) + + out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W + out = self.pool(out) + return out diff --git a/data_processing/MANIQA/timm/models/layers/cbam.py b/data_processing/MANIQA/timm/models/layers/cbam.py new file mode 100644 index 0000000..bacf5cf --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/cbam.py @@ -0,0 +1,112 @@ +""" CBAM (sort-of) Attention + +Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 + +WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on +some tasks, especially fine-grained it seems. I may end up removing this impl. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import nn as nn +import torch.nn.functional as F + +from .conv_bn_act import ConvBnAct +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible + + +class ChannelAttn(nn.Module): + """ Original CBAM channel attention module, currently avg + max pool variant only. + """ + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(ChannelAttn, self).__init__() + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
+ self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) + x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) + return x * self.gate(x_avg + x_max) + + +class LightChannelAttn(ChannelAttn): + """An experimental 'lightweight' that sums avg + max pool first + """ + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightChannelAttn, self).__init__( + channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) + + def forward(self, x): + x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) + x_attn = self.fc2(self.act(self.fc1(x_pool))) + return x * F.sigmoid(x_attn) + + +class SpatialAttn(nn.Module): + """ Original CBAM spatial attention module + """ + def __init__(self, kernel_size=7, gate_layer='sigmoid'): + super(SpatialAttn, self).__init__() + self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) + x_attn = self.conv(x_attn) + return x * self.gate(x_attn) + + +class LightSpatialAttn(nn.Module): + """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 
+ """ + def __init__(self, kernel_size=7, gate_layer='sigmoid'): + super(LightSpatialAttn, self).__init__() + self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) + x_attn = self.conv(x_attn) + return x * self.gate(x_attn) + + +class CbamModule(nn.Module): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(CbamModule, self).__init__() + self.channel = ChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + + +class LightCbamModule(nn.Module): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightCbamModule, self).__init__() + self.channel = LightChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = LightSpatialAttn(spatial_kernel_size) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + diff --git a/data_processing/MANIQA/timm/models/layers/classifier.py b/data_processing/MANIQA/timm/models/layers/classifier.py new file mode 100644 index 0000000..798748d --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/classifier.py @@ -0,0 +1,54 @@ +""" Classifier head and layer factory + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn +from torch.nn import functional as F + +from .adaptive_avgmax_pool import SelectAdaptivePool2d + + +def 
_create_pool(num_features, num_classes, pool_type='avg', use_conv=False): + flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling + if not pool_type: + assert num_classes == 0 or use_conv,\ + 'Pooling can only be disabled if classifier is also removed or conv classifier is used' + flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) + global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) + num_pooled_features = num_features * global_pool.feat_mult() + return global_pool, num_pooled_features + + +def _create_fc(num_features, num_classes, use_conv=False): + if num_classes <= 0: + fc = nn.Identity() # pass-through (no classifier) + elif use_conv: + fc = nn.Conv2d(num_features, num_classes, 1, bias=True) + else: + fc = nn.Linear(num_features, num_classes, bias=True) + return fc + + +def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): + global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) + fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) + return global_pool, fc + + +class ClassifierHead(nn.Module): + """Classifier head w/ configurable global pooling and dropout.""" + + def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): + super(ClassifierHead, self).__init__() + self.drop_rate = drop_rate + self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) + self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) + self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() + + def forward(self, x): + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + x = self.fc(x) + x = self.flatten(x) + return x diff --git a/data_processing/MANIQA/timm/models/layers/cond_conv2d.py 
b/data_processing/MANIQA/timm/models/layers/cond_conv2d.py new file mode 100644 index 0000000..8b4bbca --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/cond_conv2d.py @@ -0,0 +1,122 @@ +""" PyTorch Conditionally Parameterized Convolution (CondConv) + +Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference +(https://arxiv.org/abs/1904.04971) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import math +from functools import partial +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .helpers import to_2tuple +from .conv2d_same import conv2d_same +from .padding import get_padding_value + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditionally Parameterized Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + 
self.out_channels = out_channels + self.kernel_size = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = to_2tuple(padding_val) + self.dilation = to_2tuple(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + 
dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out diff --git a/data_processing/MANIQA/timm/models/layers/config.py b/data_processing/MANIQA/timm/models/layers/config.py new file mode 100644 index 0000000..f07b9d7 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/config.py @@ -0,0 +1,115 @@ +""" Model / Layer Config singleton state +""" +from typing import Any, Optional + +__all__ = [ + 'is_exportable', 'is_scriptable', 'is_no_jit', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. 
+_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. 
+ """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False diff --git a/data_processing/MANIQA/timm/models/layers/conv2d_same.py b/data_processing/MANIQA/timm/models/layers/conv2d_same.py new file mode 100644 index 0000000..75f0f98 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/conv2d_same.py @@ -0,0 +1,42 @@ +""" Conv2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional + +from .padding import pad_same, get_padding_value + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + x = pad_same(x, weight.shape[-2:], stride, dilation) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, 
groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + diff --git a/data_processing/MANIQA/timm/models/layers/conv_bn_act.py b/data_processing/MANIQA/timm/models/layers/conv_bn_act.py new file mode 100644 index 0000000..33005c3 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/conv_bn_act.py @@ -0,0 +1,40 @@ +""" Conv2d + BN + Act + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import convert_norm_act + + +class ConvBnAct(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, + bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, + drop_block=None): + super(ConvBnAct, self).__init__() + use_aa = aa_layer is not None + + self.conv = create_conv2d( + in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = convert_norm_act(norm_layer, act_layer) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) + self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None + + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.conv.out_channels + + def forward(self, x): + x = self.conv(x) + x = 
self.bn(x) + if self.aa is not None: + x = self.aa(x) + return x diff --git a/data_processing/MANIQA/timm/models/layers/create_act.py b/data_processing/MANIQA/timm/models/layers/create_act.py new file mode 100644 index 0000000..aa55769 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/create_act.py @@ -0,0 +1,153 @@ +""" Activation Factory +Hacked together by / Copyright 2020 Ross Wightman +""" +from typing import Union, Callable, Type + +from .activations import * +from .activations_jit import * +from .activations_me import * +from .config import is_exportable, is_scriptable, is_no_jit + +# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. +# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. +# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. +_has_silu = 'silu' in dir(torch.nn.functional) +_has_hardswish = 'hardswish' in dir(torch.nn.functional) +_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional) +_has_mish = 'mish' in dir(torch.nn.functional) + + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=F.mish if _has_mish else mish, + relu=F.relu, + relu6=F.relu6, + leaky_relu=F.leaky_relu, + elu=F.elu, + celu=F.celu, + selu=F.selu, + gelu=gelu, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, + hard_swish=F.hardswish if _has_hardswish else hard_swish, + hard_mish=hard_mish, +) + +_ACT_FN_JIT = dict( + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, + mish=F.mish if _has_mish else mish_jit, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, + hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, + hard_mish=hard_mish_jit +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=F.mish if 
_has_mish else mish_me, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, + hard_swish=F.hardswish if _has_hardswish else hard_swish_me, + hard_mish=hard_mish_me, +) + +_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) +for a in _ACT_FNS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=nn.Mish if _has_mish else Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + leaky_relu=nn.LeakyReLU, + elu=nn.ELU, + prelu=PReLU, + celu=nn.CELU, + selu=nn.SELU, + gelu=GELU, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, + hard_swish=nn.Hardswish if _has_hardswish else HardSwish, + hard_mish=HardMish, +) + +_ACT_LAYER_JIT = dict( + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, + mish=nn.Mish if _has_mish else MishJit, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, + hard_mish=HardMishJit +) + +_ACT_LAYER_ME = dict( + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=nn.Mish if _has_mish else MishMe, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, + hard_mish=HardMishMe, +) + +_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) +for a in _ACT_LAYERS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + +def get_act_fn(name: Union[Callable, str] = 'relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. 
+ """ + if not name: + return None + if isinstance(name, Callable): + return name + if not (is_no_jit() or is_exportable() or is_scriptable()): + # If not exporting or scripting the model, first look for a memory-efficient version with + # custom autograd, then fallback + if name in _ACT_FN_ME: + return _ACT_FN_ME[name] + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + if not (is_no_jit() or is_exportable()): + if name in _ACT_FN_JIT: + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if not name: + return None + if isinstance(name, type): + return name + if not (is_no_jit() or is_exportable() or is_scriptable()): + if name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + if not (is_no_jit() or is_exportable()): + if name in _ACT_LAYER_JIT: + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + +def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): + act_layer = get_act_layer(name) + if act_layer is None: + return None + return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs) diff --git a/data_processing/MANIQA/timm/models/layers/create_attn.py b/data_processing/MANIQA/timm/models/layers/create_attn.py new file mode 100644 index 0000000..028c0f7 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/create_attn.py @@ -0,0 +1,89 @@ +""" Attention Factory + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +from functools import partial + +from .bottleneck_attn import BottleneckAttn +from .cbam 
import CbamModule, LightCbamModule +from .eca import EcaModule, CecaModule +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .halo_attn import HaloAttn +from .lambda_layer import LambdaLayer +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .selective_kernel import SelectiveKernel +from .split_attn import SplitAttn +from .squeeze_excite import SEModule, EffectiveSEModule + + +def get_attn(attn_type): + if isinstance(attn_type, torch.nn.Module): + return attn_type + module_cls = None + if attn_type is not None: + if isinstance(attn_type, str): + attn_type = attn_type.lower() + # Lightweight attention modules (channel and/or coarse spatial). + # Typically added to existing network architecture blocks in addition to existing convolutions. + if attn_type == 'se': + module_cls = SEModule + elif attn_type == 'ese': + module_cls = EffectiveSEModule + elif attn_type == 'eca': + module_cls = EcaModule + elif attn_type == 'ecam': + module_cls = partial(EcaModule, use_mlp=True) + elif attn_type == 'ceca': + module_cls = CecaModule + elif attn_type == 'ge': + module_cls = GatherExcite + elif attn_type == 'gc': + module_cls = GlobalContext + elif attn_type == 'gca': + module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) + elif attn_type == 'cbam': + module_cls = CbamModule + elif attn_type == 'lcbam': + module_cls = LightCbamModule + + # Attention / attention-like modules w/ significant params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'sk': + module_cls = SelectiveKernel + elif attn_type == 'splat': + module_cls = SplitAttn + + # Self-attention / attention-like modules w/ significant compute and/or params + # Typically replace some of the existing workhorse convs in a network architecture. 
+ # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'lambda': + return LambdaLayer + elif attn_type == 'bottleneck': + return BottleneckAttn + elif attn_type == 'halo': + return HaloAttn + elif attn_type == 'nl': + module_cls = NonLocalAttn + elif attn_type == 'bat': + module_cls = BatNonLocalAttn + + # Woops! + else: + assert False, "Invalid attn module (%s)" % attn_type + elif isinstance(attn_type, bool): + if attn_type: + module_cls = SEModule + else: + module_cls = attn_type + return module_cls + + +def create_attn(attn_type, channels, **kwargs): + module_cls = get_attn(attn_type) + if module_cls is not None: + # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels + return module_cls(channels, **kwargs) + return None diff --git a/data_processing/MANIQA/timm/models/layers/create_conv2d.py b/data_processing/MANIQA/timm/models/layers/create_conv2d.py new file mode 100644 index 0000000..3a0cc03 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/create_conv2d.py @@ -0,0 +1,31 @@ +""" Create Conv2d Factory Method + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d +from .conv2d_same import create_conv2d_pad + + +def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): + """ Select a 2d convolution implementation based on arguments + Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. + + Used extensively by EfficientNet, MobileNetv3 and related networks. + """ + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + assert 'groups' not in kwargs # MixedConv groups are defined by kernel list + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
+ m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 + groups = in_channels if depthwise else kwargs.pop('groups', 1) + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + return m diff --git a/data_processing/MANIQA/timm/models/layers/create_norm_act.py b/data_processing/MANIQA/timm/models/layers/create_norm_act.py new file mode 100644 index 0000000..5b56294 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/create_norm_act.py @@ -0,0 +1,83 @@ +""" NormAct (Normalizaiton + Activation Layer) Factory + +Create norm + act combo modules that attempt to be backwards compatible with separate norm + act +isntances in models. Where these are used it will be possible to swap separate BN + act layers with +combined modules like IABN or EvoNorms. 
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +import types +import functools + +import torch +import torch.nn as nn + +from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .norm_act import BatchNormAct2d, GroupNormAct +from .inplace_abn import InplaceAbn + +_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} +_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type + + +def get_norm_act_layer(layer_class): + layer_class = layer_class.replace('_', '').lower() + if layer_class.startswith("batchnorm"): + layer = BatchNormAct2d + elif layer_class.startswith("groupnorm"): + layer = GroupNormAct + elif layer_class == "evonormbatch": + layer = EvoNormBatch2d + elif layer_class == "evonormsample": + layer = EvoNormSample2d + elif layer_class == "iabn" or layer_class == "inplaceabn": + layer = InplaceAbn + else: + assert False, "Invalid norm_act layer (%s)" % layer_class + return layer + + +def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): + layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu + assert len(layer_parts) in (1, 2) + layer = get_norm_act_layer(layer_parts[0]) + #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection? 
+ layer_instance = layer(num_features, apply_act=apply_act, **kwargs) + if jit: + layer_instance = torch.jit.script(layer_instance) + return layer_instance + + +def convert_norm_act(norm_layer, act_layer): + assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) + assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) + norm_act_kwargs = {} + + # unbind partial fn, so args can be rebound later + if isinstance(norm_layer, functools.partial): + norm_act_kwargs.update(norm_layer.keywords) + norm_layer = norm_layer.func + + if isinstance(norm_layer, str): + norm_act_layer = get_norm_act_layer(norm_layer) + elif norm_layer in _NORM_ACT_TYPES: + norm_act_layer = norm_layer + elif isinstance(norm_layer, types.FunctionType): + # if function type, must be a lambda/fn that creates a norm_act layer + norm_act_layer = norm_layer + else: + type_name = norm_layer.__name__.lower() + if type_name.startswith('batchnorm'): + norm_act_layer = BatchNormAct2d + elif type_name.startswith('groupnorm'): + norm_act_layer = GroupNormAct + else: + assert False, f"No equivalent norm_act layer for {type_name}" + + if norm_act_layer in _NORM_ACT_REQUIRES_ARG: + # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. + # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types + norm_act_kwargs.setdefault('act_layer', act_layer) + if norm_act_kwargs: + norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args + return norm_act_layer diff --git a/data_processing/MANIQA/timm/models/layers/drop.py b/data_processing/MANIQA/timm/models/layers/drop.py new file mode 100644 index 0000000..90c1933 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/drop.py @@ -0,0 +1,164 @@ +""" DropBlock, DropPath + +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. 
+ +Papers: +DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) + +Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) + +Code: +DropBlock impl inspired by two Tensorflow impl that I liked: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def drop_block_2d( + x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, + with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. This layer has been tested on a few training + runs with success, but needs further validation and possibly optimization for lower runtime impact. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + # seed_drop_rate, the gamma parameter + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + # Forces the block to be inside the feature map. 
+ w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) + + if batchwise: + # one mask for whole batch, quite a bit faster + uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) + else: + uniform_noise = torch.rand_like(x) + block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, + stride=1, + padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) + else: + x = x * block_mask + normal_noise * (1 - block_mask) + else: + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +def drop_block_fast_2d( + x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, + gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid + block mask at edges. 
+ """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + block_mask = torch.empty_like(x).bernoulli_(gamma) + block_mask = F.max_pool2d( + block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.empty_like(x).normal_() + if inplace: + x.mul_(1. - block_mask).add_(normal_noise * block_mask) + else: + x = x * (1. - block_mask) + normal_noise * block_mask + else: + block_mask = 1 - block_mask + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +class DropBlock2d(nn.Module): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """ + def __init__(self, + drop_prob=0.1, + block_size=7, + gamma_scale=1.0, + with_noise=False, + inplace=False, + batchwise=False, + fast=True): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + self.inplace = inplace + self.batchwise = batchwise + self.fast = fast # FIXME finish comparisons of fast vs not + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + if self.fast: + return drop_block_fast_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace) + else: + return drop_block_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) diff --git a/data_processing/MANIQA/timm/models/layers/eca.py b/data_processing/MANIQA/timm/models/layers/eca.py new file mode 100644 index 0000000..e29be6a --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/eca.py @@ -0,0 +1,145 @@ +""" +ECA module from ECAnet + +paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks +https://arxiv.org/abs/1910.03151 + +Original ECA model borrowed from https://github.com/BangguWu/ECANet + +Modified circular ECA implementation and adaption for use in timm package +by Chris Ha https://github.com/VRandme + +Original License: + 
+MIT License + +Copyright (c) 2019 BangguWu, Qilong Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +import math +from torch import nn +import torch.nn.functional as F + + +from .create_act import create_act_layer +from .helpers import make_divisible + + +class EcaModule(nn.Module): + """Constructs an ECA module. + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) 
+ kernel_size: Adaptive selection of kernel size (default=3) + gamm: used in kernel_size calc, see above + beta: used in kernel_size calc, see above + act_layer: optional non-linearity after conv, enables conv bias, this is an experiment + gate_layer: gating non-linearity to use + """ + def __init__( + self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid', + rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False): + super(EcaModule, self).__init__() + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) + assert kernel_size % 2 == 1 + padding = (kernel_size - 1) // 2 + if use_mlp: + # NOTE 'mlp' mode is a timm experiment, not in paper + assert channels is not None + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor) + act_layer = act_layer or nn.ReLU + self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True) + self.act = create_act_layer(act_layer) + self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True) + else: + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.act = None + self.conv2 = None + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv + y = self.conv(y) + if self.conv2 is not None: + y = self.act(y) + y = self.conv2(y) + y = self.gate(y).view(x.shape[0], -1, 1, 1) + return x * y.expand_as(x) + + +EfficientChannelAttn = EcaModule # alias + + +class CecaModule(nn.Module): + """Constructs a circular ECA module. + + ECA module where the conv uses circular padding rather than zero padding. + Unlike the spatial dimension, the channels do not have inherent ordering nor + locality. 
Although this module in essence, applies such an assumption, it is unnecessary + to limit the channels on either "edge" from being circularly adapted to each other. + This will fundamentally increase connectivity and possibly increase performance metrics + (accuracy, robustness), without significantly impacting resource metrics + (parameter size, throughput,latency, etc) + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) + kernel_size: Adaptive selection of kernel size (default=3) + gamm: used in kernel_size calc, see above + beta: used in kernel_size calc, see above + act_layer: optional non-linearity after conv, enables conv bias, this is an experiment + gate_layer: gating non-linearity to use + """ + + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'): + super(CecaModule, self).__init__() + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) + has_act = act_layer is not None + assert kernel_size % 2 == 1 + + # PyTorch circular padding mode is buggy as of pytorch 1.4 + # see https://github.com/pytorch/pytorch/pull/17240 + # implement manual circular padding + self.padding = (kernel_size - 1) // 2 + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + y = x.mean((2, 3)).view(x.shape[0], 1, -1) + # Manually implement circular padding, F.pad does not seemed to be bugged + y = F.pad(y, (self.padding, self.padding), mode='circular') + y = 
self.conv(y) + y = self.gate(y).view(x.shape[0], -1, 1, 1) + return x * y.expand_as(x) + + +CircularEfficientChannelAttn = CecaModule diff --git a/data_processing/MANIQA/timm/models/layers/evo_norm.py b/data_processing/MANIQA/timm/models/layers/evo_norm.py new file mode 100644 index 0000000..6ef0c88 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/evo_norm.py @@ -0,0 +1,81 @@ +"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch + +An attempt at getting decent performing EvoNorms running in PyTorch. +While currently faster than other impl, still quite a ways off the built-in BN +in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed). + +Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +import torch.nn as nn + +from .trace_utils import _assert + + +class EvoNormBatch2d(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): + super(EvoNormBatch2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.momentum = momentum + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True) + self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.apply_act: + nn.init.ones_(self.v) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_type = x.dtype + if self.v is not None: + running_var = self.running_var.view(1, -1, 1, 1) + if self.training: + var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) + n = x.numel() / x.shape[1] + running_var = var.detach() * self.momentum * (n / (n - 1)) + 
running_var * (1 - self.momentum) + self.running_var.copy_(running_var.view(self.running_var.shape)) + else: + var = running_var + v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1) + d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) + d = d.max((var + self.eps).sqrt().to(dtype=x_type)) + x = x / d + return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) + + +class EvoNormSample2d(nn.Module): + def __init__(self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None): + super(EvoNormSample2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.groups = groups + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True) + self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.apply_act: + nn.init.ones_(self.v) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + B, C, H, W = x.shape + _assert(C % self.groups == 0, '') + if self.v is not None: + n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid() + x = x.reshape(B, self.groups, -1) + x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() + x = x.reshape(B, C, H, W) + return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) diff --git a/data_processing/MANIQA/timm/models/layers/gather_excite.py b/data_processing/MANIQA/timm/models/layers/gather_excite.py new file mode 100644 index 0000000..2d60dc9 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/gather_excite.py @@ -0,0 +1,90 @@ +""" Gather-Excite Attention Block + +Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 
+ +Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet + +I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another +impl that covers all of the cases. + +NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math + +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .create_conv2d import create_conv2d +from .helpers import make_divisible +from .mlp import ConvMlp + + +class GatherExcite(nn.Module): + """ Gather-Excite Attention Module + """ + def __init__( + self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, + rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): + super(GatherExcite, self).__init__() + self.add_maxpool = add_maxpool + act_layer = get_act_layer(act_layer) + self.extent = extent + if extra_params: + self.gather = nn.Sequential() + if extent == 0: + assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' + self.gather.add_module( + 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) + else: + assert extent % 2 == 0 + num_conv = int(math.log2(extent)) + for i in range(num_conv): + self.gather.add_module( + f'conv{i + 1}', + create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) + if i != num_conv - 1: + self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) + else: + self.gather = None + if self.extent == 0: + self.gk = 0 + self.gs = 0 + else: + assert extent % 2 == 0 + self.gk = 
self.extent * 2 - 1 + self.gs = self.extent + + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + size = x.shape[-2:] + if self.gather is not None: + x_ge = self.gather(x) + else: + if self.extent == 0: + # global extent + x_ge = x.mean(dim=(2, 3), keepdims=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) + else: + x_ge = F.avg_pool2d( + x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) + x_ge = self.mlp(x_ge) + if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: + x_ge = F.interpolate(x_ge, size=size) + return x * self.gate(x_ge) diff --git a/data_processing/MANIQA/timm/models/layers/global_context.py b/data_processing/MANIQA/timm/models/layers/global_context.py new file mode 100644 index 0000000..de7fb5c --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/global_context.py @@ -0,0 +1,67 @@ +""" Global Context Attention Block + +Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` + - https://arxiv.org/abs/1904.11492 + +Official code consulted as reference: https://github.com/xvjiarui/GCNet + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible +from .mlp import ConvMlp +from .norm import LayerNorm2d + + +class GlobalContext(nn.Module): + + def __init__(self, channels, 
use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, + rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): + super(GlobalContext, self).__init__() + act_layer = get_act_layer(act_layer) + + self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None + + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + if fuse_add: + self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_add = None + if fuse_scale: + self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_scale = None + + self.gate = create_act_layer(gate_layer) + self.init_last_zero = init_last_zero + self.reset_parameters() + + def reset_parameters(self): + if self.conv_attn is not None: + nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') + if self.mlp_add is not None: + nn.init.zeros_(self.mlp_add.fc2.weight) + + def forward(self, x): + B, C, H, W = x.shape + + if self.conv_attn is not None: + attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) + attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) + context = x.reshape(B, C, H * W).unsqueeze(1) @ attn + context = context.view(B, C, 1, 1) + else: + context = x.mean(dim=(2, 3), keepdim=True) + + if self.mlp_scale is not None: + mlp_x = self.mlp_scale(context) + x = x * self.gate(mlp_x) + if self.mlp_add is not None: + mlp_x = self.mlp_add(context) + x = x + mlp_x + + return x diff --git a/data_processing/MANIQA/timm/models/layers/halo_attn.py b/data_processing/MANIQA/timm/models/layers/halo_attn.py new file mode 100644 index 0000000..f2ac64f --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/halo_attn.py @@ -0,0 +1,233 @@ +""" Halo Self Attention + +Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` + - 
https://arxiv.org/abs/2103.12731 + +@misc{2103.12731, +Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and + Jonathon Shlens}, +Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones}, +Year = {2021}, +} + +Status: +This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me. +The attention mechanism works but it's slow as implemented. + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import List + +import torch +from torch import nn +import torch.nn.functional as F + +from .helpers import make_divisible +from .weight_init import trunc_normal_ +from .trace_utils import _assert + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, height, width, dim) + rel_k: (2 * window - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + rel_size = rel_k.shape[0] + win_size = (rel_size + 1) // 2 + + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, rel_size) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, rel_size - W]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, rel_size) + x = x_pad[:, :W, win_size - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: 
class HaloAttn(nn.Module):
    """ Halo Attention

    Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
        - https://arxiv.org/abs/2103.12731

    Queries are taken from non-overlapping blocks of the feature map; keys and values are taken from
    the same blocks enlarged by `halo_size` on every side, so neighbouring windows overlap.

    The internal dimensions of the attention module are controlled by the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
      * the query and key (qk) dimensions are determined by
        * num_heads * dim_head if dim_head is not None
        * num_heads * (make_divisible(dim_out * qk_ratio, divisor=8) // num_heads) if dim_head is None
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not used

    Args:
        dim (int): input dimension to the module
        dim_out (int): output dimension of the module, same as dim if not set
        feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda)
        stride: output stride of the module, must be 1 or 2, query downscaled if > 1 (default: 1).
        num_heads: parallel attention heads (default: 8).
        dim_head: dimension of query and key heads, derived from dim_out * qk_ratio and num_heads if not set
        block_size (int): size of blocks. (default: 8)
        halo_size (int): size of halo overlap. (default: 3)
        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
        qkv_bias (bool) : add bias to q, k, and v projections
        avg_down (bool): use average pool downsample instead of strided query blocks
        scale_pos_embed (bool): scale the position embedding as well as Q @ K
    """
    def __init__(
            self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
            qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0
        assert stride in (1, 2)
        self.num_heads = num_heads
        # an explicit dim_head wins; otherwise the per-head q/k width is derived from dim_out * qk_ratio
        self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        self.dim_head_v = dim_out // self.num_heads
        self.dim_out_qk = num_heads * self.dim_head_qk
        self.dim_out_v = num_heads * self.dim_head_v
        self.scale = self.dim_head_qk ** -0.5
        self.scale_pos_embed = scale_pos_embed
        # block_size_ds is the query block size after any strided downscale (see below)
        self.block_size = self.block_size_ds = block_size
        self.halo_size = halo_size
        self.win_size = block_size + halo_size * 2  # neighbourhood window size
        self.block_stride = 1
        use_avg_pool = False
        if stride > 1:
            # fall back to average pooling when requested or when the block cannot be evenly strided
            use_avg_pool = avg_down or block_size % stride != 0
            self.block_stride = 1 if use_avg_pool else stride
            self.block_size_ds = self.block_size // self.block_stride

        # FIXME not clear if this stride behaviour is what the paper intended
        # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
        # data in unfolded block form. I haven't wrapped my head around how that'd look.
        self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias)
        # k and v come from one projection and are split apart in forward()
        self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)

        self.pos_embed = PosEmbedRel(
            block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)

        self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the projection weights and the relative position tables with truncated normals."""
        std = self.q.weight.shape[1] ** -0.5  # fan-in
        trunc_normal_(self.q.weight, std=std)
        trunc_normal_(self.kv.weight, std=std)
        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
        trunc_normal_(self.pos_embed.width_rel, std=self.scale)

    def forward(self, x):
        """Run blocked attention over ``x``.

        Args:
            x: 4-D tensor unpacked as (B, C, H, W); H and W must be multiples of ``block_size``.

        Returns:
            Tensor viewed as (B, dim_out_v, H // block_stride, W // block_stride) and then passed through
            ``self.pool``.
        """
        B, C, H, W = x.shape
        _assert(H % self.block_size == 0, '')
        _assert(W % self.block_size == 0, '')
        num_h_blocks = H // self.block_size
        num_w_blocks = W // self.block_size
        num_blocks = num_h_blocks * num_w_blocks

        q = self.q(x)
        # unfold
        q = q.reshape(
            -1, self.dim_head_qk,
            num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4)
        # B, num_heads * dim_head * block_size ** 2, num_blocks
        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
        # B * num_heads, num_blocks, block_size ** 2, dim_head

        kv = self.kv(x)
        # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not
        # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach.
        # FIXME figure out how to switch impl between this and conv2d if XLA being used.
        kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
        kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
            B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
        k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
        # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v

        # scale_pos_embed decides whether the position term is scaled together with q @ k or added unscaled
        if self.scale_pos_embed:
            attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale
        else:
            attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)
        # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 3)  # B * num_heads, dim_head_v, block_size ** 2, num_blocks
        # fold
        out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks)
        out = out.permute(0, 3, 1, 4, 2).contiguous().view(
            B, self.dim_out_v, H // self.block_stride, W // self.block_stride)
        # B, dim_out, H // block_stride, W // block_stride
        out = self.pool(out)
        return out
+ WW = self.win_size ** 2 + pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size) + kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size) + elif self.stride_tricks: + kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() + kv = kv.as_strided(( + B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), + stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) + else: + kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + + kv = kv.reshape( + B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3) +""" diff --git a/data_processing/MANIQA/timm/models/layers/helpers.py b/data_processing/MANIQA/timm/models/layers/helpers.py new file mode 100644 index 0000000..cc54ca7 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/helpers.py @@ -0,0 +1,31 @@ +""" Layer/Module Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +from itertools import repeat +import collections.abc + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
class InplaceAbn(nn.Module):
    """Activated Batch Normalization

    Bundles batch normalisation and an activation into one module by
    delegating to the ``inplace_abn`` function.

    Parameters
    ----------
    num_features : int
        Number of feature channels in the input and output.
    eps : float
        Small constant to prevent numerical issues.
    momentum : float
        Momentum factor applied to compute running statistics.
    affine : bool
        If `True` apply learned scale and shift transformation after normalization.
    apply_act : bool
        If `False` the activation name is forced to `identity`.
    act_layer : str or nn.Module type
        Name or type of the activation functions, one of: `leaky_relu`, `elu`, `identity`
    act_param : float
        Negative slope for the `leaky_relu` activation.
    drop_block :
        Accepted but not used by this module.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
                 act_layer="leaky_relu", act_param=0.01, drop_block=None):
        super(InplaceAbn, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum

        if not apply_act:
            self.act_name = 'identity'
        elif isinstance(act_layer, str):
            assert act_layer in ('leaky_relu', 'elu', 'identity', '')
            self.act_name = act_layer or 'identity'
        else:
            # convert act layer passed as type to string
            known_types = ((nn.ELU, 'elu'), (nn.LeakyReLU, 'leaky_relu'), (nn.Identity, 'identity'))
            for layer_type, layer_name in known_types:
                if act_layer == layer_type:
                    self.act_name = layer_name
                    break
            else:
                assert False, f'Invalid act layer {act_layer.__name__} for IABN'
        self.act_param = act_param

        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset running statistics to (0, 1) and, when affine, weight/bias to (1, 0)."""
        nn.init.constant_(self.running_mean, 0)
        nn.init.constant_(self.running_var, 1)
        if not self.affine:
            return
        nn.init.constant_(self.weight, 1)
        nn.init.constant_(self.bias, 0)

    def forward(self, x):
        """Call ``inplace_abn`` on ``x``; if it returns a tuple only the first element is kept."""
        result = inplace_abn(
            x, self.weight, self.bias, self.running_mean, self.running_var,
            self.training, self.momentum, self.eps, self.act_name, self.act_param)
        return result[0] if isinstance(result, tuple) else result
class LambdaLayer(nn.Module):
    """Lambda Layer

    Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
        - https://arxiv.org/abs/2102.08602

    NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.

    The internal dimensions of the lambda module are controlled via the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
      * the query (q) and key (k) dimension are determined by
        * dim_head = make_divisible(dim_out * qk_ratio, divisor=8) // num_heads if dim_head is None
        * q = num_heads * dim_head, k = dim_head
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not set

    Args:
        dim (int): input dimension to the module
        dim_out (int): output dimension of the module, same as dim if not set
        feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W
        stride (int): output stride of the module, avg pool used if stride == 2
        num_heads (int): parallel attention heads.
        dim_head (int): dimension of query and key heads, derived from dim_out * qk_ratio and num_heads if not set
        r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9)
        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
        qkv_bias (bool): add bias to q, k, and v projections
    """
    def __init__(
            self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9,
            qk_ratio=1.0, qkv_bias=False):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0, 'dim_out should be divisible by num_heads'
        self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        self.num_heads = num_heads
        self.dim_v = dim_out // num_heads

        # single projection producing per-head queries, one shared key and one shared value
        self.qkv = nn.Conv2d(
            dim,
            num_heads * self.dim_qk + self.dim_qk + self.dim_v,
            kernel_size=1, bias=qkv_bias)
        self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
        self.norm_v = nn.BatchNorm2d(self.dim_v)

        if r is not None:
            # local lambda convolution for pos
            self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0))
            self.pos_emb = None
            self.rel_pos_indices = None
        else:
            # relative pos embedding
            assert feat_size is not None
            feat_size = to_2tuple(feat_size)
            rel_size = [2 * s - 1 for s in feat_size]
            self.conv_lambda = None
            self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk))
            self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        """Truncated-normal init of the projection and whichever position parameters exist."""
        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)  # fan-in
        if self.conv_lambda is not None:
            trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
        if self.pos_emb is not None:
            trunc_normal_(self.pos_emb, std=.02)

    def forward(self, x):
        """Apply the lambda layer.

        Args:
            x: 4-D tensor unpacked as (B, C, H, W).

        Returns:
            Tensor of shape (B, num_heads * dim_v, H, W) passed through ``self.pool``.
        """
        B, C, H, W = x.shape
        M = H * W
        qkv = self.qkv(x)
        q, k, v = torch.split(qkv, [
            self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
        q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2)  # B, num_heads, M, K
        v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V
        k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1)  # B, K, M

        content_lam = k @ v  # B, K, V
        content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V

        if self.pos_emb is None:
            position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, H, W, V, K
            position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
        else:
            # FIXME relative pos embedding path not fully verified
            pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
            position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1)  # B, 1, M, K, V
        position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2)  # B, num_heads, M, V

        # The channel count of the result is num_heads * dim_v (== dim_out), which only equals the
        # input channel count C when dim_out == dim, so it must not be reshaped to C.
        out = (content_out + position_out).transpose(-1, -2).reshape(
            B, self.num_heads * self.dim_v, H, W)  # B, num_heads * V, H, W
        out = self.pool(out)
        return out
class MedianPool2d(nn.Module):
    """ Median pool (usable as median filter when stride=1) module.

    Args:
         kernel_size: size of pooling kernel, int or 2-tuple
         stride: pool stride, int or 2-tuple
         padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
         same: override padding and enforce same padding, boolean
    """
    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super(MedianPool2d, self).__init__()
        self.k = to_2tuple(kernel_size)
        self.stride = to_2tuple(stride)
        self.padding = to_4tuple(padding)  # convert to l, r, t, b
        self.same = same

    def _padding(self, x):
        """Return the (left, right, top, bottom) padding to apply to ``x``."""
        if not self.same:
            return self.padding

        def total_pad(size, kernel, stride):
            # when the size is a multiple of the stride the remainder is 0 and the stride is used instead
            remainder = size % stride
            return max(kernel - (remainder or stride), 0)

        in_h, in_w = x.size()[2:]
        pad_h = total_pad(in_h, self.k[0], self.stride[0])
        pad_w = total_pad(in_w, self.k[1], self.stride[1])
        # any odd leftover goes to the right / bottom edge
        left, top = pad_w // 2, pad_h // 2
        return (left, pad_w - left, top, pad_h - top)

    def forward(self, x):
        """Reflect-pad ``x``, gather every pooling window and return the per-window median."""
        padded = F.pad(x, self._padding(x), mode='reflect')
        windows = padded.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        # merge the two kernel axes so the median runs over the whole window
        return windows.flatten(start_dim=4).median(dim=-1).values
class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202

    The first linear layer produces ``hidden_features`` outputs that are split in two halves:
    the first half is the signal, the second half is passed through the activation and used
    as its gate, so the second linear layer only sees ``hidden_features // 2`` inputs.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        assert hidden_features % 2 == 0
        drop_first, drop_second = to_2tuple(drop)[0], to_2tuple(drop)[1]

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_first)
        self.fc2 = nn.Linear(hidden_features // 2, out_features)
        self.drop2 = nn.Dropout(drop_second)

    def init_weights(self):
        """Re-init the gate half of fc1: bias set to one, weights drawn with a tiny std."""
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        gate_start = self.fc1.bias.shape[0] // 2
        nn.init.ones_(self.fc1.bias[gate_start:])
        nn.init.normal_(self.fc1.weight[gate_start:], std=1e-6)

    def forward(self, x):
        """Project, gate the first half of the features with the second, then project again."""
        signal, gates = self.fc1(x).chunk(2, dim=-1)
        gated = self.drop1(signal * self.act(gates))
        return self.drop2(self.fc2(gated))
class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keeps spatial dims

    Layer order is fc1 -> norm -> act -> drop -> fc2.  ``hidden_features`` and
    ``out_features`` fall back to ``in_features`` when not given, and ``norm``
    is an identity when no ``norm_layer`` is supplied.
    """
    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
        if norm_layer:
            self.norm = norm_layer(hidden_features)
        else:
            self.norm = nn.Identity()
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run the two pointwise convolutions with norm, activation and dropout in between."""
        hidden = self.act(self.norm(self.fc1(x)))
        return self.fc2(self.drop(hidden))
class BilinearAttnTransform(nn.Module):
    """Bilinear attentional transform: computes ``P @ x @ Q`` per channel and applies a 1x1 conv block.

    ``P`` and ``Q`` are (block_size x block_size) matrices predicted per sample and per group from
    max-pooled features, normalised, shared across the channels of a group and expanded to the
    input's spatial size by ``resize_mat``.
    """

    def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super(BilinearAttnTransform, self).__init__()

        # reduces the input to one feature map per group
        self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
        # each conv collapses its pooled (block_size x 1) / (1 x block_size) input into
        # block_size ** 2 values per group, i.e. one flattened square matrix per group
        self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1))
        self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size))
        self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.block_size = block_size
        self.groups = groups
        self.in_channels = in_channels

    def resize_mat(self, x, t: int):
        """Expand every entry of the square matrices in ``x`` into a (t x t) block.

        Each element is multiplied by a (t x t) identity matrix and the resulting blocks are
        laid out in place, giving matrices of size (block_size * t).  ``x`` is returned
        unchanged when ``t <= 1``.
        """
        B, C, block_size, block_size1 = x.shape
        _assert(block_size == block_size1, '')
        if t <= 1:
            return x
        x = x.view(B * C, -1, 1, 1)
        # broadcast: every scalar entry becomes entry * I_t
        x = x * torch.eye(t, t, dtype=x.dtype, device=x.device)
        x = x.view(B * C, block_size, block_size, t, t)
        # stitch the (t x t) blocks back together, first along dim 1 then along dim 2 of the block grid
        x = torch.cat(torch.split(x, 1, dim=1), dim=3)
        x = torch.cat(torch.split(x, 1, dim=2), dim=4)
        x = x.view(B, C, block_size * t, block_size * t)
        return x

    def forward(self, x):
        """Apply the transform to ``x`` (unpacked as B, C, H, W; H and W must be multiples of block_size)."""
        _assert(x.shape[-1] % self.block_size == 0, '')
        _assert(x.shape[-2] % self.block_size == 0, '')
        B, C, H, W = x.shape
        out = self.conv1(x)
        # pool to block_size positions along one axis and a single position along the other
        rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
        cp = F.adaptive_max_pool2d(out, (1, self.block_size))
        p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
        q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
        # normalise: rows of p and columns of q sum to one
        p = p / p.sum(dim=3, keepdim=True)
        q = q / q.sum(dim=2, keepdim=True)
        # share each group's matrix across the C // groups channels belonging to that group
        p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
            0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
        p = p.view(B, C, self.block_size, self.block_size)
        q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
            0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
        q = q.view(B, C, self.block_size, self.block_size)
        # grow the matrices so they can be multiplied with the (H x W) feature maps
        p = self.resize_mat(p, H // self.block_size)
        q = self.resize_mat(q, W // self.block_size)
        y = p.matmul(x)
        y = y.matmul(q)

        y = self.conv2(y)
        return y
class BatchNormAct2d(nn.BatchNorm2d):
    """ BatchNorm2d immediately followed by an activation.

    The class derives from nn.BatchNorm2d (rather than holding a `.bn` child) so that its
    parameter/buffer names stay identical to a plain BN layer, keeping it compatible with
    weights trained with separate bn and act modules.

    Args:
        num_features, eps, momentum, affine, track_running_stats: forwarded to nn.BatchNorm2d.
        apply_act: when False, the activation is replaced by nn.Identity.
        act_layer: activation class, or a string resolved through `get_act_layer`;
            None disables the activation.
        inplace: when True, the activation is constructed with `inplace=True`.
        drop_block: accepted for signature compatibility; not used by this class.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
        super().__init__(
            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
        if isinstance(act_layer, str):
            act_layer = get_act_layer(act_layer)
        if act_layer is None or not apply_act:
            self.act = nn.Identity()
        elif inplace:
            self.act = act_layer(inplace=True)
        else:
            self.act = act_layer()

    def _forward_jit(self, x):
        """ Script-friendly copy of the PyTorch BatchNorm2d forward logic (no activation).
        """
        # The averaging factor mirrors self.momentum (when set) so that it is picked up
        # in the ONNX graph when this node is exported.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # the None check exists so the jit can skip emitting this branch
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:
                    # cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    # exponential moving average
                    exponential_average_factor = self.momentum

        use_batch_stats = self.training or not self.track_running_stats
        x = F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            use_batch_stats, exponential_average_factor, self.eps)
        return x

    @torch.jit.ignore
    def _forward_python(self, x):
        # eager path: defer to the stock BatchNorm2d implementation
        return super(BatchNormAct2d, self).forward(x)

    def forward(self, x):
        # The parent forward() is not called while scripting; the hand-copied variant is used instead.
        if torch.jit.is_scripting():
            normed = self._forward_jit(x)
        else:
            normed = self._forward_python(x)
        return self.act(normed)
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
    """ Return the total padding needed along one dimension for 'SAME' output size.

    Args:
        x: input size along the dimension.
        k: kernel size.
        s: stride.
        d: dilation.

    Returns:
        Non-negative total padding such that the output size is ceil(x / s).
    """
    out_size = math.ceil(x / s)
    # input extent required to produce `out_size` outputs with a dilated kernel
    required = (out_size - 1) * s + (k - 1) * d + 1
    return max(required - x, 0)
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding

    Splits an image into non-overlapping patches with a strided convolution
    (kernel size == stride == patch size) and projects each patch to `embed_dim` channels.

    Args:
        img_size: expected input size, int or pair (passed through `to_2tuple`).
        patch_size: patch size, int or pair (passed through `to_2tuple`).
        in_chans: number of input image channels.
        embed_dim: number of output channels per patch.
        norm_layer: optional callable building a norm layer from `embed_dim`; nn.Identity when falsy.
        flatten: when True, the output is reshaped from BCHW to BNC.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        grid_h = self.img_size[0] // self.patch_size[0]
        grid_w = self.img_size[1] // self.patch_size[1]
        self.grid_size = (grid_h, grid_w)
        self.num_patches = grid_h * grid_w
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        _, _, H, W = x.shape
        # the input resolution is fixed at construction time
        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
        patches = self.proj(x)
        if self.flatten:
            patches = patches.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return self.norm(patches)
class AvgPool2dSame(nn.AvgPool2d):
    """ Tensorflow like 'SAME' wrapper for 2D average pooling

    The input is padded dynamically in `forward` via `pad_same`, so the parent layer is
    always configured with zero padding.

    Args:
        kernel_size: pooling window, int or pair.
        stride: pooling stride, int or pair. Defaults to `kernel_size` when None,
            matching nn.AvgPool2d.
        padding: accepted for signature compatibility; ignored (padding is computed per input).
        ceil_mode, count_include_pad: forwarded to nn.AvgPool2d.
    """
    def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
        kernel_size = to_2tuple(kernel_size)
        # Resolve the default here: nn.AvgPool2d only falls back to kernel_size when stride is
        # exactly None, and a None wrapped into a tuple would reach pad_same() as an unusable
        # stride. NOTE(review): assumes to_2tuple repeats non-iterable values, as it does for ints.
        stride = kernel_size if stride is None else to_2tuple(stride)
        super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)

    def forward(self, x):
        x = pad_same(x, self.kernel_size, self.stride)
        return F.avg_pool2d(
            x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
pad_same(x, self.kernel_size, self.stride, value=-float('inf')) + return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) + + +def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): + stride = stride or kernel_size + padding = kwargs.pop('padding', '') + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) + if is_dynamic: + if pool_type == 'avg': + return AvgPool2dSame(kernel_size, stride=stride, **kwargs) + elif pool_type == 'max': + return MaxPool2dSame(kernel_size, stride=stride, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' + else: + if pool_type == 'avg': + return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + elif pool_type == 'max': + return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' diff --git a/data_processing/MANIQA/timm/models/layers/selective_kernel.py b/data_processing/MANIQA/timm/models/layers/selective_kernel.py new file mode 100644 index 0000000..1aeb929 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/selective_kernel.py @@ -0,0 +1,120 @@ +""" Selective Kernel Convolution/Attention + +Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import nn as nn + +from .conv_bn_act import ConvBnAct +from .helpers import make_divisible +from .trace_utils import _assert + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + """ Selective Kernel Attention Module + + Selective Kernel attention mechanism factored out into its own module. 
class SelectiveKernel(nn.Module):

    def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
                 rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
                 drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
        """ Selective Kernel Convolution Module

        Implements the module from Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with
        some modifications. The main one is the optional input split: the input channels are divided
        evenly across the convolution paths while every path still produces `out_channels` outputs,
        which keeps the parameter count down when the convolutions are not grouped.

        Args:
            in_channels (int): module input (feature) channel count
            out_channels (int): module output (feature) channel count, defaults to in_channels
            kernel_size (int, list): kernel size per convolution path, defaults to [3, 5];
                a non-list value is used for two paths
            stride (int): stride for the path convolutions
            dilation (int): base dilation of the module, scaled per path when keep_3x3 is set
            groups (int): number of groups for each path (capped at out_channels)
            rd_ratio (int, float): reduction ratio for the attention channels
            rd_channels (int): explicit attention channel count, overrides rd_ratio when set
            rd_divisor (int): divisor passed to make_divisible for the attention channels
            keep_3x3 (bool): use 3x3 kernels on every path and emulate larger kernels with dilation
            split_input (bool): split input channels evenly across the paths
            drop_block (nn.Module): drop block module, forwarded to each path
            act_layer (nn.Module): activation layer, forwarded to each path
            norm_layer (nn.Module): norm layer, forwarded to each path
            aa_layer (nn.Module): anti-aliasing layer, forwarded to each path
        """
        super().__init__()
        out_channels = out_channels or in_channels
        kernel_size = kernel_size or [3, 5]  # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
        _kernel_valid(kernel_size)
        if not isinstance(kernel_size, list):
            kernel_size = [kernel_size] * 2
        num_paths = len(kernel_size)
        if keep_3x3:
            # a k x k kernel is emulated by a 3x3 kernel with dilation (k - 1) // 2
            dilation = [dilation * (k - 1) // 2 for k in kernel_size]
            kernel_size = [3] * num_paths
        else:
            dilation = [dilation] * num_paths

        self.num_paths = num_paths
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.split_input = split_input
        if self.split_input:
            assert in_channels % self.num_paths == 0
            in_channels = in_channels // self.num_paths
        groups = min(out_channels, groups)

        conv_kwargs = dict(
            stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer,
            aa_layer=aa_layer)
        branches = []
        for k, d in zip(kernel_size, dilation):
            branches.append(ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs))
        self.paths = nn.ModuleList(branches)

        attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
        self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
        self.drop_block = drop_block

    def forward(self, x):
        if not self.split_input:
            x_paths = [op(x) for op in self.paths]
        else:
            # each path only sees its own slice of the input channels
            x_split = torch.split(x, self.in_channels // self.num_paths, 1)
            x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
        stacked = torch.stack(x_paths, dim=1)
        # attention weights over the path dimension, then a weighted sum across paths
        weighted = stacked * self.attn(stacked)
        return torch.sum(weighted, dim=1)
class SeparableConvBnAct(nn.Module):
    """ Separable Conv w/ trailing Norm and Activation

    A depthwise convolution followed by a pointwise convolution, then a combined
    norm + activation layer built from `norm_layer` and `act_layer`.

    Args:
        in_channels, out_channels: channel counts of the module input and output.
        kernel_size, stride, dilation: configuration of the depthwise convolution.
        padding: padding spec passed to both convolutions.
        bias: whether the pointwise convolution has a bias.
        channel_multiplier: depthwise output channels are int(in_channels * channel_multiplier).
        pw_kernel_size: kernel size of the pointwise convolution.
        norm_layer, act_layer: combined via `convert_norm_act`.
        apply_act, drop_block: forwarded to the norm + act layer.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
                 channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
                 apply_act=True, drop_block=None):
        super().__init__()
        dw_channels = int(in_channels * channel_multiplier)

        self.conv_dw = create_conv2d(
            in_channels, dw_channels, kernel_size,
            stride=stride, dilation=dilation, padding=padding, depthwise=True)

        self.conv_pw = create_conv2d(
            dw_channels, out_channels, pw_kernel_size, padding=padding, bias=bias)

        norm_act_layer = convert_norm_act(norm_layer, act_layer)
        self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block)

    @property
    def in_channels(self):
        # input channel count of the depthwise convolution
        return self.conv_dw.in_channels

    @property
    def out_channels(self):
        # output channel count of the pointwise convolution
        return self.conv_pw.out_channels

    def forward(self, x):
        features = self.conv_pw(self.conv_dw(x))
        if self.bn is None:
            return features
        return self.bn(features)
class DepthToSpace(nn.Module):
    """ Rearrange channel blocks into spatial positions.

    For block size bs and C' = C // bs**2, the output satisfies
    out[n, c, h * bs + i, w * bs + j] == x[n, (i * bs + j) * C' + c, h, w].
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        batch, channels, height, width = x.size()
        bs = self.bs
        out_channels = channels // (bs ** 2)
        blocks = x.view(batch, bs, bs, out_channels, height, width)  # (N, bs, bs, C//bs^2, H, W)
        interleaved = blocks.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
        return interleaved.view(batch, out_channels, height * bs, width * bs)  # (N, C//bs^2, H * bs, W * bs)
class RadixSoftmax(nn.Module):
    """ Softmax across the radix splits of each cardinal group.

    With radix > 1 the input is viewed as (batch, cardinality, radix, -1), softmax is taken
    over the radix axis and the result is flattened to (batch, -1). With radix <= 1 an
    element-wise sigmoid is applied and the input shape is kept.
    """

    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix <= 1:
            return torch.sigmoid(x)
        # move radix in front of cardinality so the softmax runs over dim 1
        grouped = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
        weights = F.softmax(grouped, dim=1)
        return weights.reshape(batch, -1)
norm_layer(mid_chs) if norm_layer else nn.Identity() + self.act0 = act_layer(inplace=True) + self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) + self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() + self.act1 = act_layer(inplace=True) + self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) + self.rsoftmax = RadixSoftmax(radix, groups) + + def forward(self, x): + x = self.conv(x) + x = self.bn0(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act0(x) + + B, RC, H, W = x.shape + if self.radix > 1: + x = x.reshape((B, self.radix, RC // self.radix, H, W)) + x_gap = x.sum(dim=1) + else: + x_gap = x + x_gap = x_gap.mean((2, 3), keepdim=True) + x_gap = self.fc1(x_gap) + x_gap = self.bn1(x_gap) + x_gap = self.act1(x_gap) + x_attn = self.fc2(x_gap) + + x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) + if self.radix > 1: + out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) + else: + out = x * x_attn + return out.contiguous() diff --git a/data_processing/MANIQA/timm/models/layers/split_batchnorm.py b/data_processing/MANIQA/timm/models/layers/split_batchnorm.py new file mode 100644 index 0000000..830781b --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/split_batchnorm.py @@ -0,0 +1,75 @@ +""" Split BatchNorm + +A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through +a separate BN layer. The first split is passed through the parent BN layers with weight/bias +keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' +namespace. 
def convert_splitbn_model(module, num_splits=2):
    """
    Recursively traverse module and its children to replace all instances of
    ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`.

    The replacement keeps the running statistics and affine parameters of the original
    layer; every auxiliary BN starts from a copy of them. Instance-norm layers are
    returned unchanged. Layers created with ``track_running_stats=False`` (whose running
    buffers are ``None``) are converted without copying statistics.

    Args:
        module (torch.nn.Module): input module
        num_splits: number of separate batchnorm layers to split input across
    Returns:
        The converted module: a new ``SplitBatchNorm2d`` when ``module`` is a batch-norm
        layer, otherwise ``module`` itself with its children converted in place.
    Example::
        >>> # model is an instance of torch.nn.Module
        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
    """
    # Checked before the batch-norm test so instance norm is never converted.
    # NOTE(review): presumably needed because _InstanceNorm derives from _BatchNorm in some
    # PyTorch versions - confirm.
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
        return module

    mod = module
    # NOTE(review): any _BatchNorm subclass (1d/2d/3d) is replaced by the 2d split variant.
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        mod = SplitBatchNorm2d(
            module.num_features, module.eps, module.momentum, module.affine,
            module.track_running_stats, num_splits=num_splits)
        # the primary split takes over the original buffers
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
        if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
        for aux in mod.aux_bn:
            # Buffers are None when running stats are not tracked; cloning None would raise.
            # The aux layers were built with the same flag, so their buffers are already None.
            if module.running_mean is not None:
                aux.running_mean = module.running_mean.clone()
            if module.running_var is not None:
                aux.running_var = module.running_var.clone()
            if module.num_batches_tracked is not None:
                aux.num_batches_tracked = module.num_batches_tracked.clone()
            if module.affine:
                aux.weight.data = module.weight.data.clone().detach()
                aux.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
    return mod
class SEModule(nn.Module):
    """ Squeeze-and-Excitation channel attention, following the original SE-Nets with extras:
        * the reduced channel count is rounded via `make_divisible` with `rd_divisor` (default: 8)
        * the reduced channel count can be given directly (`rd_channels`)
        * or derived from a float ratio (`rd_ratio`, default: 1/16)
        * global max pooling can be mixed into the squeeze step (`add_maxpool`)
        * activation, normalization and gate layers are configurable
    """
    def __init__(
            self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
            act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
        super().__init__()
        self.add_maxpool = add_maxpool
        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
        self.bn = nn.Identity() if not norm_layer else norm_layer(rd_channels)
        self.act = create_act_layer(act_layer, inplace=True)
        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # squeeze: global average over the spatial dims, kept as (B, C, 1, 1)
        squeezed = x.mean((2, 3), keepdim=True)
        if self.add_maxpool:
            # experimental codepath, may remove or change
            squeezed = 0.5 * squeezed + 0.5 * x.amax((2, 3), keepdim=True)
        # excite: reduce -> norm -> act -> expand
        reduced = self.act(self.bn(self.fc1(squeezed)))
        scale = self.fc2(reduced)
        return x * self.gate(scale)
class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.

    The kernel of every output channel is standardized (zero mean, unit variance over its
    own elements) on each forward pass before the convolution is applied.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2
    """
    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, padding=None,
            dilation=1, groups=1, bias=False, eps=1e-6):
        # symmetric padding is derived from the conv arguments when none is given
        resolved_padding = get_padding(kernel_size, stride, dilation) if padding is None else padding
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=resolved_padding, dilation=dilation, groups=groups, bias=bias)
        self.eps = eps

    def forward(self, x):
        # batch_norm in training mode over a (1, out_channels, -1) view standardizes each
        # output channel's kernel independently
        flat_weight = self.weight.reshape(1, self.out_channels, -1)
        standardized = F.batch_norm(flat_weight, None, None, training=True, momentum=0., eps=self.eps)
        weight = standardized.reshape_as(self.weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class ScaledStdConv2d(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization.

    The kernel is standardized per output channel on each forward pass and multiplied by a
    learnable per-channel ``gain`` times the fixed ``scale`` (``gamma / sqrt(fan_in)``).
    When ``padding`` is None it is derived with ``get_padding(kernel_size, stride, dilation)``.

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
        https://arxiv.org/abs/2101.08692

    NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, padding=None,
            dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
        resolved_padding = get_padding(kernel_size, stride, dilation) if padding is None else padding
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride, padding=resolved_padding,
            dilation=dilation, groups=groups, bias=bias)
        gain_shape = (self.out_channels, 1, 1, 1)
        self.gain = nn.Parameter(torch.full(gain_shape, gain_init))
        fan_in = self.weight[0].numel()
        self.scale = gamma * fan_in ** -0.5  # gamma * 1 / sqrt(fan-in)
        self.eps = eps

    def forward(self, x):
        # The effective per-channel gain is passed as the batch_norm affine weight, so
        # standardization and scaling happen in a single call.
        channel_gain = (self.gain * self.scale).view(-1)
        flat_kernel = self.weight.reshape(1, self.out_channels, -1)
        standardized = F.batch_norm(
            flat_kernel, None, None,
            weight=channel_gain,
            training=True, momentum=0., eps=self.eps)
        kernel = standardized.reshape_as(self.weight)
        return F.conv2d(x, kernel, self.bias, self.stride, self.padding, self.dilation, self.groups)
+ """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', + dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) + self.scale = gamma * self.weight[0].numel() ** -0.5 + self.same_pad = is_dynamic + self.eps = eps + + def forward(self, x): + if self.same_pad: + x = pad_same(x, self.kernel_size, self.stride, self.dilation) + weight = F.batch_norm( + self.weight.reshape(1, self.out_channels, -1), None, None, + weight=(self.gain * self.scale).view(-1), + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) diff --git a/data_processing/MANIQA/timm/models/layers/test_time_pool.py b/data_processing/MANIQA/timm/models/layers/test_time_pool.py new file mode 100644 index 0000000..98c0bf5 --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/test_time_pool.py @@ -0,0 +1,52 @@ +""" Test Time Pooling (Average-Max Pool) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import logging +from torch import nn +import torch.nn.functional as F + +from .adaptive_avgmax_pool import adaptive_avgmax_pool2d + + +_logger = logging.getLogger(__name__) + + +class TestTimePoolHead(nn.Module): + def __init__(self, base, original_pool=7): + super(TestTimePoolHead, self).__init__() + self.base = base + self.original_pool = original_pool + base_fc = self.base.get_classifier() + if isinstance(base_fc, nn.Conv2d): + self.fc = base_fc + else: + self.fc = nn.Conv2d( + self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) + 
def apply_test_time_pool(model, config, use_test_size=True):
    """Wrap ``model`` in a ``TestTimePoolHead`` when the target input is larger than its default.

    Args:
        model: model whose ``default_cfg`` (if present and non-empty) supplies ``input_size``,
            optionally ``test_input_size``, and ``pool_size``.
        config: mapping with an ``input_size`` entry describing the target input.
        use_test_size: prefer ``default_cfg['test_input_size']`` over ``input_size`` when present.

    Returns:
        ``(model, test_time_pool)`` where the flag is True only if the model was wrapped.
    """
    default_cfg = getattr(model, 'default_cfg', None)
    if not default_cfg:
        # No pretrained config to compare against; leave the model alone.
        return model, False

    prefer_test_size = use_test_size and 'test_input_size' in default_cfg
    df_input_size = default_cfg['test_input_size' if prefer_test_size else 'input_size']

    # Both of the last two dims must be strictly larger than the pretrained default.
    last_dim_larger = config['input_size'][-1] > df_input_size[-1]
    if not (last_dim_larger and config['input_size'][-2] > df_input_size[-2]):
        return model, False

    _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
                 (str(config['input_size'][-2:]), str(df_input_size[-2:])))
    wrapped = TestTimePoolHead(model, original_pool=default_cfg['pool_size'])
    return wrapped, True
+ Hint: Inbuilt `int` can't accept an argument of type `Proxy` + """ + return int(x) diff --git a/data_processing/MANIQA/timm/models/layers/weight_init.py b/data_processing/MANIQA/timm/models/layers/weight_init.py new file mode 100644 index 0000000..305a2fd --- /dev/null +++ b/data_processing/MANIQA/timm/models/layers/weight_init.py @@ -0,0 +1,89 @@ +import torch +import math +import warnings + +from torch.nn.init import _calculate_fan_in_and_fan_out + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. 
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    """Initialize ``tensor`` in place with variance ``scale / denom``.

    Args:
        tensor: tensor to fill; fan-in / fan-out are computed from its shape.
        scale: numerator of the target variance.
        mode: which fan value is the denominator: ``'fan_in'``, ``'fan_out'`` or
            ``'fan_avg'`` (mean of the two).
        distribution: ``'truncated_normal'``, ``'normal'`` or ``'uniform'``.

    Raises:
        ValueError: if ``mode`` or ``distribution`` is not one of the supported values.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    else:
        # Previously an unknown mode fell through and crashed on the unbound `denom`.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        # In-place fills must not be tracked by autograd, otherwise they raise on
        # parameters (leaf tensors that require grad).
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
Inference}, + author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze}, + journal={arXiv preprint arXiv:22104.01136}, + year={2021} +} + +Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow. + +This version combines both conv/linear models and fixes torchscript compatibility. + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" + +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. + +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License +import itertools +from copy import deepcopy +from functools import partial +from typing import Dict + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_ntuple, get_act_layer +from .vision_transformer import trunc_normal_ +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'), + **kwargs + } + + +default_cfgs = dict( + levit_128s=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth' + ), + levit_128=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth' + ), + levit_192=_cfg( + 
class ConvNorm(nn.Sequential):
    """Bias-free ``nn.Conv2d`` followed by ``nn.BatchNorm2d``, registered as ``c`` and ``bn``.

    Args:
        a: input channels.
        b: output channels.
        ks, stride, pad, dilation, groups: forwarded positionally to ``nn.Conv2d``.
        bn_weight_init: constant used to initialize the batch-norm weight (its bias is set to 0).
        resolution: accepted for signature compatibility; not used by this class.
    """
    def __init__(
            self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
        bn = nn.BatchNorm2d(b)
        nn.init.constant_(bn.weight, bn_weight_init)
        nn.init.constant_(bn.bias, 0)
        self.add_module('bn', bn)

    @torch.no_grad()
    def fuse(self):
        """Return a new ``nn.Conv2d`` (with bias) that folds the batch-norm running stats in."""
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
        # The conv weight has shape (out, in / groups, kH, kW), so w.size(1) is only the true
        # input channel count when groups == 1. Use the conv's own in_channels so grouped
        # convolutions fuse without a shape mismatch.
        m = nn.Conv2d(
            c.in_channels, w.size(0), w.shape[2:], stride=self.c.stride,
            padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
class Residual(nn.Module):
    """Residual wrapper ``x + m(x)`` with optional per-sample drop of the residual branch.

    Args:
        m: module computing the residual branch.
        drop: probability of zeroing the branch for a sample while training; kept samples
            are rescaled by ``1 / (1 - drop)``. No dropping happens in eval mode or when
            ``drop`` is 0.
    """
    def __init__(self, m, drop):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            branch = self.m(x)
            # One random draw per sample, broadcast over every remaining dim. The mask rank
            # follows the input so both (B, N, C) and (B, C, H, W) activations work; a fixed
            # (B, 1, 1) mask mis-aligns against 4-D inputs.
            mask_shape = [x.size(0)] + [1] * (x.dim() - 1)
            keep = torch.rand(mask_shape, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
            return x + branch * keep
        else:
            return x + self.m(x)
len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + self.ab = {} + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.ab[device_key] + + def forward(self, x): # x (B,C,H,W) + if self.use_conv: + B, C, H, W = x.shape + q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + else: + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class AttentionSubsample(nn.Module): + ab: Dict[str, torch.Tensor] + + def __init__( + self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, + act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False): + super().__init__() + self.num_heads = num_heads + self.scale = 
key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = self.d * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_ ** 2 + self.use_conv = use_conv + if self.use_conv: + ln_layer = ConvNorm + sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0) + else: + ln_layer = LinearNorm + sub_layer = partial(Subsample, resolution=resolution) + + h = self.dh + nh_kd + self.kv = ln_layer(in_dim, h, resolution=resolution) + self.q = nn.Sequential( + sub_layer(stride=stride), + ln_layer(in_dim, nh_kd, resolution=resolution_)) + self.proj = nn.Sequential( + act_layer(), + ln_layer(self.dh, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list(itertools.product(range(resolution_), range(resolution_))) + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + for p1 in points_: + for p2 in points: + size = 1 + offset = ( + abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) + self.ab = {} # per-device attention_biases cache + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return 
self.ab[device_key] + + def forward(self, x): + if self.use_conv: + B, C, H, W = x.shape + k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) + else: + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = self.proj(x) + return x + + +class Levit(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + + NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems + w/ train scripts that don't take tuple outputs, + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=(192,), + key_dim=64, + depth=(12,), + num_heads=(3,), + attn_ratio=2, + mlp_ratio=2, + hybrid_backbone=None, + down_ops=None, + act_layer='hard_swish', + attn_act_layer='hard_swish', + distillation=True, + use_conv=False, + drop_rate=0., + drop_path_rate=0.): + super().__init__() + act_layer = get_act_layer(act_layer) + attn_act_layer = get_act_layer(attn_act_layer) + if isinstance(img_size, tuple): + # FIXME origin impl passes single img/res dim through whole hierarchy, + # not sure this model will be used enough to spend time fixing it. 
+ assert img_size[0] == img_size[1] + img_size = img_size[0] + self.num_classes = num_classes + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + N = len(embed_dim) + assert len(depth) == len(num_heads) == N + key_dim = to_ntuple(N)(key_dim) + attn_ratio = to_ntuple(N)(attn_ratio) + mlp_ratio = to_ntuple(N)(mlp_ratio) + down_ops = down_ops or ( + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2), + ('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2), + ('',) + ) + self.distillation = distillation + self.use_conv = use_conv + ln_layer = ConvNorm if self.use_conv else LinearNorm + + self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer) + + self.blocks = [] + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention( + ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, + resolution=resolution, use_conv=use_conv), + drop_path_rate)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual(nn.Sequential( + ln_layer(ed, h, resolution=resolution), + act_layer(), + ln_layer(h, ed, bn_weight_init=0, resolution=resolution), + ), drop_path_rate)) + if do[0] == 'Subsample': + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], + attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], + resolution=resolution, resolution_=resolution_, use_conv=use_conv)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(nn.Sequential( + ln_layer(embed_dim[i + 1], h, resolution=resolution), + act_layer(), + ln_layer(h, embed_dim[i + 
def checkpoint_filter_fn(state_dict, model):
    """Adapt a checkpoint's 2-D weights to the model's 4-D weights.

    If the checkpoint has a ``'model'`` entry, that entry is used as the state dict
    (for deit models). Every checkpoint tensor with 2 dims whose same-named entry in
    ``model.state_dict()`` has 4 dims gets two trailing singleton dims appended.
    The (possibly unwrapped) dict is updated in place and returned.
    """
    weights = state_dict['model'] if 'model' in state_dict else state_dict
    reference = model.state_dict()
    # Only existing keys are reassigned, so the dict size is stable while iterating.
    for name, tensor in weights.items():
        if name not in reference:
            continue
        if reference[name].ndim == 4 and tensor.ndim == 2:
            weights[name] = tensor[:, :, None, None]
    return weights
**kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model_cfg = dict(**model_cfgs[variant], **kwargs) + model = build_model_with_cfg( + Levit, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **model_cfg) + #if fuse: + # utils.replace_batchnorm(model) + return model + diff --git a/data_processing/MANIQA/timm/models/mlp_mixer.py b/data_processing/MANIQA/timm/models/mlp_mixer.py new file mode 100644 index 0000000..727b655 --- /dev/null +++ b/data_processing/MANIQA/timm/models/mlp_mixer.py @@ -0,0 +1,659 @@ +""" MLP-Mixer, ResMLP, and gMLP in PyTorch + +This impl originally based on MLP-Mixer paper. + +Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py + +Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + +@article{tolstikhin2021, + title={MLP-Mixer: An all-MLP Architecture for Vision}, + author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner, + Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey}, + journal={arXiv preprint arXiv:2105.01601}, + year={2021} +} + +Also supporting ResMlp, and a preliminary (not verified) implementations of gMLP + +Code: https://github.com/facebookresearch/deit +Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 +@misc{touvron2021resmlp, + title={ResMLP: Feedforward networks for image classification with data-efficient training}, + author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin 
El-Nouby and + Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou}, + year={2021}, + eprint={2105.03404}, +} + +Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 +@misc{liu2021pay, + title={Pay Attention to MLPs}, + author={Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le}, + year={2021}, + eprint={2105.08050}, +} + +A thank you to paper authors for releasing code and weights. + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply +from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'stem.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = dict( + mixer_s32_224=_cfg(), + mixer_s16_224=_cfg(), + mixer_b32_224=_cfg(), + mixer_b16_224=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth', + ), + mixer_b16_224_in21k=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth', + num_classes=21843 + ), + mixer_l32_224=_cfg(), + mixer_l16_224=_cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth', + ), + mixer_l16_224_in21k=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth', + num_classes=21843 + ), + + # Mixer ImageNet-21K-P pretraining + mixer_b16_224_miil_in21k=_cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, + ), + mixer_b16_224_miil=_cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', + ), + + gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + gmixer_24_224=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + resmlp_12_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_24_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth', + #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_36_224=_cfg( + 
url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_big_24_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + resmlp_12_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_24_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_36_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_big_24_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + resmlp_big_24_224_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + resmlp_12_224_dino=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dino.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_24_224_dino=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + gmlp_ti16_224=_cfg(), + gmlp_s16_224=_cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth', + ), + gmlp_b16_224=_cfg(), +) + + +class MixerBlock(nn.Module): + """ Residual Block w/ token mixing and channel MLPs + Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + def __init__( + self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): + super().__init__() + tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)] + self.norm1 = norm_layer(dim) + self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)) + x = x + self.drop_path(self.mlp_channels(self.norm2(x))) + return x + + +class Affine(nn.Module): + def __init__(self, dim): + super().__init__() + self.alpha = nn.Parameter(torch.ones((1, 1, dim))) + self.beta = nn.Parameter(torch.zeros((1, 1, dim))) + + def forward(self, x): + return torch.addcmul(self.beta, self.alpha, x) + + +class ResBlock(nn.Module): + """ Residual MLP block w/ LayerScale and Affine 'norm' + + Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + def __init__( + self, dim, seq_len, mlp_ratio=4, mlp_layer=Mlp, norm_layer=Affine, + act_layer=nn.GELU, init_values=1e-4, drop=0., drop_path=0.): + super().__init__() + channel_dim = int(dim * mlp_ratio) + self.norm1 = norm_layer(dim) + self.linear_tokens = nn.Linear(seq_len, seq_len) + 
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop) + self.ls1 = nn.Parameter(init_values * torch.ones(dim)) + self.ls2 = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + x = x + self.drop_path(self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)) + x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x))) + return x + + +class SpatialGatingUnit(nn.Module): + """ Spatial Gating Unit + + Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm): + super().__init__() + gate_dim = dim // 2 + self.norm = norm_layer(gate_dim) + self.proj = nn.Linear(seq_len, seq_len) + + def init_weights(self): + # special init for the projection gate, called as override by base model init + nn.init.normal_(self.proj.weight, std=1e-6) + nn.init.ones_(self.proj.bias) + + def forward(self, x): + u, v = x.chunk(2, dim=-1) + v = self.norm(v) + v = self.proj(v.transpose(-1, -2)) + return u * v.transpose(-1, -2) + + +class SpatialGatingBlock(nn.Module): + """ Residual Block w/ Spatial Gating + + Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + def __init__( + self, dim, seq_len, mlp_ratio=4, mlp_layer=GatedMlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): + super().__init__() + channel_dim = int(dim * mlp_ratio) + self.norm = norm_layer(dim) + sgu = partial(SpatialGatingUnit, seq_len=seq_len) + self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.mlp_channels(self.norm(x))) + return x + + +class MlpMixer(nn.Module): + + def __init__( + self, + num_classes=1000, + img_size=224, + in_chans=3, + patch_size=16, + num_blocks=8, + embed_dim=512, + mlp_ratio=(0.5, 4.0), + block_layer=MixerBlock, + mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + drop_rate=0., + drop_path_rate=0., + nlhb=False, + stem_norm=False, + ): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.stem = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None) + # FIXME drop_path (stochastic depth scaling rule or all the same?) + self.blocks = nn.Sequential(*[ + block_layer( + embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer, + act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate) + for _ in range(num_blocks)]) + self.norm = norm_layer(embed_dim) + self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(nlhb=nlhb) + + def init_weights(self, nlhb=False): + head_bias = -math.log(self.num_classes) if nlhb else 0. 
+ named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.blocks(x) + x = self.norm(x) + x = x.mean(dim=1) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False): + """ Mixer weight initialization (trying to match Flax defaults) + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + if flax: + # Flax defaults + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # like MLP init in vit (my original init) + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + # NOTE if a parent module contains init_weights method, it can override the init of the + # child modules as this will be called in depth-first order. 
+ module.init_weights() + + +def checkpoint_filter_fn(state_dict, model): + """ Remap checkpoints if needed """ + if 'patch_embed.proj.weight' in state_dict: + # Remap FB ResMlp models -> timm + out_dict = {} + for k, v in state_dict.items(): + k = k.replace('patch_embed.', 'stem.') + k = k.replace('attn.', 'linear_tokens.') + k = k.replace('mlp.', 'mlp_channels.') + k = k.replace('gamma_', 'ls') + if k.endswith('.alpha') or k.endswith('.beta'): + v = v.reshape(1, 1, -1) + out_dict[k] = v + return out_dict + return state_dict + + +def _create_mixer(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for MLP-Mixer models.') + + model = build_model_with_cfg( + MlpMixer, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def mixer_s32_224(pretrained=False, **kwargs): + """ Mixer-S/32 224x224 + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs) + model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_s16_224(pretrained=False, **kwargs): + """ Mixer-S/16 224x224 + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs) + model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_b32_224(pretrained=False, **kwargs): + """ Mixer-B/32 224x224 + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=32, 
num_blocks=12, embed_dim=768, **kwargs) + model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_b16_224(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) + model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_b16_224_in21k(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) + model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_l32_224(pretrained=False, **kwargs): + """ Mixer-L/32 224x224. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs) + model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_l16_224(pretrained=False, **kwargs): + """ Mixer-L/16 224x224. ImageNet-1k pretrained weights. 
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs) + model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_l16_224_in21k(pretrained=False, **kwargs): + """ Mixer-L/16 224x224. ImageNet-21k pretrained weights. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs) + model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_b16_224_miil(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) + model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_b16_224_miil_in21k(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. 
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) + model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmixer_12_224(pretrained=False, **kwargs): + """ Glu-Mixer-12 224x224 + Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer + """ + model_args = dict( + patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=(1.0, 4.0), + mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) + model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmixer_24_224(pretrained=False, **kwargs): + """ Glu-Mixer-24 224x224 + Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer + """ + model_args = dict( + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 4.0), + mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) + model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_12_224(pretrained=False, **kwargs): + """ ResMLP-12 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_24_224(pretrained=False, **kwargs): + """ ResMLP-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) + model = 
_create_mixer('resmlp_24_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_36_224(pretrained=False, **kwargs): + """ ResMLP-36 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_big_24_224(pretrained=False, **kwargs): + """ ResMLP-B-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_big_24_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_12_distilled_224(pretrained=False, **kwargs): + """ ResMLP-12 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_12_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_24_distilled_224(pretrained=False, **kwargs): + """ ResMLP-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, 
**kwargs) + model = _create_mixer('resmlp_24_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_36_distilled_224(pretrained=False, **kwargs): + """ ResMLP-36 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_36_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_big_24_distilled_224(pretrained=False, **kwargs): + """ ResMLP-B-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_big_24_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs): + """ ResMLP-B-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_big_24_224_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_12_224_dino(pretrained=False, **kwargs): + """ ResMLP-12 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + + Model pretrained via DINO (self-supervised) - 
https://arxiv.org/abs/2104.14294 + """ + model_args = dict( + patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_12_224_dino', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_24_224_dino(pretrained=False, **kwargs): + """ ResMLP-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + + Model pretrained via DINO (self-supervised) - https://arxiv.org/abs/2104.14294 + """ + model_args = dict( + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_24_224_dino', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmlp_ti16_224(pretrained=False, **kwargs): + """ gMLP-Tiny + Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + model_args = dict( + patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock, + mlp_layer=GatedMlp, **kwargs) + model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmlp_s16_224(pretrained=False, **kwargs): + """ gMLP-Small + Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + model_args = dict( + patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock, + mlp_layer=GatedMlp, **kwargs) + model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmlp_b16_224(pretrained=False, **kwargs): + """ gMLP-Base + Paper: `Pay Attention to MLPs` - 
https://arxiv.org/abs/2105.08050 + """ + model_args = dict( + patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock, + mlp_layer=GatedMlp, **kwargs) + model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args) + return model diff --git a/data_processing/MANIQA/timm/models/mobilenetv3.py b/data_processing/MANIQA/timm/models/mobilenetv3.py new file mode 100644 index 0000000..8047412 --- /dev/null +++ b/data_processing/MANIQA/timm/models/mobilenetv3.py @@ -0,0 +1,679 @@ +""" MobileNet V3 + +A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. + +Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 + +Hacked together by / Copyright 2019, Ross Wightman +""" +from functools import partial +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .efficientnet_blocks import SqueezeExcite +from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ + round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT +from .features import FeatureInfo, FeatureHooks +from .helpers import build_model_with_cfg, default_cfg_for_features +from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid +from .registry import register_model + +__all__ = ['MobileNetV3', 'MobileNetV3Features'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'mobilenetv3_large_075': 
_cfg(url=''), + 'mobilenetv3_large_100': _cfg( + interpolation='bicubic', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'), + 'mobilenetv3_large_100_miil': _cfg( + interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1), + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_1k_miil_78_0.pth'), + 'mobilenetv3_large_100_miil_in21k': _cfg( + interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1), + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_in21k_miil.pth', num_classes=11221), + + 'mobilenetv3_small_050': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth', + interpolation='bicubic'), + 'mobilenetv3_small_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth', + interpolation='bicubic'), + 'mobilenetv3_small_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth', + interpolation='bicubic'), + + 'mobilenetv3_rw': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + interpolation='bicubic'), + + 'tf_mobilenetv3_large_075': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_minimal_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_100': _cfg( + url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_minimal_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + + 'fbnetv3_b': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth', + test_input_size=(3, 256, 256), crop_pct=0.95), + 'fbnetv3_d': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth', + test_input_size=(3, 256, 256), crop_pct=0.95), + 'fbnetv3_g': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth', + input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95), + + "lcnet_035": _cfg(), + "lcnet_050": _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth', + interpolation='bicubic', + ), + "lcnet_075": _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth', + interpolation='bicubic', + ), + "lcnet_100": _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth', + interpolation='bicubic', + ), + "lcnet_150": _cfg(), +} + + +class MobileNetV3(nn.Module): + """ MobileNet-V3 + + Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific + 'efficient head', where global pooling is done before the head convolution without a final batch-norm + layer before the classifier. 
+ + Paper: `Searching for MobileNetV3` - https://arxiv.org/abs/1905.02244 + + Other architectures utilizing MobileNet-V3 efficient head that are supported by this impl include: + * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 (defn in hardcorenas.py uses this class) + * FBNet-V3 - https://arxiv.org/abs/2006.02049 + * LCNet - https://arxiv.org/abs/2109.15099 + """ + + def __init__( + self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280, + head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True, + round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'): + super(MobileNetV3, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + + # Stem + if not fix_stem: + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = builder.features + head_chs = builder.in_chs + + # Head + Pooling + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + num_pooled_chs = head_chs * self.global_pool.feat_mult() + self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) 
+ self.act2 = act_layer(inplace=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + efficientnet_init_weights(self) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([self.global_pool, self.conv_head, self.act2]) + layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + # cannot meaningfully change pooling of efficient head after creation + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.flatten(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +class MobileNetV3Features(nn.Module): + """ MobileNetV3 Feature Extractor + + A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation + and object detection models. 
+ """ + + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, + stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels, + se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): + super(MobileNetV3Features, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite + self.drop_rate = drop_rate + + # Stem + if not fix_stem: + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, + drop_path_rate=drop_path_rate, feature_location=feature_location) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = FeatureInfo(builder.features, out_indices) + self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} + + efficientnet_init_weights(self) + + # Register feature extraction hooks with FeatureHooks helper + self.feature_hooks = None + if feature_location != 'bottleneck': + hooks = self.feature_info.get_dicts(keys=('module', 'hook_type')) + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + + def forward(self, x) -> List[torch.Tensor]: + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + if self.feature_hooks is None: + features = [] + if 0 in self._stage_out_idx: + features.append(x) # add stem out + for i, b in enumerate(self.blocks): + x = b(x) + if i + 1 in self._stage_out_idx: + features.append(x) + return features + else: + self.blocks(x) + out = self.feature_hooks.get_output(x.device) + return 
list(out.values()) + + +def _create_mnv3(variant, pretrained=False, **kwargs): + features_only = False + model_cls = MobileNetV3 + kwargs_filter = None + if kwargs.pop('features_only', False): + features_only = True + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool') + model_cls = MobileNetV3Features + model = build_model_with_cfg( + model_cls, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model + + +def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + head_bias=False, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'), + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, 
**kwargs): + """Creates a MobileNet-V3 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + if 'small' in variant: + num_features = 1024 + if 'minimal' in variant: + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16'], + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], + # stage 2, 28x28 in + ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], + # stage 3, 14x14 in + ['ir_r2_k3_s1_e3_c48'], + # stage 4, 14x14in + ['ir_r3_k3_s2_e6_c96'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], + ] + else: + act_layer = resolve_act_layer(kwargs, 'hard_swish') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu + # stage 2, 28x28 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish + # stage 3, 14x14 in + ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish + # stage 4, 14x14in + ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], # hard-swish + ] + else: + num_features = 1280 + if 'minimal' in variant: + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k3_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112'], + # stage 5, 14x14in + ['ir_r3_k3_s2_e6_c160'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + act_layer = resolve_act_layer(kwargs, 'hard_swish') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + 
# stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=16, + fix_stem=channel_multiplier < 0.75, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=act_layer, + se_layer=se_layer, + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNetV3 + Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining` + - https://arxiv.org/abs/2006.02049 + FIXME untested, this is a preliminary impl of some FBNet-V3 variants. 
+ """ + vl = variant.split('_')[-1] + if vl in ('a', 'b'): + stem_size = 16 + arch_def = [ + ['ds_r2_k3_s1_e1_c16'], + ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'], + ['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'], + ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], + ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'], + ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'], + ['cn_r1_k1_s1_c1344'], + ] + elif vl == 'd': + stem_size = 24 + arch_def = [ + ['ds_r2_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'], + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'], + ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], + ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'], + ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'], + ['cn_r1_k1_s1_c1440'], + ] + elif vl == 'g': + stem_size = 32 + arch_def = [ + ['ds_r3_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'], + ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'], + ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'], + ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'], + ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'], + ['cn_r1_k1_s1_c1728'], + ] + else: + raise NotImplemented + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95) + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn) + act_layer = resolve_act_layer(kwargs, 'hard_swish') + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1984, + head_bias=False, + stem_size=stem_size, + round_chs_fn=round_chs_fn, + se_from_exp=False, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=act_layer, + se_layer=se_layer, + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +def _gen_lcnet(variant, channel_multiplier=1.0, 
pretrained=False, **kwargs): + """ LCNet + Essentially a MobileNet-V3 crossed with a MobileNet-V1 + + Paper: `PP-LCNet: A Lightweight CPU Convolutional Neural Network` - https://arxiv.org/abs/2109.15099 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['dsa_r1_k3_s1_c32'], + # stage 1, 112x112 in + ['dsa_r2_k3_s2_c64'], + # stage 2, 56x56 in + ['dsa_r2_k3_s2_c128'], + # stage 3, 28x28 in + ['dsa_r1_k3_s2_c256', 'dsa_r1_k5_s1_c256'], + # stage 4, 14x14in + ['dsa_r4_k5_s1_c256'], + # stage 5, 14x14in + ['dsa_r2_k5_s2_c512_se0.25'], + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU), + num_features=1280, + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +@register_model +def mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_large_100_miil(pretrained=False, **kwargs): + """ MobileNet V3 + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model = _gen_mobilenet_v3('mobilenetv3_large_100_miil', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_large_100_miil_in21k(pretrained=False, **kwargs): + """ MobileNet 
V3, 21k pretraining + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model = _gen_mobilenet_v3('mobilenetv3_large_100_miil_in21k', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_050(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_050', 0.50, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_rw(pretrained=False, **kwargs): + """ MobileNet V3 """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 
1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetv3_b(pretrained=False, **kwargs): + """ FBNetV3-B """ + model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetv3_d(pretrained=False, **kwargs): + """ FBNetV3-D """ + model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetv3_g(pretrained=False, **kwargs): + """ FBNetV3-G """ + model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs) + return model + + +@register_model +def lcnet_035(pretrained=False, **kwargs): + """ PP-LCNet 0.35""" + model = _gen_lcnet('lcnet_035', 0.35, pretrained=pretrained, **kwargs) + return model + + +@register_model +def lcnet_050(pretrained=False, **kwargs): + """ PP-LCNet 0.5""" + model = _gen_lcnet('lcnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def lcnet_075(pretrained=False, **kwargs): + """ PP-LCNet 1.0""" + model = _gen_lcnet('lcnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def lcnet_100(pretrained=False, 
**kwargs): + """ PP-LCNet 1.0""" + model = _gen_lcnet('lcnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def lcnet_150(pretrained=False, **kwargs): + """ PP-LCNet 1.5""" + model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs) + return model diff --git a/data_processing/MANIQA/timm/models/nasnet.py b/data_processing/MANIQA/timm/models/nasnet.py new file mode 100644 index 0000000..2afe82c --- /dev/null +++ b/data_processing/MANIQA/timm/models/nasnet.py @@ -0,0 +1,567 @@ +""" NasNet-A (Large) + nasnetalarge implementation grabbed from Cadene's pretrained models + https://github.com/Cadene/pretrained-models.pytorch +""" +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import build_model_with_cfg +from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier +from .registry import register_model + +__all__ = ['NASNetALarge'] + +default_cfgs = { + 'nasnetalarge': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth', + 'input_size': (3, 331, 331), + 'pool_size': (11, 11), + 'crop_pct': 0.911, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1000, + 'first_conv': 'conv0.conv', + 'classifier': 'last_linear', + 'label_offset': 1, # 1001 classes in pretrained weights + }, +} + + +class ActConvBn(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''): + super(ActConvBn, self).__init__() + self.act = nn.ReLU() + self.conv = create_conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1) + + def forward(self, x): + x = self.act(x) + x = self.conv(x) + x = self.bn(x) + return x + + +class 
SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = create_conv2d( + in_channels, in_channels, kernel_size=kernel_size, + stride=stride, padding=padding, groups=in_channels) + self.pointwise_conv2d = create_conv2d( + in_channels, out_channels, kernel_size=1, padding=0) + + def forward(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_type='', stem_cell=False): + super(BranchSeparables, self).__init__() + middle_channels = out_channels if stem_cell else in_channels + self.act_1 = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, middle_channels, kernel_size, stride=stride, padding=pad_type) + self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, momentum=0.1) + self.act_2 = nn.ReLU(inplace=True) + self.separable_2 = SeparableConv2d( + middle_channels, out_channels, kernel_size, stride=1, padding=pad_type) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1) + + def forward(self, x): + x = self.act_1(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.act_2(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class CellStem0(nn.Module): + def __init__(self, stem_size, num_channels=42, pad_type=''): + super(CellStem0, self).__init__() + self.num_channels = num_channels + self.stem_size = stem_size + self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1) + + self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, 
pad_type, stem_cell=True) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x): + x1 = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x1) + x_comb_iter_0_right = self.comb_iter_0_right(x) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x1) + x_comb_iter_1_right = self.comb_iter_1_right(x) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x1) + x_comb_iter_2_right = self.comb_iter_2_right(x) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x1) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class CellStem1(nn.Module): + + def __init__(self, stem_size, num_channels, pad_type=''): + super(CellStem1, self).__init__() + self.num_channels = num_channels + self.stem_size = stem_size + self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1) + + self.act = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) + + self.path_2 = nn.Sequential() + 
self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1) + + self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x_conv0, x_stem_0): + x_left = self.conv_1x1(x_stem_0) + + x_relu = self.act(x_conv0) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2(x_relu) + # final path + x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_right) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_left) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = 
self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_left) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class FirstCell(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(FirstCell, self).__init__() + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1) + + self.act = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False)) + + self.path_2 = nn.Sequential() + self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + def forward(self, x, x_prev): + x_relu = 
self.act(x_prev) + x_path1 = self.path_1(x_relu) + x_path2 = self.path_2(x_relu) + x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NormalCell(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(NormalCell, self).__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type) + + self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, 
padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell0(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(ReductionCell0, self).__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, 
count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell1(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(ReductionCell1, self).__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, 
padding=pad_type) + self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NASNetALarge(nn.Module): + """NASNetALarge (6 @ 4032) """ + + def __init__(self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2, + num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'): + super(NASNetALarge, self).__init__() + self.num_classes = num_classes + self.stem_size = stem_size + self.num_features = num_features + self.channel_multiplier = channel_multiplier + self.drop_rate = drop_rate + assert output_stride == 32 + + channels = self.num_features // 24 + # 24 is default value for the architecture + + self.conv0 = ConvBnAct( + 
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) + + self.cell_stem_0 = CellStem0( + self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type) + self.cell_stem_1 = CellStem1( + self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type) + + self.cell_0 = FirstCell( + in_chs_left=channels, out_chs_left=channels // 2, + in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_1 = NormalCell( + in_chs_left=2 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_2 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_3 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_4 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_5 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + + self.reduction_cell_0 = ReductionCell0( + in_chs_left=6 * channels, out_chs_left=2 * channels, + in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_6 = FirstCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_7 = NormalCell( + in_chs_left=8 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_8 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_9 = NormalCell( + 
in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_10 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_11 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + + self.reduction_cell_1 = ReductionCell1( + in_chs_left=12 * channels, out_chs_left=4 * channels, + in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_12 = FirstCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_13 = NormalCell( + in_chs_left=16 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_14 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_15 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_16 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_17 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.act = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=96, reduction=2, module='conv0'), + dict(num_chs=168, reduction=4, module='cell_stem_1.conv_1x1.act'), + dict(num_chs=1008, reduction=8, module='reduction_cell_0.conv_1x1.act'), + dict(num_chs=2016, reduction=16, module='reduction_cell_1.conv_1x1.act'), + dict(num_chs=4032, reduction=32, module='act'), + ] + + 
self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x_conv0 = self.conv0(x) + + x_stem_0 = self.cell_stem_0(x_conv0) + x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0) + + x_cell_0 = self.cell_0(x_stem_1, x_stem_0) + x_cell_1 = self.cell_1(x_cell_0, x_stem_1) + x_cell_2 = self.cell_2(x_cell_1, x_cell_0) + x_cell_3 = self.cell_3(x_cell_2, x_cell_1) + x_cell_4 = self.cell_4(x_cell_3, x_cell_2) + x_cell_5 = self.cell_5(x_cell_4, x_cell_3) + + x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4) + x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4) + x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0) + x_cell_8 = self.cell_8(x_cell_7, x_cell_6) + x_cell_9 = self.cell_9(x_cell_8, x_cell_7) + x_cell_10 = self.cell_10(x_cell_9, x_cell_8) + x_cell_11 = self.cell_11(x_cell_10, x_cell_9) + + x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10) + x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10) + x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1) + x_cell_14 = self.cell_14(x_cell_13, x_cell_12) + x_cell_15 = self.cell_15(x_cell_14, x_cell_13) + x_cell_16 = self.cell_16(x_cell_15, x_cell_14) + x_cell_17 = self.cell_17(x_cell_16, x_cell_15) + x = self.act(x_cell_17) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, self.drop_rate, training=self.training) + x = self.last_linear(x) + return x + + +def _create_nasnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + NASNetALarge, variant, pretrained, + default_cfg=default_cfgs[variant], + 
@register_model
def nasnetalarge(pretrained=False, **kwargs):
    """Build the NASNet-A large model.

    `pad_type='same'` is always supplied by this factory; passing `pad_type` in `kwargs` as well raises
    a `TypeError` for the duplicated keyword.
    """
    return _create_nasnet('nasnetalarge', pretrained, **dict(pad_type='same', **kwargs))
class Attention(nn.Module):
    """
    Multi-head self-attention that is *localised* to image blocks.

    This is much like `.vision_transformer.Attention`, but the input carries an extra "image block" dim and the
    attention matrix is computed separately for every (batch, head, block) triple, so tokens never attend to
    tokens of another block.

    Args:
        dim (int): embedding (channel) size of the input; must be divisible by `num_heads`
        num_heads (int): number of attention heads
        qkv_bias (bool): add a bias term to the fused qkv projection if True
        attn_drop (float): dropout probability applied to the attention weights
        proj_drop (float): dropout probability applied after the output projection
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        # Fail at construction time; otherwise the mismatch only shows up as a reshape error inside forward().
        assert dim % num_heads == 0, '`dim` must be divisible by `num_heads`'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, 3*dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)

        Returns a tensor of the same (B, T, N, C) shape.
        """
        B, T, N, C = x.shape
        # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
        qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, T, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, H, T, N, C'), permute -> (B, T, N, C', H)
        # Heads end up as the fastest-varying index of the merged channel dim; keep this exact layout,
        # the output projection weights are tied to it.
        x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x  # (B, T, N, C)
@register_notrace_function  # reason: int receives Proxy
def deblockify(x, block_size: int):
    """Reassemble per-block token sequences into one channels-last image (blocks to image).

    Args:
        x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
        block_size (int): edge length of a single square block in units of desired H, W

    Returns:
        Tensor with shape (B, H, W, C) where H == W == sqrt(T) * block_size
    """
    batch, num_blocks, _, channels = x.shape
    blocks_per_side = int(math.sqrt(num_blocks))  # blocks are laid out on a square grid
    side = blocks_per_side * block_size
    grid = x.reshape(batch, blocks_per_side, blocks_per_side, block_size, block_size, channels)
    # Swap the block-column axis with the in-block row axis so rows/cols become contiguous again.
    return grid.transpose(2, 3).reshape(batch, side, side, channels)  # (B, H, W, C)
class Nest(nn.Module):
    """ Nested Transformer (NesT)

    A PyTorch impl of : `Aggregating Nested Transformers`
        - https://arxiv.org/abs/2105.12723
    """

    def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
                 num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None,
                 pad_type='', weight_init='', global_pool='avg'):
        """
        Args:
            img_size (int, tuple): input image size
            in_chans (int): number of input channels
            patch_size (int): patch size
            num_levels (int): number of block hierarchies (T_d in the paper)
            embed_dims (int, tuple): embedding dimensions of each level
            num_heads (int, tuple): number of attention heads for each level
            depths (int, tuple): number of transformer layers for each level
            num_classes (int): number of classes for classification head
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim for MLP of transformer layers
            qkv_bias (bool): enable bias for qkv if True
            drop_rate (float): dropout rate for MLP of transformer layers, MSA final projection layer, and classifier
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer for transformer layers
            act_layer: (nn.Module): activation layer in MLP of transformer layers
            pad_type: str: Type of padding to use '' for PyTorch symmetric, 'same' for TF SAME
            weight_init: (str): weight init scheme
            global_pool: (str): type of pooling operation to apply to final feature map

        Notes:
            - Default values follow NesT-B from the original Jax code.
            - `embed_dims`, `num_heads`, `depths` should be ints or tuples with length `num_levels`.
            - For those following the paper, Table A1 may have errors!
                - https://github.com/google-research/nested-transformer/issues/2
        """
        super().__init__()

        # Sequence-valued per-level settings must provide exactly one entry per level.
        for param_name in ['embed_dims', 'num_heads', 'depths']:
            param_value = locals()[param_name]
            if isinstance(param_value, collections.abc.Sequence):
                assert len(param_value) == num_levels, f'Require `len({param_name}) == num_levels`'

        # Scalars are broadcast to one value per level.
        embed_dims = to_ntuple(num_levels)(embed_dims)
        num_heads = to_ntuple(num_levels)(num_heads)
        depths = to_ntuple(num_levels)(depths)
        self.num_classes = num_classes
        self.num_features = embed_dims[-1]
        self.feature_info = []
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.drop_rate = drop_rate
        self.num_levels = num_levels
        if isinstance(img_size, collections.abc.Sequence):
            assert img_size[0] == img_size[1], 'Model only handles square inputs'
            img_size = img_size[0]
        assert img_size % patch_size == 0, '`patch_size` must divide `img_size` evenly'
        self.patch_size = patch_size

        # Number of blocks at each level
        self.num_blocks = (4 ** torch.arange(num_levels)).flip(0).tolist()
        assert (img_size // patch_size) % math.sqrt(self.num_blocks[0]) == 0, \
            'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`'

        # Block edge size in units of patches
        # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
        #  number of blocks along edge of image
        self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))

        # Patch embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], flatten=False)
        self.num_patches = self.patch_embed.num_patches
        self.seq_length = self.num_patches // self.num_blocks[0]

        # Build up each hierarchical level
        levels = []
        # One drop-path rate per transformer layer, increasing linearly over the whole network, grouped per level.
        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        prev_dim = None
        # The first level sees the patch-embedded map, i.e. the input reduced by `patch_size`
        # (4 with the default `patch_size`); every later level halves the resolution again.
        curr_stride = patch_size
        for i in range(len(self.num_blocks)):
            dim = embed_dims[i]
            levels.append(NestLevel(
                self.num_blocks[i], self.block_size, self.seq_length, num_heads[i], depths[i], dim, prev_dim,
                mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dp_rates[i], norm_layer, act_layer, pad_type=pad_type))
            self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')]
            prev_dim = dim
            curr_stride *= 2
        self.levels = nn.Sequential(*levels)

        # Final normalization layer
        self.norm = norm_layer(embed_dims[-1])

        # Classifier
        self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        """Initialise position embeddings and all sub-module weights.

        `mode` must be '' or 'nlhb'; with 'nlhb' the classifier bias is set to -log(num_classes).
        """
        assert mode in ('nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        for level in self.levels:
            trunc_normal_(level.pos_embed, std=.02, a=-2, b=2)
        named_apply(partial(_init_nest_weights, head_bias=head_bias), self)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names that should be excluded from weight decay.

        The names must match `named_parameters()`; the levels are registered as `self.levels`, so the
        position embeddings are named `levels.{i}.pos_embed`.
        """
        return {f'levels.{i}.pos_embed' for i in range(len(self.levels))}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling layer and classification head for a new number of classes."""
        self.num_classes = num_classes
        self.global_pool, self.head = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        """ x shape (B, C, H, W)
        """
        x = self.patch_embed(x)
        x = self.levels(x)
        # Layer norm done over channel dim only (to NHWC and back)
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x

    def forward(self, x):
        """ x shape (B, C, H, W)
        """
        x = self.forward_features(x)
        x = self.global_pool(x)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return self.head(x)
def checkpoint_filter_fn(state_dict, model):
    """Resize positional embeddings of pretrained weights to match `model`.

    Every checkpoint entry whose final name component starts with `pos_embed` (e.g. `levels.0.pos_embed`,
    or a top-level `pos_embed_*` key) is compared against the tensor of the same name in the model. When
    the shapes differ the checkpoint tensor is replaced by `resize_pos_embed(...)` of it.

    Entries with no counterpart in the model are left untouched instead of raising, so the subsequent
    `load_state_dict` call reports them in the usual way.

    Args:
        state_dict (dict): checkpoint name -> tensor mapping; modified in place
        model (nn.Module): model the checkpoint is going to be loaded into

    Returns:
        The same `state_dict` object.
    """
    model_state = model.state_dict()
    # Snapshot the keys: values are reassigned while walking the mapping.
    for key in list(state_dict.keys()):
        if not key.rsplit('.', 1)[-1].startswith('pos_embed'):
            continue
        target = model_state.get(key)
        if target is None:
            # Fall back to a plain attribute on the model for names that are not in its state dict.
            target = getattr(model, key, None)
        if target is None:
            continue
        if state_dict[key].shape != target.shape:
            state_dict[key] = resize_pos_embed(state_dict[key], target)
    return state_dict
@register_model
def nest_base(pretrained=False, **kwargs):
    """ Nest-B @ 224x224

    Fixes `embed_dims`, `num_heads` and `depths` for the base variant; passing any of them in `kwargs`
    raises a `TypeError` for the duplicated keyword.
    """
    base_args = dict(
        embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
    return _create_nest('nest_base', pretrained=pretrained, **base_args)
+ """ + kwargs['pad_type'] = 'same' + model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs) + model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs) + return model diff --git a/data_processing/MANIQA/timm/models/nfnet.py b/data_processing/MANIQA/timm/models/nfnet.py new file mode 100644 index 0000000..973cbd6 --- /dev/null +++ b/data_processing/MANIQA/timm/models/nfnet.py @@ -0,0 +1,968 @@ +""" Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models + +Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + +Paper: `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + +Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +Status: +* These models are a work in progress, experiments ongoing. +* Pretrained weights for two models so far, more to come. +* Model details updated to closer match official JAX code now that it's released +* NF-ResNet, NF-RegNet-B, and NFNet-F models supported + +Hacked together by / copyright Ross Wightman, 2021. 
def _dcfg(url='', **kwargs):
    """Return the default pretrained-config dict for this file's models.

    Starts from the shared defaults below and lets any keyword argument override or extend them.
    """
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (7, 7),
        'crop_pct': 0.9,
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv1',
        'classifier': 'head.fc',
    }
    cfg.update(kwargs)  # caller-supplied entries win over the defaults
    return cfg
test_input_size=(3, 416, 416), crop_pct=0.94), + dm_nfnet_f4=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f4-0ac5b10b.pth', + pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), crop_pct=0.951), + dm_nfnet_f5=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f5-ecb20ab1.pth', + pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), crop_pct=0.954), + dm_nfnet_f6=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f6-e0f12116.pth', + pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), crop_pct=0.956), + + nfnet_f0=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), + nfnet_f1=_dcfg( + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)), + nfnet_f2=_dcfg( + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)), + nfnet_f3=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)), + nfnet_f4=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)), + nfnet_f5=_dcfg( + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)), + nfnet_f6=_dcfg( + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)), + nfnet_f7=_dcfg( + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)), + + nfnet_f0s=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), + nfnet_f1s=_dcfg( + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)), + nfnet_f2s=_dcfg( + 
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)), + nfnet_f3s=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)), + nfnet_f4s=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)), + nfnet_f5s=_dcfg( + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)), + nfnet_f6s=_dcfg( + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)), + nfnet_f7s=_dcfg( + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)), + + nfnet_l0=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0_ra2-45c6688d.pth', + pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0), + eca_nfnet_l0=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l0_ra2-e3e9ac50.pth', + hf_hub='timm/eca_nfnet_l0', + pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0), + eca_nfnet_l1=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), + eca_nfnet_l2=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth', + pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0), + eca_nfnet_l3=_dcfg( + url='', + pool_size=(11, 11), input_size=(3, 352, 352), test_input_size=(3, 448, 448), crop_pct=1.0), + + nf_regnet_b0=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), 
@dataclass
class NfCfg:
    """Architecture configuration record for the normalization-free model variants in this file.

    Instances are produced by the `_nf*_cfg` builder helpers; `depths` and `channels` are the only
    required fields, everything else has a default.
    """
    # Per-stage values, one entry for each of the four stages.
    depths: Tuple[int, int, int, int]
    channels: Tuple[int, int, int, int]
    alpha: float = 0.2
    # The builders in this file use '3x3', '7x7_pool' and 'deep_quad'.
    stem_type: str = '3x3'
    stem_chs: Optional[int] = None
    group_size: Optional[int] = None
    # Name of the attention layer (the builders use 'se'); None leaves it unset.
    attn_layer: Optional[str] = None
    # Extra keyword arguments for the attention layer, e.g. dict(rd_ratio=0.5); defaults to None.
    attn_kwargs: Optional[dict] = None
    attn_gain: float = 2.0  # NF correction gain to apply if attn layer is used
    width_factor: float = 1.0
    bottle_ratio: float = 0.5
    num_features: int = 0  # num out_channels for final conv, no final_conv if 0
    ch_div: int = 8  # round channels % 8 == 0 to keep tensor-core use optimal
    reg: bool = False  # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle
    extra_conv: bool = False  # extra 3x3 bottleneck convolution for NFNet models
    gamma_in_act: bool = False
    same_padding: bool = False
    std_conv_eps: float = 1e-5
    skipinit: bool = False  # disabled by default, non-trivial performance impact
    zero_init_fc: bool = False
    # Activation name; the builders in this file pass 'relu', 'gelu' or 'silu'.
    act_layer: str = 'silu'
512, 1536, 1536), act_layer='gelu', skipinit=True): + cfg = NfCfg( + depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128, + bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit, + num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5)) + return cfg + + + +model_cfgs = dict( + # NFNet-F models w/ GELU compatible with DeepMind weights + dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)), + dm_nfnet_f1=_dm_nfnet_cfg(depths=(2, 4, 12, 6)), + dm_nfnet_f2=_dm_nfnet_cfg(depths=(3, 6, 18, 9)), + dm_nfnet_f3=_dm_nfnet_cfg(depths=(4, 8, 24, 12)), + dm_nfnet_f4=_dm_nfnet_cfg(depths=(5, 10, 30, 15)), + dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)), + dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)), + + # NFNet-F models w/ GELU (I will likely deprecate/remove these models and just keep dm_ ver for GELU) + nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)), + nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)), + nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)), + nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)), + nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)), + nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)), + nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)), + nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)), + + # NFNet-F models w/ SiLU (much faster in PyTorch) + nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'), + nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'), + nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'), + nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'), + nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'), + nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'), + nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'), + nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'), + + # Experimental 'light' versions of NFNet-F that are little leaner + nfnet_l0=_nfnet_cfg( + depths=(1, 2, 6, 3), 
feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'), + eca_nfnet_l0=_nfnet_cfg( + depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l1=_nfnet_cfg( + depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l2=_nfnet_cfg( + depths=(3, 6, 18, 9), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l3=_nfnet_cfg( + depths=(4, 8, 24, 12), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + + # EffNet influenced RegNet defs. + # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. + nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)), + nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)), + nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)), + nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)), + nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)), + nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)), + # FIXME add B6-B8 + + # ResNet (preact, D style deep stem/avg down) defs + nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)), + nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)), + nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)), + + nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + + nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()), + nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', 
@register_notrace_module  # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
class NormFreeBlock(nn.Module):
    """Normalization-Free pre-activation block.

    The input is activated and scaled by `beta`, run through a 1x1 -> 3x3 (-> optional extra 3x3) -> 1x1
    bottleneck, scaled by `alpha` and added to the shortcut. Attention, when configured, is applied
    between the 3x3 and final 1x1 convs for `reg=True` and after the final 1x1 conv otherwise.
    """

    def __init__(
            self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
            alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
            skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
        super().__init__()
        if not first_dilation:
            first_dilation = dilation
        if not out_chs:
            out_chs = in_chs

        # RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
        bottleneck_base = in_chs if reg else out_chs
        mid_chs = make_divisible(bottleneck_base * bottle_ratio, ch_div)
        groups = mid_chs // group_size if group_size else 1
        if group_size and group_size % ch_div == 0:
            # correct mid_chs if group_size divisible by ch_div, otherwise error
            mid_chs = group_size * groups

        self.alpha = alpha
        self.beta = beta
        self.attn_gain = attn_gain

        # A projection shortcut is only needed when shape or dilation changes across the block.
        needs_projection = in_chs != out_chs or stride != 1 or dilation != first_dilation
        self.downsample = DownsampleAvg(
            in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation,
            conv_layer=conv_layer) if needs_projection else None

        self.act1 = act_layer()
        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.act2 = act_layer(inplace=True)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)

        # Optional second 3x3 conv (act + conv created together, or both left as None).
        self.act2b = act_layer(inplace=True) if extra_conv else None
        self.conv2b = conv_layer(
            mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups) if extra_conv else None

        # RegNet blocks apply attn btw conv2 & 3
        use_mid_attn = reg and attn_layer is not None
        self.attn = attn_layer(mid_chs) if use_mid_attn else None

        self.act3 = act_layer()
        conv3_gain = 1. if skipinit else 0.
        self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=conv3_gain)

        # ResNet blocks apply attn after conv3
        use_last_attn = not reg and attn_layer is not None
        self.attn_last = attn_layer(out_chs) if use_last_attn else None

        if drop_path_rate > 0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()
        self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None

    def forward(self, x):
        pre_act = self.act1(x) * self.beta

        # shortcut branch: identity on the raw input, or a projection of the pre-activated input
        shortcut = self.downsample(pre_act) if self.downsample is not None else x

        # residual branch
        out = self.conv1(pre_act)
        out = self.act2(out)
        out = self.conv2(out)
        if self.conv2b is not None:
            out = self.act2b(out)
            out = self.conv2b(out)
        if self.attn is not None:
            out = self.attn_gain * self.attn(out)
        out = self.act3(out)
        out = self.conv3(out)
        if self.attn_last is not None:
            out = self.attn_gain * self.attn_last(out)
        out = self.drop_path(out)

        if self.skipinit_gain is not None:
            out.mul_(self.skipinit_gain)  # this slows things down more than expected, TBD
        return out * self.alpha + shortcut
class NormFreeNet(nn.Module):
    """ Normalization-Free Network

    As described in :
    `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
        - https://arxiv.org/abs/2101.08692
    and
    `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171

    This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and
    the (preact) ResNet models described earlier in the paper.

    There are a few differences:
        * channels are rounded to be divisible by 8 by default (keep tensor core kernels happy),
            this changes channel dim and param counts slightly from the paper models
        * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
            impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
        * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
            apply it in each activation. This is slightly slower, numerically different, but matches official impl.
        * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
            for what it is/does. Approx 8-10% throughput loss.
    """
    def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
                 drop_rate=0., drop_path_rate=0.):
        """
        Args:
            cfg: architecture config (depths, channels, stem / attention / activation options).
            num_classes: number of outputs of the classifier head.
            in_chans: number of input image channels fed to the stem.
            global_pool: pooling type passed to the classifier head.
            output_stride: once the running network stride reaches this value, later stages use
                dilation instead of striding.
            drop_rate: dropout rate passed to the classifier head.
            drop_path_rate: final stochastic-depth rate; per-block rates ramp linearly from 0 to this value.
        """
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
        conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
        if cfg.gamma_in_act:
            act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
            conv_layer = partial(conv_layer, eps=cfg.std_conv_eps)
        else:
            act_layer = get_act_layer(cfg.act_layer)
            conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps)
        # attn_kwargs defaults to None in the config; guard so `**` never receives a non-mapping.
        attn_layer = partial(get_attn(cfg.attn_layer), **(cfg.attn_kwargs or {})) if cfg.attn_layer else None

        stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
        self.stem, stem_stride, stem_feat = create_stem(
            in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)

        self.feature_info = [stem_feat]
        # one list of per-block drop path rates for each stage
        drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
        prev_chs = stem_chs
        net_stride = stem_stride
        dilation = 1
        expected_var = 1.0
        stages = []
        for stage_idx, stage_depth in enumerate(cfg.depths):
            # first stage does not stride when the stem already reduced by more than 2
            stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
            if net_stride >= output_stride and stride > 1:
                # target output stride reached: trade the stride for dilation
                dilation *= stride
                stride = 1
            net_stride *= stride
            first_dilation = 1 if dilation in (1, 2) else 2

            # output width is constant across the blocks of a stage
            out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div)
            blocks = []
            for block_idx in range(stage_depth):
                first_block = block_idx == 0 and stage_idx == 0
                blocks.append(NormFreeBlock(
                    in_chs=prev_chs, out_chs=out_chs,
                    alpha=cfg.alpha,
                    beta=1. / expected_var ** 0.5,
                    stride=stride if block_idx == 0 else 1,
                    dilation=dilation,
                    first_dilation=first_dilation,
                    group_size=cfg.group_size,
                    bottle_ratio=1. if cfg.reg and first_block else cfg.bottle_ratio,
                    ch_div=cfg.ch_div,
                    reg=cfg.reg,
                    extra_conv=cfg.extra_conv,
                    skipinit=cfg.skipinit,
                    attn_layer=attn_layer,
                    attn_gain=cfg.attn_gain,
                    act_layer=act_layer,
                    conv_layer=conv_layer,
                    drop_path_rate=drop_path_rates[stage_idx][block_idx],
                ))
                if block_idx == 0:
                    expected_var = 1.  # expected var is reset after first block of each stage
                expected_var += cfg.alpha ** 2  # Even if reset occurs, increment expected variance
                first_dilation = dilation
                prev_chs = out_chs
            self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')]
            stages += [nn.Sequential(*blocks)]
        self.stages = nn.Sequential(*stages)

        if cfg.num_features:
            # The paper NFRegNet models have an EfficientNet-like final head convolution.
            self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
            self.final_conv = conv_layer(prev_chs, self.num_features, 1)
            # the last feature tap moves from the final stage to the final conv
            self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module='final_conv')
        else:
            self.num_features = prev_chs
            self.final_conv = nn.Identity()
        self.final_act = act_layer(inplace=cfg.num_features > 0)

        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

        for n, m in self.named_modules():
            if 'fc' in n and isinstance(m, nn.Linear):
                if cfg.zero_init_fc:
                    nn.init.zeros_(m.weight)
                else:
                    nn.init.normal_(m.weight, 0., .01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def get_classifier(self):
        """Return the `fc` module of the classifier head."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the classifier head with a new one for `num_classes` outputs."""
        # keep the public attribute in sync with the rebuilt head
        self.num_classes = num_classes
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        """Run stem, stages, final conv and final activation; the classifier head is not applied."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.final_conv(x)
        x = self.final_act(x)
        return x

    def forward(self, x):
        """Run the feature extractor followed by the classifier head."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
@register_model
def dm_nfnet_f5(pretrained=False, **kwargs):
    """ NFNet-F5 (DeepMind weight compatible)

    Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
        - https://arxiv.org/abs/2102.06171

    `pretrained` and any extra keyword arguments are forwarded unchanged to `_create_normfreenet`.
    """
    variant = 'dm_nfnet_f5'
    model = _create_normfreenet(variant, pretrained=pretrained, **kwargs)
    return model
@register_model
def nfnet_f4(pretrained=False, **kwargs):
    """ NFNet-F4

    Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
        - https://arxiv.org/abs/2102.06171

    `pretrained` and any extra keyword arguments are forwarded unchanged to `_create_normfreenet`.
    """
    variant = 'nfnet_f4'
    model = _create_normfreenet(variant, pretrained=pretrained, **kwargs)
    return model
@register_model
def nfnet_f2s(pretrained=False, **kwargs):
    """ NFNet-F2 w/ SiLU

    Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
        - https://arxiv.org/abs/2102.06171

    `pretrained` and any extra keyword arguments are forwarded unchanged to `_create_normfreenet`.
    """
    variant = 'nfnet_f2s'
    model = _create_normfreenet(variant, pretrained=pretrained, **kwargs)
    return model
@register_model
def nfnet_l0(pretrained=False, **kwargs):
    """ NFNet-L0b w/ SiLU

    My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio

    `pretrained` and any extra keyword arguments are forwarded unchanged to `_create_normfreenet`.
    """
    variant = 'nfnet_l0'
    model = _create_normfreenet(variant, pretrained=pretrained, **kwargs)
    return model
@register_model
def nf_regnet_b1(pretrained=False, **kwargs):
    """ Normalization-Free RegNet-B1

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
        - https://arxiv.org/abs/2101.08692

    `pretrained` and any extra keyword arguments are forwarded unchanged to `_create_normfreenet`.
    """
    variant = 'nf_regnet_b1'
    model = _create_normfreenet(variant, pretrained=pretrained, **kwargs)
    return model
@register_model
def nf_resnet50(pretrained=False, **kwargs):
    """ Normalization-Free ResNet-50

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
        - https://arxiv.org/abs/2101.08692

    `pretrained` and any extra keyword arguments are forwarded unchanged to `_create_normfreenet`.
    """
    variant = 'nf_resnet50'
    model = _create_normfreenet(variant, pretrained=pretrained, **kwargs)
    return model
def _cfg(url='', **kwargs):
    """Build a default-config dict for a PiT variant.

    Starts from the shared defaults below, then applies `kwargs` on top, so any key passed by the
    caller replaces the default and unknown keys are appended.
    """
    cfg = dict(
        url=url,
        num_classes=1000, input_size=(3, 224, 224), pool_size=None,
        crop_pct=.9, interpolation='bicubic', fixed_input_size=True,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.conv', classifier='head',
    )
    cfg.update(kwargs)
    return cfg
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_781.pth'), + 'pit_s_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_809.pth'), + 'pit_b_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'), + 'pit_ti_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth', + classifier=('head', 'head_dist')), + 'pit_xs_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth', + classifier=('head', 'head_dist')), + 'pit_s_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth', + classifier=('head', 'head_dist')), + 'pit_b_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth', + classifier=('head', 'head_dist')), +} + + +class SequentialTuple(nn.Sequential): + """ This module exists to work around torchscript typing issues list -> list""" + def __init__(self, *args): + super(SequentialTuple, self).__init__(*args) + + def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + for module in self: + x = module(x) + return x + + +class Transformer(nn.Module): + def __init__( + self, base_dim, depth, heads, mlp_ratio, pool=None, drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None): + 
super(Transformer, self).__init__() + self.layers = nn.ModuleList([]) + embed_dim = base_dim * heads + + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, + num_heads=heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_prob[i], + norm_layer=partial(nn.LayerNorm, eps=1e-6) + ) + for i in range(depth)]) + + self.pool = pool + + def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + x, cls_tokens = x + B, C, H, W = x.shape + token_length = cls_tokens.shape[1] + + x = x.flatten(2).transpose(1, 2) + x = torch.cat((cls_tokens, x), dim=1) + + x = self.blocks(x) + + cls_tokens = x[:, :token_length] + x = x[:, token_length:] + x = x.transpose(1, 2).reshape(B, C, H, W) + + if self.pool is not None: + x, cls_tokens = self.pool(x, cls_tokens) + return x, cls_tokens + + +class ConvHeadPooling(nn.Module): + def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'): + super(ConvHeadPooling, self).__init__() + + self.conv = nn.Conv2d( + in_feature, out_feature, kernel_size=stride + 1, padding=stride // 2, stride=stride, + padding_mode=padding_mode, groups=in_feature) + self.fc = nn.Linear(in_feature, out_feature) + + def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]: + + x = self.conv(x) + cls_token = self.fc(cls_token) + + return x, cls_token + + +class ConvEmbedding(nn.Module): + def __init__(self, in_channels, out_channels, patch_size, stride, padding): + super(ConvEmbedding, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=patch_size, stride=stride, padding=padding, bias=True) + + def forward(self, x): + x = self.conv(x) + return x + + +class PoolingVisionTransformer(nn.Module): + """ Pooling-based Vision Transformer + + A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers' + - https://arxiv.org/abs/2103.16302 + """ + def 
__init__(self, img_size, patch_size, stride, base_dims, depth, heads, + mlp_ratio, num_classes=1000, in_chans=3, distilled=False, + attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0): + super(PoolingVisionTransformer, self).__init__() + + padding = 0 + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + height = math.floor((img_size[0] + 2 * padding - patch_size[0]) / stride + 1) + width = math.floor((img_size[1] + 2 * padding - patch_size[1]) / stride + 1) + + self.base_dims = base_dims + self.heads = heads + self.num_classes = num_classes + self.num_tokens = 2 if distilled else 1 + + self.patch_size = patch_size + self.pos_embed = nn.Parameter(torch.randn(1, base_dims[0] * heads[0], height, width)) + self.patch_embed = ConvEmbedding(in_chans, base_dims[0] * heads[0], patch_size, stride, padding) + + self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, base_dims[0] * heads[0])) + self.pos_drop = nn.Dropout(p=drop_rate) + + transformers = [] + # stochastic depth decay rule + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth)).split(depth)] + for stage in range(len(depth)): + pool = None + if stage < len(heads) - 1: + pool = ConvHeadPooling( + base_dims[stage] * heads[stage], base_dims[stage + 1] * heads[stage + 1], stride=2) + transformers += [Transformer( + base_dims[stage], depth[stage], heads[stage], mlp_ratio, pool=pool, + drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_prob=dpr[stage]) + ] + self.transformers = SequentialTuple(*transformers) + self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6) + self.num_features = self.embed_dim = base_dims[-1] * heads[-1] + + # Classifier head + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, 
std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + if self.head_dist is not None: + return self.head, self.head_dist + else: + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.head_dist is not None: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + x = self.pos_drop(x + self.pos_embed) + cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) + x, cls_tokens = self.transformers((x, cls_tokens)) + cls_tokens = self.norm(cls_tokens) + if self.head_dist is not None: + return cls_tokens[:, 0], cls_tokens[:, 1] + else: + return cls_tokens[:, 0] + + def forward(self, x): + x = self.forward_features(x) + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple + if self.training and not torch.jit.is_scripting(): + return x, x_dist + else: + return (x + x_dist) / 2 + else: + return self.head(x) + + +def checkpoint_filter_fn(state_dict, model): + """ preprocess checkpoints """ + out_dict = {} + p_blocks = re.compile(r'pools\.(\d)\.') + for k, v in state_dict.items(): + # FIXME need to update resize for PiT impl + # if k == 'pos_embed' and v.shape != model.pos_embed.shape: + # # To resize pos embedding when using model at different size from pretrained weights + # v = resize_pos_embed(v, model.pos_embed) + k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1))}.pool.', k) + out_dict[k] = v + return out_dict + + +def _create_pit(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise 
RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + PoolingVisionTransformer, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def pit_b_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=14, + stride=7, + base_dims=[64, 64, 64], + depth=[3, 6, 4], + heads=[4, 8, 16], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_b_224', pretrained, **model_kwargs) + + +@register_model +def pit_s_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 4], + heads=[3, 6, 12], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_s_224', pretrained, **model_kwargs) + + +@register_model +def pit_xs_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 4], + heads=[2, 4, 8], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_xs_224', pretrained, **model_kwargs) + + +@register_model +def pit_ti_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[32, 32, 32], + depth=[2, 6, 4], + heads=[2, 4, 8], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_ti_224', pretrained, **model_kwargs) + + +@register_model +def pit_b_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=14, + stride=7, + base_dims=[64, 64, 64], + depth=[3, 6, 4], + heads=[4, 8, 16], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_b_distilled_224', pretrained, **model_kwargs) + + +@register_model +def pit_s_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 4], + heads=[3, 6, 12], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_s_distilled_224', pretrained, **model_kwargs) + + +@register_model +def 
pit_xs_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 4], + heads=[2, 4, 8], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_xs_distilled_224', pretrained, **model_kwargs) + + +@register_model +def pit_ti_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[32, 32, 32], + depth=[2, 6, 4], + heads=[2, 4, 8], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_ti_distilled_224', pretrained, **model_kwargs) \ No newline at end of file diff --git a/data_processing/MANIQA/timm/models/pnasnet.py b/data_processing/MANIQA/timm/models/pnasnet.py new file mode 100644 index 0000000..9991815 --- /dev/null +++ b/data_processing/MANIQA/timm/models/pnasnet.py @@ -0,0 +1,350 @@ +""" + pnasnet5large implementation grabbed from Cadene's pretrained models + Additional credit to https://github.com/creafz + + https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py + +""" +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import build_model_with_cfg +from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier +from .registry import register_model + +__all__ = ['PNASNet5Large'] + +default_cfgs = { + 'pnasnet5large': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth', + 'input_size': (3, 331, 331), + 'pool_size': (11, 11), + 'crop_pct': 0.911, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1000, + 'first_conv': 'conv_0.conv', + 'classifier': 'last_linear', + 
'label_offset': 1, # 1001 classes in pretrained weights + }, +} + + +class SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = create_conv2d( + in_channels, in_channels, kernel_size=kernel_size, + stride=stride, padding=padding, groups=in_channels) + self.pointwise_conv2d = create_conv2d( + in_channels, out_channels, kernel_size=1, padding=padding) + + def forward(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, stem_cell=False, padding=''): + super(BranchSeparables, self).__init__() + middle_channels = out_channels if stem_cell else in_channels + self.act_1 = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, middle_channels, kernel_size, stride=stride, padding=padding) + self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001) + self.act_2 = nn.ReLU() + self.separable_2 = SeparableConv2d( + middle_channels, out_channels, kernel_size, stride=1, padding=padding) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.act_1(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.act_2(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class ActConvBn(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''): + super(ActConvBn, self).__init__() + self.act = nn.ReLU() + self.conv = create_conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.act(x) + x = self.conv(x) + x = self.bn(x) + return x + + +class FactorizedReduction(nn.Module): + + def __init__(self, in_channels, out_channels, padding=''): + super(FactorizedReduction, self).__init__() + self.act = 
nn.ReLU() + self.path_1 = nn.Sequential(OrderedDict([ + ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), + ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)), + ])) + self.path_2 = nn.Sequential(OrderedDict([ + ('pad', nn.ZeroPad2d((-1, 1, -1, 1))), # shift + ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), + ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)), + ])) + self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.act(x) + x_path1 = self.path_1(x) + x_path2 = self.path_2(x) + out = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + return out + + +class CellBase(nn.Module): + + def cell_forward(self, x_left, x_right): + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_left = self.comb_iter_3_left(x_comb_iter_2) + x_comb_iter_3_right = self.comb_iter_3_right(x_right) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_left) + if self.comb_iter_4_right is not None: + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + else: + x_comb_iter_4_right = x_right + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class CellStem0(CellBase): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(CellStem0, 
self).__init__() + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables( + in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type) + self.comb_iter_0_right = nn.Sequential(OrderedDict([ + ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)), + ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)), + ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)), + ])) + + self.comb_iter_1_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type) + self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type) + + self.comb_iter_2_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type) + self.comb_iter_2_right = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type) + + self.comb_iter_3_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=3, padding=pad_type) + self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables( + in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type) + self.comb_iter_4_right = ActConvBn( + out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=pad_type) + + def forward(self, x_left): + x_right = self.conv_1x1(x_left) + x_out = self.cell_forward(x_left, x_right) + return x_out + + +class Cell(CellBase): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type='', + is_reduction=False, match_prev_layer_dims=False): + super(Cell, self).__init__() + + # If `is_reduction` is set to `True` stride 2 is used for + # convolution and pooling layers to reduce the spatial size of + # the output of a cell approximately by a factor of 2. 
+ stride = 2 if is_reduction else 1 + + # If `match_prev_layer_dimensions` is set to `True` + # `FactorizedReduction` is used to reduce the spatial size + # of the left input of a cell approximately by a factor of 2. + self.match_prev_layer_dimensions = match_prev_layer_dims + if match_prev_layer_dims: + self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type) + else: + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables( + out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type) + self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type) + + self.comb_iter_1_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type) + self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type) + + self.comb_iter_2_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type) + self.comb_iter_2_right = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type) + + self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3) + self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables( + out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type) + if is_reduction: + self.comb_iter_4_right = ActConvBn( + out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type) + else: + self.comb_iter_4_right = None + + def forward(self, x_left, x_right): + x_left = self.conv_prev_1x1(x_left) + x_right = self.conv_1x1(x_right) + x_out = self.cell_forward(x_left, x_right) + return x_out + + +class PNASNet5Large(nn.Module): + def __init__(self, num_classes=1000, in_chans=3, 
output_stride=32, drop_rate=0., global_pool='avg', pad_type=''): + super(PNASNet5Large, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + self.num_features = 4320 + assert output_stride == 32 + + self.conv_0 = ConvBnAct( + in_chans, 96, kernel_size=3, stride=2, padding=0, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) + + self.cell_stem_0 = CellStem0( + in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type) + + self.cell_stem_1 = Cell( + in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, pad_type=pad_type, + match_prev_layer_dims=True, is_reduction=True) + self.cell_0 = Cell( + in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, pad_type=pad_type, + match_prev_layer_dims=True) + self.cell_1 = Cell( + in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) + self.cell_2 = Cell( + in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) + self.cell_3 = Cell( + in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) + + self.cell_4 = Cell( + in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, pad_type=pad_type, + is_reduction=True) + self.cell_5 = Cell( + in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type, + match_prev_layer_dims=True) + self.cell_6 = Cell( + in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type) + self.cell_7 = Cell( + in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type) + + self.cell_8 = Cell( + in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, pad_type=pad_type, + is_reduction=True) + self.cell_9 = Cell( + in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type, + match_prev_layer_dims=True) + self.cell_10 = 
Cell( + in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type) + self.cell_11 = Cell( + in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type) + self.act = nn.ReLU() + self.feature_info = [ + dict(num_chs=96, reduction=2, module='conv_0'), + dict(num_chs=270, reduction=4, module='cell_stem_1.conv_1x1.act'), + dict(num_chs=1080, reduction=8, module='cell_4.conv_1x1.act'), + dict(num_chs=2160, reduction=16, module='cell_8.conv_1x1.act'), + dict(num_chs=4320, reduction=32, module='act'), + ] + + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x_conv_0 = self.conv_0(x) + x_stem_0 = self.cell_stem_0(x_conv_0) + x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0) + x_cell_0 = self.cell_0(x_stem_0, x_stem_1) + x_cell_1 = self.cell_1(x_stem_1, x_cell_0) + x_cell_2 = self.cell_2(x_cell_0, x_cell_1) + x_cell_3 = self.cell_3(x_cell_1, x_cell_2) + x_cell_4 = self.cell_4(x_cell_2, x_cell_3) + x_cell_5 = self.cell_5(x_cell_3, x_cell_4) + x_cell_6 = self.cell_6(x_cell_4, x_cell_5) + x_cell_7 = self.cell_7(x_cell_5, x_cell_6) + x_cell_8 = self.cell_8(x_cell_6, x_cell_7) + x_cell_9 = self.cell_9(x_cell_7, x_cell_8) + x_cell_10 = self.cell_10(x_cell_8, x_cell_9) + x_cell_11 = self.cell_11(x_cell_9, x_cell_10) + x = self.act(x_cell_11) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, self.drop_rate, training=self.training) + x = self.last_linear(x) + return x + + +def _create_pnasnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + 
PNASNet5Large, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model + **kwargs) + + +@register_model +def pnasnet5large(pretrained=False, **kwargs): + r"""PNASNet-5 model architecture from the + `"Progressive Neural Architecture Search" + `_ paper. + """ + model_kwargs = dict(pad_type='same', **kwargs) + return _create_pnasnet('pnasnet5large', pretrained, **model_kwargs) diff --git a/data_processing/MANIQA/timm/models/pruned/ecaresnet101d_pruned.txt b/data_processing/MANIQA/timm/models/pruned/ecaresnet101d_pruned.txt new file mode 100644 index 0000000..2589b2f --- /dev/null +++ b/data_processing/MANIQA/timm/models/pruned/ecaresnet101d_pruned.txt @@ -0,0 +1 @@ +conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[45, 64, 1, 1]***layer1.0.bn1.weight:[45]***layer1.0.conv2.weight:[25, 45, 3, 3]***layer1.0.bn2.weight:[25]***layer1.0.conv3.weight:[26, 25, 1, 1]***layer1.0.bn3.weight:[26]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[26, 64, 1, 1]***layer1.0.downsample.2.weight:[26]***layer1.1.conv1.weight:[53, 26, 1, 1]***layer1.1.bn1.weight:[53]***layer1.1.conv2.weight:[20, 53, 3, 3]***layer1.1.bn2.weight:[20]***layer1.1.conv3.weight:[26, 20, 1, 1]***layer1.1.bn3.weight:[26]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[60, 26, 1, 1]***layer1.2.bn1.weight:[60]***layer1.2.conv2.weight:[27, 60, 3, 3]***layer1.2.bn2.weight:[27]***layer1.2.conv3.weight:[26, 27, 1, 1]***layer1.2.bn3.weight:[26]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[81, 26, 1, 1]***layer2.0.bn1.weight:[81]***layer2.0.conv2.weight:[24, 81, 3, 3]***layer2.0.bn2.weight:[24]***layer2.0.conv3.weight:[142, 24, 1, 1]***layer2.0.bn3.weight:[142]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[142, 26, 1, 
1]***layer2.0.downsample.2.weight:[142]***layer2.1.conv1.weight:[93, 142, 1, 1]***layer2.1.bn1.weight:[93]***layer2.1.conv2.weight:[49, 93, 3, 3]***layer2.1.bn2.weight:[49]***layer2.1.conv3.weight:[142, 49, 1, 1]***layer2.1.bn3.weight:[142]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[102, 142, 1, 1]***layer2.2.bn1.weight:[102]***layer2.2.conv2.weight:[54, 102, 3, 3]***layer2.2.bn2.weight:[54]***layer2.2.conv3.weight:[142, 54, 1, 1]***layer2.2.bn3.weight:[142]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[122, 142, 1, 1]***layer2.3.bn1.weight:[122]***layer2.3.conv2.weight:[78, 122, 3, 3]***layer2.3.bn2.weight:[78]***layer2.3.conv3.weight:[142, 78, 1, 1]***layer2.3.bn3.weight:[142]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[101, 142, 1, 1]***layer3.0.bn1.weight:[101]***layer3.0.conv2.weight:[25, 101, 3, 3]***layer3.0.bn2.weight:[25]***layer3.0.conv3.weight:[278, 25, 1, 1]***layer3.0.bn3.weight:[278]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[278, 142, 1, 1]***layer3.0.downsample.2.weight:[278]***layer3.1.conv1.weight:[239, 278, 1, 1]***layer3.1.bn1.weight:[239]***layer3.1.conv2.weight:[160, 239, 3, 3]***layer3.1.bn2.weight:[160]***layer3.1.conv3.weight:[278, 160, 1, 1]***layer3.1.bn3.weight:[278]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[234, 278, 1, 1]***layer3.2.bn1.weight:[234]***layer3.2.conv2.weight:[156, 234, 3, 3]***layer3.2.bn2.weight:[156]***layer3.2.conv3.weight:[278, 156, 1, 1]***layer3.2.bn3.weight:[278]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[250, 278, 1, 1]***layer3.3.bn1.weight:[250]***layer3.3.conv2.weight:[176, 250, 3, 3]***layer3.3.bn2.weight:[176]***layer3.3.conv3.weight:[278, 176, 1, 1]***layer3.3.bn3.weight:[278]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[253, 278, 1, 1]***layer3.4.bn1.weight:[253]***layer3.4.conv2.weight:[191, 253, 3, 3]***layer3.4.bn2.weight:[191]***layer3.4.conv3.weight:[278, 191, 1, 
1]***layer3.4.bn3.weight:[278]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[251, 278, 1, 1]***layer3.5.bn1.weight:[251]***layer3.5.conv2.weight:[175, 251, 3, 3]***layer3.5.bn2.weight:[175]***layer3.5.conv3.weight:[278, 175, 1, 1]***layer3.5.bn3.weight:[278]***layer3.5.se.conv.weight:[1, 1, 5]***layer3.6.conv1.weight:[230, 278, 1, 1]***layer3.6.bn1.weight:[230]***layer3.6.conv2.weight:[128, 230, 3, 3]***layer3.6.bn2.weight:[128]***layer3.6.conv3.weight:[278, 128, 1, 1]***layer3.6.bn3.weight:[278]***layer3.6.se.conv.weight:[1, 1, 5]***layer3.7.conv1.weight:[244, 278, 1, 1]***layer3.7.bn1.weight:[244]***layer3.7.conv2.weight:[154, 244, 3, 3]***layer3.7.bn2.weight:[154]***layer3.7.conv3.weight:[278, 154, 1, 1]***layer3.7.bn3.weight:[278]***layer3.7.se.conv.weight:[1, 1, 5]***layer3.8.conv1.weight:[244, 278, 1, 1]***layer3.8.bn1.weight:[244]***layer3.8.conv2.weight:[159, 244, 3, 3]***layer3.8.bn2.weight:[159]***layer3.8.conv3.weight:[278, 159, 1, 1]***layer3.8.bn3.weight:[278]***layer3.8.se.conv.weight:[1, 1, 5]***layer3.9.conv1.weight:[238, 278, 1, 1]***layer3.9.bn1.weight:[238]***layer3.9.conv2.weight:[97, 238, 3, 3]***layer3.9.bn2.weight:[97]***layer3.9.conv3.weight:[278, 97, 1, 1]***layer3.9.bn3.weight:[278]***layer3.9.se.conv.weight:[1, 1, 5]***layer3.10.conv1.weight:[244, 278, 1, 1]***layer3.10.bn1.weight:[244]***layer3.10.conv2.weight:[149, 244, 3, 3]***layer3.10.bn2.weight:[149]***layer3.10.conv3.weight:[278, 149, 1, 1]***layer3.10.bn3.weight:[278]***layer3.10.se.conv.weight:[1, 1, 5]***layer3.11.conv1.weight:[253, 278, 1, 1]***layer3.11.bn1.weight:[253]***layer3.11.conv2.weight:[181, 253, 3, 3]***layer3.11.bn2.weight:[181]***layer3.11.conv3.weight:[278, 181, 1, 1]***layer3.11.bn3.weight:[278]***layer3.11.se.conv.weight:[1, 1, 5]***layer3.12.conv1.weight:[245, 278, 1, 1]***layer3.12.bn1.weight:[245]***layer3.12.conv2.weight:[119, 245, 3, 3]***layer3.12.bn2.weight:[119]***layer3.12.conv3.weight:[278, 119, 1, 
1]***layer3.12.bn3.weight:[278]***layer3.12.se.conv.weight:[1, 1, 5]***layer3.13.conv1.weight:[255, 278, 1, 1]***layer3.13.bn1.weight:[255]***layer3.13.conv2.weight:[216, 255, 3, 3]***layer3.13.bn2.weight:[216]***layer3.13.conv3.weight:[278, 216, 1, 1]***layer3.13.bn3.weight:[278]***layer3.13.se.conv.weight:[1, 1, 5]***layer3.14.conv1.weight:[256, 278, 1, 1]***layer3.14.bn1.weight:[256]***layer3.14.conv2.weight:[201, 256, 3, 3]***layer3.14.bn2.weight:[201]***layer3.14.conv3.weight:[278, 201, 1, 1]***layer3.14.bn3.weight:[278]***layer3.14.se.conv.weight:[1, 1, 5]***layer3.15.conv1.weight:[253, 278, 1, 1]***layer3.15.bn1.weight:[253]***layer3.15.conv2.weight:[149, 253, 3, 3]***layer3.15.bn2.weight:[149]***layer3.15.conv3.weight:[278, 149, 1, 1]***layer3.15.bn3.weight:[278]***layer3.15.se.conv.weight:[1, 1, 5]***layer3.16.conv1.weight:[254, 278, 1, 1]***layer3.16.bn1.weight:[254]***layer3.16.conv2.weight:[141, 254, 3, 3]***layer3.16.bn2.weight:[141]***layer3.16.conv3.weight:[278, 141, 1, 1]***layer3.16.bn3.weight:[278]***layer3.16.se.conv.weight:[1, 1, 5]***layer3.17.conv1.weight:[256, 278, 1, 1]***layer3.17.bn1.weight:[256]***layer3.17.conv2.weight:[190, 256, 3, 3]***layer3.17.bn2.weight:[190]***layer3.17.conv3.weight:[278, 190, 1, 1]***layer3.17.bn3.weight:[278]***layer3.17.se.conv.weight:[1, 1, 5]***layer3.18.conv1.weight:[256, 278, 1, 1]***layer3.18.bn1.weight:[256]***layer3.18.conv2.weight:[217, 256, 3, 3]***layer3.18.bn2.weight:[217]***layer3.18.conv3.weight:[278, 217, 1, 1]***layer3.18.bn3.weight:[278]***layer3.18.se.conv.weight:[1, 1, 5]***layer3.19.conv1.weight:[255, 278, 1, 1]***layer3.19.bn1.weight:[255]***layer3.19.conv2.weight:[156, 255, 3, 3]***layer3.19.bn2.weight:[156]***layer3.19.conv3.weight:[278, 156, 1, 1]***layer3.19.bn3.weight:[278]***layer3.19.se.conv.weight:[1, 1, 5]***layer3.20.conv1.weight:[256, 278, 1, 1]***layer3.20.bn1.weight:[256]***layer3.20.conv2.weight:[155, 256, 3, 3]***layer3.20.bn2.weight:[155]***layer3.20.conv3.weight:[278, 155, 1, 
1]***layer3.20.bn3.weight:[278]***layer3.20.se.conv.weight:[1, 1, 5]***layer3.21.conv1.weight:[256, 278, 1, 1]***layer3.21.bn1.weight:[256]***layer3.21.conv2.weight:[232, 256, 3, 3]***layer3.21.bn2.weight:[232]***layer3.21.conv3.weight:[278, 232, 1, 1]***layer3.21.bn3.weight:[278]***layer3.21.se.conv.weight:[1, 1, 5]***layer3.22.conv1.weight:[256, 278, 1, 1]***layer3.22.bn1.weight:[256]***layer3.22.conv2.weight:[214, 256, 3, 3]***layer3.22.bn2.weight:[214]***layer3.22.conv3.weight:[278, 214, 1, 1]***layer3.22.bn3.weight:[278]***layer3.22.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[499, 278, 1, 1]***layer4.0.bn1.weight:[499]***layer4.0.conv2.weight:[289, 499, 3, 3]***layer4.0.bn2.weight:[289]***layer4.0.conv3.weight:[2042, 289, 1, 1]***layer4.0.bn3.weight:[2042]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2042, 278, 1, 1]***layer4.0.downsample.2.weight:[2042]***layer4.1.conv1.weight:[512, 2042, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[512, 512, 3, 3]***layer4.1.bn2.weight:[512]***layer4.1.conv3.weight:[2042, 512, 1, 1]***layer4.1.bn3.weight:[2042]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2042, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[502, 512, 3, 3]***layer4.2.bn2.weight:[502]***layer4.2.conv3.weight:[2042, 502, 1, 1]***layer4.2.bn3.weight:[2042]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2042]***layer1_2_conv3_M.weight:[256, 26]***layer2_3_conv3_M.weight:[512, 142]***layer3_22_conv3_M.weight:[1024, 278]***layer4_2_conv3_M.weight:[2048, 2042] \ No newline at end of file diff --git a/data_processing/MANIQA/timm/models/pruned/ecaresnet50d_pruned.txt b/data_processing/MANIQA/timm/models/pruned/ecaresnet50d_pruned.txt new file mode 100644 index 0000000..9a8b2bf --- /dev/null +++ b/data_processing/MANIQA/timm/models/pruned/ecaresnet50d_pruned.txt @@ -0,0 +1 @@ +conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 
3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 
198, 3, 3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[492, 818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 
1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022] \ No newline at end of file diff --git a/data_processing/MANIQA/timm/models/pruned/efficientnet_b1_pruned.txt b/data_processing/MANIQA/timm/models/pruned/efficientnet_b1_pruned.txt new file mode 100644 index 0000000..0972b52 --- /dev/null +++ b/data_processing/MANIQA/timm/models/pruned/efficientnet_b1_pruned.txt @@ -0,0 +1 @@ +conv_stem.weight:[32, 3, 3, 3]***bn1.weight:[32]***bn1.bias:[32]***bn1.running_mean:[32]***bn1.running_var:[32]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[32, 1, 3, 3]***blocks.0.0.bn1.weight:[32]***blocks.0.0.bn1.bias:[32]***blocks.0.0.bn1.running_mean:[32]***blocks.0.0.bn1.running_var:[32]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[8, 32, 1, 1]***blocks.0.0.se.conv_reduce.bias:[8]***blocks.0.0.se.conv_expand.weight:[32, 8, 1, 1]***blocks.0.0.se.conv_expand.bias:[32]***blocks.0.0.conv_pw.weight:[16, 32, 1, 1]***blocks.0.0.bn2.weight:[16]***blocks.0.0.bn2.bias:[16]***blocks.0.0.bn2.running_mean:[16]***blocks.0.0.bn2.running_var:[16]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[16, 1, 3, 3]***blocks.0.1.bn1.weight:[16]***blocks.0.1.bn1.bias:[16]***blocks.0.1.bn1.running_mean:[16]***blocks.0.1.bn1.running_var:[16]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[4, 16, 1, 1]***blocks.0.1.se.conv_reduce.bias:[4]***blocks.0.1.se.conv_expand.weight:[16, 4, 1, 1]***blocks.0.1.se.conv_expand.bias:[16]***blocks.0.1.conv_pw.weight:[16, 16, 1, 
1]***blocks.0.1.bn2.weight:[16]***blocks.0.1.bn2.bias:[16]***blocks.0.1.bn2.running_mean:[16]***blocks.0.1.bn2.running_var:[16]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[48, 16, 1, 1]***blocks.1.0.bn1.weight:[48]***blocks.1.0.bn1.bias:[48]***blocks.1.0.bn1.running_mean:[48]***blocks.1.0.bn1.running_var:[48]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[48, 1, 3, 3]***blocks.1.0.bn2.weight:[48]***blocks.1.0.bn2.bias:[48]***blocks.1.0.bn2.running_mean:[48]***blocks.1.0.bn2.running_var:[48]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[4, 48, 1, 1]***blocks.1.0.se.conv_reduce.bias:[4]***blocks.1.0.se.conv_expand.weight:[48, 4, 1, 1]***blocks.1.0.se.conv_expand.bias:[48]***blocks.1.0.conv_pwl.weight:[12, 48, 1, 1]***blocks.1.0.bn3.weight:[12]***blocks.1.0.bn3.bias:[12]***blocks.1.0.bn3.running_mean:[12]***blocks.1.0.bn3.running_var:[12]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[62, 12, 1, 1]***blocks.1.1.bn1.weight:[62]***blocks.1.1.bn1.bias:[62]***blocks.1.1.bn1.running_mean:[62]***blocks.1.1.bn1.running_var:[62]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[62, 1, 3, 3]***blocks.1.1.bn2.weight:[62]***blocks.1.1.bn2.bias:[62]***blocks.1.1.bn2.running_mean:[62]***blocks.1.1.bn2.running_var:[62]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[6, 62, 1, 1]***blocks.1.1.se.conv_reduce.bias:[6]***blocks.1.1.se.conv_expand.weight:[62, 6, 1, 1]***blocks.1.1.se.conv_expand.bias:[62]***blocks.1.1.conv_pwl.weight:[12, 62, 1, 1]***blocks.1.1.bn3.weight:[12]***blocks.1.1.bn3.bias:[12]***blocks.1.1.bn3.running_mean:[12]***blocks.1.1.bn3.running_var:[12]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[48, 12, 1, 
1]***blocks.1.2.bn1.weight:[48]***blocks.1.2.bn1.bias:[48]***blocks.1.2.bn1.running_mean:[48]***blocks.1.2.bn1.running_var:[48]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[48, 1, 3, 3]***blocks.1.2.bn2.weight:[48]***blocks.1.2.bn2.bias:[48]***blocks.1.2.bn2.running_mean:[48]***blocks.1.2.bn2.running_var:[48]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[6, 48, 1, 1]***blocks.1.2.se.conv_reduce.bias:[6]***blocks.1.2.se.conv_expand.weight:[48, 6, 1, 1]***blocks.1.2.se.conv_expand.bias:[48]***blocks.1.2.conv_pwl.weight:[12, 48, 1, 1]***blocks.1.2.bn3.weight:[12]***blocks.1.2.bn3.bias:[12]***blocks.1.2.bn3.running_mean:[12]***blocks.1.2.bn3.running_var:[12]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[70, 12, 1, 1]***blocks.2.0.bn1.weight:[70]***blocks.2.0.bn1.bias:[70]***blocks.2.0.bn1.running_mean:[70]***blocks.2.0.bn1.running_var:[70]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[70, 1, 5, 5]***blocks.2.0.bn2.weight:[70]***blocks.2.0.bn2.bias:[70]***blocks.2.0.bn2.running_mean:[70]***blocks.2.0.bn2.running_var:[70]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[6, 70, 1, 1]***blocks.2.0.se.conv_reduce.bias:[6]***blocks.2.0.se.conv_expand.weight:[70, 6, 1, 1]***blocks.2.0.se.conv_expand.bias:[70]***blocks.2.0.conv_pwl.weight:[35, 70, 1, 1]***blocks.2.0.bn3.weight:[35]***blocks.2.0.bn3.bias:[35]***blocks.2.0.bn3.running_mean:[35]***blocks.2.0.bn3.running_var:[35]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[61, 35, 1, 1]***blocks.2.1.bn1.weight:[61]***blocks.2.1.bn1.bias:[61]***blocks.2.1.bn1.running_mean:[61]***blocks.2.1.bn1.running_var:[61]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[61, 1, 5, 
5]***blocks.2.1.bn2.weight:[61]***blocks.2.1.bn2.bias:[61]***blocks.2.1.bn2.running_mean:[61]***blocks.2.1.bn2.running_var:[61]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[10, 61, 1, 1]***blocks.2.1.se.conv_reduce.bias:[10]***blocks.2.1.se.conv_expand.weight:[61, 10, 1, 1]***blocks.2.1.se.conv_expand.bias:[61]***blocks.2.1.conv_pwl.weight:[35, 61, 1, 1]***blocks.2.1.bn3.weight:[35]***blocks.2.1.bn3.bias:[35]***blocks.2.1.bn3.running_mean:[35]***blocks.2.1.bn3.running_var:[35]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[51, 35, 1, 1]***blocks.2.2.bn1.weight:[51]***blocks.2.2.bn1.bias:[51]***blocks.2.2.bn1.running_mean:[51]***blocks.2.2.bn1.running_var:[51]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[51, 1, 5, 5]***blocks.2.2.bn2.weight:[51]***blocks.2.2.bn2.bias:[51]***blocks.2.2.bn2.running_mean:[51]***blocks.2.2.bn2.running_var:[51]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[10, 51, 1, 1]***blocks.2.2.se.conv_reduce.bias:[10]***blocks.2.2.se.conv_expand.weight:[51, 10, 1, 1]***blocks.2.2.se.conv_expand.bias:[51]***blocks.2.2.conv_pwl.weight:[35, 51, 1, 1]***blocks.2.2.bn3.weight:[35]***blocks.2.2.bn3.bias:[35]***blocks.2.2.bn3.running_mean:[35]***blocks.2.2.bn3.running_var:[35]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[175, 35, 1, 1]***blocks.3.0.bn1.weight:[175]***blocks.3.0.bn1.bias:[175]***blocks.3.0.bn1.running_mean:[175]***blocks.3.0.bn1.running_var:[175]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[175, 1, 3, 3]***blocks.3.0.bn2.weight:[175]***blocks.3.0.bn2.bias:[175]***blocks.3.0.bn2.running_mean:[175]***blocks.3.0.bn2.running_var:[175]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[10, 175, 1, 1]***blocks.3.0.se.conv_reduce.bias:[10]***blocks.3.0.se.conv_expand.weight:[175, 10, 1, 1]***blocks.3.0.se.conv_expand.bias:[175]***blocks.3.0.conv_pwl.weight:[74, 175, 1, 
1]***blocks.3.0.bn3.weight:[74]***blocks.3.0.bn3.bias:[74]***blocks.3.0.bn3.running_mean:[74]***blocks.3.0.bn3.running_var:[74]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[188, 74, 1, 1]***blocks.3.1.bn1.weight:[188]***blocks.3.1.bn1.bias:[188]***blocks.3.1.bn1.running_mean:[188]***blocks.3.1.bn1.running_var:[188]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[188, 1, 3, 3]***blocks.3.1.bn2.weight:[188]***blocks.3.1.bn2.bias:[188]***blocks.3.1.bn2.running_mean:[188]***blocks.3.1.bn2.running_var:[188]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[20, 188, 1, 1]***blocks.3.1.se.conv_reduce.bias:[20]***blocks.3.1.se.conv_expand.weight:[188, 20, 1, 1]***blocks.3.1.se.conv_expand.bias:[188]***blocks.3.1.conv_pwl.weight:[74, 188, 1, 1]***blocks.3.1.bn3.weight:[74]***blocks.3.1.bn3.bias:[74]***blocks.3.1.bn3.running_mean:[74]***blocks.3.1.bn3.running_var:[74]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[137, 74, 1, 1]***blocks.3.2.bn1.weight:[137]***blocks.3.2.bn1.bias:[137]***blocks.3.2.bn1.running_mean:[137]***blocks.3.2.bn1.running_var:[137]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[137, 1, 3, 3]***blocks.3.2.bn2.weight:[137]***blocks.3.2.bn2.bias:[137]***blocks.3.2.bn2.running_mean:[137]***blocks.3.2.bn2.running_var:[137]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[20, 137, 1, 1]***blocks.3.2.se.conv_reduce.bias:[20]***blocks.3.2.se.conv_expand.weight:[137, 20, 1, 1]***blocks.3.2.se.conv_expand.bias:[137]***blocks.3.2.conv_pwl.weight:[74, 137, 1, 1]***blocks.3.2.bn3.weight:[74]***blocks.3.2.bn3.bias:[74]***blocks.3.2.bn3.running_mean:[74]***blocks.3.2.bn3.running_var:[74]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[164, 74, 1, 
1]***blocks.3.3.bn1.weight:[164]***blocks.3.3.bn1.bias:[164]***blocks.3.3.bn1.running_mean:[164]***blocks.3.3.bn1.running_var:[164]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[164, 1, 3, 3]***blocks.3.3.bn2.weight:[164]***blocks.3.3.bn2.bias:[164]***blocks.3.3.bn2.running_mean:[164]***blocks.3.3.bn2.running_var:[164]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[20, 164, 1, 1]***blocks.3.3.se.conv_reduce.bias:[20]***blocks.3.3.se.conv_expand.weight:[164, 20, 1, 1]***blocks.3.3.se.conv_expand.bias:[164]***blocks.3.3.conv_pwl.weight:[74, 164, 1, 1]***blocks.3.3.bn3.weight:[74]***blocks.3.3.bn3.bias:[74]***blocks.3.3.bn3.running_mean:[74]***blocks.3.3.bn3.running_var:[74]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[399, 74, 1, 1]***blocks.4.0.bn1.weight:[399]***blocks.4.0.bn1.bias:[399]***blocks.4.0.bn1.running_mean:[399]***blocks.4.0.bn1.running_var:[399]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[399, 1, 5, 5]***blocks.4.0.bn2.weight:[399]***blocks.4.0.bn2.bias:[399]***blocks.4.0.bn2.running_mean:[399]***blocks.4.0.bn2.running_var:[399]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[20, 399, 1, 1]***blocks.4.0.se.conv_reduce.bias:[20]***blocks.4.0.se.conv_expand.weight:[399, 20, 1, 1]***blocks.4.0.se.conv_expand.bias:[399]***blocks.4.0.conv_pwl.weight:[67, 399, 1, 1]***blocks.4.0.bn3.weight:[67]***blocks.4.0.bn3.bias:[67]***blocks.4.0.bn3.running_mean:[67]***blocks.4.0.bn3.running_var:[67]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[201, 67, 1, 1]***blocks.4.1.bn1.weight:[201]***blocks.4.1.bn1.bias:[201]***blocks.4.1.bn1.running_mean:[201]***blocks.4.1.bn1.running_var:[201]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[201, 1, 5, 
5]***blocks.4.1.bn2.weight:[201]***blocks.4.1.bn2.bias:[201]***blocks.4.1.bn2.running_mean:[201]***blocks.4.1.bn2.running_var:[201]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[28, 201, 1, 1]***blocks.4.1.se.conv_reduce.bias:[28]***blocks.4.1.se.conv_expand.weight:[201, 28, 1, 1]***blocks.4.1.se.conv_expand.bias:[201]***blocks.4.1.conv_pwl.weight:[67, 201, 1, 1]***blocks.4.1.bn3.weight:[67]***blocks.4.1.bn3.bias:[67]***blocks.4.1.bn3.running_mean:[67]***blocks.4.1.bn3.running_var:[67]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[160, 67, 1, 1]***blocks.4.2.bn1.weight:[160]***blocks.4.2.bn1.bias:[160]***blocks.4.2.bn1.running_mean:[160]***blocks.4.2.bn1.running_var:[160]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[160, 1, 5, 5]***blocks.4.2.bn2.weight:[160]***blocks.4.2.bn2.bias:[160]***blocks.4.2.bn2.running_mean:[160]***blocks.4.2.bn2.running_var:[160]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[28, 160, 1, 1]***blocks.4.2.se.conv_reduce.bias:[28]***blocks.4.2.se.conv_expand.weight:[160, 28, 1, 1]***blocks.4.2.se.conv_expand.bias:[160]***blocks.4.2.conv_pwl.weight:[67, 160, 1, 1]***blocks.4.2.bn3.weight:[67]***blocks.4.2.bn3.bias:[67]***blocks.4.2.bn3.running_mean:[67]***blocks.4.2.bn3.running_var:[67]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[213, 67, 1, 1]***blocks.4.3.bn1.weight:[213]***blocks.4.3.bn1.bias:[213]***blocks.4.3.bn1.running_mean:[213]***blocks.4.3.bn1.running_var:[213]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[213, 1, 5, 5]***blocks.4.3.bn2.weight:[213]***blocks.4.3.bn2.bias:[213]***blocks.4.3.bn2.running_mean:[213]***blocks.4.3.bn2.running_var:[213]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[28, 213, 1, 1]***blocks.4.3.se.conv_reduce.bias:[28]***blocks.4.3.se.conv_expand.weight:[213, 28, 1, 
1]***blocks.4.3.se.conv_expand.bias:[213]***blocks.4.3.conv_pwl.weight:[67, 213, 1, 1]***blocks.4.3.bn3.weight:[67]***blocks.4.3.bn3.bias:[67]***blocks.4.3.bn3.running_mean:[67]***blocks.4.3.bn3.running_var:[67]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[637, 67, 1, 1]***blocks.5.0.bn1.weight:[637]***blocks.5.0.bn1.bias:[637]***blocks.5.0.bn1.running_mean:[637]***blocks.5.0.bn1.running_var:[637]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[637, 1, 5, 5]***blocks.5.0.bn2.weight:[637]***blocks.5.0.bn2.bias:[637]***blocks.5.0.bn2.running_mean:[637]***blocks.5.0.bn2.running_var:[637]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[27, 637, 1, 1]***blocks.5.0.se.conv_reduce.bias:[27]***blocks.5.0.se.conv_expand.weight:[637, 27, 1, 1]***blocks.5.0.se.conv_expand.bias:[637]***blocks.5.0.conv_pwl.weight:[192, 637, 1, 1]***blocks.5.0.bn3.weight:[192]***blocks.5.0.bn3.bias:[192]***blocks.5.0.bn3.running_mean:[192]***blocks.5.0.bn3.running_var:[192]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[806, 192, 1, 1]***blocks.5.1.bn1.weight:[806]***blocks.5.1.bn1.bias:[806]***blocks.5.1.bn1.running_mean:[806]***blocks.5.1.bn1.running_var:[806]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[806, 1, 5, 5]***blocks.5.1.bn2.weight:[806]***blocks.5.1.bn2.bias:[806]***blocks.5.1.bn2.running_mean:[806]***blocks.5.1.bn2.running_var:[806]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[48, 806, 1, 1]***blocks.5.1.se.conv_reduce.bias:[48]***blocks.5.1.se.conv_expand.weight:[806, 48, 1, 1]***blocks.5.1.se.conv_expand.bias:[806]***blocks.5.1.conv_pwl.weight:[192, 806, 1, 1]***blocks.5.1.bn3.weight:[192]***blocks.5.1.bn3.bias:[192]***blocks.5.1.bn3.running_mean:[192]***blocks.5.1.bn3.running_var:[192]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[798, 192, 1, 
1]***blocks.5.2.bn1.weight:[798]***blocks.5.2.bn1.bias:[798]***blocks.5.2.bn1.running_mean:[798]***blocks.5.2.bn1.running_var:[798]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[798, 1, 5, 5]***blocks.5.2.bn2.weight:[798]***blocks.5.2.bn2.bias:[798]***blocks.5.2.bn2.running_mean:[798]***blocks.5.2.bn2.running_var:[798]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[48, 798, 1, 1]***blocks.5.2.se.conv_reduce.bias:[48]***blocks.5.2.se.conv_expand.weight:[798, 48, 1, 1]***blocks.5.2.se.conv_expand.bias:[798]***blocks.5.2.conv_pwl.weight:[192, 798, 1, 1]***blocks.5.2.bn3.weight:[192]***blocks.5.2.bn3.bias:[192]***blocks.5.2.bn3.running_mean:[192]***blocks.5.2.bn3.running_var:[192]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[891, 192, 1, 1]***blocks.5.3.bn1.weight:[891]***blocks.5.3.bn1.bias:[891]***blocks.5.3.bn1.running_mean:[891]***blocks.5.3.bn1.running_var:[891]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[891, 1, 5, 5]***blocks.5.3.bn2.weight:[891]***blocks.5.3.bn2.bias:[891]***blocks.5.3.bn2.running_mean:[891]***blocks.5.3.bn2.running_var:[891]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[48, 891, 1, 1]***blocks.5.3.se.conv_reduce.bias:[48]***blocks.5.3.se.conv_expand.weight:[891, 48, 1, 1]***blocks.5.3.se.conv_expand.bias:[891]***blocks.5.3.conv_pwl.weight:[192, 891, 1, 1]***blocks.5.3.bn3.weight:[192]***blocks.5.3.bn3.bias:[192]***blocks.5.3.bn3.running_mean:[192]***blocks.5.3.bn3.running_var:[192]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[990, 192, 1, 1]***blocks.5.4.bn1.weight:[990]***blocks.5.4.bn1.bias:[990]***blocks.5.4.bn1.running_mean:[990]***blocks.5.4.bn1.running_var:[990]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[990, 1, 5, 
5]***blocks.5.4.bn2.weight:[990]***blocks.5.4.bn2.bias:[990]***blocks.5.4.bn2.running_mean:[990]***blocks.5.4.bn2.running_var:[990]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[48, 990, 1, 1]***blocks.5.4.se.conv_reduce.bias:[48]***blocks.5.4.se.conv_expand.weight:[990, 48, 1, 1]***blocks.5.4.se.conv_expand.bias:[990]***blocks.5.4.conv_pwl.weight:[192, 990, 1, 1]***blocks.5.4.bn3.weight:[192]***blocks.5.4.bn3.bias:[192]***blocks.5.4.bn3.running_mean:[192]***blocks.5.4.bn3.running_var:[192]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1152, 192, 1, 1]***blocks.6.0.bn1.weight:[1152]***blocks.6.0.bn1.bias:[1152]***blocks.6.0.bn1.running_mean:[1152]***blocks.6.0.bn1.running_var:[1152]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1152, 1, 3, 3]***blocks.6.0.bn2.weight:[1152]***blocks.6.0.bn2.bias:[1152]***blocks.6.0.bn2.running_mean:[1152]***blocks.6.0.bn2.running_var:[1152]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[48, 1152, 1, 1]***blocks.6.0.se.conv_reduce.bias:[48]***blocks.6.0.se.conv_expand.weight:[1152, 48, 1, 1]***blocks.6.0.se.conv_expand.bias:[1152]***blocks.6.0.conv_pwl.weight:[320, 1152, 1, 1]***blocks.6.0.bn3.weight:[320]***blocks.6.0.bn3.bias:[320]***blocks.6.0.bn3.running_mean:[320]***blocks.6.0.bn3.running_var:[320]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[1912, 320, 1, 1]***blocks.6.1.bn1.weight:[1912]***blocks.6.1.bn1.bias:[1912]***blocks.6.1.bn1.running_mean:[1912]***blocks.6.1.bn1.running_var:[1912]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[1912, 1, 3, 3]***blocks.6.1.bn2.weight:[1912]***blocks.6.1.bn2.bias:[1912]***blocks.6.1.bn2.running_mean:[1912]***blocks.6.1.bn2.running_var:[1912]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[80, 1912, 1, 1]***blocks.6.1.se.conv_reduce.bias:[80]***blocks.6.1.se.conv_expand.weight:[1912, 80, 1, 
1]***blocks.6.1.se.conv_expand.bias:[1912]***blocks.6.1.conv_pwl.weight:[320, 1912, 1, 1]***blocks.6.1.bn3.weight:[320]***blocks.6.1.bn3.bias:[320]***blocks.6.1.bn3.running_mean:[320]***blocks.6.1.bn3.running_var:[320]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1280, 320, 1, 1]***bn2.weight:[1280]***bn2.bias:[1280]***bn2.running_mean:[1280]***bn2.running_var:[1280]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1280]***classifier.bias:[1000] \ No newline at end of file diff --git a/data_processing/MANIQA/timm/models/pruned/efficientnet_b2_pruned.txt b/data_processing/MANIQA/timm/models/pruned/efficientnet_b2_pruned.txt new file mode 100644 index 0000000..6e3fade --- /dev/null +++ b/data_processing/MANIQA/timm/models/pruned/efficientnet_b2_pruned.txt @@ -0,0 +1 @@ +conv_stem.weight:[32, 3, 3, 3]***bn1.weight:[32]***bn1.bias:[32]***bn1.running_mean:[32]***bn1.running_var:[32]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[32, 1, 3, 3]***blocks.0.0.bn1.weight:[32]***blocks.0.0.bn1.bias:[32]***blocks.0.0.bn1.running_mean:[32]***blocks.0.0.bn1.running_var:[32]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[8, 32, 1, 1]***blocks.0.0.se.conv_reduce.bias:[8]***blocks.0.0.se.conv_expand.weight:[32, 8, 1, 1]***blocks.0.0.se.conv_expand.bias:[32]***blocks.0.0.conv_pw.weight:[16, 32, 1, 1]***blocks.0.0.bn2.weight:[16]***blocks.0.0.bn2.bias:[16]***blocks.0.0.bn2.running_mean:[16]***blocks.0.0.bn2.running_var:[16]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[16, 1, 3, 3]***blocks.0.1.bn1.weight:[16]***blocks.0.1.bn1.bias:[16]***blocks.0.1.bn1.running_mean:[16]***blocks.0.1.bn1.running_var:[16]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[4, 16, 1, 1]***blocks.0.1.se.conv_reduce.bias:[4]***blocks.0.1.se.conv_expand.weight:[16, 4, 1, 1]***blocks.0.1.se.conv_expand.bias:[16]***blocks.0.1.conv_pw.weight:[16, 16, 1, 
1]***blocks.0.1.bn2.weight:[16]***blocks.0.1.bn2.bias:[16]***blocks.0.1.bn2.running_mean:[16]***blocks.0.1.bn2.running_var:[16]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[54, 16, 1, 1]***blocks.1.0.bn1.weight:[54]***blocks.1.0.bn1.bias:[54]***blocks.1.0.bn1.running_mean:[54]***blocks.1.0.bn1.running_var:[54]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[54, 1, 3, 3]***blocks.1.0.bn2.weight:[54]***blocks.1.0.bn2.bias:[54]***blocks.1.0.bn2.running_mean:[54]***blocks.1.0.bn2.running_var:[54]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[4, 54, 1, 1]***blocks.1.0.se.conv_reduce.bias:[4]***blocks.1.0.se.conv_expand.weight:[54, 4, 1, 1]***blocks.1.0.se.conv_expand.bias:[54]***blocks.1.0.conv_pwl.weight:[17, 54, 1, 1]***blocks.1.0.bn3.weight:[17]***blocks.1.0.bn3.bias:[17]***blocks.1.0.bn3.running_mean:[17]***blocks.1.0.bn3.running_var:[17]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[69, 17, 1, 1]***blocks.1.1.bn1.weight:[69]***blocks.1.1.bn1.bias:[69]***blocks.1.1.bn1.running_mean:[69]***blocks.1.1.bn1.running_var:[69]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[69, 1, 3, 3]***blocks.1.1.bn2.weight:[69]***blocks.1.1.bn2.bias:[69]***blocks.1.1.bn2.running_mean:[69]***blocks.1.1.bn2.running_var:[69]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[6, 69, 1, 1]***blocks.1.1.se.conv_reduce.bias:[6]***blocks.1.1.se.conv_expand.weight:[69, 6, 1, 1]***blocks.1.1.se.conv_expand.bias:[69]***blocks.1.1.conv_pwl.weight:[17, 69, 1, 1]***blocks.1.1.bn3.weight:[17]***blocks.1.1.bn3.bias:[17]***blocks.1.1.bn3.running_mean:[17]***blocks.1.1.bn3.running_var:[17]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[61, 17, 1, 
1]***blocks.1.2.bn1.weight:[61]***blocks.1.2.bn1.bias:[61]***blocks.1.2.bn1.running_mean:[61]***blocks.1.2.bn1.running_var:[61]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[61, 1, 3, 3]***blocks.1.2.bn2.weight:[61]***blocks.1.2.bn2.bias:[61]***blocks.1.2.bn2.running_mean:[61]***blocks.1.2.bn2.running_var:[61]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[6, 61, 1, 1]***blocks.1.2.se.conv_reduce.bias:[6]***blocks.1.2.se.conv_expand.weight:[61, 6, 1, 1]***blocks.1.2.se.conv_expand.bias:[61]***blocks.1.2.conv_pwl.weight:[17, 61, 1, 1]***blocks.1.2.bn3.weight:[17]***blocks.1.2.bn3.bias:[17]***blocks.1.2.bn3.running_mean:[17]***blocks.1.2.bn3.running_var:[17]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[86, 17, 1, 1]***blocks.2.0.bn1.weight:[86]***blocks.2.0.bn1.bias:[86]***blocks.2.0.bn1.running_mean:[86]***blocks.2.0.bn1.running_var:[86]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[86, 1, 5, 5]***blocks.2.0.bn2.weight:[86]***blocks.2.0.bn2.bias:[86]***blocks.2.0.bn2.running_mean:[86]***blocks.2.0.bn2.running_var:[86]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[6, 86, 1, 1]***blocks.2.0.se.conv_reduce.bias:[6]***blocks.2.0.se.conv_expand.weight:[86, 6, 1, 1]***blocks.2.0.se.conv_expand.bias:[86]***blocks.2.0.conv_pwl.weight:[42, 86, 1, 1]***blocks.2.0.bn3.weight:[42]***blocks.2.0.bn3.bias:[42]***blocks.2.0.bn3.running_mean:[42]***blocks.2.0.bn3.running_var:[42]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[72, 42, 1, 1]***blocks.2.1.bn1.weight:[72]***blocks.2.1.bn1.bias:[72]***blocks.2.1.bn1.running_mean:[72]***blocks.2.1.bn1.running_var:[72]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[72, 1, 5, 
5]***blocks.2.1.bn2.weight:[72]***blocks.2.1.bn2.bias:[72]***blocks.2.1.bn2.running_mean:[72]***blocks.2.1.bn2.running_var:[72]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[12, 72, 1, 1]***blocks.2.1.se.conv_reduce.bias:[12]***blocks.2.1.se.conv_expand.weight:[72, 12, 1, 1]***blocks.2.1.se.conv_expand.bias:[72]***blocks.2.1.conv_pwl.weight:[42, 72, 1, 1]***blocks.2.1.bn3.weight:[42]***blocks.2.1.bn3.bias:[42]***blocks.2.1.bn3.running_mean:[42]***blocks.2.1.bn3.running_var:[42]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[98, 42, 1, 1]***blocks.2.2.bn1.weight:[98]***blocks.2.2.bn1.bias:[98]***blocks.2.2.bn1.running_mean:[98]***blocks.2.2.bn1.running_var:[98]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[98, 1, 5, 5]***blocks.2.2.bn2.weight:[98]***blocks.2.2.bn2.bias:[98]***blocks.2.2.bn2.running_mean:[98]***blocks.2.2.bn2.running_var:[98]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[12, 98, 1, 1]***blocks.2.2.se.conv_reduce.bias:[12]***blocks.2.2.se.conv_expand.weight:[98, 12, 1, 1]***blocks.2.2.se.conv_expand.bias:[98]***blocks.2.2.conv_pwl.weight:[42, 98, 1, 1]***blocks.2.2.bn3.weight:[42]***blocks.2.2.bn3.bias:[42]***blocks.2.2.bn3.running_mean:[42]***blocks.2.2.bn3.running_var:[42]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[245, 42, 1, 1]***blocks.3.0.bn1.weight:[245]***blocks.3.0.bn1.bias:[245]***blocks.3.0.bn1.running_mean:[245]***blocks.3.0.bn1.running_var:[245]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[245, 1, 3, 3]***blocks.3.0.bn2.weight:[245]***blocks.3.0.bn2.bias:[245]***blocks.3.0.bn2.running_mean:[245]***blocks.3.0.bn2.running_var:[245]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[12, 245, 1, 1]***blocks.3.0.se.conv_reduce.bias:[12]***blocks.3.0.se.conv_expand.weight:[245, 12, 1, 1]***blocks.3.0.se.conv_expand.bias:[245]***blocks.3.0.conv_pwl.weight:[85, 245, 1, 
1]***blocks.3.0.bn3.weight:[85]***blocks.3.0.bn3.bias:[85]***blocks.3.0.bn3.running_mean:[85]***blocks.3.0.bn3.running_var:[85]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[274, 85, 1, 1]***blocks.3.1.bn1.weight:[274]***blocks.3.1.bn1.bias:[274]***blocks.3.1.bn1.running_mean:[274]***blocks.3.1.bn1.running_var:[274]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[274, 1, 3, 3]***blocks.3.1.bn2.weight:[274]***blocks.3.1.bn2.bias:[274]***blocks.3.1.bn2.running_mean:[274]***blocks.3.1.bn2.running_var:[274]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[22, 274, 1, 1]***blocks.3.1.se.conv_reduce.bias:[22]***blocks.3.1.se.conv_expand.weight:[274, 22, 1, 1]***blocks.3.1.se.conv_expand.bias:[274]***blocks.3.1.conv_pwl.weight:[85, 274, 1, 1]***blocks.3.1.bn3.weight:[85]***blocks.3.1.bn3.bias:[85]***blocks.3.1.bn3.running_mean:[85]***blocks.3.1.bn3.running_var:[85]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[254, 85, 1, 1]***blocks.3.2.bn1.weight:[254]***blocks.3.2.bn1.bias:[254]***blocks.3.2.bn1.running_mean:[254]***blocks.3.2.bn1.running_var:[254]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[254, 1, 3, 3]***blocks.3.2.bn2.weight:[254]***blocks.3.2.bn2.bias:[254]***blocks.3.2.bn2.running_mean:[254]***blocks.3.2.bn2.running_var:[254]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[22, 254, 1, 1]***blocks.3.2.se.conv_reduce.bias:[22]***blocks.3.2.se.conv_expand.weight:[254, 22, 1, 1]***blocks.3.2.se.conv_expand.bias:[254]***blocks.3.2.conv_pwl.weight:[85, 254, 1, 1]***blocks.3.2.bn3.weight:[85]***blocks.3.2.bn3.bias:[85]***blocks.3.2.bn3.running_mean:[85]***blocks.3.2.bn3.running_var:[85]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[292, 85, 1, 
1]***blocks.3.3.bn1.weight:[292]***blocks.3.3.bn1.bias:[292]***blocks.3.3.bn1.running_mean:[292]***blocks.3.3.bn1.running_var:[292]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[292, 1, 3, 3]***blocks.3.3.bn2.weight:[292]***blocks.3.3.bn2.bias:[292]***blocks.3.3.bn2.running_mean:[292]***blocks.3.3.bn2.running_var:[292]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[22, 292, 1, 1]***blocks.3.3.se.conv_reduce.bias:[22]***blocks.3.3.se.conv_expand.weight:[292, 22, 1, 1]***blocks.3.3.se.conv_expand.bias:[292]***blocks.3.3.conv_pwl.weight:[85, 292, 1, 1]***blocks.3.3.bn3.weight:[85]***blocks.3.3.bn3.bias:[85]***blocks.3.3.bn3.running_mean:[85]***blocks.3.3.bn3.running_var:[85]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[502, 85, 1, 1]***blocks.4.0.bn1.weight:[502]***blocks.4.0.bn1.bias:[502]***blocks.4.0.bn1.running_mean:[502]***blocks.4.0.bn1.running_var:[502]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[502, 1, 5, 5]***blocks.4.0.bn2.weight:[502]***blocks.4.0.bn2.bias:[502]***blocks.4.0.bn2.running_mean:[502]***blocks.4.0.bn2.running_var:[502]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[22, 502, 1, 1]***blocks.4.0.se.conv_reduce.bias:[22]***blocks.4.0.se.conv_expand.weight:[502, 22, 1, 1]***blocks.4.0.se.conv_expand.bias:[502]***blocks.4.0.conv_pwl.weight:[116, 502, 1, 1]***blocks.4.0.bn3.weight:[116]***blocks.4.0.bn3.bias:[116]***blocks.4.0.bn3.running_mean:[116]***blocks.4.0.bn3.running_var:[116]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[315, 116, 1, 1]***blocks.4.1.bn1.weight:[315]***blocks.4.1.bn1.bias:[315]***blocks.4.1.bn1.running_mean:[315]***blocks.4.1.bn1.running_var:[315]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[315, 1, 5, 
5]***blocks.4.1.bn2.weight:[315]***blocks.4.1.bn2.bias:[315]***blocks.4.1.bn2.running_mean:[315]***blocks.4.1.bn2.running_var:[315]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[30, 315, 1, 1]***blocks.4.1.se.conv_reduce.bias:[30]***blocks.4.1.se.conv_expand.weight:[315, 30, 1, 1]***blocks.4.1.se.conv_expand.bias:[315]***blocks.4.1.conv_pwl.weight:[116, 315, 1, 1]***blocks.4.1.bn3.weight:[116]***blocks.4.1.bn3.bias:[116]***blocks.4.1.bn3.running_mean:[116]***blocks.4.1.bn3.running_var:[116]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[354, 116, 1, 1]***blocks.4.2.bn1.weight:[354]***blocks.4.2.bn1.bias:[354]***blocks.4.2.bn1.running_mean:[354]***blocks.4.2.bn1.running_var:[354]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[354, 1, 5, 5]***blocks.4.2.bn2.weight:[354]***blocks.4.2.bn2.bias:[354]***blocks.4.2.bn2.running_mean:[354]***blocks.4.2.bn2.running_var:[354]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[30, 354, 1, 1]***blocks.4.2.se.conv_reduce.bias:[30]***blocks.4.2.se.conv_expand.weight:[354, 30, 1, 1]***blocks.4.2.se.conv_expand.bias:[354]***blocks.4.2.conv_pwl.weight:[116, 354, 1, 1]***blocks.4.2.bn3.weight:[116]***blocks.4.2.bn3.bias:[116]***blocks.4.2.bn3.running_mean:[116]***blocks.4.2.bn3.running_var:[116]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[443, 116, 1, 1]***blocks.4.3.bn1.weight:[443]***blocks.4.3.bn1.bias:[443]***blocks.4.3.bn1.running_mean:[443]***blocks.4.3.bn1.running_var:[443]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[443, 1, 5, 5]***blocks.4.3.bn2.weight:[443]***blocks.4.3.bn2.bias:[443]***blocks.4.3.bn2.running_mean:[443]***blocks.4.3.bn2.running_var:[443]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[30, 443, 1, 1]***blocks.4.3.se.conv_reduce.bias:[30]***blocks.4.3.se.conv_expand.weight:[443, 30, 1, 
1]***blocks.4.3.se.conv_expand.bias:[443]***blocks.4.3.conv_pwl.weight:[116, 443, 1, 1]***blocks.4.3.bn3.weight:[116]***blocks.4.3.bn3.bias:[116]***blocks.4.3.bn3.running_mean:[116]***blocks.4.3.bn3.running_var:[116]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[719, 116, 1, 1]***blocks.5.0.bn1.weight:[719]***blocks.5.0.bn1.bias:[719]***blocks.5.0.bn1.running_mean:[719]***blocks.5.0.bn1.running_var:[719]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[719, 1, 5, 5]***blocks.5.0.bn2.weight:[719]***blocks.5.0.bn2.bias:[719]***blocks.5.0.bn2.running_mean:[719]***blocks.5.0.bn2.running_var:[719]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[30, 719, 1, 1]***blocks.5.0.se.conv_reduce.bias:[30]***blocks.5.0.se.conv_expand.weight:[719, 30, 1, 1]***blocks.5.0.se.conv_expand.bias:[719]***blocks.5.0.conv_pwl.weight:[208, 719, 1, 1]***blocks.5.0.bn3.weight:[208]***blocks.5.0.bn3.bias:[208]***blocks.5.0.bn3.running_mean:[208]***blocks.5.0.bn3.running_var:[208]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[1148, 208, 1, 1]***blocks.5.1.bn1.weight:[1148]***blocks.5.1.bn1.bias:[1148]***blocks.5.1.bn1.running_mean:[1148]***blocks.5.1.bn1.running_var:[1148]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[1148, 1, 5, 5]***blocks.5.1.bn2.weight:[1148]***blocks.5.1.bn2.bias:[1148]***blocks.5.1.bn2.running_mean:[1148]***blocks.5.1.bn2.running_var:[1148]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[52, 1148, 1, 1]***blocks.5.1.se.conv_reduce.bias:[52]***blocks.5.1.se.conv_expand.weight:[1148, 52, 1, 1]***blocks.5.1.se.conv_expand.bias:[1148]***blocks.5.1.conv_pwl.weight:[208, 1148, 1, 1]***blocks.5.1.bn3.weight:[208]***blocks.5.1.bn3.bias:[208]***blocks.5.1.bn3.running_mean:[208]***blocks.5.1.bn3.running_var:[208]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[1160, 208, 1, 
1]***blocks.5.2.bn1.weight:[1160]***blocks.5.2.bn1.bias:[1160]***blocks.5.2.bn1.running_mean:[1160]***blocks.5.2.bn1.running_var:[1160]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[1160, 1, 5, 5]***blocks.5.2.bn2.weight:[1160]***blocks.5.2.bn2.bias:[1160]***blocks.5.2.bn2.running_mean:[1160]***blocks.5.2.bn2.running_var:[1160]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[52, 1160, 1, 1]***blocks.5.2.se.conv_reduce.bias:[52]***blocks.5.2.se.conv_expand.weight:[1160, 52, 1, 1]***blocks.5.2.se.conv_expand.bias:[1160]***blocks.5.2.conv_pwl.weight:[208, 1160, 1, 1]***blocks.5.2.bn3.weight:[208]***blocks.5.2.bn3.bias:[208]***blocks.5.2.bn3.running_mean:[208]***blocks.5.2.bn3.running_var:[208]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[1182, 208, 1, 1]***blocks.5.3.bn1.weight:[1182]***blocks.5.3.bn1.bias:[1182]***blocks.5.3.bn1.running_mean:[1182]***blocks.5.3.bn1.running_var:[1182]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[1182, 1, 5, 5]***blocks.5.3.bn2.weight:[1182]***blocks.5.3.bn2.bias:[1182]***blocks.5.3.bn2.running_mean:[1182]***blocks.5.3.bn2.running_var:[1182]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[52, 1182, 1, 1]***blocks.5.3.se.conv_reduce.bias:[52]***blocks.5.3.se.conv_expand.weight:[1182, 52, 1, 1]***blocks.5.3.se.conv_expand.bias:[1182]***blocks.5.3.conv_pwl.weight:[208, 1182, 1, 1]***blocks.5.3.bn3.weight:[208]***blocks.5.3.bn3.bias:[208]***blocks.5.3.bn3.running_mean:[208]***blocks.5.3.bn3.running_var:[208]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[1228, 208, 1, 1]***blocks.5.4.bn1.weight:[1228]***blocks.5.4.bn1.bias:[1228]***blocks.5.4.bn1.running_mean:[1228]***blocks.5.4.bn1.running_var:[1228]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[1228, 1, 5, 
5]***blocks.5.4.bn2.weight:[1228]***blocks.5.4.bn2.bias:[1228]***blocks.5.4.bn2.running_mean:[1228]***blocks.5.4.bn2.running_var:[1228]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[52, 1228, 1, 1]***blocks.5.4.se.conv_reduce.bias:[52]***blocks.5.4.se.conv_expand.weight:[1228, 52, 1, 1]***blocks.5.4.se.conv_expand.bias:[1228]***blocks.5.4.conv_pwl.weight:[208, 1228, 1, 1]***blocks.5.4.bn3.weight:[208]***blocks.5.4.bn3.bias:[208]***blocks.5.4.bn3.running_mean:[208]***blocks.5.4.bn3.running_var:[208]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1248, 208, 1, 1]***blocks.6.0.bn1.weight:[1248]***blocks.6.0.bn1.bias:[1248]***blocks.6.0.bn1.running_mean:[1248]***blocks.6.0.bn1.running_var:[1248]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1248, 1, 3, 3]***blocks.6.0.bn2.weight:[1248]***blocks.6.0.bn2.bias:[1248]***blocks.6.0.bn2.running_mean:[1248]***blocks.6.0.bn2.running_var:[1248]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[52, 1248, 1, 1]***blocks.6.0.se.conv_reduce.bias:[52]***blocks.6.0.se.conv_expand.weight:[1248, 52, 1, 1]***blocks.6.0.se.conv_expand.bias:[1248]***blocks.6.0.conv_pwl.weight:[352, 1248, 1, 1]***blocks.6.0.bn3.weight:[352]***blocks.6.0.bn3.bias:[352]***blocks.6.0.bn3.running_mean:[352]***blocks.6.0.bn3.running_var:[352]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[2112, 352, 1, 1]***blocks.6.1.bn1.weight:[2112]***blocks.6.1.bn1.bias:[2112]***blocks.6.1.bn1.running_mean:[2112]***blocks.6.1.bn1.running_var:[2112]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[2112, 1, 3, 3]***blocks.6.1.bn2.weight:[2112]***blocks.6.1.bn2.bias:[2112]***blocks.6.1.bn2.running_mean:[2112]***blocks.6.1.bn2.running_var:[2112]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[88, 2112, 1, 1]***blocks.6.1.se.conv_reduce.bias:[88]***blocks.6.1.se.conv_expand.weight:[2112, 88, 1, 
1]***blocks.6.1.se.conv_expand.bias:[2112]***blocks.6.1.conv_pwl.weight:[352, 2112, 1, 1]***blocks.6.1.bn3.weight:[352]***blocks.6.1.bn3.bias:[352]***blocks.6.1.bn3.running_mean:[352]***blocks.6.1.bn3.running_var:[352]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1408, 352, 1, 1]***bn2.weight:[1408]***bn2.bias:[1408]***bn2.running_mean:[1408]***bn2.running_var:[1408]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1408]***classifier.bias:[1000] \ No newline at end of file diff --git a/data_processing/MANIQA/timm/models/pruned/efficientnet_b3_pruned.txt b/data_processing/MANIQA/timm/models/pruned/efficientnet_b3_pruned.txt new file mode 100644 index 0000000..4897817 --- /dev/null +++ b/data_processing/MANIQA/timm/models/pruned/efficientnet_b3_pruned.txt @@ -0,0 +1 @@ +conv_stem.weight:[40, 3, 3, 3]***bn1.weight:[40]***bn1.bias:[40]***bn1.running_mean:[40]***bn1.running_var:[40]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[40, 1, 3, 3]***blocks.0.0.bn1.weight:[40]***blocks.0.0.bn1.bias:[40]***blocks.0.0.bn1.running_mean:[40]***blocks.0.0.bn1.running_var:[40]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[10, 40, 1, 1]***blocks.0.0.se.conv_reduce.bias:[10]***blocks.0.0.se.conv_expand.weight:[40, 10, 1, 1]***blocks.0.0.se.conv_expand.bias:[40]***blocks.0.0.conv_pw.weight:[24, 40, 1, 1]***blocks.0.0.bn2.weight:[24]***blocks.0.0.bn2.bias:[24]***blocks.0.0.bn2.running_mean:[24]***blocks.0.0.bn2.running_var:[24]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[24, 1, 3, 3]***blocks.0.1.bn1.weight:[24]***blocks.0.1.bn1.bias:[24]***blocks.0.1.bn1.running_mean:[24]***blocks.0.1.bn1.running_var:[24]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[6, 24, 1, 1]***blocks.0.1.se.conv_reduce.bias:[6]***blocks.0.1.se.conv_expand.weight:[24, 6, 1, 1]***blocks.0.1.se.conv_expand.bias:[24]***blocks.0.1.conv_pw.weight:[24, 24, 1, 
1]***blocks.0.1.bn2.weight:[24]***blocks.0.1.bn2.bias:[24]***blocks.0.1.bn2.running_mean:[24]***blocks.0.1.bn2.running_var:[24]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[27, 24, 1, 1]***blocks.1.0.bn1.weight:[27]***blocks.1.0.bn1.bias:[27]***blocks.1.0.bn1.running_mean:[27]***blocks.1.0.bn1.running_var:[27]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[27, 1, 3, 3]***blocks.1.0.bn2.weight:[27]***blocks.1.0.bn2.bias:[27]***blocks.1.0.bn2.running_mean:[27]***blocks.1.0.bn2.running_var:[27]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[6, 27, 1, 1]***blocks.1.0.se.conv_reduce.bias:[6]***blocks.1.0.se.conv_expand.weight:[27, 6, 1, 1]***blocks.1.0.se.conv_expand.bias:[27]***blocks.1.0.conv_pwl.weight:[12, 27, 1, 1]***blocks.1.0.bn3.weight:[12]***blocks.1.0.bn3.bias:[12]***blocks.1.0.bn3.running_mean:[12]***blocks.1.0.bn3.running_var:[12]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[49, 12, 1, 1]***blocks.1.1.bn1.weight:[49]***blocks.1.1.bn1.bias:[49]***blocks.1.1.bn1.running_mean:[49]***blocks.1.1.bn1.running_var:[49]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[49, 1, 3, 3]***blocks.1.1.bn2.weight:[49]***blocks.1.1.bn2.bias:[49]***blocks.1.1.bn2.running_mean:[49]***blocks.1.1.bn2.running_var:[49]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[8, 49, 1, 1]***blocks.1.1.se.conv_reduce.bias:[8]***blocks.1.1.se.conv_expand.weight:[49, 8, 1, 1]***blocks.1.1.se.conv_expand.bias:[49]***blocks.1.1.conv_pwl.weight:[12, 49, 1, 1]***blocks.1.1.bn3.weight:[12]***blocks.1.1.bn3.bias:[12]***blocks.1.1.bn3.running_mean:[12]***blocks.1.1.bn3.running_var:[12]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[48, 12, 1, 
1]***blocks.1.2.bn1.weight:[48]***blocks.1.2.bn1.bias:[48]***blocks.1.2.bn1.running_mean:[48]***blocks.1.2.bn1.running_var:[48]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[48, 1, 3, 3]***blocks.1.2.bn2.weight:[48]***blocks.1.2.bn2.bias:[48]***blocks.1.2.bn2.running_mean:[48]***blocks.1.2.bn2.running_var:[48]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[8, 48, 1, 1]***blocks.1.2.se.conv_reduce.bias:[8]***blocks.1.2.se.conv_expand.weight:[48, 8, 1, 1]***blocks.1.2.se.conv_expand.bias:[48]***blocks.1.2.conv_pwl.weight:[12, 48, 1, 1]***blocks.1.2.bn3.weight:[12]***blocks.1.2.bn3.bias:[12]***blocks.1.2.bn3.running_mean:[12]***blocks.1.2.bn3.running_var:[12]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[83, 12, 1, 1]***blocks.2.0.bn1.weight:[83]***blocks.2.0.bn1.bias:[83]***blocks.2.0.bn1.running_mean:[83]***blocks.2.0.bn1.running_var:[83]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[83, 1, 5, 5]***blocks.2.0.bn2.weight:[83]***blocks.2.0.bn2.bias:[83]***blocks.2.0.bn2.running_mean:[83]***blocks.2.0.bn2.running_var:[83]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[8, 83, 1, 1]***blocks.2.0.se.conv_reduce.bias:[8]***blocks.2.0.se.conv_expand.weight:[83, 8, 1, 1]***blocks.2.0.se.conv_expand.bias:[83]***blocks.2.0.conv_pwl.weight:[40, 83, 1, 1]***blocks.2.0.bn3.weight:[40]***blocks.2.0.bn3.bias:[40]***blocks.2.0.bn3.running_mean:[40]***blocks.2.0.bn3.running_var:[40]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[90, 40, 1, 1]***blocks.2.1.bn1.weight:[90]***blocks.2.1.bn1.bias:[90]***blocks.2.1.bn1.running_mean:[90]***blocks.2.1.bn1.running_var:[90]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[90, 1, 5, 
5]***blocks.2.1.bn2.weight:[90]***blocks.2.1.bn2.bias:[90]***blocks.2.1.bn2.running_mean:[90]***blocks.2.1.bn2.running_var:[90]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[12, 90, 1, 1]***blocks.2.1.se.conv_reduce.bias:[12]***blocks.2.1.se.conv_expand.weight:[90, 12, 1, 1]***blocks.2.1.se.conv_expand.bias:[90]***blocks.2.1.conv_pwl.weight:[40, 90, 1, 1]***blocks.2.1.bn3.weight:[40]***blocks.2.1.bn3.bias:[40]***blocks.2.1.bn3.running_mean:[40]***blocks.2.1.bn3.running_var:[40]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[85, 40, 1, 1]***blocks.2.2.bn1.weight:[85]***blocks.2.2.bn1.bias:[85]***blocks.2.2.bn1.running_mean:[85]***blocks.2.2.bn1.running_var:[85]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[85, 1, 5, 5]***blocks.2.2.bn2.weight:[85]***blocks.2.2.bn2.bias:[85]***blocks.2.2.bn2.running_mean:[85]***blocks.2.2.bn2.running_var:[85]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[12, 85, 1, 1]***blocks.2.2.se.conv_reduce.bias:[12]***blocks.2.2.se.conv_expand.weight:[85, 12, 1, 1]***blocks.2.2.se.conv_expand.bias:[85]***blocks.2.2.conv_pwl.weight:[40, 85, 1, 1]***blocks.2.2.bn3.weight:[40]***blocks.2.2.bn3.bias:[40]***blocks.2.2.bn3.running_mean:[40]***blocks.2.2.bn3.running_var:[40]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[215, 40, 1, 1]***blocks.3.0.bn1.weight:[215]***blocks.3.0.bn1.bias:[215]***blocks.3.0.bn1.running_mean:[215]***blocks.3.0.bn1.running_var:[215]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[215, 1, 3, 3]***blocks.3.0.bn2.weight:[215]***blocks.3.0.bn2.bias:[215]***blocks.3.0.bn2.running_mean:[215]***blocks.3.0.bn2.running_var:[215]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[12, 215, 1, 1]***blocks.3.0.se.conv_reduce.bias:[12]***blocks.3.0.se.conv_expand.weight:[215, 12, 1, 1]***blocks.3.0.se.conv_expand.bias:[215]***blocks.3.0.conv_pwl.weight:[93, 215, 1, 
1]***blocks.3.0.bn3.weight:[93]***blocks.3.0.bn3.bias:[93]***blocks.3.0.bn3.running_mean:[93]***blocks.3.0.bn3.running_var:[93]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[261, 93, 1, 1]***blocks.3.1.bn1.weight:[261]***blocks.3.1.bn1.bias:[261]***blocks.3.1.bn1.running_mean:[261]***blocks.3.1.bn1.running_var:[261]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[261, 1, 3, 3]***blocks.3.1.bn2.weight:[261]***blocks.3.1.bn2.bias:[261]***blocks.3.1.bn2.running_mean:[261]***blocks.3.1.bn2.running_var:[261]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[24, 261, 1, 1]***blocks.3.1.se.conv_reduce.bias:[24]***blocks.3.1.se.conv_expand.weight:[261, 24, 1, 1]***blocks.3.1.se.conv_expand.bias:[261]***blocks.3.1.conv_pwl.weight:[93, 261, 1, 1]***blocks.3.1.bn3.weight:[93]***blocks.3.1.bn3.bias:[93]***blocks.3.1.bn3.running_mean:[93]***blocks.3.1.bn3.running_var:[93]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[219, 93, 1, 1]***blocks.3.2.bn1.weight:[219]***blocks.3.2.bn1.bias:[219]***blocks.3.2.bn1.running_mean:[219]***blocks.3.2.bn1.running_var:[219]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[219, 1, 3, 3]***blocks.3.2.bn2.weight:[219]***blocks.3.2.bn2.bias:[219]***blocks.3.2.bn2.running_mean:[219]***blocks.3.2.bn2.running_var:[219]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[24, 219, 1, 1]***blocks.3.2.se.conv_reduce.bias:[24]***blocks.3.2.se.conv_expand.weight:[219, 24, 1, 1]***blocks.3.2.se.conv_expand.bias:[219]***blocks.3.2.conv_pwl.weight:[93, 219, 1, 1]***blocks.3.2.bn3.weight:[93]***blocks.3.2.bn3.bias:[93]***blocks.3.2.bn3.running_mean:[93]***blocks.3.2.bn3.running_var:[93]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[254, 93, 1, 
1]***blocks.3.3.bn1.weight:[254]***blocks.3.3.bn1.bias:[254]***blocks.3.3.bn1.running_mean:[254]***blocks.3.3.bn1.running_var:[254]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[254, 1, 3, 3]***blocks.3.3.bn2.weight:[254]***blocks.3.3.bn2.bias:[254]***blocks.3.3.bn2.running_mean:[254]***blocks.3.3.bn2.running_var:[254]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[24, 254, 1, 1]***blocks.3.3.se.conv_reduce.bias:[24]***blocks.3.3.se.conv_expand.weight:[254, 24, 1, 1]***blocks.3.3.se.conv_expand.bias:[254]***blocks.3.3.conv_pwl.weight:[93, 254, 1, 1]***blocks.3.3.bn3.weight:[93]***blocks.3.3.bn3.bias:[93]***blocks.3.3.bn3.running_mean:[93]***blocks.3.3.bn3.running_var:[93]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.3.4.conv_pw.weight:[236, 93, 1, 1]***blocks.3.4.bn1.weight:[236]***blocks.3.4.bn1.bias:[236]***blocks.3.4.bn1.running_mean:[236]***blocks.3.4.bn1.running_var:[236]***blocks.3.4.bn1.num_batches_tracked:[]***blocks.3.4.conv_dw.weight:[236, 1, 3, 3]***blocks.3.4.bn2.weight:[236]***blocks.3.4.bn2.bias:[236]***blocks.3.4.bn2.running_mean:[236]***blocks.3.4.bn2.running_var:[236]***blocks.3.4.bn2.num_batches_tracked:[]***blocks.3.4.se.conv_reduce.weight:[24, 236, 1, 1]***blocks.3.4.se.conv_reduce.bias:[24]***blocks.3.4.se.conv_expand.weight:[236, 24, 1, 1]***blocks.3.4.se.conv_expand.bias:[236]***blocks.3.4.conv_pwl.weight:[93, 236, 1, 1]***blocks.3.4.bn3.weight:[93]***blocks.3.4.bn3.bias:[93]***blocks.3.4.bn3.running_mean:[93]***blocks.3.4.bn3.running_var:[93]***blocks.3.4.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[480, 93, 1, 1]***blocks.4.0.bn1.weight:[480]***blocks.4.0.bn1.bias:[480]***blocks.4.0.bn1.running_mean:[480]***blocks.4.0.bn1.running_var:[480]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[480, 1, 5, 
5]***blocks.4.0.bn2.weight:[480]***blocks.4.0.bn2.bias:[480]***blocks.4.0.bn2.running_mean:[480]***blocks.4.0.bn2.running_var:[480]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[24, 480, 1, 1]***blocks.4.0.se.conv_reduce.bias:[24]***blocks.4.0.se.conv_expand.weight:[480, 24, 1, 1]***blocks.4.0.se.conv_expand.bias:[480]***blocks.4.0.conv_pwl.weight:[120, 480, 1, 1]***blocks.4.0.bn3.weight:[120]***blocks.4.0.bn3.bias:[120]***blocks.4.0.bn3.running_mean:[120]***blocks.4.0.bn3.running_var:[120]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[235, 120, 1, 1]***blocks.4.1.bn1.weight:[235]***blocks.4.1.bn1.bias:[235]***blocks.4.1.bn1.running_mean:[235]***blocks.4.1.bn1.running_var:[235]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[235, 1, 5, 5]***blocks.4.1.bn2.weight:[235]***blocks.4.1.bn2.bias:[235]***blocks.4.1.bn2.running_mean:[235]***blocks.4.1.bn2.running_var:[235]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[34, 235, 1, 1]***blocks.4.1.se.conv_reduce.bias:[34]***blocks.4.1.se.conv_expand.weight:[235, 34, 1, 1]***blocks.4.1.se.conv_expand.bias:[235]***blocks.4.1.conv_pwl.weight:[120, 235, 1, 1]***blocks.4.1.bn3.weight:[120]***blocks.4.1.bn3.bias:[120]***blocks.4.1.bn3.running_mean:[120]***blocks.4.1.bn3.running_var:[120]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[217, 120, 1, 1]***blocks.4.2.bn1.weight:[217]***blocks.4.2.bn1.bias:[217]***blocks.4.2.bn1.running_mean:[217]***blocks.4.2.bn1.running_var:[217]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[217, 1, 5, 5]***blocks.4.2.bn2.weight:[217]***blocks.4.2.bn2.bias:[217]***blocks.4.2.bn2.running_mean:[217]***blocks.4.2.bn2.running_var:[217]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[34, 217, 1, 1]***blocks.4.2.se.conv_reduce.bias:[34]***blocks.4.2.se.conv_expand.weight:[217, 34, 1, 
1]***blocks.4.2.se.conv_expand.bias:[217]***blocks.4.2.conv_pwl.weight:[120, 217, 1, 1]***blocks.4.2.bn3.weight:[120]***blocks.4.2.bn3.bias:[120]***blocks.4.2.bn3.running_mean:[120]***blocks.4.2.bn3.running_var:[120]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[226, 120, 1, 1]***blocks.4.3.bn1.weight:[226]***blocks.4.3.bn1.bias:[226]***blocks.4.3.bn1.running_mean:[226]***blocks.4.3.bn1.running_var:[226]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[226, 1, 5, 5]***blocks.4.3.bn2.weight:[226]***blocks.4.3.bn2.bias:[226]***blocks.4.3.bn2.running_mean:[226]***blocks.4.3.bn2.running_var:[226]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[33, 226, 1, 1]***blocks.4.3.se.conv_reduce.bias:[33]***blocks.4.3.se.conv_expand.weight:[226, 33, 1, 1]***blocks.4.3.se.conv_expand.bias:[226]***blocks.4.3.conv_pwl.weight:[120, 226, 1, 1]***blocks.4.3.bn3.weight:[120]***blocks.4.3.bn3.bias:[120]***blocks.4.3.bn3.running_mean:[120]***blocks.4.3.bn3.running_var:[120]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.4.4.conv_pw.weight:[340, 120, 1, 1]***blocks.4.4.bn1.weight:[340]***blocks.4.4.bn1.bias:[340]***blocks.4.4.bn1.running_mean:[340]***blocks.4.4.bn1.running_var:[340]***blocks.4.4.bn1.num_batches_tracked:[]***blocks.4.4.conv_dw.weight:[340, 1, 5, 5]***blocks.4.4.bn2.weight:[340]***blocks.4.4.bn2.bias:[340]***blocks.4.4.bn2.running_mean:[340]***blocks.4.4.bn2.running_var:[340]***blocks.4.4.bn2.num_batches_tracked:[]***blocks.4.4.se.conv_reduce.weight:[34, 340, 1, 1]***blocks.4.4.se.conv_reduce.bias:[34]***blocks.4.4.se.conv_expand.weight:[340, 34, 1, 1]***blocks.4.4.se.conv_expand.bias:[340]***blocks.4.4.conv_pwl.weight:[120, 340, 1, 1]***blocks.4.4.bn3.weight:[120]***blocks.4.4.bn3.bias:[120]***blocks.4.4.bn3.running_mean:[120]***blocks.4.4.bn3.running_var:[120]***blocks.4.4.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[802, 120, 1, 
1]***blocks.5.0.bn1.weight:[802]***blocks.5.0.bn1.bias:[802]***blocks.5.0.bn1.running_mean:[802]***blocks.5.0.bn1.running_var:[802]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[802, 1, 5, 5]***blocks.5.0.bn2.weight:[802]***blocks.5.0.bn2.bias:[802]***blocks.5.0.bn2.running_mean:[802]***blocks.5.0.bn2.running_var:[802]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[34, 802, 1, 1]***blocks.5.0.se.conv_reduce.bias:[34]***blocks.5.0.se.conv_expand.weight:[802, 34, 1, 1]***blocks.5.0.se.conv_expand.bias:[802]***blocks.5.0.conv_pwl.weight:[232, 802, 1, 1]***blocks.5.0.bn3.weight:[232]***blocks.5.0.bn3.bias:[232]***blocks.5.0.bn3.running_mean:[232]***blocks.5.0.bn3.running_var:[232]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[1030, 232, 1, 1]***blocks.5.1.bn1.weight:[1030]***blocks.5.1.bn1.bias:[1030]***blocks.5.1.bn1.running_mean:[1030]***blocks.5.1.bn1.running_var:[1030]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[1030, 1, 5, 5]***blocks.5.1.bn2.weight:[1030]***blocks.5.1.bn2.bias:[1030]***blocks.5.1.bn2.running_mean:[1030]***blocks.5.1.bn2.running_var:[1030]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[58, 1030, 1, 1]***blocks.5.1.se.conv_reduce.bias:[58]***blocks.5.1.se.conv_expand.weight:[1030, 58, 1, 1]***blocks.5.1.se.conv_expand.bias:[1030]***blocks.5.1.conv_pwl.weight:[232, 1030, 1, 1]***blocks.5.1.bn3.weight:[232]***blocks.5.1.bn3.bias:[232]***blocks.5.1.bn3.running_mean:[232]***blocks.5.1.bn3.running_var:[232]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[924, 232, 1, 1]***blocks.5.2.bn1.weight:[924]***blocks.5.2.bn1.bias:[924]***blocks.5.2.bn1.running_mean:[924]***blocks.5.2.bn1.running_var:[924]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[924, 1, 5, 
5]***blocks.5.2.bn2.weight:[924]***blocks.5.2.bn2.bias:[924]***blocks.5.2.bn2.running_mean:[924]***blocks.5.2.bn2.running_var:[924]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[58, 924, 1, 1]***blocks.5.2.se.conv_reduce.bias:[58]***blocks.5.2.se.conv_expand.weight:[924, 58, 1, 1]***blocks.5.2.se.conv_expand.bias:[924]***blocks.5.2.conv_pwl.weight:[232, 924, 1, 1]***blocks.5.2.bn3.weight:[232]***blocks.5.2.bn3.bias:[232]***blocks.5.2.bn3.running_mean:[232]***blocks.5.2.bn3.running_var:[232]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[1016, 232, 1, 1]***blocks.5.3.bn1.weight:[1016]***blocks.5.3.bn1.bias:[1016]***blocks.5.3.bn1.running_mean:[1016]***blocks.5.3.bn1.running_var:[1016]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[1016, 1, 5, 5]***blocks.5.3.bn2.weight:[1016]***blocks.5.3.bn2.bias:[1016]***blocks.5.3.bn2.running_mean:[1016]***blocks.5.3.bn2.running_var:[1016]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[58, 1016, 1, 1]***blocks.5.3.se.conv_reduce.bias:[58]***blocks.5.3.se.conv_expand.weight:[1016, 58, 1, 1]***blocks.5.3.se.conv_expand.bias:[1016]***blocks.5.3.conv_pwl.weight:[232, 1016, 1, 1]***blocks.5.3.bn3.weight:[232]***blocks.5.3.bn3.bias:[232]***blocks.5.3.bn3.running_mean:[232]***blocks.5.3.bn3.running_var:[232]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[1130, 232, 1, 1]***blocks.5.4.bn1.weight:[1130]***blocks.5.4.bn1.bias:[1130]***blocks.5.4.bn1.running_mean:[1130]***blocks.5.4.bn1.running_var:[1130]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[1130, 1, 5, 5]***blocks.5.4.bn2.weight:[1130]***blocks.5.4.bn2.bias:[1130]***blocks.5.4.bn2.running_mean:[1130]***blocks.5.4.bn2.running_var:[1130]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[58, 1130, 1, 1]***blocks.5.4.se.conv_reduce.bias:[58]***blocks.5.4.se.conv_expand.weight:[1130, 58, 1, 
1]***blocks.5.4.se.conv_expand.bias:[1130]***blocks.5.4.conv_pwl.weight:[232, 1130, 1, 1]***blocks.5.4.bn3.weight:[232]***blocks.5.4.bn3.bias:[232]***blocks.5.4.bn3.running_mean:[232]***blocks.5.4.bn3.running_var:[232]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.5.5.conv_pw.weight:[1266, 232, 1, 1]***blocks.5.5.bn1.weight:[1266]***blocks.5.5.bn1.bias:[1266]***blocks.5.5.bn1.running_mean:[1266]***blocks.5.5.bn1.running_var:[1266]***blocks.5.5.bn1.num_batches_tracked:[]***blocks.5.5.conv_dw.weight:[1266, 1, 5, 5]***blocks.5.5.bn2.weight:[1266]***blocks.5.5.bn2.bias:[1266]***blocks.5.5.bn2.running_mean:[1266]***blocks.5.5.bn2.running_var:[1266]***blocks.5.5.bn2.num_batches_tracked:[]***blocks.5.5.se.conv_reduce.weight:[58, 1266, 1, 1]***blocks.5.5.se.conv_reduce.bias:[58]***blocks.5.5.se.conv_expand.weight:[1266, 58, 1, 1]***blocks.5.5.se.conv_expand.bias:[1266]***blocks.5.5.conv_pwl.weight:[232, 1266, 1, 1]***blocks.5.5.bn3.weight:[232]***blocks.5.5.bn3.bias:[232]***blocks.5.5.bn3.running_mean:[232]***blocks.5.5.bn3.running_var:[232]***blocks.5.5.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1392, 232, 1, 1]***blocks.6.0.bn1.weight:[1392]***blocks.6.0.bn1.bias:[1392]***blocks.6.0.bn1.running_mean:[1392]***blocks.6.0.bn1.running_var:[1392]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1392, 1, 3, 3]***blocks.6.0.bn2.weight:[1392]***blocks.6.0.bn2.bias:[1392]***blocks.6.0.bn2.running_mean:[1392]***blocks.6.0.bn2.running_var:[1392]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[58, 1392, 1, 1]***blocks.6.0.se.conv_reduce.bias:[58]***blocks.6.0.se.conv_expand.weight:[1392, 58, 1, 1]***blocks.6.0.se.conv_expand.bias:[1392]***blocks.6.0.conv_pwl.weight:[384, 1392, 1, 1]***blocks.6.0.bn3.weight:[384]***blocks.6.0.bn3.bias:[384]***blocks.6.0.bn3.running_mean:[384]***blocks.6.0.bn3.running_var:[384]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[2301, 384, 1, 
1]***blocks.6.1.bn1.weight:[2301]***blocks.6.1.bn1.bias:[2301]***blocks.6.1.bn1.running_mean:[2301]***blocks.6.1.bn1.running_var:[2301]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[2301, 1, 3, 3]***blocks.6.1.bn2.weight:[2301]***blocks.6.1.bn2.bias:[2301]***blocks.6.1.bn2.running_mean:[2301]***blocks.6.1.bn2.running_var:[2301]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[96, 2301, 1, 1]***blocks.6.1.se.conv_reduce.bias:[96]***blocks.6.1.se.conv_expand.weight:[2301, 96, 1, 1]***blocks.6.1.se.conv_expand.bias:[2301]***blocks.6.1.conv_pwl.weight:[384, 2301, 1, 1]***blocks.6.1.bn3.weight:[384]***blocks.6.1.bn3.bias:[384]***blocks.6.1.bn3.running_mean:[384]***blocks.6.1.bn3.running_var:[384]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1536, 384, 1, 1]***bn2.weight:[1536]***bn2.bias:[1536]***bn2.running_mean:[1536]***bn2.running_var:[1536]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1536]***classifier.bias:[1000] \ No newline at end of file diff --git a/data_processing/MANIQA/timm/models/registry.py b/data_processing/MANIQA/timm/models/registry.py new file mode 100644 index 0000000..f92219b --- /dev/null +++ b/data_processing/MANIQA/timm/models/registry.py @@ -0,0 +1,149 @@ +""" Model Registry +Hacked together by / Copyright 2020 Ross Wightman +""" + +import sys +import re +import fnmatch +from collections import defaultdict +from copy import deepcopy + +__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', + 'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained'] + +_module_to_models = defaultdict(set) # dict of sets to check membership of model in module +_model_to_module = {} # mapping of model names to module names +_model_entrypoints = {} # mapping of model names to entrypoint fns +_model_has_pretrained = set() # set of model names that have pretrained weight url present +_model_default_cfgs = dict() 
def _natural_key(string_):
    """Build a sort key that compares embedded digit runs numerically.

    The lower-cased input is split on runs of digits (the digit runs are kept
    as separate pieces). All-digit pieces are converted to ``int``; every other
    piece is left as a string, so e.g. ``'resnet18'`` sorts before ``'resnet101'``.
    """
    key = []
    for piece in re.split(r'(\d+)', string_.lower()):
        key.append(int(piece) if piece.isdigit() else piece)
    return key
def is_model_in_modules(model_name, module_names):
    """Check if a model exists within a subset of modules

    Args:
        model_name (str) - name of model to check
        module_names (tuple, list, set) - names of modules to search in

    Returns:
        bool - True if at least one of the named modules has `model_name` registered.
    """
    assert isinstance(module_names, (tuple, list, set))
    # `_module_to_models` is a defaultdict(set): indexing it with a module name that was
    # never registered would insert an empty entry, which list_modules() would then report.
    # .get() keeps this query read-only.
    return any(model_name in _module_to_models.get(n, ()) for n in module_names)
def get_model_default_value(model_name, cfg_key):
    """ Fetch a single value from a registered model's default_cfg.

    Returns None when no default_cfg is registered under `model_name`, or when
    that cfg has no `cfg_key` entry.
    """
    try:
        model_cfg = _model_default_cfgs[model_name]
    except KeyError:
        return None
    return model_cfg.get(cfg_key, None)
def _mcfg(**kwargs):
    """Return a model config dict: the shared defaults overlaid with `kwargs`.

    Defaults are ``se_ratio=0.``, ``bottle_ratio=1.`` and ``stem_width=32``;
    any of them may be overridden by passing the same key.
    """
    return {'se_ratio': 0., 'bottle_ratio': 1., 'stem_width': 32, **kwargs}
se_ratio=0.25), + regnety_032=_mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25), + regnety_040=_mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25), + regnety_064=_mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25), + regnety_080=_mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25), + regnety_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25), + regnety_160=_mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25), + regnety_320=_mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25), +) + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + regnetx_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth'), + regnetx_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth'), + regnetx_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth'), + regnetx_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth'), + regnetx_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth'), + 
regnetx_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth'), + regnetx_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth'), + regnetx_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth'), + regnetx_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth'), + regnetx_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'), + regnetx_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'), + regnetx_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'), + regnety_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'), + regnety_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'), + regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'), + 
def quantize_float(f, q):
    """Round `f` to the nearest multiple of `q` and return it as an int.

    NOTE: nothing here forces the result to be non-zero; an `f` smaller than
    half of `q` rounds to 0.
    """
    multiples = round(f / q)
    return int(multiples * q)
def generate_regnet(width_slope, width_initial, width_mult, depth, q=8):
    """Generates per block widths from RegNet parameters.

    Returns a 4-tuple:
        * per-block widths (list of int), each a multiple of `q`
        * number of distinct widths (i.e. number of stages)
        * largest width exponent plus one
        * the un-quantized, linearly growing per-block widths (list of float)
    """
    assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % q == 0
    # Linear ramp: one continuous width per block index.
    continuous = np.arange(depth) * width_slope + width_initial
    # Snap each continuous width to the nearest integer power of `width_mult` (relative to width_initial).
    exponents = np.round(np.log(continuous / width_initial) / np.log(width_mult))
    snapped = width_initial * np.power(width_mult, exponents)
    # Quantize to a multiple of q.
    quantized = np.round(np.divide(snapped, q)) * q
    stage_count = len(np.unique(quantized))
    max_stage = exponents.max() + 1
    return quantized.astype(int).tolist(), stage_count, max_stage, continuous.tolist()
def downsample_conv(
        in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None):
    """Build the conv + norm projection (no activation) used on a block's shortcut path.

    When neither striding nor dilation is requested the kernel is forced to 1x1,
    and dilation is dropped whenever the effective kernel is not larger than 1.
    """
    if not norm_layer:
        norm_layer = nn.BatchNorm2d
    if stride == 1 and dilation == 1:
        kernel_size = 1
    if not kernel_size > 1:
        dilation = 1
    return ConvBnAct(
        in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, act_layer=None)
This is not in RegNet space but I might experiment.""" + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + pool = nn.Identity() + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + return nn.Sequential(*[ + pool, ConvBnAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, act_layer=None)]) + + +class RegStage(nn.Module): + """Stage (sequence of blocks w/ the same output shape).""" + + def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width, + block_fn=Bottleneck, se_ratio=0., drop_path_rates=None, drop_block=None): + super(RegStage, self).__init__() + block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args + first_dilation = 1 if dilation in (1, 2) else 2 + for i in range(depth): + block_stride = stride if i == 0 else 1 + block_in_chs = in_chs if i == 0 else out_chs + block_dilation = first_dilation if i == 0 else dilation + if drop_path_rates is not None and drop_path_rates[i] > 0.: + drop_path = DropPath(drop_path_rates[i]) + else: + drop_path = None + if (block_in_chs != out_chs) or (block_stride != 1): + proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation) + else: + proj_block = None + + name = "b{}".format(i + 1) + self.add_module( + name, block_fn( + block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio, + downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs) + ) + + def forward(self, x): + for block in self.children(): + x = block(x) + return x + + +class RegNet(nn.Module): + """RegNet model. 
+ + Paper: https://arxiv.org/abs/2003.13678 + Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py + """ + + def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., + drop_path_rate=0., zero_init_last_bn=True): + super().__init__() + # TODO add drop block, drop path, anti-aliasing, custom bn/act args + self.num_classes = num_classes + self.drop_rate = drop_rate + assert output_stride in (8, 16, 32) + + # Construct the stem + stem_width = cfg['stem_width'] + self.stem = ConvBnAct(in_chans, stem_width, 3, stride=2) + self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')] + + # Construct the stages + prev_width = stem_width + curr_stride = 2 + stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate) + se_ratio = cfg['se_ratio'] + for i, stage_args in enumerate(stage_params): + stage_name = "s{}".format(i + 1) + self.add_module(stage_name, RegStage(prev_width, **stage_args, se_ratio=se_ratio)) + prev_width = stage_args['out_chs'] + curr_stride *= stage_args['stride'] + self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)] + + # Construct the head + self.num_features = prev_width + self.head = ClassifierHead( + in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0.0, std=0.01) + nn.init.zeros_(m.bias) + if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def _get_stage_params(self, cfg, 
def _filter_fn(state_dict):
    """ Unwrap a checkpoint that nests its weights under a 'model' key.

    If `state_dict` has a 'model' entry, that entry is returned; otherwise the
    input is returned unchanged.
    """
    # For the DeiT-trained regnety_160 pretrained model
    return state_dict['model'] if 'model' in state_dict else state_dict
**kwargs): + """RegNetX-32GF""" + return _create_regnet('regnetx_320', pretrained, **kwargs) + + +@register_model +def regnety_002(pretrained=False, **kwargs): + """RegNetY-200MF""" + return _create_regnet('regnety_002', pretrained, **kwargs) + + +@register_model +def regnety_004(pretrained=False, **kwargs): + """RegNetY-400MF""" + return _create_regnet('regnety_004', pretrained, **kwargs) + + +@register_model +def regnety_006(pretrained=False, **kwargs): + """RegNetY-600MF""" + return _create_regnet('regnety_006', pretrained, **kwargs) + + +@register_model +def regnety_008(pretrained=False, **kwargs): + """RegNetY-800MF""" + return _create_regnet('regnety_008', pretrained, **kwargs) + + +@register_model +def regnety_016(pretrained=False, **kwargs): + """RegNetY-1.6GF""" + return _create_regnet('regnety_016', pretrained, **kwargs) + + +@register_model +def regnety_032(pretrained=False, **kwargs): + """RegNetY-3.2GF""" + return _create_regnet('regnety_032', pretrained, **kwargs) + + +@register_model +def regnety_040(pretrained=False, **kwargs): + """RegNetY-4.0GF""" + return _create_regnet('regnety_040', pretrained, **kwargs) + + +@register_model +def regnety_064(pretrained=False, **kwargs): + """RegNetY-6.4GF""" + return _create_regnet('regnety_064', pretrained, **kwargs) + + +@register_model +def regnety_080(pretrained=False, **kwargs): + """RegNetY-8.0GF""" + return _create_regnet('regnety_080', pretrained, **kwargs) + + +@register_model +def regnety_120(pretrained=False, **kwargs): + """RegNetY-12GF""" + return _create_regnet('regnety_120', pretrained, **kwargs) + + +@register_model +def regnety_160(pretrained=False, **kwargs): + """RegNetY-16GF""" + return _create_regnet('regnety_160', pretrained, **kwargs) + + +@register_model +def regnety_320(pretrained=False, **kwargs): + """RegNetY-32GF""" + return _create_regnet('regnety_320', pretrained, **kwargs) diff --git a/data_processing/MANIQA/timm/models/res2net.py b/data_processing/MANIQA/timm/models/res2net.py 
new file mode 100644 index 0000000..282baba --- /dev/null +++ b/data_processing/MANIQA/timm/models/res2net.py @@ -0,0 +1,216 @@ +""" Res2Net and Res2NeXt +Adapted from Official Pytorch impl at: https://github.com/gasvn/Res2Net/ +Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169 +""" +import math + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .registry import register_model +from .resnet import ResNet + +__all__ = [] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'res2net50_26w_4s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth'), + 'res2net50_48w_2s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth'), + 'res2net50_14w_8s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth'), + 'res2net50_26w_6s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth'), + 'res2net50_26w_8s': _cfg( + 
class Bottle2neck(nn.Module):
    """ Res2Net/Res2NeXT Bottleneck
    Adapted from https://github.com/gasvn/Res2Net/blob/master/res2net.py

    The 1x1 `conv1` output is split channel-wise into `scale` groups of `width` channels.
    All but the last group go through their own 3x3 conv/norm/act; the last group is passed
    through (average-pooled when the block is a "first" block). The groups are concatenated
    and projected by the 1x1 `conv3` before the residual add.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
                 act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
        """
        Args:
            inplanes: input channels.
            planes: base channel count; the block outputs ``planes * expansion`` channels.
            stride: stride of the 3x3 convs (and of the pass-through pool).
            downsample: optional module applied to the shortcut.
            cardinality: groups of each 3x3 conv; also multiplies the split width.
            base_width: split width is ``floor(planes * base_width / 64) * cardinality``.
            scale: number of channel splits.
            dilation: dilation used when `first_dilation` is not given.
            first_dilation: dilation (and padding) of the 3x3 convs.
            act_layer: activation class, constructed with ``inplace=True``.
            norm_layer: callable ``channels -> module``; defaults to ``nn.BatchNorm2d`` when None.
            attn_layer: optional callable ``channels -> module`` applied after `bn3`.
            **_: extra keyword arguments are accepted and ignored.
        """
        super(Bottle2neck, self).__init__()
        # The signature advertises norm_layer=None, so fall back to BatchNorm2d instead of
        # failing with "'NoneType' object is not callable" on the first norm construction.
        norm_layer = norm_layer or nn.BatchNorm2d
        self.scale = scale
        self.is_first = stride > 1 or downsample is not None
        self.num_scales = max(1, scale - 1)
        width = int(math.floor(planes * (base_width / 64.0))) * cardinality
        self.width = width
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation

        self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
        self.bn1 = norm_layer(width * scale)

        convs = []
        bns = []
        for _i in range(self.num_scales):
            convs.append(nn.Conv2d(
                width, width, kernel_size=3, stride=stride, padding=first_dilation,
                dilation=first_dilation, groups=cardinality, bias=False))
            bns.append(norm_layer(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        if self.is_first:
            # FIXME this should probably have count_include_pad=False, but hurts original weights
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
        else:
            self.pool = None

        self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(outplanes)
        self.se = attn_layer(outplanes) if attn_layer is not None else None

        self.relu = act_layer(inplace=True)
        self.downsample = downsample

    def zero_init_last_bn(self):
        """Zero the weight of the last norm layer (`bn3`)."""
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x):
        shortcut = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Channel-wise split into chunks of `self.width` channels.
        spx = torch.split(out, self.width, 1)
        spo = []
        sp = spx[0]  # redundant, for torchscript
        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            if i == 0 or self.is_first:
                # "first" blocks convolve every split independently
                sp = spx[i]
            else:
                # otherwise the previous split's output is added in before the conv
                sp = sp + spx[i]
            sp = conv(sp)
            sp = bn(sp)
            sp = self.relu(sp)
            spo.append(sp)
        if self.scale > 1:
            # The last split bypasses the 3x3 convs.
            if self.pool is not None:
                # self.is_first == True, None check for torchscript
                spo.append(self.pool(spx[-1]))
            else:
                spo.append(spx[-1])
        out = torch.cat(spo, 1)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.se is not None:
            out = self.se(out)

        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        out = self.relu(out)

        return out
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs) + return _create_res2net('res2net101_26w_4s', pretrained, **model_args) + + +@register_model +def res2net50_26w_6s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 26w6s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs) + return _create_res2net('res2net50_26w_6s', pretrained, **model_args) + + +@register_model +def res2net50_26w_8s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 26w8s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs) + return _create_res2net('res2net50_26w_8s', pretrained, **model_args) + + +@register_model +def res2net50_48w_2s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 48w2s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs) + return _create_res2net('res2net50_48w_2s', pretrained, **model_args) + + +@register_model +def res2net50_14w_8s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 14w8s model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs) + return _create_res2net('res2net50_14w_8s', pretrained, **model_args) + + +@register_model +def res2next50(pretrained=False, **kwargs): + """Construct Res2NeXt-50 4s + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs) + return _create_res2net('res2next50', pretrained, **model_args) diff --git a/data_processing/MANIQA/timm/models/resnest.py b/data_processing/MANIQA/timm/models/resnest.py new file mode 100644 index 0000000..31eebd8 --- /dev/null +++ b/data_processing/MANIQA/timm/models/resnest.py @@ -0,0 +1,237 @@ +""" ResNeSt Models + +Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt by Hang Zhang + +Modified for torchscript compat, and consistency with timm by Ross Wightman +""" +import torch +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import SplitAttn +from .registry import register_model +from .resnet import ResNet + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1.0', 'classifier': 'fc', + **kwargs + } + +default_cfgs = { + 'resnest14d': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'), + 'resnest26d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'), + 'resnest50d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth'), + 'resnest101e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'resnest200e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth', + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.909, interpolation='bicubic'), + 'resnest269e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth', + input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.928, interpolation='bicubic'), + 'resnest50d_4s2x40d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth', + interpolation='bicubic'), + 'resnest50d_1s4x24d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth', + interpolation='bicubic') +} + + +class ResNestBottleneck(nn.Module): + """ResNet Bottleneck + """ + # pylint: disable=unused-argument + expansion = 4 + + def __init__(self, inplanes, planes, 
stride=1, downsample=None, + radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(ResNestBottleneck, self).__init__() + assert reduce_first == 1 # not supported + assert attn_layer is None # not supported + assert aa_layer is None # TODO not yet supported + assert drop_path is None # TODO not yet supported + + group_width = int(planes * (base_width / 64.)) * cardinality + first_dilation = first_dilation or dilation + if avd and (stride > 1 or is_first): + avd_stride = stride + stride = 1 + else: + avd_stride = 0 + self.radix = radix + self.drop_block = drop_block + + self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) + self.bn1 = norm_layer(group_width) + self.act1 = act_layer(inplace=True) + self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None + + if self.radix >= 1: + self.conv2 = SplitAttn( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block) + self.bn2 = nn.Identity() + self.act2 = nn.Identity() + else: + self.conv2 = nn.Conv2d( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(group_width) + self.act2 = act_layer(inplace=True) + self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None + + self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes*4) + self.act3 = act_layer(inplace=True) + self.downsample = downsample + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + shortcut = x + + out = self.conv1(x) + out = 
self.bn1(out) + if self.drop_block is not None: + out = self.drop_block(out) + out = self.act1(out) + + if self.avd_first is not None: + out = self.avd_first(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.drop_block is not None: + out = self.drop_block(out) + out = self.act2(out) + + if self.avd_last is not None: + out = self.avd_last(out) + + out = self.conv3(out) + out = self.bn3(out) + if self.drop_block is not None: + out = self.drop_block(out) + + if self.downsample is not None: + shortcut = self.downsample(x) + + out += shortcut + out = self.act3(out) + return out + + +def _create_resnest(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def resnest14d(pretrained=False, **kwargs): + """ ResNeSt-14d model. Weights ported from GluonCV. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[1, 1, 1, 1], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest26d(pretrained=False, **kwargs): + """ ResNeSt-26d model. Weights ported from GluonCV. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[2, 2, 2, 2], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest50d(pretrained=False, **kwargs): + """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample. 
+ """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest101e(pretrained=False, **kwargs): + """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 23, 3], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest200e(pretrained=False, **kwargs): + """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 24, 36, 3], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest269e(pretrained=False, **kwargs): + """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. 
+ """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 30, 48, 8], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest50d_4s2x40d(pretrained=False, **kwargs): + """ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2, + block_args=dict(radix=4, avd=True, avd_first=True), **kwargs) + return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest50d_1s4x24d(pretrained=False, **kwargs): + """ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4, + block_args=dict(radix=1, avd=True, avd_first=True), **kwargs) + return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs) diff --git a/data_processing/MANIQA/timm/models/resnet.py b/data_processing/MANIQA/timm/models/resnet.py new file mode 100644 index 0000000..f0ce507 --- /dev/null +++ b/data_processing/MANIQA/timm/models/resnet.py @@ -0,0 +1,1547 @@ +"""PyTorch ResNet + +This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with +additional dropout and dynamic global avg/max pool. 
+ +ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman + +Copyright 2019, Ross Wightman +""" +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier +from .registry import register_model + +__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + # ResNet and Wide ResNet + 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), + 'resnet18d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet34': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), + 'resnet34d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34d_ra2-f8dcfcaf.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet26': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth', + 
interpolation='bicubic'), + 'resnet26d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet26t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94), + 'resnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth', + interpolation='bicubic', crop_pct=0.95), + 'resnet50d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet50t': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet101': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth', + interpolation='bicubic', crop_pct=0.95), + 'resnet101d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=1.0, test_input_size=(3, 320, 320)), + 'resnet152': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a1h-dc400468.pth', + interpolation='bicubic', crop_pct=0.95), + 'resnet152d': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=1.0, test_input_size=(3, 320, 320)), + 'resnet200': _cfg(url='', interpolation='bicubic'), + 'resnet200d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet200d_ra2-bdba9bf9.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=1.0, test_input_size=(3, 320, 320)), + 'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'), + 'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), + 'tv_resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'), + 'tv_resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'), + 'wide_resnet50_2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/wide_resnet50_racm-8234f177.pth', + interpolation='bicubic'), + 'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'), + + # ResNets w/ alternative norm layers + 'resnet50_gn': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_gn_a1h2-8fe6c4d0.pth', + crop_pct=0.94, interpolation='bicubic'), + + # ResNeXt + 'resnext50_32x4d': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1h-0146ab0a.pth', + interpolation='bicubic', crop_pct=0.95), + 'resnext50d_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'resnext101_32x4d': _cfg(url=''), + 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'), + 'resnext101_64x4d': _cfg(url=''), + 'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'), + + # ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags + # from https://github.com/facebookresearch/WSL-Images + # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only. 
+ 'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'), + 'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'), + 'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'), + 'ig_resnext101_32x48d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth'), + + # Semi-Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models + # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only. + 'ssl_resnet18': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth'), + 'ssl_resnet50': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth'), + 'ssl_resnext50_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth'), + 'ssl_resnext101_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth'), + 'ssl_resnext101_32x8d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth'), + 'ssl_resnext101_32x16d': _cfg( + 
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth'), + + # Semi-Weakly Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models + # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only. + 'swsl_resnet18': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth'), + 'swsl_resnet50': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth'), + 'swsl_resnext50_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth'), + 'swsl_resnext101_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth'), + 'swsl_resnext101_32x8d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth'), + 'swsl_resnext101_32x16d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'), + + # Squeeze-Excitation ResNets, to eventually replace the models in senet.py + 'seresnet18': _cfg( + url='', + interpolation='bicubic'), + 'seresnet34': _cfg( + url='', + interpolation='bicubic'), + 'seresnet50': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet50_ra_224-8efdb4bb.pth', + interpolation='bicubic'), + 'seresnet50t': _cfg( + url='', + interpolation='bicubic', + first_conv='conv1.0'), + 'seresnet101': _cfg( + url='', + interpolation='bicubic'), + 'seresnet152': _cfg( + url='', + interpolation='bicubic'), + 'seresnet152d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=1.0, test_input_size=(3, 320, 320) + ), + 'seresnet200d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), + 'seresnet269d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), + + + # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py + 'seresnext26d_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'seresnext26t_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'seresnext50_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext50_32x4d_racm-a304a460.pth', + interpolation='bicubic'), + 'seresnext101_32x4d': _cfg( + url='', + interpolation='bicubic'), + 'seresnext101_32x8d': _cfg( + url='', + 
interpolation='bicubic'), + 'senet154': _cfg( + url='', + interpolation='bicubic', + first_conv='conv1.0'), + + # Efficient Channel Attention ResNets + 'ecaresnet26t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=0.95, test_input_size=(3, 320, 320)), + 'ecaresnetlight': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth', + interpolation='bicubic'), + 'ecaresnet50d': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet50d_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet50t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=0.95, test_input_size=(3, 320, 320)), + 'ecaresnet101d': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'ecaresnet101d_pruned': _cfg( + 
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet200d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), + 'ecaresnet269d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10), + crop_pct=1.0, test_input_size=(3, 352, 352)), + + # Efficient Channel Attention ResNeXts + 'ecaresnext26t_32x4d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + 'ecaresnext50t_32x4d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + + # ResNets with anti-aliasing blur pool + 'resnetblur18': _cfg( + interpolation='bicubic'), + 'resnetblur50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth', + interpolation='bicubic'), + 'resnetblur50d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + 'resnetblur101d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + 'resnetaa50d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + 'resnetaa101d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + 'seresnetaa50d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + + # ResNet-RS models + 'resnetrs50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50_ema-6b53758b.pth', + input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.91, test_input_size=(3, 224, 
224), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs101': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101_i192_ema-1509bbf6.pth', + input_size=(3, 192, 192), pool_size=(6, 6), crop_pct=0.94, test_input_size=(3, 288, 288), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs152': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152_i256_ema-a9aff7f9.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs200': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200_ema-623d2f59.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs270': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270_ema-b40e674c.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 352, 352), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs350': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350_i256_ema-5a1aa8f1.pth', + input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, test_input_size=(3, 384, 384), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs420': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420_ema-972dee69.pth', + input_size=(3, 320, 320), 
pool_size=(10, 10), crop_pct=1.0, test_input_size=(3, 416, 416), + interpolation='bicubic', first_conv='conv1.0'), +} + + +def get_padding(kernel_size, stride, dilation=1): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def create_aa(aa_layer, channels, stride=2, enable=True): + if not aa_layer or not enable: + return None + return aa_layer(stride) if issubclass(aa_layer, nn.AvgPool2d) else aa_layer(channels=channels, stride=stride) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(BasicBlock, self).__init__() + + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock does not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) + + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, + dilation=first_dilation, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + self.aa = create_aa(aa_layer, channels=first_planes, stride=stride, enable=use_aa) + + self.conv2 = nn.Conv2d( + first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) + self.bn2 = norm_layer(outplanes) + + self.se = create_attn(attn_layer, outplanes) + + self.act2 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) + + def forward(self, x): + shortcut = x + + 
x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) + if self.aa is not None: + x = self.aa(x) + + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) + + if self.se is not None: + x = self.se(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act2(x) + + return x + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(Bottleneck, self).__init__() + + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) + + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=1 if use_aa else stride, + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(width) + self.act2 = act_layer(inplace=True) + self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa) + + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + + self.se = create_attn(attn_layer, outplanes) + + self.act3 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def 
forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act2(x) + if self.aa is not None: + x = self.aa(x) + + x = self.conv3(x) + x = self.bn3(x) + if self.drop_block is not None: + x = self.drop_block(x) + + if self.se is not None: + x = self.se(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act3(x) + + return x + + +def downsample_conv( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size + first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 + p = get_padding(kernel_size, stride, first_dilation) + + return nn.Sequential(*[ + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False), + norm_layer(out_channels) + ]) + + +def downsample_avg( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + if stride == 1 and dilation == 1: + pool = nn.Identity() + else: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + + return nn.Sequential(*[ + pool, + nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False), + norm_layer(out_channels) + ]) + + +def drop_blocks(drop_block_rate=0.): + return [ + None, None, + DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None, + DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None] + + +def 
make_blocks( + block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32, + down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs): + stages = [] + feature_info = [] + net_num_blocks = sum(block_repeats) + net_block_idx = 0 + net_stride = 4 + dilation = prev_dilation = 1 + for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))): + stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it + stride = 1 if stage_idx == 0 else 2 + if net_stride >= output_stride: + dilation *= stride + stride = 1 + else: + net_stride *= stride + + downsample = None + if stride != 1 or inplanes != planes * block_fn.expansion: + down_kwargs = dict( + in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size, + stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer')) + downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs) + + block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs) + blocks = [] + for block_idx in range(num_blocks): + downsample = downsample if block_idx == 0 else None + stride = stride if block_idx == 0 else 1 + block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule + blocks.append(block_fn( + inplanes, planes, stride, downsample, first_dilation=prev_dilation, + drop_path=DropPath(block_dpr) if block_dpr > 0. 
else None, **block_kwargs)) + prev_dilation = dilation + inplanes = planes * block_fn.expansion + net_block_idx += 1 + + stages.append((stage_name, nn.Sequential(*blocks))) + feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name)) + + return stages, feature_info + + +class ResNet(nn.Module): + """ResNet / ResNeXt / SE-ResNeXt / SE-Net + + This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that + * have > 1 stride in the 3x3 conv layer of bottleneck + * have conv-bn-act ordering + + This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s + variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the + 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default. + + ResNet variants (the same modifications can be used in SE/ResNeXt models as well): + * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b + * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64) + * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample + * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample + * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128) + * t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample + * tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample + + ResNeXt + * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths + * same c,d, e, s variants as ResNet can be enabled + + SE-ResNeXt + * normal - 7x7 stem, stem_width = 64 + * same c, d, e, s variants as ResNet can be enabled + + SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, + reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs 
after first block + + Parameters + ---------- + block : Block + Class for the residual block. Options are BasicBlockGl, BottleneckGl. + layers : list of int + Numbers of layers in each block + num_classes : int, default 1000 + Number of classification classes. + in_chans : int, default 3 + Number of input (color) channels. + cardinality : int, default 1 + Number of convolution groups for 3x3 conv in Bottleneck. + base_width : int, default 64 + Factor determining bottleneck channels. `planes * base_width / 64 * cardinality` + stem_width : int, default 64 + Number of channels in stem convolutions + stem_type : str, default '' + The type of stem: + * '', default - a single 7x7 conv with a width of stem_width + * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2 + * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2 + block_reduce_first: int, default 1 + Reduction factor for first convolution output width of residual blocks, + 1 for all archs except senets, where 2 + down_kernel_size: int, default 1 + Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets + avg_down : bool, default False + Whether to use average pooling for projection skip connection between stages/downsample. + output_stride : int, default 32 + Set the output stride of the network, 32, 16, or 8. Typically used in segmentation. + act_layer : nn.Module, activation layer + norm_layer : nn.Module, normalization layer + aa_layer : nn.Module, anti-aliasing layer + drop_rate : float, default 0. + Dropout probability before classifier, for training + global_pool : str, default 'avg' + Global pooling type. 
One of 'avg', 'max', 'avgmax', 'catavgmax' + """ + + def __init__(self, block, layers, num_classes=1000, in_chans=3, + cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, + output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0., + drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): + block_args = block_args or dict() + assert output_stride in (8, 16, 32) + self.num_classes = num_classes + self.drop_rate = drop_rate + super(ResNet, self).__init__() + + # Stem + deep_stem = 'deep' in stem_type + inplanes = stem_width * 2 if deep_stem else 64 + if deep_stem: + stem_chs = (stem_width, stem_width) + if 'tiered' in stem_type: + stem_chs = (3 * (stem_width // 4), stem_width) + self.conv1 = nn.Sequential(*[ + nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False), + norm_layer(stem_chs[0]), + act_layer(inplace=True), + nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False), + norm_layer(stem_chs[1]), + act_layer(inplace=True), + nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)]) + else: + self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(inplanes) + self.act1 = act_layer(inplace=True) + self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')] + + # Stem pooling. The name 'maxpool' remains for weight compatibility. 
+ if replace_stem_pool: + self.maxpool = nn.Sequential(*filter(None, [ + nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False), + create_aa(aa_layer, channels=inplanes, stride=2), + norm_layer(inplanes), + act_layer(inplace=True) + ])) + else: + if aa_layer is not None: + if issubclass(aa_layer, nn.AvgPool2d): + self.maxpool = aa_layer(2) + else: + self.maxpool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=inplanes, stride=2)]) + else: + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # Feature Blocks + channels = [64, 128, 256, 512] + stage_modules, stage_feature_info = make_blocks( + block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width, + output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down, + down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, + drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args) + for stage in stage_modules: + self.add_module(*stage) # layer1, layer2, etc + self.feature_info.extend(stage_feature_info) + + # Head (Pooling and Classifier) + self.num_features = 512 * block.expansion + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + self.init_weights(zero_init_last_bn=zero_init_last_bn) + + def init_weights(self, zero_init_last_bn=True): + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = 
create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + x = self.fc(x) + return x + + +def _create_resnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('resnet18', pretrained, **model_args) + + +@register_model +def resnet18d(pretrained=False, **kwargs): + """Constructs a ResNet-18-D model. + """ + model_args = dict( + block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet18d', pretrained, **model_args) + + +@register_model +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('resnet34', pretrained, **model_args) + + +@register_model +def resnet34d(pretrained=False, **kwargs): + """Constructs a ResNet-34-D model. + """ + model_args = dict( + block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet34d', pretrained, **model_args) + + +@register_model +def resnet26(pretrained=False, **kwargs): + """Constructs a ResNet-26 model. 
+ """ + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('resnet26', pretrained, **model_args) + + +@register_model +def resnet26t(pretrained=False, **kwargs): + """Constructs a ResNet-26-T model. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) + return _create_resnet('resnet26t', pretrained, **model_args) + + +@register_model +def resnet26d(pretrained=False, **kwargs): + """Constructs a ResNet-26-D model. + """ + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet26d', pretrained, **model_args) + + +@register_model +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('resnet50', pretrained, **model_args) + + +@register_model +def resnet50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet50d', pretrained, **model_args) + + +@register_model +def resnet50t(pretrained=False, **kwargs): + """Constructs a ResNet-50-T model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) + return _create_resnet('resnet50t', pretrained, **model_args) + + +@register_model +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) + return _create_resnet('resnet101', pretrained, **model_args) + + +@register_model +def resnet101d(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model. 
+ """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet101d', pretrained, **model_args) + + +@register_model +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) + return _create_resnet('resnet152', pretrained, **model_args) + + +@register_model +def resnet152d(pretrained=False, **kwargs): + """Constructs a ResNet-152-D model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet152d', pretrained, **model_args) + + +@register_model +def resnet200(pretrained=False, **kwargs): + """Constructs a ResNet-200 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], **kwargs) + return _create_resnet('resnet200', pretrained, **model_args) + + +@register_model +def resnet200d(pretrained=False, **kwargs): + """Constructs a ResNet-200-D model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet200d', pretrained, **model_args) + + +@register_model +def tv_resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model with original Torchvision weights. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('tv_resnet34', pretrained, **model_args) + + +@register_model +def tv_resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model with original Torchvision weights. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('tv_resnet50', pretrained, **model_args) + + +@register_model +def tv_resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model w/ Torchvision pretrained weights. 
+ """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) + return _create_resnet('tv_resnet101', pretrained, **model_args) + + +@register_model +def tv_resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model w/ Torchvision pretrained weights. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) + return _create_resnet('tv_resnet152', pretrained, **model_args) + + +@register_model +def wide_resnet50_2(pretrained=False, **kwargs): + """Constructs a Wide ResNet-50-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs) + return _create_resnet('wide_resnet50_2', pretrained, **model_args) + + +@register_model +def wide_resnet101_2(pretrained=False, **kwargs): + """Constructs a Wide ResNet-101-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128, **kwargs) + return _create_resnet('wide_resnet101_2', pretrained, **model_args) + + +@register_model +def resnet50_gn(pretrained=False, **kwargs): + """Constructs a ResNet-50 model w/ GroupNorm + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('resnet50_gn', pretrained, norm_layer=GroupNorm, **model_args) + + +@register_model +def resnext50_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50-32x4d model. 
+ """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('resnext50_32x4d', pretrained, **model_args) + + +@register_model +def resnext50d_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnext50d_32x4d', pretrained, **model_args) + + +@register_model +def resnext101_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 32x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('resnext101_32x4d', pretrained, **model_args) + + +@register_model +def resnext101_32x8d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 32x8d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('resnext101_32x8d', pretrained, **model_args) + + +@register_model +def resnext101_64x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt101-64x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs) + return _create_resnet('resnext101_64x4d', pretrained, **model_args) + + +@register_model +def tv_resnext50_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50-32x4d model with original Torchvision weights. 
+ """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('tv_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('ig_resnext101_32x8d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x16d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('ig_resnext101_32x16d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x32d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32, **kwargs) + return _create_resnet('ig_resnext101_32x32d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x48d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 
32x48 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48, **kwargs) + return _create_resnet('ig_resnext101_32x48d', pretrained, **model_args) + + +@register_model +def ssl_resnet18(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNet-18 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('ssl_resnet18', pretrained, **model_args) + + +@register_model +def ssl_resnet50(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNet-50 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('ssl_resnet50', pretrained, **model_args) + + +@register_model +def ssl_resnext50_32x4d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-50 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], 
cardinality=32, base_width=4, **kwargs) + return _create_resnet('ssl_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def ssl_resnext101_32x4d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-101 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('ssl_resnext101_32x4d', pretrained, **model_args) + + +@register_model +def ssl_resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-101 32x8 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('ssl_resnext101_32x8d', pretrained, **model_args) + + +@register_model +def ssl_resnext101_32x16d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-101 32x16 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('ssl_resnext101_32x16d', pretrained, **model_args) + + +@register_model +def swsl_resnet18(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised Resnet-18 model pre-trained 
on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('swsl_resnet18', pretrained, **model_args) + + +@register_model +def swsl_resnet50(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNet-50 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('swsl_resnet50', pretrained, **model_args) + + +@register_model +def swsl_resnext50_32x4d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-50 32x4 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('swsl_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def swsl_resnext101_32x4d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-101 32x4 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. 
+ `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('swsl_resnext101_32x4d', pretrained, **model_args) + + +@register_model +def swsl_resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-101 32x8 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('swsl_resnext101_32x8d', pretrained, **model_args) + + +@register_model +def swsl_resnext101_32x16d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-101 32x16 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args) + + +@register_model +def ecaresnet26t(pretrained=False, **kwargs): + """Constructs an ECA-ResNeXt-26-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem and ECA attn. 
+ """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet26t', pretrained, **model_args) + + +@register_model +def ecaresnet50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model with eca. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50d', pretrained, **model_args) + + +@register_model +def resnetrs50(pretrained=False, **kwargs): + """Constructs a ResNet-RS-50 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs50', pretrained, **model_args) + + +@register_model +def resnetrs101(pretrained=False, **kwargs): + """Constructs a ResNet-RS-101 model. 
+ Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs101', pretrained, **model_args) + + +@register_model +def resnetrs152(pretrained=False, **kwargs): + """Constructs a ResNet-RS-152 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs152', pretrained, **model_args) + + +@register_model +def resnetrs200(pretrained=False, **kwargs): + """Constructs a ResNet-RS-200 model. 
+ Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs200', pretrained, **model_args) + + +@register_model +def resnetrs270(pretrained=False, **kwargs): + """Constructs a ResNet-RS-270 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs270', pretrained, **model_args) + + + +@register_model +def resnetrs350(pretrained=False, **kwargs): + """Constructs a ResNet-RS-350 model. 
+ Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs350', pretrained, **model_args) + + +@register_model +def resnetrs420(pretrained=False, **kwargs): + """Constructs a ResNet-RS-420 model + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs420', pretrained, **model_args) + + +@register_model +def ecaresnet50d_pruned(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model pruned with eca. + The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args) + + +@register_model +def ecaresnet50t(pretrained=False, **kwargs): + """Constructs an ECA-ResNet-50-T model. 
+ Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50t', pretrained, **model_args) + + +@register_model +def ecaresnetlight(pretrained=False, **kwargs): + """Constructs a ResNet-50-D light model with eca. + """ + model_args = dict( + block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnetlight', pretrained, **model_args) + + +@register_model +def ecaresnet101d(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model with eca. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet101d', pretrained, **model_args) + + +@register_model +def ecaresnet101d_pruned(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model pruned with eca. + The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args) + + +@register_model +def ecaresnet200d(pretrained=False, **kwargs): + """Constructs a ResNet-200-D model with ECA. + """ + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet200d', pretrained, **model_args) + + +@register_model +def ecaresnet269d(pretrained=False, **kwargs): + """Constructs a ResNet-269-D model with ECA. 
+ """ + model_args = dict( + block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet269d', pretrained, **model_args) + + +@register_model +def ecaresnext26t_32x4d(pretrained=False, **kwargs): + """Constructs an ECA-ResNeXt-26-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem. This model replaces SE module with the ECA module + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnext26t_32x4d', pretrained, **model_args) + + +@register_model +def ecaresnext50t_32x4d(pretrained=False, **kwargs): + """Constructs an ECA-ResNeXt-50-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem. 
This model replaces SE module with the ECA module + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args) + + +@register_model +def resnetblur18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model with blur anti-aliasing + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur18', pretrained, **model_args) + + +@register_model +def resnetblur50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model with blur anti-aliasing + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur50', pretrained, **model_args) + + +@register_model +def resnetblur50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model with blur anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetblur50d', pretrained, **model_args) + + +@register_model +def resnetblur101d(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model with blur anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetblur101d', pretrained, **model_args) + + +@register_model +def resnetaa50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model with avgpool anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetaa50d', pretrained, **model_args) + + +@register_model +def resnetaa101d(pretrained=False, 
**kwargs): + """Constructs a ResNet-101-D model with avgpool anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetaa101d', pretrained, **model_args) + + +@register_model +def seresnetaa50d(pretrained=False, **kwargs): + """Constructs a SE=ResNet-50-D model with avgpool anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, + stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnetaa50d', pretrained, **model_args) + + +@register_model +def seresnet18(pretrained=False, **kwargs): + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet18', pretrained, **model_args) + + +@register_model +def seresnet34(pretrained=False, **kwargs): + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet34', pretrained, **model_args) + + +@register_model +def seresnet50(pretrained=False, **kwargs): + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet50', pretrained, **model_args) + + +@register_model +def seresnet50t(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet50t', pretrained, **model_args) + + +@register_model +def seresnet101(pretrained=False, **kwargs): + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet101', pretrained, **model_args) + + +@register_model +def seresnet152(pretrained=False, **kwargs): + model_args = 
dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet152', pretrained, **model_args) + + +@register_model +def seresnet152d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet152d', pretrained, **model_args) + + +@register_model +def seresnet200d(pretrained=False, **kwargs): + """Constructs a ResNet-200-D model with SE attn. + """ + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet200d', pretrained, **model_args) + + +@register_model +def seresnet269d(pretrained=False, **kwargs): + """Constructs a ResNet-269-D model with SE attn. + """ + model_args = dict( + block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet269d', pretrained, **model_args) + + +@register_model +def seresnext26d_32x4d(pretrained=False, **kwargs): + """Constructs a SE-ResNeXt-26-D model.` + This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for + combination of deep stem and avg_pool in downsample. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext26d_32x4d', pretrained, **model_args) + + +@register_model +def seresnext26t_32x4d(pretrained=False, **kwargs): + """Constructs a SE-ResNet-26-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem. 
+ """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext26t_32x4d', pretrained, **model_args) + + +@register_model +def seresnext26tn_32x4d(pretrained=False, **kwargs): + """Constructs a SE-ResNeXt-26-T model. + NOTE I deprecated previous 't' model defs and replaced 't' with 'tn', this was the only tn model of note + so keeping this def for backwards compat with any uses out there. Old 't' model is lost. + """ + return seresnext26t_32x4d(pretrained=pretrained, **kwargs) + + +@register_model +def seresnext50_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext50_32x4d', pretrained, **model_args) + + +@register_model +def seresnext101_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext101_32x4d', pretrained, **model_args) + + +@register_model +def seresnext101_32x8d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext101_32x8d', pretrained, **model_args) + + +@register_model +def senet154(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', + down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('senet154', pretrained, **model_args) diff --git a/data_processing/MANIQA/timm/models/resnetv2.py b/data_processing/MANIQA/timm/models/resnetv2.py new file mode 100644 index 0000000..e38eaf5 --- /dev/null +++ 
b/data_processing/MANIQA/timm/models/resnetv2.py @@ -0,0 +1,672 @@ +"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization. + +A PyTorch implementation of ResNetV2 adapted from the Google Big-Transfer (BiT) source code +at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have +been included here as pretrained models from their original .NPZ checkpoints. + +Additionally, supports non pre-activation bottleneck for use as a backbone for Vision Transformers (ViT) and +extra padding support to allow porting of official Hybrid ResNet pretrained weights from +https://github.com/google-research/vision_transformer + +Thanks to the Google team for the above two repositories and associated papers: +* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370 +* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929 +* Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237 + +Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020. +""" +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict # pylint: disable=g-importing-member + +import torch +import torch.nn as nn +from functools import partial + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv +from .registry import register_model +from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,\ + ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + # pretrained on imagenet21k, finetuned on imagenet1k + 'resnetv2_50x1_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + 'resnetv2_50x3_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + 'resnetv2_101x1_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + 'resnetv2_101x3_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + 'resnetv2_152x2_bitm': _cfg( + 
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + 'resnetv2_152x4_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz', + input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0), # only one at 480x480? + + # trained on imagenet-21k + 'resnetv2_50x1_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz', + num_classes=21843), + 'resnetv2_50x3_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz', + num_classes=21843), + 'resnetv2_101x1_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz', + num_classes=21843), + 'resnetv2_101x3_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz', + num_classes=21843), + 'resnetv2_152x2_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz', + num_classes=21843), + 'resnetv2_152x4_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz', + num_classes=21843), + + 'resnetv2_50x1_bit_distilled': _cfg( + url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz', + interpolation='bicubic'), + 'resnetv2_152x2_bit_teacher': _cfg( + url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz', + interpolation='bicubic'), + 'resnetv2_152x2_bit_teacher_384': _cfg( + 
url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'), + + 'resnetv2_50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1h-000cdf49.pth', + interpolation='bicubic', crop_pct=0.95), + 'resnetv2_50d': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_50t': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_101': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_101_a1h-5d01f016.pth', + interpolation='bicubic', crop_pct=0.95), + 'resnetv2_101d': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_152': _cfg( + interpolation='bicubic'), + 'resnetv2_152d': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + + 'resnetv2_50d_gn': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_50d_evob': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_50d_evos': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), +} + + +def make_div(v, divisor=8): + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + + Follows the implementation of "Identity Mappings in Deep Residual Networks": + https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua + + Except it puts the stride on 3x3 conv when available. 
+ """ + + def __init__( + self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1, + act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.): + super().__init__() + first_dilation = first_dilation or dilation + conv_layer = conv_layer or StdConv2d + norm_layer = norm_layer or partial(GroupNormAct, num_groups=32) + out_chs = out_chs or in_chs + mid_chs = make_div(out_chs * bottle_ratio) + + if proj_layer is not None: + self.downsample = proj_layer( + in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True, + conv_layer=conv_layer, norm_layer=norm_layer) + else: + self.downsample = None + + self.norm1 = norm_layer(in_chs) + self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.norm2 = norm_layer(mid_chs) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + self.norm3 = norm_layer(mid_chs) + self.conv3 = conv_layer(mid_chs, out_chs, 1) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + + def zero_init_last(self): + nn.init.zeros_(self.conv3.weight) + + def forward(self, x): + x_preact = self.norm1(x) + + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x_preact) + + # residual branch + x = self.conv1(x_preact) + x = self.conv2(self.norm2(x)) + x = self.conv3(self.norm3(x)) + x = self.drop_path(x) + return x + shortcut + + +class Bottleneck(nn.Module): + """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT. 
+ """ + def __init__( + self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1, + act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.): + super().__init__() + first_dilation = first_dilation or dilation + act_layer = act_layer or nn.ReLU + conv_layer = conv_layer or StdConv2d + norm_layer = norm_layer or partial(GroupNormAct, num_groups=32) + out_chs = out_chs or in_chs + mid_chs = make_div(out_chs * bottle_ratio) + + if proj_layer is not None: + self.downsample = proj_layer( + in_chs, out_chs, stride=stride, dilation=dilation, preact=False, + conv_layer=conv_layer, norm_layer=norm_layer) + else: + self.downsample = None + + self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.norm1 = norm_layer(mid_chs) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + self.norm2 = norm_layer(mid_chs) + self.conv3 = conv_layer(mid_chs, out_chs, 1) + self.norm3 = norm_layer(out_chs, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.act3 = act_layer(inplace=True) + + def zero_init_last(self): + nn.init.zeros_(self.norm3.weight) + + def forward(self, x): + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + # residual + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.conv3(x) + x = self.norm3(x) + x = self.drop_path(x) + x = self.act3(x + shortcut) + return x + + +class DownsampleConv(nn.Module): + def __init__( + self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True, + conv_layer=None, norm_layer=None): + super(DownsampleConv, self).__init__() + self.conv = conv_layer(in_chs, out_chs, 1, stride=stride) + self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + + def forward(self, x): + return self.norm(self.conv(x)) + + +class DownsampleAvg(nn.Module): + def 
__init__( + self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, + preact=True, conv_layer=None, norm_layer=None): + """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" + super(DownsampleAvg, self).__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + else: + self.pool = nn.Identity() + self.conv = conv_layer(in_chs, out_chs, 1, stride=1) + self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + + def forward(self, x): + return self.norm(self.conv(self.pool(x))) + + +class ResNetStage(nn.Module): + """ResNet Stage.""" + def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1, + avg_down=False, block_dpr=None, block_fn=PreActBottleneck, + act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs): + super(ResNetStage, self).__init__() + first_dilation = 1 if dilation in (1, 2) else 2 + layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer) + proj_layer = DownsampleAvg if avg_down else DownsampleConv + prev_chs = in_chs + self.blocks = nn.Sequential() + for block_idx in range(depth): + drop_path_rate = block_dpr[block_idx] if block_dpr else 0. 
+ stride = stride if block_idx == 0 else 1 + self.blocks.add_module(str(block_idx), block_fn( + prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups, + first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate, + **layer_kwargs, **block_kwargs)) + prev_chs = out_chs + first_dilation = dilation + proj_layer = None + + def forward(self, x): + x = self.blocks(x) + return x + + +def is_stem_deep(stem_type): + return any([s in stem_type for s in ('deep', 'tiered')]) + + +def create_resnetv2_stem( + in_chs, out_chs=64, stem_type='', preact=True, + conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)): + stem = OrderedDict() + assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered') + + # NOTE conv padding mode can be changed by overriding the conv_layer def + if is_stem_deep(stem_type): + # A 3 deep 3x3 conv stack as in ResNet V1D models + if 'tiered' in stem_type: + stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py + else: + stem_chs = (out_chs // 2, out_chs // 2) # 'D' ResNets + stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2) + stem['norm1'] = norm_layer(stem_chs[0]) + stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1) + stem['norm2'] = norm_layer(stem_chs[1]) + stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1) + if not preact: + stem['norm3'] = norm_layer(out_chs) + else: + # The usual 7x7 stem conv + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) + if not preact: + stem['norm'] = norm_layer(out_chs) + + if 'fixed' in stem_type: + # 'fixed' SAME padding approximation that is used in BiT models + stem['pad'] = nn.ConstantPad2d(1, 0.) 
+ stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0) + elif 'same' in stem_type: + # full, input size based 'SAME' padding, used in ViT Hybrid model + stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same') + else: + # the usual PyTorch symmetric padding + stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + return nn.Sequential(stem) + + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode. + """ + + def __init__( + self, layers, channels=(256, 512, 1024, 2048), + num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, + width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, + act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32), + drop_rate=0., drop_path_rate=0., zero_init_last=False): + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + wf = width_factor + + self.feature_info = [] + stem_chs = make_div(stem_chs * wf) + self.stem = create_resnetv2_stem( + in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer) + stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm' + self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat)) + + prev_chs = stem_chs + curr_stride = 4 + dilation = 1 + block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] + block_fn = PreActBottleneck if preact else Bottleneck + self.stages = nn.Sequential() + for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)): + out_chs = make_div(c * wf) + stride = 1 if stage_idx == 0 else 2 + if curr_stride >= output_stride: + dilation *= stride + stride = 1 + stage = ResNetStage( + prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down, + act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn) + prev_chs = 
def _init_weights(module: nn.Module, name: str = '', zero_init_last=True):
    """Initialise the parameters of a single sub-module.

    Written to be mapped over a model with ``named_apply`` (one call per sub-module).

    Args:
        module: the sub-module to initialise.
        name: dotted name of ``module`` inside the parent model; a ``Conv2d`` whose name
            contains ``'head.fc'`` is treated as the classifier.
        zero_init_last: when True, call ``module.zero_init_last()`` on modules that
            define that hook.
    """
    if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
        # Classifier layers: small-variance normal weights, zero bias.
        nn.init.normal_(module.weight, mean=0.0, std=0.01)
        if module.bias is not None:  # layer may have been built with bias=False
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
        # Norm layers built without affine parameters expose weight/bias as None.
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif zero_init_last and hasattr(module, 'zero_init_last'):
        module.zero_init_last()
t2p(conv_weights): + """Possibly convert HWIO to OIHW.""" + if conv_weights.ndim == 4: + conv_weights = conv_weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(conv_weights) + + weights = np.load(checkpoint_path) + stem_conv_w = adapt_input_conv( + model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) + model.stem.conv.weight.copy_(stem_conv_w) + model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma'])) + model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta'])) + if isinstance(getattr(model.head, 'fc', None), nn.Conv2d) and \ + model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]: + model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel'])) + model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias'])) + for i, (sname, stage) in enumerate(model.stages.named_children()): + for j, (bname, block) in enumerate(stage.blocks.named_children()): + cname = 'standardized_conv2d' + block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/' + block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel'])) + block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel'])) + block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel'])) + block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma'])) + block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma'])) + block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma'])) + block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta'])) + block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta'])) + block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta'])) + if block.downsample is not None: + w = weights[f'{block_prefix}a/proj/{cname}/kernel'] + block.downsample.conv.weight.copy_(t2p(w)) + + +def _create_resnetv2(variant, pretrained=False, **kwargs): + feature_cfg = dict(flatten_sequential=True) + return 
def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
    """Build a ResNetV2 variant with the BiT-specific stem and conv settings.

    Forwards to `_create_resnetv2` with `stem_type='fixed'` and a `StdConv2d`
    conv layer whose `eps` is pinned to 1e-8; all other keyword arguments are
    passed through unchanged.
    """
    bit_conv_layer = partial(StdConv2d, eps=1e-8)
    return _create_resnetv2(
        variant,
        pretrained=pretrained,
        stem_type='fixed',
        conv_layer=bit_conv_layer,
        **kwargs)
@register_model
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
    """ResNetV2-101x1 BiT-M with a 21843-class head unless `num_classes` is supplied."""
    # Use the BiT factory like the other *_bitm* entrypoints: it supplies the 'fixed'
    # padded stem and StdConv2d(eps=1e-8). Calling _create_resnetv2 directly dropped both.
    return _create_resnetv2_bit(
        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 4, 23, 3], width_factor=1, **kwargs)
@register_model
def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
    """ ResNetV2-152x2-BiT Teacher @ 384x384
    Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
    """
    return _create_resnetv2_bit(
        'resnetv2_152x2_bit_teacher_384', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
pretrained=pretrained, + layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, + stem_type='deep', avg_down=True, **kwargs) + + +# Experimental configs (may change / be removed) + +@register_model +def resnetv2_50d_gn(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50d_gn', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct, + stem_type='deep', avg_down=True, **kwargs) + + +@register_model +def resnetv2_50d_evob(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50d_evob', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d, + stem_type='deep', avg_down=True, **kwargs) + + +@register_model +def resnetv2_50d_evos(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50d_evos', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d, + stem_type='deep', avg_down=True, **kwargs) diff --git a/data_processing/MANIQA/timm/models/rexnet.py b/data_processing/MANIQA/timm/models/rexnet.py new file mode 100644 index 0000000..f27ce5d --- /dev/null +++ b/data_processing/MANIQA/timm/models/rexnet.py @@ -0,0 +1,239 @@ +""" ReXNet + +A PyTorch impl of `ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network` - +https://arxiv.org/abs/2007.00992 + +Adapted from original impl at https://github.com/clovaai/rexnet +Copyright (c) 2020-present NAVER Corp. 
class LinearBottleneck(nn.Module):
    """ReXNet building block: optional expansion conv, depthwise conv, optional SE, pointwise projection.

    The final pointwise conv is built with `apply_act=False`. When the stride is 1 and the
    block does not reduce the channel count, the input is added onto the first `in_chs`
    output channels (partial residual).
    """

    def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1,
                 act_layer='swish', dw_act_layer='relu6', drop_path=None):
        """
        Args:
            in_chs: number of input channels.
            out_chs: number of output channels.
            stride: stride of the depthwise 3x3 conv.
            exp_ratio: channel expansion ratio; exactly 1 skips the expansion conv.
            se_ratio: squeeze-excite reduction ratio; values <= 0 disable SE.
            ch_div: divisor passed to `make_divisible` when rounding channel counts.
            act_layer: activation for the expansion conv.
            dw_act_layer: activation applied after the depthwise conv (and SE, if present).
            drop_path: optional module applied to the block output before the shortcut add.
        """
        super(LinearBottleneck, self).__init__()
        # Shortcut only when spatial size is preserved and channels do not shrink.
        self.use_shortcut = stride == 1 and in_chs <= out_chs
        self.in_channels = in_chs
        self.out_channels = out_chs

        if exp_ratio != 1.:
            dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div)
            self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer=act_layer)
        else:
            # No expansion: the depthwise conv works directly on the input channels.
            dw_chs = in_chs
            self.conv_exp = None

        # groups == channels makes this a depthwise conv; the activation is applied later (act_dw).
        self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
        if se_ratio > 0:
            self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
        else:
            self.se = None
        self.act_dw = create_act_layer(dw_act_layer)

        self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)
        self.drop_path = drop_path

    def feat_channels(self, exp=False):
        """Return the depthwise (expanded) channel count if `exp`, else the block's output channels."""
        return self.conv_dw.out_channels if exp else self.out_channels

    def forward(self, x):
        """Apply the block to `x`, indexing channels on dim 1."""
        shortcut = x
        if self.conv_exp is not None:
            x = self.conv_exp(x)
        x = self.conv_dw(x)
        if self.se is not None:
            x = self.se(x)
        x = self.act_dw(x)
        x = self.conv_pwl(x)
        if self.use_shortcut:
            if self.drop_path is not None:
                x = self.drop_path(x)
            # Output may have more channels than the input: add the shortcut to the first
            # in_channels channels only and pass the remaining channels through unchanged.
            x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1)
        return x
class ReXNetV1(nn.Module):
    """ReXNet V1 classifier: conv stem, a stack of LinearBottleneck blocks, and a pooled classifier head."""

    def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
                 initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
                 ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.):
        """
        Args:
            in_chans: number of input image channels.
            num_classes: number of classifier outputs.
            global_pool: pooling type handed to `ClassifierHead`.
            output_stride: only 32 is supported (asserted below).
            initial_chs, final_chs: channel schedule parameters forwarded to `_block_cfg`.
            width_mult, depth_mult: width / depth multipliers forwarded to `_block_cfg`.
            se_ratio: squeeze-excite ratio forwarded to `_block_cfg`.
            ch_div: divisor used when rounding channel counts.
            act_layer, dw_act_layer: activation names for the stem/blocks.
            drop_rate: dropout rate handed to `ClassifierHead`.
            drop_path_rate: maximum stochastic-depth rate forwarded to `_build_blocks`.
        """
        super(ReXNetV1, self).__init__()
        self.drop_rate = drop_rate
        self.num_classes = num_classes

        assert output_stride == 32  # FIXME support dilation
        # For width_mult < 1 the base is pre-divided so the stem keeps 32 channels before rounding.
        stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
        stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
        self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)

        block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
        features, self.feature_info = _build_blocks(
            block_cfg, stem_chs, width_mult, ch_div, act_layer, dw_act_layer, drop_path_rate)
        self.num_features = features[-1].out_channels
        self.features = nn.Sequential(*features)

        self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate)

        efficientnet_init_weights(self)

    def get_classifier(self):
        """Return the final classifier layer of the head."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the head with a new one for `num_classes` outputs."""
        # Keep the recorded class count in sync with the new head; it was previously left stale.
        self.num_classes = num_classes
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        """Run the stem and feature blocks, returning the feature map before the head."""
        x = self.stem(x)
        x = self.features(x)
        return x

    def forward(self, x):
        """Run the full network including the classifier head."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
"""ReXNet V1 1.0x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_100', pretrained, ch_div=8, **kwargs) + + +@register_model +def rexnetr_130(pretrained=False, **kwargs): + """ReXNet V1 1.3x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_130', pretrained, width_mult=1.3, ch_div=8, **kwargs) + + +@register_model +def rexnetr_150(pretrained=False, **kwargs): + """ReXNet V1 1.5x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_150', pretrained, width_mult=1.5, ch_div=8, **kwargs) + + +@register_model +def rexnetr_200(pretrained=False, **kwargs): + """ReXNet V1 2.0x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_200', pretrained, width_mult=2.0, ch_div=8, **kwargs) diff --git a/data_processing/MANIQA/timm/models/selecsls.py b/data_processing/MANIQA/timm/models/selecsls.py new file mode 100644 index 0000000..1f3379d --- /dev/null +++ b/data_processing/MANIQA/timm/models/selecsls.py @@ -0,0 +1,362 @@ +"""PyTorch SelecSLS Net example for ImageNet Classification +License: CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/legalcode) +Author: Dushyant Mehta (@mehtadushy) + +SelecSLS (core) Network Architecture as proposed in "XNect: Real-time Multi-person 3D +Human Pose Estimation with a Single RGB Camera, Mehta et al." 
def _cfg(url='', **kwargs):
    """Build a default pretrained-config dict for a SelecSLS variant.

    Starts from the shared defaults below and lets any keyword argument
    override or extend them.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(4, 4),
        crop_pct=0.875,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.0',
        classifier='fc',
    )
    cfg.update(kwargs)
    return cfg
def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1):
    """Return a bias-free Conv2d -> BatchNorm2d -> in-place ReLU sequence.

    When `padding` is None it is derived from the kernel size, stride and
    dilation as ((stride - 1) + dilation * (k - 1)) // 2.
    """
    pad = padding if padding is not None else ((stride - 1) + dilation * (k - 1)) // 2
    conv = nn.Conv2d(in_chs, out_chs, k, stride, padding=pad, dilation=dilation, bias=False)
    norm = nn.BatchNorm2d(out_chs)
    act = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, act)
class SelecSLS(nn.Module):
    """SelecSLS42 / SelecSLS60 / SelecSLS84

    Parameters
    ----------
    cfg : network config dictionary specifying block type, feature, and head args
    num_classes : int, default 1000
        Number of classification classes.
    in_chans : int, default 3
        Number of input (color) channels.
    drop_rate : float, default 0.
        Dropout probability before classifier, for training
    global_pool : str, default 'avg'
        Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
    """

    def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'):
        # Initialise nn.Module before assigning any attribute, as the Module contract requires.
        super(SelecSLS, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate

        self.stem = conv_bn(in_chans, 32, stride=2)
        self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']])
        self.from_seq = SelectSeq()  # from List[tensor] -> Tensor in module compatible way
        self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']])
        self.num_features = cfg['num_features']
        self.feature_info = cfg['feature_info']

        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

        # Kaiming init for convs, unit-scale / zero-shift for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.)
                nn.init.constant_(m.bias, 0.)

    def get_classifier(self):
        """Return the final classifier layer."""
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling + classifier pair for a new number of classes."""
        self.num_classes = num_classes
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        """Run stem, feature blocks and head convs; returns the pre-pooling feature map."""
        x = self.stem(x)
        x = self.features(x)
        # The feature blocks return a list; from_seq picks the tensor passed on to the head.
        x = self.head(self.from_seq(x))
        return x

    def forward(self, x):
        """Run the full network: features, global pool, optional dropout, classifier."""
        x = self.forward_features(x)
        x = self.global_pool(x)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.fc(x)
        return x
Define configuration of the network after the initial neck + cfg['features'] = [ + # in_chs, skip_chs, mid_chs, out_chs, is_first, stride + (32, 0, 64, 64, True, 2), + (64, 64, 64, 128, False, 1), + (128, 0, 128, 128, True, 2), + (128, 128, 128, 128, False, 1), + (128, 128, 128, 288, False, 1), + (288, 0, 288, 288, True, 2), + (288, 288, 288, 288, False, 1), + (288, 288, 288, 288, False, 1), + (288, 288, 288, 416, False, 1), + ] + feature_info.extend([ + dict(num_chs=128, reduction=4, module='features.1'), + dict(num_chs=288, reduction=8, module='features.4'), + dict(num_chs=416, reduction=16, module='features.8'), + ]) + # Head can be replaced with alternative configurations depending on the problem + feature_info.append(dict(num_chs=1024, reduction=32, module='head.1')) + if variant == 'selecsls60b': + cfg['head'] = [ + (416, 756, 3, 2), + (756, 1024, 3, 1), + (1024, 1280, 3, 2), + (1280, 1024, 1, 1), + ] + feature_info.append(dict(num_chs=1024, reduction=64, module='head.3')) + cfg['num_features'] = 1024 + else: + cfg['head'] = [ + (416, 756, 3, 2), + (756, 1024, 3, 1), + (1024, 1024, 3, 2), + (1024, 1280, 1, 1), + ] + feature_info.append(dict(num_chs=1280, reduction=64, module='head.3')) + cfg['num_features'] = 1280 + + elif variant == 'selecsls84': + cfg['block'] = SelecSLSBlock + # Define configuration of the network after the initial neck + cfg['features'] = [ + # in_chs, skip_chs, mid_chs, out_chs, is_first, stride + (32, 0, 64, 64, True, 2), + (64, 64, 64, 144, False, 1), + (144, 0, 144, 144, True, 2), + (144, 144, 144, 144, False, 1), + (144, 144, 144, 144, False, 1), + (144, 144, 144, 144, False, 1), + (144, 144, 144, 304, False, 1), + (304, 0, 304, 304, True, 2), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 512, False, 1), + ] + feature_info.extend([ + dict(num_chs=144, reduction=4, module='features.1'), + dict(num_chs=304, reduction=8, 
module='features.6'), + dict(num_chs=512, reduction=16, module='features.12'), + ]) + # Head can be replaced with alternative configurations depending on the problem + cfg['head'] = [ + (512, 960, 3, 2), + (960, 1024, 3, 1), + (1024, 1024, 3, 2), + (1024, 1280, 3, 1), + ] + cfg['num_features'] = 1280 + feature_info.extend([ + dict(num_chs=1024, reduction=32, module='head.1'), + dict(num_chs=1280, reduction=64, module='head.3') + ]) + else: + raise ValueError('Invalid net configuration ' + variant + ' !!!') + cfg['feature_info'] = feature_info + + # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises? + return build_model_with_cfg( + SelecSLS, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=cfg, + feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), + **kwargs) + + +@register_model +def selecsls42(pretrained=False, **kwargs): + """Constructs a SelecSLS42 model. + """ + return _create_selecsls('selecsls42', pretrained, **kwargs) + + +@register_model +def selecsls42b(pretrained=False, **kwargs): + """Constructs a SelecSLS42_B model. + """ + return _create_selecsls('selecsls42b', pretrained, **kwargs) + + +@register_model +def selecsls60(pretrained=False, **kwargs): + """Constructs a SelecSLS60 model. + """ + return _create_selecsls('selecsls60', pretrained, **kwargs) + + +@register_model +def selecsls60b(pretrained=False, **kwargs): + """Constructs a SelecSLS60_B model. + """ + return _create_selecsls('selecsls60b', pretrained, **kwargs) + + +@register_model +def selecsls84(pretrained=False, **kwargs): + """Constructs a SelecSLS84 model. 
+ """ + return _create_selecsls('selecsls84', pretrained, **kwargs) diff --git a/data_processing/MANIQA/timm/models/senet.py b/data_processing/MANIQA/timm/models/senet.py new file mode 100644 index 0000000..3d0ba7b --- /dev/null +++ b/data_processing/MANIQA/timm/models/senet.py @@ -0,0 +1,467 @@ +""" +SEResNet implementation from Cadene's pretrained models +https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py +Additional credit to https://github.com/creafz + +Original model: https://github.com/hujie-frank/SENet + +ResNet code gently borrowed from +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py + +FIXME I'm deprecating this model and moving them to ResNet as I don't want to maintain duplicate +support for extras like dilation, switchable BN/activations, feature extraction, etc that don't exist here. 
+""" +import math +from collections import OrderedDict + +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['SENet'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'layer0.conv1', 'classifier': 'last_linear', + **kwargs + } + + +default_cfgs = { + 'legacy_senet154': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'), + 'legacy_seresnet18': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth', + interpolation='bicubic'), + 'legacy_seresnet34': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'), + 'legacy_seresnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'), + 'legacy_seresnet101': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'), + 'legacy_seresnet152': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'), + 'legacy_seresnext26_32x4d': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth', + interpolation='bicubic'), + 'legacy_seresnext50_32x4d': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'), + 'legacy_seresnext101_32x4d': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth'), +} + + +def _weight_init(m): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + + +class SEModule(nn.Module): + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + module_input = x + x = x.mean((2, 3), keepdim=True) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class Bottleneck(nn.Module): + """ + Base class for bottlenecks that implements `forward()` method. + """ + + def forward(self, x): + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + shortcut = self.downsample(x) + + out = self.se_module(out) + shortcut + out = self.relu(out) + + return out + + +class SEBottleneck(Bottleneck): + """ + Bottleneck for SENet154. 
+ """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes * 2) + self.conv2 = nn.Conv2d( + planes * 2, planes * 4, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes * 4) + self.conv3 = nn.Conv2d( + planes * 4, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBottleneck(Bottleneck): + """ + ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe + implementation and uses `stride=stride` in `conv1` and not in `conv2` + (the latter is used in the torchvision implementation of ResNet). + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEResNetBottleneck, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=1, bias=False, stride=stride) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNeXtBottleneck(Bottleneck): + """ + ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 
+ """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None, base_width=4): + super(SEResNeXtBottleneck, self).__init__() + width = math.floor(planes * (base_width / 64)) * groups + self.conv1 = nn.Conv2d( + inplanes, width, kernel_size=1, bias=False, stride=1) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d( + width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): + super(SEResNetBlock, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes, reduction=reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + if self.downsample is not None: + shortcut = self.downsample(x) + + out = self.se_module(out) + shortcut + out = self.relu(out) + + return out + + +class SENet(nn.Module): + + def __init__(self, block, layers, groups, reduction, drop_rate=0.2, + in_chans=3, inplanes=64, input_3x3=False, downsample_kernel_size=1, + downsample_padding=0, num_classes=1000, global_pool='avg'): + """ + Parameters + ---------- + block (nn.Module): Bottleneck class. 
+ - For SENet154: SEBottleneck + - For SE-ResNet models: SEResNetBottleneck + - For SE-ResNeXt models: SEResNeXtBottleneck + layers (list of ints): Number of residual blocks for 4 layers of the + network (layer1...layer4). + groups (int): Number of groups for the 3x3 convolution in each + bottleneck block. + - For SENet154: 64 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 32 + reduction (int): Reduction ratio for Squeeze-and-Excitation modules. + - For all models: 16 + dropout_p (float or None): Drop probability for the Dropout layer. + If `None` the Dropout layer is not used. + - For SENet154: 0.2 + - For SE-ResNet models: None + - For SE-ResNeXt models: None + inplanes (int): Number of input channels for layer1. + - For SENet154: 128 + - For SE-ResNet models: 64 + - For SE-ResNeXt models: 64 + input_3x3 (bool): If `True`, use three 3x3 convolutions instead of + a single 7x7 convolution in layer0. + - For SENet154: True + - For SE-ResNet models: False + - For SE-ResNeXt models: False + downsample_kernel_size (int): Kernel size for downsampling convolutions + in layer2, layer3 and layer4. + - For SENet154: 3 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 1 + downsample_padding (int): Padding for downsampling convolutions in + layer2, layer3 and layer4. + - For SENet154: 1 + - For SE-ResNet models: 0 + - For SE-ResNeXt models: 0 + num_classes (int): Number of outputs in `last_linear` layer. 
+ - For all models: 1000 + """ + super(SENet, self).__init__() + self.inplanes = inplanes + self.num_classes = num_classes + self.drop_rate = drop_rate + if input_3x3: + layer0_modules = [ + ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)), + ('bn1', nn.BatchNorm2d(64)), + ('relu1', nn.ReLU(inplace=True)), + ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), + ('bn2', nn.BatchNorm2d(64)), + ('relu2', nn.ReLU(inplace=True)), + ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)), + ('bn3', nn.BatchNorm2d(inplanes)), + ('relu3', nn.ReLU(inplace=True)), + ] + else: + layer0_modules = [ + ('conv1', nn.Conv2d( + in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), + ('bn1', nn.BatchNorm2d(inplanes)), + ('relu1', nn.ReLU(inplace=True)), + ] + self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) + # To preserve compatibility with Caffe weights `ceil_mode=True` is used instead of `padding=1`. + self.pool0 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.feature_info = [dict(num_chs=inplanes, reduction=2, module='layer0')] + self.layer1 = self._make_layer( + block, + planes=64, + blocks=layers[0], + groups=groups, + reduction=reduction, + downsample_kernel_size=1, + downsample_padding=0 + ) + self.feature_info += [dict(num_chs=64 * block.expansion, reduction=4, module='layer1')] + self.layer2 = self._make_layer( + block, + planes=128, + blocks=layers[1], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=128 * block.expansion, reduction=8, module='layer2')] + self.layer3 = self._make_layer( + block, + planes=256, + blocks=layers[2], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=256 * block.expansion, reduction=16, module='layer3')] + 
self.layer4 = self._make_layer( + block, + planes=512, + blocks=layers[3], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')] + self.num_features = 512 * block.expansion + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + for m in self.modules(): + _weight_init(m) + + def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, + downsample_kernel_size=1, downsample_padding=0): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size, + stride=stride, padding=downsample_padding, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, groups, reduction, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups, reduction)) + + return nn.Sequential(*layers) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.layer0(x) + x = self.pool0(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def logits(self, x): + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.last_linear(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.logits(x) + return x + + +def _create_senet(variant, pretrained=False, **kwargs): + return 
build_model_with_cfg( + SENet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def legacy_seresnet18(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBlock, layers=[2, 2, 2, 2], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet18', pretrained, **model_args) + + +@register_model +def legacy_seresnet34(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBlock, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet34', pretrained, **model_args) + + +@register_model +def legacy_seresnet50(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBottleneck, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet50', pretrained, **model_args) + + +@register_model +def legacy_seresnet101(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBottleneck, layers=[3, 4, 23, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet101', pretrained, **model_args) + + +@register_model +def legacy_seresnet152(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBottleneck, layers=[3, 8, 36, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet152', pretrained, **model_args) + + +@register_model +def legacy_senet154(pretrained=False, **kwargs): + model_args = dict( + block=SEBottleneck, layers=[3, 8, 36, 3], groups=64, reduction=16, + downsample_kernel_size=3, downsample_padding=1, inplanes=128, input_3x3=True, **kwargs) + return _create_senet('legacy_senet154', pretrained, **model_args) + + +@register_model +def legacy_seresnext26_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=SEResNeXtBottleneck, layers=[2, 2, 2, 2], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext26_32x4d', pretrained, **model_args) + + +@register_model +def legacy_seresnext50_32x4d(pretrained=False, **kwargs): + 
model_args = dict( + block=SEResNeXtBottleneck, layers=[3, 4, 6, 3], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext50_32x4d', pretrained, **model_args) + + +@register_model +def legacy_seresnext101_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=SEResNeXtBottleneck, layers=[3, 4, 23, 3], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext101_32x4d', pretrained, **model_args) diff --git a/data_processing/MANIQA/timm/models/sknet.py b/data_processing/MANIQA/timm/models/sknet.py new file mode 100644 index 0000000..4dc2aa5 --- /dev/null +++ b/data_processing/MANIQA/timm/models/sknet.py @@ -0,0 +1,215 @@ +""" Selective Kernel Networks (ResNet base) + +Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) + +This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268) +and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer +to the original paper with some modifications of my own to better balance param count vs accuracy. 
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +import math + +from torch import nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import SelectiveKernel, ConvBnAct, create_attn +from .registry import register_model +from .resnet import ResNet + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'skresnet18': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'), + 'skresnet34': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'), + 'skresnet50': _cfg(), + 'skresnet50d': _cfg( + first_conv='conv1.0'), + 'skresnext50_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'), +} + + +class SelectiveKernelBasic(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(SelectiveKernelBasic, self).__init__() + + sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock doest not support changing base 
width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + + self.conv1 = SelectiveKernel( + inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv2 = ConvBnAct( + first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs) + self.se = create_attn(attn_layer, outplanes) + self.act = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv2.bn.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.conv2(x) + if self.se is not None: + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act(x) + return x + + +class SelectiveKernelBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, + drop_block=None, drop_path=None): + super(SelectiveKernelBottleneck, self).__init__() + + sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + + self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) + self.conv2 = SelectiveKernel( + first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, + **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv3 = 
ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs) + self.se = create_attn(attn_layer, outplanes) + self.act = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + if self.se is not None: + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act(x) + return x + + +def _create_skresnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def skresnet18(pretrained=False, **kwargs): + """Constructs a Selective Kernel ResNet-18 model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. + """ + sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True) + model_args = dict( + block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet18', pretrained, **model_args) + + +@register_model +def skresnet34(pretrained=False, **kwargs): + """Constructs a Selective Kernel ResNet-34 model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. 
+ """ + sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True) + model_args = dict( + block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet34', pretrained, **model_args) + + +@register_model +def skresnet50(pretrained=False, **kwargs): + """Constructs a Select Kernel ResNet-50 model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. + """ + sk_kwargs = dict(split_input=True) + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet50', pretrained, **model_args) + + +@register_model +def skresnet50d(pretrained=False, **kwargs): + """Constructs a Select Kernel ResNet-50-D model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. + """ + sk_kwargs = dict(split_input=True) + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet50d', pretrained, **model_args) + + +@register_model +def skresnext50_32x4d(pretrained=False, **kwargs): + """Constructs a Select Kernel ResNeXt50-32x4d model. 
This should be equivalent to + the SKNet-50 model in the Select Kernel Paper + """ + sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False) + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnext50_32x4d', pretrained, **model_args) + diff --git a/data_processing/MANIQA/timm/models/swin_transformer.py b/data_processing/MANIQA/timm/models/swin_transformer.py new file mode 100644 index 0000000..37b08ba --- /dev/null +++ b/data_processing/MANIQA/timm/models/swin_transformer.py @@ -0,0 +1,657 @@ +""" Swin Transformer +A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` + - https://arxiv.org/pdf/2103.14030 + +Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +import logging +import math +from copy import deepcopy +from typing import Optional + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_function +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ +from .layers import _assert +from .registry import register_model +from .vision_transformer import checkpoint_filter_fn, _init_vit_weights + + +_logger = 
logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (my experiments) + 'swin_base_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_base_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth', + ), + + 'swin_large_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_large_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth', + ), + + 'swin_small_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', + ), + + 'swin_tiny_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', + ), + + 'swin_base_patch4_window12_384_in22k': _cfg( + 
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_base_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', + num_classes=21841), + + 'swin_large_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_large_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', + num_classes=21841), + +} + + +def window_partition(x, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window 
based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + 
def forward(self, x, mask: Optional[torch.Tensor] = None):
    """Apply windowed multi-head self-attention with relative position bias.

    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: additive mask with shape of (num_windows, Wh*Ww, Wh*Ww), or
            None to attend without masking

    Returns:
        Tensor with the same (num_windows*B, N, C) shape as ``x``.
    """
    batch_windows, tokens, channels = x.shape
    heads = self.num_heads

    # One projection produces q, k and v; split them on the leading axis.
    packed = self.qkv(x).reshape(batch_windows, tokens, 3, heads, channels // heads)
    query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)

    scores = (query * self.scale) @ key.transpose(-2, -1)

    # Gather the learned bias for every (token, token) pair inside a window.
    window_area = self.window_size[0] * self.window_size[1]
    bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
    bias = bias.view(window_area, window_area, -1)  # Wh*Ww, Wh*Ww, nH
    bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    scores = scores + bias.unsqueeze(0)

    if mask is not None:
        # The mask is defined per window, so expose the window axis,
        # broadcast the mask over batch and heads, then fold it back.
        num_windows = mask.shape[0]
        scores = scores.view(batch_windows // num_windows, num_windows, heads, tokens, tokens)
        scores = scores + mask.unsqueeze(1).unsqueeze(0)
        scores = scores.view(-1, heads, tokens, tokens)

    weights = self.attn_drop(self.softmax(scores))

    merged = (weights @ value).transpose(1, 2).reshape(batch_windows, tokens, channels)
    return self.proj_drop(self.proj(merged))
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
             mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
             act_layer=nn.GELU, norm_layer=nn.LayerNorm):
    """Build the attention and MLP sub-layers and, for a shifted block, the attention mask.

    The ``attn_mask`` buffer is ``None`` when ``shift_size`` ends up as 0.
    """
    super().__init__()
    self.dim = dim
    self.input_resolution = input_resolution
    self.num_heads = num_heads
    self.window_size = window_size
    self.shift_size = shift_size
    self.mlp_ratio = mlp_ratio

    shortest_side = min(self.input_resolution)
    if shortest_side <= self.window_size:
        # The window would cover the whole feature map: use a single
        # unshifted window of exactly that size instead of partitioning.
        self.shift_size = 0
        self.window_size = shortest_side
    assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

    self.norm1 = norm_layer(dim)
    self.attn = WindowAttention(
        dim,
        window_size=to_2tuple(self.window_size),
        num_heads=num_heads,
        qkv_bias=qkv_bias,
        attn_drop=attn_drop,
        proj_drop=drop)

    if drop_path > 0.:
        self.drop_path = DropPath(drop_path)
    else:
        self.drop_path = nn.Identity()
    self.norm2 = norm_layer(dim)
    self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    attn_mask = None
    if self.shift_size > 0:
        # Label the 3x3 grid of regions created by the cyclic shift; tokens
        # from different regions that share a window must not attend to
        # each other.
        height, width = self.input_resolution
        region_ids = torch.zeros((1, height, width, 1))  # 1 H W 1
        bands = (slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
                 slice(-self.shift_size, None))
        for row_idx, rows in enumerate(bands):
            for col_idx, cols in enumerate(bands):
                region_ids[:, rows, cols, :] = row_idx * len(bands) + col_idx

        windows = window_partition(region_ids, self.window_size)  # nW, window_size, window_size, 1
        windows = windows.view(-1, self.window_size * self.window_size)
        # Non-zero difference means the two tokens come from different regions.
        region_diff = windows.unsqueeze(1) - windows.unsqueeze(2)
        attn_mask = region_diff.masked_fill(region_diff != 0, float(-100.0))
        attn_mask = attn_mask.masked_fill(region_diff == 0, float(0.0))

    self.register_buffer("attn_mask", attn_mask)
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Builds ``depth`` SwinTransformerBlocks that alternate between a shift of 0
    and a shift of ``window_size // 2``, followed by an optional downsample layer.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate. A list or
            tuple is indexed per block; a scalar is shared by all blocks. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Tuples must be indexed per block just like lists; otherwise the whole
        # sequence would be handed to every block as if it were a scalar rate.
        per_block_rates = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[i] if per_block_rates else drop_path, norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        """Run ``x`` through every block, then through the downsample layer if one was built."""
        for blk in self.blocks:
            # Checkpointing is skipped while scripting.
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, weight_init='', **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + self.patch_grid = self.patch_embed.grid_size + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + else: + self.absolute_pos_embed = None + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + layers = [] + for i_layer in range(self.num_layers): + layers += [BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + ] + self.layers = nn.Sequential(*layers) + + self.norm 
def _create_swin_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
    """Build a SwinTransformer for ``variant`` through ``build_model_with_cfg``.

    Args:
        variant: key into ``default_cfgs``, also forwarded as the model variant name.
        pretrained: forwarded to ``build_model_with_cfg``.
        default_cfg: config to use; when None, a deep copy of ``default_cfgs[variant]``.
        **kwargs: model arguments. ``num_classes`` and ``img_size`` are popped and
            default to the values in the config.

    Raises:
        RuntimeError: if a truthy ``features_only`` is requested.
    """
    cfg = deepcopy(default_cfgs[variant]) if default_cfg is None else default_cfg
    overlay_external_default_cfg(cfg, kwargs)

    # Read both fallbacks before popping, so a config missing either key
    # fails the same way regardless of what the caller passed.
    fallback_num_classes = cfg['num_classes']
    fallback_img_size = cfg['input_size'][-2:]
    num_classes = kwargs.pop('num_classes', fallback_num_classes)
    img_size = kwargs.pop('img_size', fallback_img_size)

    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    return build_model_with_cfg(
        SwinTransformer, variant, pretrained,
        default_cfg=cfg,
        img_size=img_size,
        num_classes=num_classes,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)
@register_model
def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
    """ Swin-T @ 224x224, trained ImageNet-1k

    Extra keyword arguments are forwarded to the model factory; repeating one
    of the preset architecture arguments raises ``TypeError``.
    """
    variant = 'swin_tiny_patch4_window7_224'
    arch = {
        'patch_size': 4,
        'window_size': 7,
        'embed_dim': 96,
        'depths': (2, 2, 6, 2),
        'num_heads': (3, 6, 12, 24),
    }
    # dict(**a, **b) rejects duplicate keys instead of silently overriding.
    model_kwargs = dict(**arch, **kwargs)
    return _create_swin_transformer(variant, pretrained=pretrained, **model_kwargs)
class Attention(nn.Module):
    """ Multi-Head Attention

    Queries and keys are projected to ``hidden_dim`` and values to ``dim``;
    both are split evenly across ``num_heads``.

    Args:
        dim (int): Number of input (and output) channels.
        hidden_dim (int): Total query/key width across all heads.
        num_heads (int): Number of attention heads. Default: 8
        qkv_bias (bool): If True, add a learnable bias to the q/k and v projections. Default: False
        attn_drop (float): Dropout ratio of attention weights. Default: 0.0
        proj_drop (float): Dropout ratio of the output projection. Default: 0.0
    """
    def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        head_dim = hidden_dim // num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        # Deliberately NOT in-place: this dropout is applied directly to the
        # softmax output, which autograd keeps for the softmax backward pass.
        # Overwriting it in place breaks backward() whenever attn_drop > 0
        # in training mode.
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C); the result has the same shape."""
        B, N, C = x.shape
        qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k = qk.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
        v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class PixelEmbed(nn.Module):
    """ Image to Pixel Embedding

    A strided 7x7 convolution embeds the image, then the feature map is
    unfolded into one block of ``new_patch_size`` per patch.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
        super().__init__()
        image_hw = to_2tuple(img_size)
        patch_hw = to_2tuple(patch_size)
        # grid_size property necessary for resizing positional embedding
        grid_h = image_hw[0] // patch_hw[0]
        grid_w = image_hw[1] // patch_hw[1]
        self.grid_size = (grid_h, grid_w)
        self.img_size = image_hw
        self.num_patches = grid_h * grid_w
        self.in_dim = in_dim
        # Patch extent measured on the feature map produced by the strided conv.
        self.new_patch_size = [math.ceil(side / stride) for side in patch_hw]

        self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
        self.unfold = nn.Unfold(kernel_size=self.new_patch_size, stride=self.new_patch_size)

    def forward(self, x, pixel_pos):
        """Embed image batch ``x`` and add ``pixel_pos`` to every patch.

        Returns a tensor of shape (B * num_patches, pixels_per_patch, in_dim).
        """
        B, C, H, W = x.shape
        _assert(H == self.img_size[0],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        _assert(W == self.img_size[1],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        total_patches = B * self.num_patches
        patch_h, patch_w = self.new_patch_size[0], self.new_patch_size[1]

        blocks = self.unfold(self.proj(x))
        blocks = blocks.transpose(1, 2).reshape(total_patches, self.in_dim, patch_h, patch_w)
        blocks = blocks + pixel_pos
        return blocks.reshape(total_patches, self.in_dim, -1).transpose(1, 2)
def checkpoint_filter_fn(state_dict, model):
    """Resize the checkpoint's ``patch_pos`` entry when its shape differs from the model's.

    ``state_dict`` is updated in place and returned; it is returned untouched
    when the shapes already agree.
    """
    loaded_pos = state_dict['patch_pos']
    if loaded_pos.shape == model.patch_pos.shape:
        return state_dict

    num_tokens = getattr(model, 'num_tokens', 1)
    state_dict['patch_pos'] = resize_pos_embed(
        loaded_pos, model.patch_pos, num_tokens, model.pixel_embed.grid_size)
    return state_dict
@register_model
def tnt_s_patch16_224(pretrained=False, **kwargs):
    """Create the 'tnt_s_patch16_224' TNT variant.

    Extra keyword arguments are forwarded to the model factory; repeating one
    of the preset architecture arguments raises ``TypeError``.
    """
    arch = {
        'patch_size': 16,
        'embed_dim': 384,
        'in_dim': 24,
        'depth': 12,
        'num_heads': 6,
        'in_num_head': 4,
        'qkv_bias': False,
    }
    # dict(**a, **b) rejects duplicate keys instead of silently overriding.
    model_cfg = dict(**arch, **kwargs)
    return _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg)
def IABN2Float(module: nn.Module) -> nn.Module:
    """Cast every InplaceAbn inside ``module`` (including ``module`` itself) to float.

    The module is modified in place and returned.
    """
    # modules() yields the module itself followed by all of its descendants.
    for submodule in module.modules():
        if isinstance(submodule, InplaceAbn):
            submodule.float()
    return module
class BasicBlock(nn.Module):
    """Two 3x3 conv + InplaceAbn layers with optional squeeze-excite and a residual shortcut.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels (``expansion`` is 1).
        stride: 1 keeps the resolution; any other value selects the stride-2 path.
        downsample: Optional module applied to the input to form the shortcut.
        use_se: If True, apply an SEModule before the residual add.
        aa_layer: Optional anti-aliasing layer constructor. When given and the
            stride is not 1, the first conv runs at stride 1 and this layer
            performs the stride-2 reduction.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_layer=None):
        super(BasicBlock, self).__init__()
        if stride == 1:
            self.conv1 = conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3)
        elif aa_layer is None:
            self.conv1 = conv2d_iabn(inplanes, planes, stride=2, act_param=1e-3)
        else:
            self.conv1 = nn.Sequential(
                conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3),
                aa_layer(channels=planes, filt_size=3, stride=2))

        self.conv2 = conv2d_iabn(planes, planes, stride=1, act_layer="identity")
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        rd_chs = max(planes * self.expansion // 4, 64)
        self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None

    def forward(self, x):
        """Return ``relu(block(x) + shortcut(x))``."""
        if self.downsample is not None:
            shortcut = self.downsample(x)
        else:
            shortcut = x

        out = self.conv1(x)
        out = self.conv2(out)

        if self.se is not None:
            out = self.se(out)

        # No in-place add, matching Bottleneck.forward: with use_se=False,
        # `out` is the direct output of the InplaceAbn layer in conv2 and
        # must not be overwritten before the backward pass.
        out = out + shortcut
        out = self.relu(out)
        return out
self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None + + self.conv3 = conv2d_iabn( + planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + if self.downsample is not None: + shortcut = self.downsample(x) + else: + shortcut = x + + out = self.conv1(x) + out = self.conv2(out) + if self.se is not None: + out = self.se(out) + + out = self.conv3(out) + out = out + shortcut # no inplace + out = self.relu(out) + + return out + + +class TResNet(nn.Module): + def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, global_pool='fast', drop_rate=0.): + self.num_classes = num_classes + self.drop_rate = drop_rate + super(TResNet, self).__init__() + + aa_layer = BlurPool2d + + # TResnet stages + self.inplanes = int(64 * width_factor) + self.planes = int(64 * width_factor) + conv1 = conv2d_iabn(in_chans * 16, self.planes, stride=1, kernel_size=3) + layer1 = self._make_layer( + BasicBlock, self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer) # 56x56 + layer2 = self._make_layer( + BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer) # 28x28 + layer3 = self._make_layer( + Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer) # 14x14 + layer4 = self._make_layer( + Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer) # 7x7 + + # body + self.body = nn.Sequential(OrderedDict([ + ('SpaceToDepth', SpaceToDepthModule()), + ('conv1', conv1), + ('layer1', layer1), + ('layer2', layer2), + ('layer3', layer3), + ('layer4', layer4)])) + + self.feature_info = [ + dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D? 
+ dict(num_chs=self.planes, reduction=4, module='body.layer1'), + dict(num_chs=self.planes * 2, reduction=8, module='body.layer2'), + dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'), + dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'), + ] + + # head + self.num_features = (self.planes * 8) * Bottleneck.expansion + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + # model initilization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InplaceAbn): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # residual connections special initialization + for m in self.modules(): + if isinstance(m, BasicBlock): + m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight)) # BN to zero + if isinstance(m, Bottleneck): + m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero + if isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) + + def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + layers = [] + if stride == 2: + # avg pooling before 1x1 conv + layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False)) + layers += [conv2d_iabn( + self.inplanes, planes * block.expansion, kernel_size=1, stride=1, act_layer="identity")] + downsample = nn.Sequential(*layers) + + layers = [] + layers.append(block( + self.inplanes, planes, stride, downsample, use_se=use_se, aa_layer=aa_layer)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block(self.inplanes, planes, use_se=use_se, aa_layer=aa_layer)) + return nn.Sequential(*layers) + + def get_classifier(self): + return 
self.head.fc + + def reset_classifier(self, num_classes, global_pool='fast'): + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + return self.body(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_tresnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + TResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True), + **kwargs) + + +@register_model +def tresnet_m(pretrained=False, **kwargs): + model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs) + return _create_tresnet('tresnet_m', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_m_miil_in21k(pretrained=False, **kwargs): + model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs) + return _create_tresnet('tresnet_m_miil_in21k', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_l(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs) + return _create_tresnet('tresnet_l', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_xl(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs) + return _create_tresnet('tresnet_xl', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_m_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs) + return _create_tresnet('tresnet_m_448', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_l_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs) + return _create_tresnet('tresnet_l_448', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_xl_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs) + 
return _create_tresnet('tresnet_xl_448', pretrained=pretrained, **model_kwargs) diff --git a/data_processing/MANIQA/timm/models/twins.py b/data_processing/MANIQA/timm/models/twins.py new file mode 100644 index 0000000..67a939d --- /dev/null +++ b/data_processing/MANIQA/timm/models/twins.py @@ -0,0 +1,424 @@ +""" Twins +A PyTorch impl of : `Twins: Revisiting the Design of Spatial Attention in Vision Transformers` + - https://arxiv.org/pdf/2104.13840.pdf + +Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below + +""" +# -------------------------------------------------------- +# Twins +# Copyright (c) 2021 Meituan +# Licensed under The Apache 2.0 License [see LICENSE for details] +# Written by Xinjie Li, Xiangxiang Chu +# -------------------------------------------------------- +import math +from copy import deepcopy +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from .fx_features import register_notrace_module +from .registry import register_model +from .vision_transformer import Attention +from .helpers import build_model_with_cfg + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'twins_pcpvt_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth', + ), + 
'twins_pcpvt_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth', + ), + 'twins_pcpvt_large': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth', + ), + 'twins_svt_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth', + ), + 'twins_svt_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth', + ), + 'twins_svt_large': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth', + ), +} + +Size_ = Tuple[int, int] + + +@register_notrace_module # reason: FX can't symbolically trace control flow in forward method +class LocallyGroupedAttn(nn.Module): + """ LSA: self attention within a group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): + assert ws != 1 + super(LocallyGroupedAttn, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. 
You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + qkv = self.qkv(x).reshape( + B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + # def forward_mask(self, x, size: Size_): + # B, N, C = x.shape + # H, W = size + # x = x.view(B, H, W, C) + # pad_l = pad_t = 0 + # pad_r = (self.ws - W % self.ws) % self.ws + # pad_b = (self.ws - H % self.ws) % self.ws + # x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + # _, Hp, Wp, _ = x.shape + # _h, _w = Hp // self.ws, Wp // self.ws + # mask = torch.zeros((1, Hp, Wp), device=x.device) + # mask[:, -pad_b:, :].fill_(1) + # mask[:, :, -pad_r:].fill_(1) + # + # x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) # B, _h, _w, ws, ws, C + # mask = mask.reshape(1, _h, self.ws, _w, self.ws).transpose(2, 3).reshape(1, _h * _w, self.ws * self.ws) + # attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) # 1, _h*_w, ws*ws, ws*ws + # attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-1000.0)).masked_fill(attn_mask == 0, float(0.0)) + # qkv = self.qkv(x).reshape( + # B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // 
self.num_heads).permute(3, 0, 1, 4, 2, 5) + # # n_h, B, _w*_h, nhead, ws*ws, dim + # q, k, v = qkv[0], qkv[1], qkv[2] # B, _h*_w, n_head, ws*ws, dim_head + # attn = (q @ k.transpose(-2, -1)) * self.scale # B, _h*_w, n_head, ws*ws, ws*ws + # attn = attn + attn_mask.unsqueeze(2) + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) # attn @v -> B, _h*_w, n_head, ws*ws, dim_head + # attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + # x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + # if pad_r > 0 or pad_b > 0: + # x = x[:, :H, :W, :].contiguous() + # x = x.reshape(B, N, C) + # x = self.proj(x) + # x = self.proj_drop(x) + # return x + + +class GlobalSubSampleAttn(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
+ + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None): + super().__init__() + self.norm1 = norm_layer(dim) + if ws is None: + self.attn = Attention(dim, num_heads, False, None, attn_drop, drop) + elif ws == 1: + self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio) + else: + self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, size: Size_): + x = x + self.drop_path(self.attn(self.norm1(x), size)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), ) + self.stride = stride + + def forward(self, x, size: Size_): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) + return x + + def no_weight_decay(self): + return ['proj.%d.weight' % i for i in range(4)] + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ + f"img_size {img_size} should be divided by patch_size {patch_size}." 
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape + + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + out_size = (H // self.patch_size[0], W // self.patch_size[1]) + + return x, out_size + + +class Twins(nn.Module): + """ Twins Vision Transfomer (Revisiting Spatial Attention) + + Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git + """ + def __init__( + self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None, + block_cls=Block): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] + + img_size = to_2tuple(img_size) + prev_chs = in_chans + self.patch_embeds = nn.ModuleList() + self.pos_drops = nn.ModuleList() + for i in range(len(depths)): + self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i])) + self.pos_drops.append(nn.Dropout(p=drop_rate)) + prev_chs = embed_dims[i] + img_size = tuple(t // patch_size for t in img_size) + patch_size = 2 + + self.blocks = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + for k in range(len(depths)): + _block = nn.ModuleList([block_cls( + dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k], + 
ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])]) + self.blocks.append(_block) + cur += depths[k] + + self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) + + self.norm = norm_layer(self.num_features) + + # classification head + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + # init weights + self.apply(self._init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return set(['pos_block.' + n for n, p in self.pos_block.named_parameters()]) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward_features(self, x): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + if i < len(self.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + x = self.norm(x) + return x.mean(dim=1) # GAP here + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_twins(variant, pretrained=False, 
**kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + Twins, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + return model + + +@register_model +def twins_pcpvt_small(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_pcpvt_base(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_pcpvt_large(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_svt_small(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], + depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_svt_base(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], + depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs) + + +@register_model +def 
twins_svt_large(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], + depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) diff --git a/data_processing/MANIQA/timm/models/vgg.py b/data_processing/MANIQA/timm/models/vgg.py new file mode 100644 index 0000000..ccaa21d --- /dev/null +++ b/data_processing/MANIQA/timm/models/vgg.py @@ -0,0 +1,263 @@ +"""VGG + +Adapted from https://github.com/pytorch/vision 'vgg.py' (BSD-3-Clause) with a few changes for +timm functionality. + +Copyright 2021 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Union, List, Dict, Any, cast + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .fx_features import register_notrace_module +from .layers import ClassifierHead +from .registry import register_model + +__all__ = [ + 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', + 'vgg19_bn', 'vgg19', +] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'features.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + 'vgg11': _cfg(url='https://download.pytorch.org/models/vgg11-bbd30ac9.pth'), + 'vgg13': _cfg(url='https://download.pytorch.org/models/vgg13-c768596a.pth'), + 'vgg16': _cfg(url='https://download.pytorch.org/models/vgg16-397923af.pth'), + 'vgg19': 
_cfg(url='https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'), + 'vgg11_bn': _cfg(url='https://download.pytorch.org/models/vgg11_bn-6002323d.pth'), + 'vgg13_bn': _cfg(url='https://download.pytorch.org/models/vgg13_bn-abd245e5.pth'), + 'vgg16_bn': _cfg(url='https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'), + 'vgg19_bn': _cfg(url='https://download.pytorch.org/models/vgg19_bn-c79401a0.pth'), +} + + +cfgs: Dict[str, List[Union[str, int]]] = { + 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +@register_notrace_module # reason: FX can't symbolically trace control flow in forward method +class ConvMlp(nn.Module): + + def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, + drop_rate: float = 0.2, act_layer: nn.Module = None, conv_layer: nn.Module = None): + super(ConvMlp, self).__init__() + self.input_kernel_size = kernel_size + mid_features = int(out_features * mlp_ratio) + self.fc1 = conv_layer(in_features, mid_features, kernel_size, bias=True) + self.act1 = act_layer(True) + self.drop = nn.Dropout(drop_rate) + self.fc2 = conv_layer(mid_features, out_features, 1, bias=True) + self.act2 = act_layer(True) + + def forward(self, x): + if x.shape[-2] < self.input_kernel_size or x.shape[-1] < self.input_kernel_size: + # keep the input size >= 7x7 + output_size = (max(self.input_kernel_size, x.shape[-2]), max(self.input_kernel_size, 
x.shape[-1])) + x = F.adaptive_avg_pool2d(x, output_size) + x = self.fc1(x) + x = self.act1(x) + x = self.drop(x) + x = self.fc2(x) + x = self.act2(x) + return x + + +class VGG(nn.Module): + + def __init__( + self, + cfg: List[Any], + num_classes: int = 1000, + in_chans: int = 3, + output_stride: int = 32, + mlp_ratio: float = 1.0, + act_layer: nn.Module = nn.ReLU, + conv_layer: nn.Module = nn.Conv2d, + norm_layer: nn.Module = None, + global_pool: str = 'avg', + drop_rate: float = 0., + ) -> None: + super(VGG, self).__init__() + assert output_stride == 32 + self.num_classes = num_classes + self.num_features = 4096 + self.drop_rate = drop_rate + self.feature_info = [] + prev_chs = in_chans + net_stride = 1 + pool_layer = nn.MaxPool2d + layers: List[nn.Module] = [] + for v in cfg: + last_idx = len(layers) - 1 + if v == 'M': + self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{last_idx}')) + layers += [pool_layer(kernel_size=2, stride=2)] + net_stride *= 2 + else: + v = cast(int, v) + conv2d = conv_layer(prev_chs, v, kernel_size=3, padding=1) + if norm_layer is not None: + layers += [conv2d, norm_layer(v), act_layer(inplace=True)] + else: + layers += [conv2d, act_layer(inplace=True)] + prev_chs = v + self.features = nn.Sequential(*layers) + self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{len(layers) - 1}')) + self.pre_logits = ConvMlp( + prev_chs, self.num_features, 7, mlp_ratio=mlp_ratio, + drop_rate=drop_rate, act_layer=act_layer, conv_layer=conv_layer) + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + self._initialize_weights() + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.head = ClassifierHead( + self.num_features, self.num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x: 
torch.Tensor) -> torch.Tensor: + x = self.features(x) + x = self.pre_logits(x) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + x = self.head(x) + return x + + def _initialize_weights(self) -> None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + +def _filter_fn(state_dict): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + k_r = k + k_r = k_r.replace('classifier.0', 'pre_logits.fc1') + k_r = k_r.replace('classifier.3', 'pre_logits.fc2') + k_r = k_r.replace('classifier.6', 'head.fc') + if 'classifier.0.weight' in k: + v = v.reshape(-1, 512, 7, 7) + if 'classifier.3.weight' in k: + v = v.reshape(-1, 4096, 1, 1) + out_dict[k_r] = v + return out_dict + + +def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG: + cfg = variant.split('_')[0] + # NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5] + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5)) + model = build_model_with_cfg( + VGG, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=cfgs[cfg], + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + pretrained_filter_fn=_filter_fn, + **kwargs) + return model + + +@register_model +def vgg11(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 11-layer model (configuration "A") from + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(**kwargs) + return _create_vgg('vgg11', pretrained=pretrained, **model_args) + + +@register_model +def vgg11_bn(pretrained: 
bool = False, **kwargs: Any) -> VGG: + r"""VGG 11-layer model (configuration "A") with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs) + return _create_vgg('vgg11_bn', pretrained=pretrained, **model_args) + + +@register_model +def vgg13(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 13-layer model (configuration "B") + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(**kwargs) + return _create_vgg('vgg13', pretrained=pretrained, **model_args) + + +@register_model +def vgg13_bn(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 13-layer model (configuration "B") with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs) + return _create_vgg('vgg13_bn', pretrained=pretrained, **model_args) + + +@register_model +def vgg16(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 16-layer model (configuration "D") + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(**kwargs) + return _create_vgg('vgg16', pretrained=pretrained, **model_args) + + +@register_model +def vgg16_bn(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 16-layer model (configuration "D") with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs) + return _create_vgg('vgg16_bn', pretrained=pretrained, **model_args) + + +@register_model +def vgg19(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 19-layer model (configuration "E") + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(**kwargs) + return _create_vgg('vgg19', pretrained=pretrained, **model_args) + + +@register_model +def 
vgg19_bn(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 19-layer model (configuration 'E') with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs) + return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args) \ No newline at end of file diff --git a/data_processing/MANIQA/timm/models/visformer.py b/data_processing/MANIQA/timm/models/visformer.py new file mode 100644 index 0000000..62f7730 --- /dev/null +++ b/data_processing/MANIQA/timm/models/visformer.py @@ -0,0 +1,413 @@ +""" Visformer + +Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533 + +From original at https://github.com/danczs/Visformer + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier +from .registry import register_model + + +__all__ = ['Visformer'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head', + **kwargs + } + + +default_cfgs = dict( + visformer_tiny=_cfg(), + visformer_small=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth' + ), +) + + +class SpatialMlp(nn.Module): 
+ def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0., group=8, spatial_conv=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + drop_probs = to_2tuple(drop) + + self.in_features = in_features + self.out_features = out_features + self.spatial_conv = spatial_conv + if self.spatial_conv: + if group < 2: # net setting + hidden_features = in_features * 5 // 6 + else: + hidden_features = in_features * 2 + self.hidden_features = hidden_features + self.group = group + self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False) + self.act1 = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + if self.spatial_conv: + self.conv2 = nn.Conv2d( + hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False) + self.act2 = act_layer() + else: + self.conv2 = None + self.act2 = None + self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False) + self.drop3 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.conv1(x) + x = self.act1(x) + x = self.drop1(x) + if self.conv2 is not None: + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + x = self.drop3(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = round(dim // num_heads * head_dim_ratio) + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, C, H, W = x.shape + x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 
4, 3) + q, k, v = x[0], x[1], x[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4., + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d, + group=8, attn_disabled=False, spatial_conv=False): + super().__init__() + self.spatial_conv = spatial_conv + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + if attn_disabled: + self.norm1 = None + self.attn = None + else: + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = SpatialMlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + group=group, spatial_conv=spatial_conv) # new setting + + def forward(self, x): + if self.attn is not None: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Visformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, + depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111', + vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None): + super().__init__() + img_size = to_2tuple(img_size) + self.num_classes = num_classes + self.embed_dim = embed_dim + self.init_channels = init_channels + self.img_size = img_size + self.vit_stem = vit_stem + self.conv_init = conv_init + if isinstance(depth, (list, tuple)): + self.stage_num1, self.stage_num2, self.stage_num3 = depth + depth = 
sum(depth) + else: + self.stage_num1 = self.stage_num3 = depth // 3 + self.stage_num2 = depth - self.stage_num1 - self.stage_num3 + self.pos_embed = pos_embed + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + # stage 1 + if self.vit_stem: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size = [x // patch_size for x in img_size] + else: + if self.init_channels is None: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size = [x // (patch_size // 2) for x in img_size] + else: + self.stem = nn.Sequential( + nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.init_channels), + nn.ReLU(inplace=True) + ) + img_size = [x // 2 for x in img_size] + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size = [x // (patch_size // 4) for x in img_size] + + if self.pos_embed: + if self.vit_stem: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) + else: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size)) + self.pos_drop = nn.Dropout(p=drop_rate) + self.stage1 = nn.ModuleList([ + Block( + dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1') + ) + for i in range(self.stage_num1) + ]) + + # stage2 + if not self.vit_stem: + self.patch_embed2 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2, + embed_dim=embed_dim, norm_layer=embed_norm, 
flatten=False) + img_size = [x // (patch_size // 8) for x in img_size] + if self.pos_embed: + self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) + self.stage2 = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1') + ) + for i in range(self.stage_num1, self.stage_num1+self.stage_num2) + ]) + + # stage 3 + if not self.vit_stem: + self.patch_embed3 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim, + embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False) + img_size = [x // (patch_size // 8) for x in img_size] + if self.pos_embed: + self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size)) + self.stage3 = nn.ModuleList([ + Block( + dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1') + ) + for i in range(self.stage_num1+self.stage_num2, depth) + ]) + + # head + self.num_features = embed_dim if self.vit_stem else embed_dim * 2 + self.norm = norm_layer(self.num_features) + self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + # weights init + if self.pos_embed: + trunc_normal_(self.pos_embed1, std=0.02) + if not self.vit_stem: + trunc_normal_(self.pos_embed2, std=0.02) + trunc_normal_(self.pos_embed3, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, 
nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + if self.conv_init: + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + else: + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + if self.stem is not None: + x = self.stem(x) + + # stage 1 + x = self.patch_embed1(x) + if self.pos_embed: + x = x + self.pos_embed1 + x = self.pos_drop(x) + for b in self.stage1: + x = b(x) + + # stage 2 + if not self.vit_stem: + x = self.patch_embed2(x) + if self.pos_embed: + x = x + self.pos_embed2 + x = self.pos_drop(x) + for b in self.stage2: + x = b(x) + + # stage3 + if not self.vit_stem: + x = self.patch_embed3(x) + if self.pos_embed: + x = x + self.pos_embed3 + x = self.pos_drop(x) + for b in self.stage3: + x = b(x) + + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + x = self.head(x) + return x + + +def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model = build_model_with_cfg( + Visformer, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + return model + + +@register_model +def visformer_tiny(pretrained=False, **kwargs): + model_cfg = dict( + init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg) 
+ return model + + +@register_model +def visformer_small(pretrained=False, **kwargs): + model_cfg = dict( + init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg) + return model + + +# @register_model +# def visformer_net1(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net2(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net3(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net4(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net5(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def 
visformer_net6(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net7(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model + + + + diff --git a/data_processing/MANIQA/timm/models/vision_transformer.py b/data_processing/MANIQA/timm/models/vision_transformer.py new file mode 100644 index 0000000..d80cb96 --- /dev/null +++ b/data_processing/MANIQA/timm/models/vision_transformer.py @@ -0,0 +1,1045 @@ +""" Vision Transformer (ViT) in PyTorch + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2020, Ross Wightman +""" +import math +import logging +import numpy as np +from functools import partial +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv +from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from .registry import register_model + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +class RandomMaskingGenerator: + def __init__(self, input_size, mask_ratio): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + + self.height, self.width = input_size + + self.num_patches = self.height * self.width # patch的总数即196 + self.num_mask = int(mask_ratio * self.num_patches) # 196 * 0.75 + + def __repr__(self): + repr_str = "Maks: total patches {}, mask patches {}".format( + self.num_patches, self.num_mask + ) + return repr_str + + def __call__(self): + mask = np.hstack([ + np.zeros(self.num_patches - self.num_mask), + np.ones(self.num_mask), + ]) + np.random.shuffle(mask) + return mask + + +default_cfgs = { + # patch models (weights from official Google JAX impl) + 'vit_tiny_patch16_224': _cfg( + 
url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_tiny_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_base_patch32_384': _cfg( + 
url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_base_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch8_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + ), + 'vit_large_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + + 'vit_huge_patch14_224': 
_cfg(url=''), + 'vit_giant_patch14_224': _cfg(url=''), + 'vit_gigantic_patch14_224': _cfg(url=''), + + 'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + + # patch models, imagenet21k (weights from official Google JAX impl) + 'vit_tiny_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch8_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_large_patch32_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + num_classes=21843), + 'vit_large_patch16_224_in21k': _cfg( + 
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + num_classes=21843), + 'vit_huge_patch14_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', + hf_hub='timm/vit_huge_patch14_224_in21k', + num_classes=21843), + + # SAM trained models (https://arxiv.org/abs/2106.01548) + 'vit_base_patch32_224_sam': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), + 'vit_base_patch16_224_sam': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), + + # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) + 'vit_small_patch16_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_small_patch8_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch16_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch8_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', + 
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + + # deit models (FB weights) + 'deit_tiny_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_base_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), + 'deit_tiny_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_small_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_base_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_base_distilled_patch16_384': _cfg( + 
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, + classifier=('head', 'head_dist')), + + # ViT ImageNet-21K-P pretraining by MILL + 'vit_base_patch16_224_miil_in21k': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, + ), + 'vit_base_patch16_224_miil': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' + '/vit_base_patch16_224_1k_miil_84_4.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', + ), +} + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + 
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, + act_layer=None, weight_init=''): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 
+ distilled (bool): model includes a distillation token and head as in DeiT models + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + # Classifier head(s) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if 
num_classes > 0 else nn.Identity() + + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + trunc_normal_(self.pos_embed, std=.02) + if self.dist_token is not None: + trunc_normal_(self.dist_token, std=.02) + if mode.startswith('jax'): + # leave cls token as zeros to match jax impl + named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) + else: + trunc_normal_(self.cls_token, std=.02) + self.apply(_init_vit_weights) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + _init_vit_weights(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + def get_classifier(self): + if self.dist_token is None: + return self.head + else: + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.num_tokens == 2: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + x = self.blocks(x) + x = self.norm(x) + # if self.dist_token is None: + # return self.pre_logits(x[:, 0]) + # else: + # return x[:, 0], x[:, 1] + return x + + def forward(self, x): + x = self.forward_features(x) + # if self.head_dist is not None: + # x, x_dist = 
self.head(x[0]), self.head_dist(x[1]) # x must be a tuple + # if self.training and not torch.jit.is_scripting(): + # # during inference, return the average of both classifier predictions + # return x, x_dist + # else: + # return (x + x_dist) / 2 + # else: + # x = self.head(x) + return x[:, 1:] + + +def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): + """ ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). + * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + elif name.startswith('pre_logits'): + lecun_normal_(module.weight) + nn.init.zeros_(module.bias) + else: + if jax_impl: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + else: + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2d): + # NOTE conv was left to pytorch default in my original init + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: 
+ w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + 
model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + out_dict[k] = v + return out_dict + + +def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): + default_cfg = default_cfg or default_cfgs[variant] + if 
kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + # NOTE this extra code to support handling of repr size for in21k pretrained models + default_num_classes = default_cfg['num_classes'] + num_classes = kwargs.get('num_classes', default_num_classes) + repr_size = kwargs.pop('representation_size', None) + if repr_size is not None and num_classes != default_num_classes: + # Remove representation layer if fine-tuning. This may not always be the desired action, + # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? + _logger.warning("Removing representation layer for fine-tuning.") + repr_size = None + + model = build_model_with_cfg( + VisionTransformer, variant, pretrained, + default_cfg=default_cfg, + representation_size=repr_size, + pretrained_filter_fn=checkpoint_filter_fn, + pretrained_custom_load='npz' in default_cfg['url'], + **kwargs) + return model + + +@register_model +def vit_tiny_patch16_224(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_patch16_384(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) @ 384x384. 
+ """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_224(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) at 384x384. + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 
+ ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base2_patch32_256(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) + # FIXME experiment + """ + model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs) + model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 
+ """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch8_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 
+ """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_384(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_384(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
+ """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_huge_patch14_224(pretrained=False, **kwargs): + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + """ + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_giant_patch14_224(pretrained=False, **kwargs): + """ ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + """ + model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_gigantic_patch14_224(pretrained=False, **kwargs): + """ ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + """ + model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch8_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights + """ + model_kwargs = dict( + patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
+ NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights + """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) + model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_sam(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 + """ + # NOTE original SAM weights release worked with representation_size=768 + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224_sam(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 + """ + # NOTE original SAM weights release worked with representation_size=768 + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224_dino(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch8_224_dino(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/8) w/ DINO pretrained weights (no head) - 
https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch8_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_dino(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) /w DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch8_224_dino(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch8_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_tiny_patch16_224(pretrained=False, **kwargs): + """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_small_patch16_224(pretrained=False, **kwargs): + """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 
+ ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_base_patch16_224(pretrained=False, **kwargs): + """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_base_patch16_384(pretrained=False, **kwargs): + """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. 
+ """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer( + 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_small_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer( + 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_base_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer( + 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_base_distilled_patch16_384(pretrained=False, **kwargs): + """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. 
+ """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer( + 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_miil(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) + return model diff --git a/data_processing/MANIQA/timm/models/vision_transformer_hybrid.py b/data_processing/MANIQA/timm/models/vision_transformer_hybrid.py new file mode 100644 index 0000000..d46297e --- /dev/null +++ b/data_processing/MANIQA/timm/models/vision_transformer_hybrid.py @@ -0,0 +1,363 @@ +""" Hybrid Vision Transformer (ViT) in PyTorch + +A PyTorch implement of the Hybrid Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? 
Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.TODO + +NOTE These hybrid model definitions depend on code in vision_transformer.py. +They were moved here to keep file sizes sane. + +Hacked together by / Copyright 2020, Ross Wightman +""" +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .layers import StdConv2dSame, StdConv2d, to_2tuple +from .resnet import resnet26d, resnet50d +from .resnetv2 import ResNetV2, create_resnetv2_stem +from .registry import register_model +from timm.models.vision_transformer import _create_vision_transformer + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # hybrid in-1k models (weights from official JAX impl where they exist) + 'vit_tiny_r_s16_p8_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + first_conv='patch_embed.backbone.conv'), + 'vit_tiny_r_s16_p8_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_r26_s32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 
'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz', + ), + 'vit_small_r26_s32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_r26_s32_224': _cfg(), + 'vit_base_r50_s16_224': _cfg(), + 'vit_base_r50_s16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_r50_s32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz' + ), + 'vit_large_r50_s32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0 + ), + + # hybrid in-21k models (weights from official Google JAX impl where they exist) + 'vit_tiny_r_s16_p8_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'), + 'vit_small_r26_s32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843, crop_pct=0.9), + 'vit_base_r50_s16_224_in21k': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', + num_classes=21843, crop_pct=0.9), + 'vit_large_r50_s32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843, crop_pct=0.9), + + # hybrid models (using timm resnet backbones) + 'vit_small_resnet26d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_small_resnet50d_s16_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_base_resnet26d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_base_resnet50d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), +} + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. 
+ """ + def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # NOTE Most reliable way of determining output dims is to run forward pass + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 + self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): + embed_layer = partial(HybridEmbed, backbone=backbone) + kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set + return _create_vision_transformer( + variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs) + + +def _resnetv2(layers=(3, 4, 9), **kwargs): + """ ResNet-V2 backbone helper""" + padding_same = 
kwargs.get('padding_same', True) + stem_type = 'same' if padding_same else '' + conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8) + if len(layers): + backbone = ResNetV2( + layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), + preact=False, stem_type=stem_type, conv_layer=conv_layer) + else: + backbone = create_resnetv2_stem( + kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer) + return backbone + + +@register_model +def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_224(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_384(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. 
+ """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r26_s32_224(pretrained=False, **kwargs): + """ R26+ViT-B/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_224(pretrained=False, **kwargs): + """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929). + """ + backbone = _resnetv2((3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_384(pretrained=False, **kwargs): + """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
+ """ + backbone = _resnetv2((3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_resnet50_384(pretrained=False, **kwargs): + # DEPRECATED this is forwarding to model def above for backwards compatibility + return vit_base_r50_s16_384(pretrained=pretrained, **kwargs) + + +@register_model +def vit_large_r50_s32_224(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_r50_s32_384(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. ImageNet-21k. 
+ """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): + """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + backbone = _resnetv2(layers=(3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): + # DEPRECATED this is forwarding to model def above for backwards compatibility + return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs) + + +@register_model +def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. ImageNet-21k. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_resnet26d_224(pretrained=False, **kwargs): + """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. 
+ """ + backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) + model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_resnet50d_s16_224(pretrained=False, **kwargs): + """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights. + """ + backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3]) + model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_resnet26d_224(pretrained=False, **kwargs): + """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights. + """ + backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_resnet50d_224(pretrained=False, **kwargs): + """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights. 
+ """ + backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model \ No newline at end of file diff --git a/data_processing/MANIQA/timm/models/vovnet.py b/data_processing/MANIQA/timm/models/vovnet.py new file mode 100644 index 0000000..ec5b3e8 --- /dev/null +++ b/data_processing/MANIQA/timm/models/vovnet.py @@ -0,0 +1,406 @@ +""" VoVNet (V1 & V2) + +Papers: +* `An Energy and GPU-Computation Efficient Backbone Network` - https://arxiv.org/abs/1904.09730 +* `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + +Looked at https://github.com/youngwanLEE/vovnet-detectron2 & +https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py +for some reference, rewrote most of the code. 
+ +Hacked together by / Copyright 2020 Ross Wightman +""" + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .registry import register_model +from .helpers import build_model_with_cfg +from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, DropPath,\ + create_attn, create_norm_act, get_norm_act_layer + + +# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 & +# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py +model_cfgs = dict( + vovnet39a=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=False, + depthwise=False, + attn='', + ), + vovnet57a=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 4, 3], + residual=False, + depthwise=False, + attn='', + + ), + ese_vovnet19b_slim_dw=dict( + stem_chs=[64, 64, 64], + stage_conv_chs=[64, 80, 96, 112], + stage_out_chs=[112, 256, 384, 512], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=True, + attn='ese', + + ), + ese_vovnet19b_dw=dict( + stem_chs=[64, 64, 64], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=True, + attn='ese', + ), + ese_vovnet19b_slim=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[64, 80, 96, 112], + stage_out_chs=[112, 256, 384, 512], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=False, + attn='ese', + ), + ese_vovnet19b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], 
+ stage_out_chs=[256, 512, 768, 1024], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=False, + attn='ese', + + ), + ese_vovnet39b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=True, + depthwise=False, + attn='ese', + ), + ese_vovnet57b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 4, 3], + residual=True, + depthwise=False, + attn='ese', + + ), + ese_vovnet99b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 3, 9, 3], + residual=True, + depthwise=False, + attn='ese', + ), + eca_vovnet39b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=True, + depthwise=False, + attn='eca', + ), +) +model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b'] +model_cfgs['ese_vovnet99b_iabn'] = model_cfgs['ese_vovnet99b'] + + +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', + } + + +default_cfgs = dict( + vovnet39a=_cfg(url=''), + vovnet57a=_cfg(url=''), + ese_vovnet19b_slim_dw=_cfg(url=''), + ese_vovnet19b_dw=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet19b_dw-a8741004.pth'), + ese_vovnet19b_slim=_cfg(url=''), + ese_vovnet39b=_cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet39b-f912fe73.pth'), + ese_vovnet57b=_cfg(url=''), + ese_vovnet99b=_cfg(url=''), + eca_vovnet39b=_cfg(url=''), + ese_vovnet39b_evos=_cfg(url=''), + ese_vovnet99b_iabn=_cfg(url=''), +) + + +class SequentialAppendList(nn.Sequential): + def __init__(self, *args): + super(SequentialAppendList, self).__init__(*args) + + def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor: + for i, module in enumerate(self): + if i == 0: + concat_list.append(module(x)) + else: + concat_list.append(module(concat_list[-1])) + x = torch.cat(concat_list, dim=1) + return x + + +class OsaBlock(nn.Module): + + def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False, + depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None): + super(OsaBlock, self).__init__() + + self.residual = residual + self.depthwise = depthwise + conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer) + + next_in_chs = in_chs + if self.depthwise and next_in_chs != mid_chs: + assert not residual + self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, **conv_kwargs) + else: + self.conv_reduction = None + + mid_convs = [] + for i in range(layer_per_block): + if self.depthwise: + conv = SeparableConvBnAct(mid_chs, mid_chs, **conv_kwargs) + else: + conv = ConvBnAct(next_in_chs, mid_chs, 3, **conv_kwargs) + next_in_chs = mid_chs + mid_convs.append(conv) + self.conv_mid = SequentialAppendList(*mid_convs) + + # feature aggregation + next_in_chs = in_chs + layer_per_block * mid_chs + self.conv_concat = ConvBnAct(next_in_chs, out_chs, **conv_kwargs) + + if attn: + self.attn = create_attn(attn, out_chs) + else: + self.attn = None + + self.drop_path = drop_path + + def forward(self, x): + output = [x] + if self.conv_reduction is not None: + x = self.conv_reduction(x) + x = self.conv_mid(x, 
output) + x = self.conv_concat(x) + if self.attn is not None: + x = self.attn(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.residual: + x = x + output[0] + return x + + +class OsaStage(nn.Module): + + def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True, + residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, + drop_path_rates=None): + super(OsaStage, self).__init__() + + if downsample: + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) + else: + self.pool = None + + blocks = [] + for i in range(block_per_stage): + last_block = i == block_per_stage - 1 + if drop_path_rates is not None and drop_path_rates[i] > 0.: + drop_path = DropPath(drop_path_rates[i]) + else: + drop_path = None + blocks += [OsaBlock( + in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise, + attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path) + ] + in_chs = out_chs + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + if self.pool is not None: + x = self.pool(x) + x = self.blocks(x) + return x + + +class VovNet(nn.Module): + + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4, + output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.): + """ VovNet (v2) + """ + super(VovNet, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert stem_stride in (4, 2) + assert output_stride == 32 # FIXME support dilation + + stem_chs = cfg["stem_chs"] + stage_conv_chs = cfg["stage_conv_chs"] + stage_out_chs = cfg["stage_out_chs"] + block_per_stage = cfg["block_per_stage"] + layer_per_block = cfg["layer_per_block"] + conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer) + + # Stem module + last_stem_stride = stem_stride // 2 + conv_type = SeparableConvBnAct if cfg["depthwise"] else 
ConvBnAct + self.stem = nn.Sequential(*[ + ConvBnAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs), + conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs), + conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs), + ]) + self.feature_info = [dict( + num_chs=stem_chs[1], reduction=2, module=f'stem.{1 if stem_stride == 4 else 2}')] + current_stride = stem_stride + + # OSA stages + stage_dpr = torch.split(torch.linspace(0, drop_path_rate, sum(block_per_stage)), block_per_stage) + in_ch_list = stem_chs[-1:] + stage_out_chs[:-1] + stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs) + stages = [] + for i in range(4): # num_stages + downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4 + stages += [OsaStage( + in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block, + downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args) + ] + self.num_features = stage_out_chs[i] + current_stride *= 2 if downsample else 1 + self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')] + + self.stages = nn.Sequential(*stages) + + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) 
+ elif isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + return self.stages(x) + + def forward(self, x): + x = self.forward_features(x) + return self.head(x) + + +def _create_vovnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + VovNet, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=model_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def vovnet39a(pretrained=False, **kwargs): + return _create_vovnet('vovnet39a', pretrained=pretrained, **kwargs) + + +@register_model +def vovnet57a(pretrained=False, **kwargs): + return _create_vovnet('vovnet57a', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_slim_dw(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_dw(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_slim(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet39b(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet57b(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet99b(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs) + + +@register_model +def eca_vovnet39b(pretrained=False, **kwargs): + return _create_vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs) + + 
+# Experimental Models + +@register_model +def ese_vovnet39b_evos(pretrained=False, **kwargs): + def norm_act_fn(num_features, **nkwargs): + return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs) + return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs) + + +@register_model +def ese_vovnet99b_iabn(pretrained=False, **kwargs): + norm_layer = get_norm_act_layer('iabn') + return _create_vovnet( + 'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs) diff --git a/data_processing/MANIQA/timm/models/xception.py b/data_processing/MANIQA/timm/models/xception.py new file mode 100644 index 0000000..86f558c --- /dev/null +++ b/data_processing/MANIQA/timm/models/xception.py @@ -0,0 +1,232 @@ +""" +Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) + +@author: tstandley +Adapted by cadene + +Creates an Xception Model as defined in: + +Francois Chollet +Xception: Deep Learning with Depthwise Separable Convolutions +https://arxiv.org/pdf/1610.02357.pdf + +This weights ported from the Keras implementation. 
Achieves the following performance on the validation set: + +Loss:0.9173 Prec@1:78.892 Prec@5:94.292 + +REMEMBER to set your image size to 3x299x299 for both test and validation + +normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5]) + +The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 +""" + +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['Xception'] + +default_cfgs = { + 'xception': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth', + 'input_size': (3, 299, 299), + 'pool_size': (10, 10), + 'crop_pct': 0.8975, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1000, + 'first_conv': 'conv1', + 'classifier': 'fc' + # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 + } +} + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1): + super(SeparableConv2d, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=False) + self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.pointwise(x) + return x + + +class Block(nn.Module): + def __init__(self, in_channels, out_channels, reps, strides=1, start_with_relu=True, grow_first=True): + super(Block, self).__init__() + + if out_channels != in_channels or strides != 1: + self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False) + self.skipbn = nn.BatchNorm2d(out_channels) + else: + self.skip = None + + rep = [] + for i in 
range(reps): + if grow_first: + inc = in_channels if i == 0 else out_channels + outc = out_channels + else: + inc = in_channels + outc = in_channels if i < (reps - 1) else out_channels + rep.append(nn.ReLU(inplace=True)) + rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1)) + rep.append(nn.BatchNorm2d(outc)) + + if not start_with_relu: + rep = rep[1:] + else: + rep[0] = nn.ReLU(inplace=False) + + if strides != 1: + rep.append(nn.MaxPool2d(3, strides, 1)) + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x += skip + return x + + +class Xception(nn.Module): + """ + Xception optimized for the ImageNet dataset, as specified in + https://arxiv.org/pdf/1610.02357.pdf + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'): + """ Constructor + Args: + num_classes: number of classes + """ + super(Xception, self).__init__() + self.drop_rate = drop_rate + self.global_pool = global_pool + self.num_classes = num_classes + self.num_features = 2048 + + self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, 3, bias=False) + self.bn2 = nn.BatchNorm2d(64) + self.act2 = nn.ReLU(inplace=True) + + self.block1 = Block(64, 128, 2, 2, start_with_relu=False) + self.block2 = Block(128, 256, 2, 2) + self.block3 = Block(256, 728, 2, 2) + + self.block4 = Block(728, 728, 3, 1) + self.block5 = Block(728, 728, 3, 1) + self.block6 = Block(728, 728, 3, 1) + self.block7 = Block(728, 728, 3, 1) + + self.block8 = Block(728, 728, 3, 1) + self.block9 = Block(728, 728, 3, 1) + self.block10 = Block(728, 728, 3, 1) + self.block11 = Block(728, 728, 3, 1) + + self.block12 = Block(728, 1024, 2, 2, grow_first=False) + + self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) + 
self.bn3 = nn.BatchNorm2d(1536) + self.act3 = nn.ReLU(inplace=True) + + self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(self.num_features) + self.act4 = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='act2'), + dict(num_chs=128, reduction=4, module='block2.rep.0'), + dict(num_chs=256, reduction=8, module='block3.rep.0'), + dict(num_chs=728, reduction=16, module='block12.rep.0'), + dict(num_chs=2048, reduction=32, module='act4'), + ] + + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + # #------- init weights -------- + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + + x = self.conv3(x) + x = self.bn3(x) + x = self.act3(x) + + x = self.conv4(x) + x = self.bn4(x) + x = self.act4(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate: + F.dropout(x, self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +def _xception(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + Xception, variant, pretrained, + 
default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook'), + **kwargs) + + +@register_model +def xception(pretrained=False, **kwargs): + return _xception('xception', pretrained=pretrained, **kwargs) diff --git a/data_processing/MANIQA/timm/models/xception_aligned.py b/data_processing/MANIQA/timm/models/xception_aligned.py new file mode 100644 index 0000000..ea7f5c0 --- /dev/null +++ b/data_processing/MANIQA/timm/models/xception_aligned.py @@ -0,0 +1,238 @@ +"""Pytorch impl of Aligned Xception 41, 65, 71 + +This is a correct, from scratch impl of Aligned Xception (Deeplab) models compatible with TF weights at +https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md + +Hacked together by / Copyright 2020 Ross Wightman +""" +from functools import partial + +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, ConvBnAct, create_conv2d +from .layers.helpers import to_3tuple +from .registry import register_model + +__all__ = ['XceptionAligned'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10), + 'crop_pct': 0.903, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + xception41=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'), + xception65=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'), + xception71=_cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'), +) + + +class SeparableConv2d(nn.Module): + def __init__( + self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='', + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SeparableConv2d, self).__init__() + self.kernel_size = kernel_size + self.dilation = dilation + + # depthwise convolution + self.conv_dw = create_conv2d( + inplanes, inplanes, kernel_size, stride=stride, + padding=padding, dilation=dilation, depthwise=True) + self.bn_dw = norm_layer(inplanes) + if act_layer is not None: + self.act_dw = act_layer(inplace=True) + else: + self.act_dw = None + + # pointwise convolution + self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1) + self.bn_pw = norm_layer(planes) + if act_layer is not None: + self.act_pw = act_layer(inplace=True) + else: + self.act_pw = None + + def forward(self, x): + x = self.conv_dw(x) + x = self.bn_dw(x) + if self.act_dw is not None: + x = self.act_dw(x) + x = self.conv_pw(x) + x = self.bn_pw(x) + if self.act_pw is not None: + x = self.act_pw(x) + return x + + +class XceptionModule(nn.Module): + def __init__( + self, in_chs, out_chs, stride=1, dilation=1, pad_type='', + start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None): + super(XceptionModule, self).__init__() + out_chs = to_3tuple(out_chs) + self.in_channels = in_chs + self.out_channels = out_chs[-1] + self.no_skip = no_skip + if not no_skip and (self.out_channels != self.in_channels or stride != 1): + self.shortcut = ConvBnAct( + in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, act_layer=None) + else: + self.shortcut = None + + separable_act_layer = None if start_with_relu else act_layer + self.stack = nn.Sequential() + for i in range(3): + if start_with_relu: + self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0)) + 
self.stack.add_module(f'conv{i + 1}', SeparableConv2d( + in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type, + act_layer=separable_act_layer, norm_layer=norm_layer)) + in_chs = out_chs[i] + + def forward(self, x): + skip = x + x = self.stack(x) + if self.shortcut is not None: + skip = self.shortcut(skip) + if not self.no_skip: + x = x + skip + return x + + +class XceptionAligned(nn.Module): + """Modified Aligned Xception + """ + + def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'): + super(XceptionAligned, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert output_stride in (8, 16, 32) + + layer_args = dict(act_layer=act_layer, norm_layer=norm_layer) + self.stem = nn.Sequential(*[ + ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args), + ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args) + ]) + + curr_dilation = 1 + curr_stride = 2 + self.feature_info = [] + self.blocks = nn.Sequential() + for i, b in enumerate(block_cfg): + b['dilation'] = curr_dilation + if b['stride'] > 1: + self.feature_info += [dict( + num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=f'blocks.{i}.stack.act3')] + next_stride = curr_stride * b['stride'] + if next_stride > output_stride: + curr_dilation *= b['stride'] + b['stride'] = 1 + else: + curr_stride = next_stride + self.blocks.add_module(str(i), XceptionModule(**b, **layer_args)) + self.num_features = self.blocks[-1].out_channels + + self.feature_info += [dict( + num_chs=self.num_features, reduction=curr_stride, module='blocks.' 
+ str(len(self.blocks) - 1))] + + self.head = ClassifierHead( + in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.blocks(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _xception(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + XceptionAligned, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), + **kwargs) + + +@register_model +def xception41(pretrained=False, **kwargs): + """ Modified Aligned Xception-41 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 8), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + ] + model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) + return _xception('xception41', pretrained=pretrained, **model_args) + + +@register_model +def xception65(pretrained=False, **kwargs): + """ Modified Aligned Xception-65 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 16), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + ] + 
model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) + return _xception('xception65', pretrained=pretrained, **model_args) + + +@register_model +def xception71(pretrained=False, **kwargs): + """ Modified Aligned Xception-71 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=1), + dict(in_chs=256, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=1), + dict(in_chs=728, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 16), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + ] + model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) + return _xception('xception71', pretrained=pretrained, **model_args) diff --git a/data_processing/MANIQA/timm/models/xcit.py b/data_processing/MANIQA/timm/models/xcit.py new file mode 100644 index 0000000..9ad162e --- /dev/null +++ b/data_processing/MANIQA/timm/models/xcit.py @@ -0,0 +1,814 @@ +""" Cross-Covariance Image Transformer (XCiT) in PyTorch + +Paper: + - https://arxiv.org/abs/2106.09681 + +Same as the official implementation, with some minor adaptations, original copyright below + - https://github.com/facebookresearch/xcit/blob/master/xcit.py + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. 
+ +import math +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .vision_transformer import _cfg, Mlp +from .registry import register_model +from .layers import DropPath, trunc_normal_, to_2tuple +from .cait import ClassAttn +from .fx_features import register_notrace_module + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj.0.0', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # Patch size 16 + 'xcit_nano_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224.pth'), + 'xcit_nano_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224_dist.pth'), + 'xcit_nano_12_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_tiny_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224.pth'), + 'xcit_tiny_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224_dist.pth'), + 'xcit_tiny_12_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_tiny_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224.pth'), + 
'xcit_tiny_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224_dist.pth'), + 'xcit_tiny_24_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_small_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth'), + 'xcit_small_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224_dist.pth'), + 'xcit_small_12_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_small_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224.pth'), + 'xcit_small_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224_dist.pth'), + 'xcit_small_24_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_medium_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224.pth'), + 'xcit_medium_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224_dist.pth'), + 'xcit_medium_24_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_large_24_p16_224': 
_cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224.pth'), + 'xcit_large_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224_dist.pth'), + 'xcit_large_24_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_384_dist.pth', input_size=(3, 384, 384)), + + # Patch size 8 + 'xcit_nano_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224.pth'), + 'xcit_nano_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224_dist.pth'), + 'xcit_nano_12_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_tiny_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224.pth'), + 'xcit_tiny_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224_dist.pth'), + 'xcit_tiny_12_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_tiny_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224.pth'), + 'xcit_tiny_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224_dist.pth'), + 'xcit_tiny_24_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_384_dist.pth', 
input_size=(3, 384, 384)), + 'xcit_small_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224.pth'), + 'xcit_small_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224_dist.pth'), + 'xcit_small_12_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_small_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224.pth'), + 'xcit_small_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224_dist.pth'), + 'xcit_small_24_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_medium_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224.pth'), + 'xcit_medium_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224_dist.pth'), + 'xcit_medium_24_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_large_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224.pth'), + 'xcit_large_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224_dist.pth'), + 'xcit_large_24_p8_384_dist': _cfg( + 
url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_384_dist.pth', input_size=(3, 384, 384)), +} + + +@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method +class PositionalEncodingFourier(nn.Module): + """ + Positional encoding relying on a fourier kernel matching the one used in the "Attention Is All You Need" paper. + Based on the official XCiT code + - https://github.com/facebookresearch/xcit/blob/master/xcit.py + """ + + def __init__(self, hidden_dim=32, dim=768, temperature=10000): + super().__init__() + self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + self.eps = 1e-6 + + def forward(self, B: int, H: int, W: int): + device = self.token_projection.weight.device + y_embed = torch.arange(1, H+1, dtype=torch.float32, device=device).unsqueeze(1).repeat(1, 1, W) + x_embed = torch.arange(1, W+1, dtype=torch.float32, device=device).repeat(1, H, 1) + y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale + dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack([pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4).flatten(3) + pos_y = torch.stack([pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos) + return pos.repeat(B, 1, 1, 1) # (B, C, H, W) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution + batch norm""" + return torch.nn.Sequential( +
nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False), + nn.BatchNorm2d(out_planes) + ) + + +class ConvPatchEmbed(nn.Module): + """Image to Patch Embedding using multiple convolutional layers""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, act_layer=nn.GELU): + super().__init__() + img_size = to_2tuple(img_size) + num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + if patch_size == 16: + self.proj = torch.nn.Sequential( + conv3x3(in_chans, embed_dim // 8, 2), + act_layer(), + conv3x3(embed_dim // 8, embed_dim // 4, 2), + act_layer(), + conv3x3(embed_dim // 4, embed_dim // 2, 2), + act_layer(), + conv3x3(embed_dim // 2, embed_dim, 2), + ) + elif patch_size == 8: + self.proj = torch.nn.Sequential( + conv3x3(in_chans, embed_dim // 4, 2), + act_layer(), + conv3x3(embed_dim // 4, embed_dim // 2, 2), + act_layer(), + conv3x3(embed_dim // 2, embed_dim, 2), + ) + else: + raise('For convolutional projection, patch size has to be in [8, 16]') + + def forward(self, x): + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + x = x.flatten(2).transpose(1, 2) # (B, N, C) + return x, (Hp, Wp) + + +class LPI(nn.Module): + """ + Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows to augment the + implicit communication performed by the block diagonal scatter attention. 
Implemented using 2 layers of separable + 3x3 convolutions with GeLU and BatchNorm2d + """ + + def __init__(self, in_features, out_features=None, act_layer=nn.GELU, kernel_size=3): + super().__init__() + out_features = out_features or in_features + + padding = kernel_size // 2 + + self.conv1 = torch.nn.Conv2d( + in_features, in_features, kernel_size=kernel_size, padding=padding, groups=in_features) + self.act = act_layer() + self.bn = nn.BatchNorm2d(in_features) + self.conv2 = torch.nn.Conv2d( + in_features, out_features, kernel_size=kernel_size, padding=padding, groups=out_features) + + def forward(self, x, H: int, W: int): + B, N, C = x.shape + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.conv1(x) + x = self.act(x) + x = self.bn(x) + x = self.conv2(x) + x = x.reshape(B, C, N).permute(0, 2, 1) + return x + + +class ClassAttentionBlock(nn.Module): + """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239""" + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False): + super().__init__() + self.norm1 = norm_layer(dim) + + self.attn = ClassAttn( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + + if eta is not None: # LayerScale Initialization (no layerscale when None) + self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + else: + self.gamma1, self.gamma2 = 1.0, 1.0 + + # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 + self.tokens_norm = tokens_norm + + def forward(self, x): + x_norm1 = self.norm1(x) + x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1) + x = x + self.drop_path(self.gamma1 * x_attn) + if self.tokens_norm: + x = self.norm2(x) + else: + x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1) + x_res = x + cls_token = x[:, 0:1] + cls_token = self.gamma2 * self.mlp(cls_token) + x = torch.cat([cls_token, x[:, 1:]], dim=1) + x = x_res + self.drop_path(x) + return x + + +class XCA(nn.Module): + """ Cross-Covariance Attention (XCA) + Operation where the channels are updated using a weighted sum. 
The weights are obtained from the (softmax + normalized) Cross-covariance matrix (Q^T \\cdot K \\in d_h \\times d_h) + """ + + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + # Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # Paper section 3.2 l2-Normalization and temperature scaling + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # (B, H, C', N), permute -> (B, N, H, C') + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return {'temperature'} + + +class XCABlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
class XCiT(nn.Module):
    """XCiT classifier: conv patch embedding, XCA blocks, class-attention blocks and a linear head.

    Based on timm and DeiT code bases
    https://github.com/rwightman/pytorch-image-models/tree/master/timm
    https://github.com/facebookresearch/deit/
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head (<= 0 gives an identity head)
            embed_dim (int): embedding dimension
            depth (int): depth of transformer (number of XCA blocks)
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            drop_rate (float): dropout rate after positional embedding, and in XCA/CA projection + MLP
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate (constant across all layers)
            act_layer: (nn.Module): activation layer, defaults to ``nn.GELU``
            norm_layer: (nn.Module): normalization layer, defaults to ``nn.LayerNorm`` with ``eps=1e-6``
            cls_attn_layers: (int) Depth of Class attention layers
            use_pos_embed: (bool) whether to use positional encoding
            eta: (float) layerscale initialization value
            tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA

        Notes:
            - Although `norm_layer` is user specifiable, there are hard-coded `BatchNorm2d`s in the local patch
              interaction (class LPI) and the patch embedding (class ConvPatchEmbed)
        """
        super().__init__()
        img_size = to_2tuple(img_size)
        # Both spatial dimensions must be a multiple of the patch size (the width check used to
        # repeat the height check and therefore never looked at img_size[1]).
        assert (img_size[0] % patch_size == 0) and (img_size[1] % patch_size == 0), \
            '`patch_size` should divide image dimensions evenly'

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = ConvPatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.use_pos_embed = use_pos_embed
        if use_pos_embed:
            self.pos_embed = PositionalEncodingFourier(dim=embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.ModuleList([
            XCABlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=drop_path_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta)
            for _ in range(depth)])

        self.cls_attn_blocks = nn.ModuleList([
            ClassAttentionBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm)
            for _ in range(cls_attn_layers)])

        # Classifier head
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Init weights
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module: truncated-normal Linear weights, zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm built with elementwise_affine=False has no weight/bias to initialise.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a fresh one for `num_classes` outputs (identity when `num_classes` <= 0).

        `global_pool` is accepted for signature compatibility and is not used here.
        """
        self.num_classes = num_classes
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Run the backbone and return the normalised class-token feature for each sample."""
        B = x.shape[0]
        # x is (B, N, C). (Hp, Wp) is (height in units of patches, width in units of patches)
        x, (Hp, Wp) = self.patch_embed(x)

        if self.use_pos_embed:
            # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
            pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            x = x + pos_encoding

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x, Hp, Wp)

        # The class token is only prepended after the XCA blocks, right before class attention.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for blk in self.cls_attn_blocks:
            x = blk(x)

        # Keep only the class token (index 0 along the token axis).
        x = self.norm(x)[:, 0]
        return x

    def forward(self, x):
        """Return the head output computed from the class-token feature."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
Also account for use_pos_embed == False + use_pos_embed = getattr(model, 'pos_embed', None) is not None + pos_embed_keys = [k for k in state_dict if k.startswith('pos_embed')] + for k in pos_embed_keys: + if use_pos_embed: + state_dict[k.replace('pos_embeder.', 'pos_embed.')] = state_dict.pop(k) + else: + del state_dict[k] + # timm's implementation of class attention in CaiT is slightly more efficient as it does not compute query vectors + # for all tokens, just the class token. To use official weights source we must split qkv into q, k, v + if 'cls_attn_blocks.0.attn.qkv.weight' in state_dict and 'cls_attn_blocks.0.attn.q.weight' in model.state_dict(): + num_ca_blocks = len(model.cls_attn_blocks) + for i in range(num_ca_blocks): + qkv_weight = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.weight') + qkv_weight = qkv_weight.reshape(3, -1, qkv_weight.shape[-1]) + for j, subscript in enumerate('qkv'): + state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.weight'] = qkv_weight[j] + qkv_bias = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.bias', None) + if qkv_bias is not None: + qkv_bias = qkv_bias.reshape(3, -1) + for j, subscript in enumerate('qkv'): + state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.bias'] = qkv_bias[j] + return state_dict + + +def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs): + default_cfg = default_cfg or default_cfgs[variant] + model = build_model_with_cfg( + XCiT, variant, pretrained, default_cfg=default_cfg, pretrained_filter_fn=checkpoint_filter_fn, **kwargs) + return model + + +@register_model +def xcit_nano_12_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_nano_12_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, 
eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_nano_12_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, img_size=384, **kwargs) + model = _create_xcit('xcit_nano_12_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_12_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_12_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def 
xcit_small_12_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = 
_create_xcit('xcit_small_24_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_large_24_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_large_24_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_large_24_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +# Patch size 8x8 models +@register_model +def 
xcit_nano_12_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_nano_12_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_nano_12_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_12_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p8_224', 
pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_12_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_12_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=24, num_heads=8, 
eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_large_24_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_large_24_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def 
class AdaBelief(Optimizer):
    r"""AdaBelief optimizer (adapted from PyTorch's Adam).

    The second-moment estimate tracks the squared deviation of the gradient from its
    running mean (``exp_avg_var``) instead of the raw squared gradient.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients for the running averages of
            the gradient and of its squared residual (default: (0.9, 0.999))
        eps (float, optional): term added for numerical stability; it is added in place to
            ``exp_avg_var`` on every step and again to the denominator (default: 1e-16)
        weight_decay (float, optional): weight decay coefficient (default: 0)
        amsgrad (boolean, optional): keep the running maximum of ``exp_avg_var`` and use it
            for the denominator of the non-rectified update (default: False)
        decoupled_decay (boolean, optional): (default: True) if True, shrink the weights
            directly (AdamW style); otherwise add ``weight_decay * p`` to the gradient
        fixed_decay (boolean, optional): (default: False) only used when ``decoupled_decay``
            is True.
            When fixed_decay == True, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay$.
            When fixed_decay == False, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay \times lr$.
        rectify (boolean, optional): (default: True) use the rectified (RAdam-style) update
        degenerated_to_sgd (boolean, optional): (default: True) in the rectified update, fall
            back to a momentum-SGD step while ``num_sma < 5``; if False that step is skipped

    reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020

    For a complete table of recommended hyperparameters, see
    https://github.com/juntang-zhuang/Adabelief-Optimizer
    For example train/args for EfficientNet see these gists
      - link to train_script: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037
      - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
    """

    def __init__(
            self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False,
            decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        # Param groups that override `betas` get their own rectification cache, because the
        # cached values depend on the betas.
        groups_given_as_dicts = (
            isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict))
        if groups_given_as_dicts:
            for group_spec in params:
                if 'betas' not in group_spec:
                    continue
                own_betas = group_spec['betas']
                if own_betas[0] != betas[0] or own_betas[1] != betas[1]:
                    group_spec['buffer'] = [[None, None, None] for _ in range(10)]

        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad,
            degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify,
            fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)])
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        # State saved without an `amsgrad` entry defaults to the plain variant.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def reset(self):
        """Re-initialise the per-parameter state (step counter and moment buffers) to zero."""
        for group in self.param_groups:
            track_max = group['amsgrad']
            for p in group['params']:
                slot = self.state[p]
                slot['step'] = 0
                # Exponential moving average of gradient values
                slot['exp_avg'] = torch.zeros_like(p)
                # Exponential moving average of squared gradient residuals
                slot['exp_avg_var'] = torch.zeros_like(p)
                if track_max:
                    # Running maximum of exp_avg_var
                    slot['max_exp_avg_var'] = torch.zeros_like(p)

    def _rectified_step_size(self, group, step_no, beta1, beta2):
        """Return ``(num_sma, step_size)`` for the rectified update at `step_no`.

        Results are cached in ``group['buffer']`` (10 slots indexed by ``step_no % 10``) so
        parameters sharing a group and step reuse the same values. ``step_size`` is -1 when
        ``num_sma < 5`` and ``degenerated_to_sgd`` is off.
        """
        slot = group['buffer'][int(step_no % 10)]
        if step_no == slot[0]:
            return slot[1], slot[2]

        slot[0] = step_no
        beta2_t = beta2 ** step_no
        num_sma_max = 2 / (1 - beta2) - 1
        num_sma = num_sma_max - 2 * step_no * beta2_t / (1 - beta2_t)
        slot[1] = num_sma

        # more conservative since it's an approximated value
        if num_sma >= 5:
            step_size = math.sqrt(
                (1 - beta2_t) *
                (num_sma - 4) / (num_sma_max - 4) *
                (num_sma - 2) / num_sma *
                num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** step_no)
        elif group['degenerated_to_sgd']:
            step_size = 1.0 / (1 - beta1 ** step_no)
        else:
            step_size = -1
        slot[2] = step_size
        return num_sma, step_size

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        half_dtypes = {torch.float16, torch.bfloat16}
        for group in self.param_groups:
            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue
                if grad.dtype in half_dtypes:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBelief does not support sparse gradients, please consider SparseAdam instead')

                # Half-precision parameters are updated through a float32 copy and written back.
                is_half_param = p.dtype in half_dtypes
                p_fp32 = p.float() if is_half_param else p

                amsgrad = group['amsgrad']
                beta1, beta2 = group['betas']
                eps = group['eps']
                lr = group['lr']
                state = self.state[p]

                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_fp32)
                    state['exp_avg_var'] = torch.zeros_like(p_fp32)
                    if amsgrad:
                        state['max_exp_avg_var'] = torch.zeros_like(p_fp32)

                # Weight decay: decoupled (shrink the weights) or coupled (fold into grad in place).
                decay = group['weight_decay']
                if group['decoupled_decay']:
                    if group['fixed_decay']:
                        p_fp32.mul_(1.0 - decay)
                    else:
                        p_fp32.mul_(1.0 - lr * decay)
                elif decay != 0:
                    grad.add_(p_fp32, alpha=decay)

                exp_avg = state['exp_avg']
                exp_avg_var = state['exp_avg_var']

                state['step'] += 1
                step_no = state['step']
                bias_correction1 = 1 - beta1 ** step_no
                bias_correction2 = 1 - beta2 ** step_no

                # Update first moment, then the variance of the gradient around it.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

                # eps is added to the stored variance in place, so it persists across steps.
                exp_avg_var.add_(eps)
                if amsgrad:
                    running_max = state['max_exp_avg_var']
                    torch.max(running_max, exp_avg_var, out=running_max)
                    variance = running_max
                else:
                    variance = exp_avg_var
                denom = (variance.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                if not group['rectify']:
                    # Default (Adam-like) update
                    p_fp32.addcdiv_(exp_avg, denom, value=-(lr / bias_correction1))
                else:
                    # Rectified update, forked from RAdam
                    num_sma, step_size = self._rectified_step_size(group, step_no, beta1, beta2)
                    if num_sma >= 5:
                        # The rectified branch builds its own denominator from exp_avg_var.
                        denom = exp_avg_var.sqrt().add_(eps)
                        p_fp32.addcdiv_(exp_avg, denom, value=-step_size * lr)
                    elif step_size > 0:
                        p_fp32.add_(exp_avg, alpha=-step_size * lr)

                if is_half_param:
                    p.copy_(p_fp32)

        return loss
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +import math + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + + Note that this optimizer internally adjusts the learning rate depending on the + *scale_parameter*, *relative_step* and *warmup_init* options. + + To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constants for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0) + decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8) + beta1 (float): coefficient used for computing running averages of gradient (default: None) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True) + warmup_init (bool): time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + """ + + def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0, + decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False): + relative_step = not lr + if warmup_init and not relative_step: + raise ValueError('warmup_init requires relative_step=True') + + beta1 = None if betas is None else betas[0] # make it compat with 
standard betas arg + defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate, + beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, + relative_step=relative_step, warmup_init=warmup_init) + super(Adafactor, self).__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + if param_group['relative_step']: + min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2 + lr_t = min(min_step, 1.0 / math.sqrt(param_state['step'])) + param_scale = 1.0 + if param_group['scale_parameter']: + param_scale = max(param_group['eps_scale'], param_state['RMS']) + param_group['lr'] = lr_t * param_scale + return param_group['lr'] + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group['beta1'] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError('Adafactor does not support sparse gradients.') + + state = self.state[p] + + factored, use_first_moment = self._get_options(group, grad.shape) + # State Initialization + if len(state) == 0: + state['step'] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(grad) + if factored: + state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad) + else: + state['exp_avg_sq'] = torch.zeros_like(grad) + + state['RMS'] = 0 + else: + if use_first_moment: + state['exp_avg'] = state['exp_avg'].to(grad) + if factored: + state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) + state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) + else: + state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) + + p_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_fp32 = p_fp32.float() + + state['step'] += 1 + state['RMS'] = self._rms(p_fp32) + lr_t = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) + update = grad ** 2 + group['eps'] + if factored: + exp_avg_sq_row = state['exp_avg_sq_row'] + exp_avg_sq_col = state['exp_avg_sq_col'] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state['exp_avg_sq'] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) + update = 
""" AdaHessian Optimizer

Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
Originally licensed MIT, Copyright 2020, David Samuel
"""
import torch


class Adahessian(torch.optim.Optimizer):
    """
    Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning".

    The Hessian diagonal is estimated by differentiating ``p.grad`` with respect to the parameters, so
    gradients must be produced with ``loss.backward(create_graph=True)``.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 0.1)
        betas ((float, float), optional): coefficients used for computing running averages of gradient and the
            squared hessian trace (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional): decoupled (AdamW style) weight decay (default: 0.0)
        hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
        update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
            (to save time) (default: 1)
        n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
        avg_conv_kernel (bool, optional): for 4-D parameters, replace the Hessian estimate by the mean of its
            absolute value over the last two dimensions (default: False)

    Raises:
        ValueError: if any hyper-parameter is outside its valid range.
    """

    def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
                 hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError(f"Invalid Hessian power value: {hessian_power}")
        # Both values are used as a modulus / divisor later on; reject them here instead of
        # failing with an unrelated error in the middle of a training step.
        if update_each < 1:
            raise ValueError(f"Invalid update_each value: {update_each}")
        if n_samples < 1:
            raise ValueError(f"Invalid n_samples value: {n_samples}")

        self.n_samples = n_samples
        self.update_each = update_each
        self.avg_conv_kernel = avg_conv_kernel

        # use a separate generator that deterministically generates the same `z`s across all GPUs
        # in case of distributed training
        self.seed = 2147483647
        self.generator = torch.Generator().manual_seed(self.seed)

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
        super(Adahessian, self).__init__(params, defaults)

        for p in self.get_params():
            # A float 0.0 marks "no Hessian estimate yet"; it becomes a tensor in set_hessian().
            p.hess = 0.0
            self.state[p]["hessian step"] = 0

    @property
    def is_second_order(self):
        return True

    def get_params(self):
        """
        Gets all parameters in all param_groups with gradients
        """

        return (p for group in self.param_groups for p in group['params'] if p.requires_grad)

    def zero_hessian(self):
        """
        Zeros out the accumulated hessian traces of parameters that are due for a fresh estimate.
        """

        for p in self.get_params():
            hess = getattr(p, 'hess', None)
            if torch.is_tensor(hess) and self.state[p].get("hessian step", 0) % self.update_each == 0:
                hess.zero_()

    @torch.no_grad()
    def set_hessian(self):
        """
        Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
        """

        params = []
        for p in self.get_params():
            if p.grad is None:
                continue
            state = self.state[p]
            hessian_step = state.get("hessian step", 0)
            if hessian_step % self.update_each == 0:  # compute the trace only each `update_each` step
                params.append(p)
            state["hessian step"] = hessian_step + 1

        if len(params) == 0:
            return

        if self.generator.device != params[0].device:  # hackish way of casting the generator to the right device
            self.generator = torch.Generator(params[0].device).manual_seed(self.seed)

        grads = [p.grad for p in params]

        for i in range(self.n_samples):
            # Rademacher distribution {-1.0, 1.0}
            zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
            h_zs = torch.autograd.grad(
                grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
            for h_z, z, p in zip(h_zs, zs, params):
                contribution = h_z * z / self.n_samples  # approximate the expected values of z*(H@z)
                hess = getattr(p, 'hess', 0.0)
                if torch.is_tensor(hess):
                    hess.add_(contribution)
                else:
                    p.hess = hess + contribution

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
            closure (callable, optional) -- a closure that reevaluates the model and returns the loss
                (default: None). It is run with gradients enabled and has to call
                ``backward(create_graph=True)`` itself.

        Returns:
            The value returned by ``closure``, or None when no closure was given.
        """

        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd to build the graph.
            with torch.enable_grad():
                loss = closure()

        self.zero_hessian()
        self.set_hessian()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                hess = getattr(p, 'hess', None)
                if not torch.is_tensor(hess):
                    # No Hessian estimate exists for this parameter, nothing to precondition with.
                    continue

                if self.avg_conv_kernel and p.dim() == 4:
                    p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()

                # Perform correct stepweight decay as in AdamW
                p.mul_(1 - group['lr'] * group['weight_decay'])

                state = self.state[p]

                # State initialization ("hessian step" may already be present, so test for a moment key)
                if 'step' not in state:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of Hessian diagonal square values
                    state['exp_hessian_diag_sq'] = torch.zeros_like(p)

                exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                k = group['hessian_power']
                denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])

                # make update
                step_size = group['lr'] / bias_correction1
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
import math

import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer


def _channel_view(x) -> torch.Tensor:
    """Reshape ``x`` to 2-D, keeping dim 0 and merging all remaining dims."""
    return x.reshape(x.size(0), -1)


def _layer_view(x) -> torch.Tensor:
    """Reshape ``x`` into a single row holding every element."""
    return x.reshape(1, -1)


def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
    """Remove from ``perturb`` its component along ``p`` when ``grad`` is nearly orthogonal to ``p``.

    The test is tried on the channel view first and on the layer view second. The first view whose
    largest absolute cosine similarity between gradient and parameter is below
    ``delta / sqrt(view width)`` triggers the projection.

    Returns:
        ``(perturb, wd)``: ``perturb`` is modified in place when projected; ``wd`` is ``wd_ratio``
        after a projection and ``1.`` otherwise.
    """
    broadcast_shape = (-1,) + (1,) * (len(p.shape) - 1)
    for flatten in (_channel_view, _layer_view):
        flat_p = flatten(p)
        flat_grad = flatten(grad)
        similarity = F.cosine_similarity(flat_grad, flat_p, dim=1, eps=eps).abs_()

        # FIXME this is a problem for PyTorch XLA
        if similarity.max() < delta / math.sqrt(flat_p.size(1)):
            unit_p = p / flat_p.norm(p=2, dim=1).add_(eps).reshape(broadcast_shape)
            along_p = flatten(unit_p * perturb).sum(dim=1).reshape(broadcast_shape)
            perturb -= unit_p * along_p
            return perturb, wd_ratio

    return perturb, 1.


class AdamP(Optimizer):
    """Adam variant that projects the update of multi-dimensional parameters (see ``projection``)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
        super().__init__(params, dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
            delta=delta, wd_ratio=wd_ratio, nesterov=nesterov))

    @torch.no_grad()
    def step(self, closure=None):
        """Run one optimization step; returns the closure's loss, or None without a closure."""
        if closure is None:
            loss = None
        else:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                beta1, beta2 = group['betas']
                state = self.state[p]

                # Lazily create the moment buffers on first use
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                first_moment = state['exp_avg']
                second_moment = state['exp_avg_sq']

                state['step'] += 1
                step_count = state['step']

                # Adam moment updates
                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (second_moment.sqrt() / math.sqrt(1 - beta2 ** step_count)).add_(group['eps'])
                step_size = group['lr'] / (1 - beta1 ** step_count)

                if group['nesterov']:
                    perturb = (beta1 * first_moment + (1 - beta1) * grad) / denom
                else:
                    perturb = first_moment / denom

                # Projection only applies to parameters with more than one dimension
                decay_scale = 1.
                if len(p.shape) > 1:
                    perturb, decay_scale = projection(
                        p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])

                # Decoupled weight decay, scaled down when a projection happened
                if group['weight_decay'] > 0:
                    p.mul_(1. - group['lr'] * group['weight_decay'] * decay_scale)

                p.add_(perturb, alpha=-step_size)

        return loss
""" AdamW Optimizer
Impl copied from PyTorch master

NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed
someday
"""
import math
import torch
from torch.optim.optimizer import Optimizer


class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    Raises:
        ValueError: if ``lr``, ``eps`` or one of ``betas`` is outside its valid range.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        # Param groups restored from an older state may lack the 'amsgrad' key.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient. The check runs before the
                parameter is touched, so a rejected parameter is left unmodified.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                # Validate first: decaying the weights and then raising would leave the
                # parameter half-updated.
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                # Perform stepweight decay
                p.mul_(1 - group['lr'] * group['weight_decay'])

                # Perform optimization step
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2019 cybertronai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
import math

import torch
from torch.optim import Optimizer


class Lamb(Optimizer):
    """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py

    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        bias_correction (bool, optional): apply Adam style bias correction to both moment
            estimates. (default: True)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-6)
        weight_decay (float, optional): weight decay, added to the update before the trust ratio
            is applied. (default: 0.01)
        grad_averaging (bool, optional): whether to scale the gradient by (1 - beta1) when
            updating the running average of the gradient. (default: True)
        max_grad_norm (float, optional): value used to clip the global grad norm; ``None`` disables
            clipping and leaves ``p.grad`` untouched. (default: 1.0)
        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)

    Note:
        Clipping divides ``p.grad`` in place, and the threshold is always read from the optimizer
        defaults, not from the individual param group.

    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(
            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
        defaults = dict(
            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
            trust_clip=trust_clip, always_adapt=always_adapt)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure was given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        device = self.param_groups[0]['params'][0].device
        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
        max_grad_norm = self.defaults['max_grad_norm']

        # The accumulator is 0-dim on purpose: the clip factor derived from it is applied with an
        # in-place div_, and a shape-(1,) factor cannot be broadcast into a 0-dim gradient.
        global_sq_norm = torch.zeros((), device=device)
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
                if max_grad_norm is not None:
                    global_sq_norm.add_(grad.pow(2).sum())

        if max_grad_norm is None:
            clip_global_grad_norm = None
        else:
            global_grad_norm = torch.sqrt(global_sq_norm)
            # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
            # scalar types properly https://github.com/pytorch/pytorch/issues/9190
            max_grad_norm_t = torch.tensor(max_grad_norm, device=device)
            clip_global_grad_norm = torch.where(
                global_grad_norm > max_grad_norm_t,
                global_grad_norm / max_grad_norm_t,
                one_tensor)

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0
            beta3 = 1 - beta1 if grad_averaging else 1.0

            # assume same step across group now to simplify things
            # per parameter step can be easily support by making it tensor, or pass list into kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            if bias_correction:
                bias_correction1 = 1 - beta1 ** group['step']
                bias_correction2 = 1 - beta2 ** group['step']
            else:
                bias_correction1, bias_correction2 = 1.0, 1.0

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if clip_global_grad_norm is not None:
                    grad.div_(clip_global_grad_norm)
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                update = (exp_avg / bias_correction1).div_(denom)

                weight_decay = group['weight_decay']
                if weight_decay != 0:
                    update.add_(p, alpha=weight_decay)

                if weight_decay != 0 or group['always_adapt']:
                    # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
                    # excluded from weight decay, unless always_adapt == True, then always enabled.
                    w_norm = p.norm(2.0)
                    g_norm = update.norm(2.0)
                    # FIXME nested where required since logical and/or not working in PT XLA
                    trust_ratio = torch.where(
                        w_norm > 0,
                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                        one_tensor,
                    )
                    if group['trust_clip']:
                        # LAMBC trust clipping, upper bound fixed at one
                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
                    update.mul_(trust_ratio)

                p.add_(update, alpha=-group['lr'])

        return loss
+ momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001) + eps (float): eps for division denominator (default: 1e-8) + trust_clip (bool): enable LARC trust ratio clipping (default: False) + always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False) + """ + + def __init__( + self, + params, + lr=1.0, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + trust_coeff=0.001, + eps=1e-8, + trust_clip=False, + always_adapt=False, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + trust_coeff=trust_coeff, + eps=eps, + trust_clip=trust_clip, + always_adapt=always_adapt, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + trust_coeff = group['trust_coeff'] + eps = group['eps'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + + # apply LARS LR adaptation, LARC clipping, weight decay + # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py + if weight_decay != 0 or group['always_adapt']: + w_norm = p.norm(2.0) + g_norm = grad.norm(2.0) + trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps) + # FIXME nested where required since logical and/or not working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, trust_ratio, one_tensor), + one_tensor, + ) + if group['trust_clip']: + trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor) + grad.add_(p, alpha=weight_decay) + grad.mul_(trust_ratio) + + # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100 + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone(grad).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(grad, alpha=1. 
""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch.optim.optimizer import Optimizer
from collections import defaultdict


class Lookahead(Optimizer):
    """Wraps ``base_optimizer``; every ``k`` steps the weights are pulled back towards a slow copy.

    The wrapper shares ``param_groups`` and ``defaults`` with the wrapped optimizer and keeps the
    slow weights inside the wrapped optimizer's per-parameter state.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        # NOTE super().__init__() not called on purpose
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        lookahead_defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self._base_optimizer = base_optimizer
        self.param_groups = base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(lookahead_defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups, keeping values a group already has
        for group in self._base_optimizer.param_groups:
            for key, value in lookahead_defaults.items():
                group.setdefault(key, value)

    @torch.no_grad()
    def update_slow(self, group):
        """Move the slow weights of ``group`` towards the fast weights, then copy them back."""
        blend = group['lookahead_alpha']
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            buffers = self._base_optimizer.state[fast_p]
            if 'lookahead_slow_buff' not in buffers:
                # First visit: the slow copy starts out equal to the current fast weights
                fresh = torch.empty_like(fast_p)
                fresh.copy_(fast_p)
                buffers['lookahead_slow_buff'] = fresh
            slow = buffers['lookahead_slow_buff']
            slow.add_(fast_p - slow, alpha=blend)
            fast_p.copy_(slow)

    def sync_lookahead(self):
        """Run the slow-weight update on every param group immediately."""
        for group in self._base_optimizer.param_groups:
            self.update_slow(group)

    @torch.no_grad()
    def step(self, closure=None):
        """Step the wrapped optimizer, then update slow weights of groups that reached ``k`` steps."""
        loss = self._base_optimizer.step(closure)
        for group in self._base_optimizer.param_groups:
            group['lookahead_step'] += 1
            due = group['lookahead_step'] % group['lookahead_k'] == 0
            if due:
                self.update_slow(group)
        return loss

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self._base_optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped optimizer and re-bind the shared param groups."""
        self._base_optimizer.load_state_dict(state_dict)
        self.param_groups = self._base_optimizer.param_groups
""" PyTorch MADGRAD optimizer

MADGRAD: https://arxiv.org/abs/2101.11075

Code from: https://github.com/facebookresearch/madgrad
"""
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
import torch.optim

if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any


class MADGRAD(torch.optim.Optimizer):
    """
    MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
    Optimization.

    .. _MADGRAD: https://arxiv.org/abs/2101.11075

    MADGRAD is a general purpose optimizer that can be used in place of SGD or
    Adam, and may converge faster and generalize better. (The upstream documentation
    describes it as GPU-only; nothing in this implementation enforces that.)
    Typically, the same learning rate schedule that is used for SGD or Adam may
    be used. The overall learning rate is not comparable to either method and
    should be determined by a hyper-parameter sweep.

    MADGRAD requires less weight decay than other methods, often as little as
    zero. Momentum values used for SGD or Adam's beta1 should work here also.

    On sparse problems both weight_decay and momentum should be set to 0.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-2).
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float):
            Weight decay, i.e. a L2 penalty (default: 0).
        eps (float):
            Term added to the denominator outside of the root operation to improve numerical stability.
            (default: 1e-6).
        decoupled_decay (bool):
            When True, weight decay shrinks the parameter directly by ``lr * weight_decay``
            instead of being added to the gradient (default: False).

    Raises:
        ValueError: if ``momentum``, ``lr``, ``weight_decay`` or ``eps`` is outside its valid range.
    """

    def __init__(
            self,
            params: _params_t,
            lr: float = 1e-2,
            momentum: float = 0.9,
            weight_decay: float = 0,
            eps: float = 1e-6,
            decoupled_decay: bool = False,
    ):
        if momentum < 0 or momentum >= 1:
            # The upper bound is exclusive; the message has to say so.
            raise ValueError(f"Momentum {momentum} must be in the range [0, 1)")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError("Eps must be non-negative")

        defaults = dict(
            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    @property
    def supports_flat_params(self) -> bool:
        return True

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure was given.

        Raises:
            RuntimeError: if a sparse gradient is combined with non-zero momentum or with
                non-decoupled weight decay.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            eps = group['eps']
            lr = group['lr'] + eps
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            ck = 1 - momentum

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError("momentum != 0 is not compatible with sparse gradients")

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['grad_sum_sq'] = torch.zeros_like(p)
                    state['s'] = torch.zeros_like(p)
                    if momentum != 0:
                        state['x0'] = torch.clone(p).detach()

                state['step'] += 1
                grad_sum_sq = state['grad_sum_sq']
                s = state['s']
                lamb = lr * math.sqrt(state['step'])

                # Apply weight decay
                if weight_decay != 0:
                    if group['decoupled_decay']:
                        p.mul_(1.0 - group['lr'] * weight_decay)
                    else:
                        if grad.is_sparse:
                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                        # Out-of-place so the caller's p.grad is not modified by the step.
                        grad = grad.add(p, alpha=weight_decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)

                    # Dense + sparse op
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # update masked copy of p
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.addcdiv(s, rms, value=1)
                    else:
                        x0 = state['x0']

                    # Accumulate second moments
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s
                    s.add_(grad, alpha=lamb)

                    # Step
                    if momentum == 0:
                        p.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z
                        p.mul_(1 - ck).add_(z, alpha=ck)

        return loss
+ """ + + def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, schedule_decay=4e-3): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay) + super(Nadam, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['m_schedule'] = 1. + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + + # Warming momentum schedule + m_schedule = state['m_schedule'] + schedule_decay = group['schedule_decay'] + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + eps = group['eps'] + state['step'] += 1 + t = state['step'] + bias_correction2 = 1 - beta2 ** t + + if group['weight_decay'] != 0: + grad = grad.add(p, alpha=group['weight_decay']) + + momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay))) + momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay))) + m_schedule_new = m_schedule * momentum_cache_t + m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 + state['m_schedule'] = m_schedule_new + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + p.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. 
class NvNovoGrad(Optimizer):
    """
    Implements Novograd algorithm.

    The second moment is tracked as one scalar per parameter tensor (the
    squared L2 norm of its gradient) rather than element-wise.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging (bool, optional): if True, the normalised gradient is
            scaled by ``1 - beta1`` before being accumulated (default: False)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    Raises:
        ValueError: if ``lr`` or ``eps`` is negative, or a beta is outside [0, 1).
    """

    def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
                 weight_decay=0, grad_averaging=False, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        amsgrad=amsgrad)

        super(NvNovoGrad, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(NvNovoGrad, self).__setstate__(state)
        # Param groups restored from older state may lack the 'amsgrad' key.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if any gradient is sparse.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Sparse gradients are not supported.')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                norm = torch.sum(torch.pow(grad, 2))

                # The first observed norm seeds the average directly.
                if exp_avg_sq == 0:
                    exp_avg_sq.copy_(norm)
                else:
                    exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                # Out-of-place division yields a fresh tensor, so p.grad itself is
                # never overwritten; the in-place ops below act on that copy only.
                grad = grad.div(denom)
                if group['weight_decay'] != 0:
                    grad.add_(p, alpha=group['weight_decay'])
                if group['grad_averaging']:
                    grad.mul_(1 - beta1)
                exp_avg.mul_(beta1).add_(grad)

                p.add_(exp_avg, alpha=-group['lr'])

        return loss
def create_optimizer_v2(
        model_or_params,
        opt: str = 'sgd',
        lr: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        filter_bias_and_bn: bool = True,
        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    The optimizer name is lower-cased and split on ``_``; the last token selects
    the optimizer and a leading ``lookahead`` token wraps the result in Lookahead
    (e.g. ``lookahead_adam``).

    Args:
        model_or_params (nn.Module): model containing parameters to optimize, or an
            iterable of parameters / param groups that is passed through unchanged
        opt: name of optimizer to create
        lr: initial learning rate; when None the optimizer's own default is used
        weight_decay: weight decay to apply in optimizer
        momentum:  momentum for momentum based optimizers (others may use betas via kwargs)
        filter_bias_and_bn:  filter out bias, bn and other 1d params from weight decay
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer

    Raises:
        ValueError: if ``opt`` does not name a known optimizer.
        AssertionError: if a fused optimizer is requested without APEX and CUDA.
    """
    if isinstance(model_or_params, nn.Module):
        # a model was passed in, extract parameters and add weight decays to appropriate layers
        if weight_decay and filter_bias_and_bn:
            skip = {}
            if hasattr(model_or_params, 'no_weight_decay'):
                skip = model_or_params.no_weight_decay()
            parameters = add_weight_decay(model_or_params, weight_decay, skip)
            # Decay now lives in the param groups, so the optimizer-level default is zeroed.
            weight_decay = 0.
        else:
            parameters = model_or_params.parameters()
    else:
        # iterable of parameters or param groups passed in
        parameters = model_or_params

    opt_lower = opt.lower()
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)
    if lr is not None:
        opt_args.setdefault('lr', lr)

    # basic SGD & related
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)

    # adaptive
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'nadam':
        try:
            # NOTE PyTorch >= 1.10 should have native NAdam
            # NOTE(review): torch.optim spells the native class `NAdam`, so this
            # `optim.Nadam` lookup looks like it always raises AttributeError and the
            # bundled Nadam below is what actually runs. Confirm before changing the
            # name, since switching implementations would alter training numerics.
            optimizer = optim.Nadam(parameters, **opt_args)
        except AttributeError:
            optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamax':
        optimizer = optim.Adamax(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
    elif opt_lower == 'radabelief':
        optimizer = AdaBelief(parameters, rectify=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adagrad':
        opt_args.setdefault('eps', 1e-8)
        optimizer = optim.Adagrad(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
    elif opt_lower == 'lambc':
        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
    elif opt_lower == 'larc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
    elif opt_lower == 'lars':
        optimizer = Lars(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'nlarc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
    elif opt_lower == 'nlars':
        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'madgrad':
        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'madgradw':
        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)

    # second order
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)

    # NVIDIA fused optimizers, require APEX to be installed
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)

    else:
        # Raise one exception type regardless of `python -O`, and say which name failed.
        raise ValueError(f"Invalid optimizer: {opt}")

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + state['step'] += 1 + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + num_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + num_sma_max = 2 / (1 - beta2) - 1 + num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = num_sma + + # more conservative since it's an approximated value + if num_sma >= 5: + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * + (num_sma - 4) / (num_sma_max - 4) * + (num_sma - 2) / num_sma * + num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr']) + + # more conservative since it's an approximated value + if num_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_fp32.addcdiv_(exp_avg, denom, value=-step_size) + else: + p_fp32.add_(exp_avg, alpha=-step_size) + + p.copy_(p_fp32) + + return loss diff --git a/data_processing/MANIQA/timm/optim/rmsprop_tf.py b/data_processing/MANIQA/timm/optim/rmsprop_tf.py new file mode 100644 index 0000000..0817887 --- /dev/null +++ b/data_processing/MANIQA/timm/optim/rmsprop_tf.py @@ -0,0 +1,139 @@ +""" RMSProp modified to behave like Tensorflow impl + +Originally cut & paste from PyTorch RMSProp +https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py +Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE + +Modifications Copyright 2021 Ross Wightman +""" + +import torch 
class RMSpropTF(Optimizer):
    """RMSprop with TensorFlow-style epsilon placement.

    NOTE: This started as a cut-and-paste of PyTorch RMSprop; eps is applied before
    the sqrt and a few other details were changed to more closely match Tensorflow
    so that hyper-params carry over.

    Differences from the PyTorch version:
      1. Epsilon is added inside the square-root
      2. square_avg starts at ones instead of zeros
      3. The LR scaling of the update can be accumulated in the momentum buffer

    Proposed by G. Hinton in his
    `course <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.

    The centered version first appears in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        momentum (float, optional): momentum factor (default: 0)
        alpha (float, optional): smoothing (decay) constant (default: 0.9)
        eps (float, optional): term added inside the square-root of the denominator
            to improve numerical stability (default: 1e-10)
        centered (bool, optional) : if ``True``, compute the centered RMSProp,
            the gradient is normalized by an estimation of its variance
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
        lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
            update as per defaults in Tensorflow

    Raises:
        ValueError: if lr, eps, momentum, weight_decay or alpha is negative.
    """

    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
                 decoupled_decay=False, lr_in_momentum=True):
        # Checked in this order; the first failing value decides the error message.
        must_be_non_negative = (
            ("learning rate", lr),
            ("epsilon value", eps),
            ("momentum value", momentum),
            ("weight_decay value", weight_decay),
            ("alpha value", alpha),
        )
        for label, value in must_be_non_negative:
            if not 0.0 <= value:
                raise ValueError("Invalid {}: {}".format(label, value))

        super(RMSpropTF, self).__init__(params, dict(
            lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
            decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum))

    def __setstate__(self, state):
        super(RMSpropTF, self).__setstate__(state)
        # Fill in keys that restored param groups may be missing.
        for group in self.param_groups:
            for key, fallback in (('momentum', 0), ('centered', False)):
                group.setdefault(key, fallback)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            Whatever ``closure`` returned, or None if it was not supplied.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError('RMSprop does not support sparse gradients')

                state = self.state[p]
                if not state:
                    # First visit for this parameter: create its buffers.
                    state['step'] = 0
                    state['square_avg'] = torch.ones_like(p)  # PyTorch inits to zero
                    if group['momentum'] > 0:
                        state['momentum_buffer'] = torch.zeros_like(p)
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(p)

                square_avg = state['square_avg']
                blend = 1. - group['alpha']
                state['step'] += 1

                decay = group['weight_decay']
                if decay != 0:
                    if group['decoupled_decay']:
                        p.mul_(1. - group['lr'] * decay)
                    else:
                        grad = grad.add(p, alpha=decay)

                # Tensorflow order of ops: move the running average toward grad**2.
                sq_delta = grad.pow(2) - square_avg
                square_avg.add_(sq_delta, alpha=blend)

                if group['centered']:
                    grad_avg = state['grad_avg']
                    mean_delta = grad - grad_avg
                    grad_avg.add_(mean_delta, alpha=blend)
                    # eps goes inside the sqrt
                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(group['eps']).sqrt_()
                else:
                    # eps goes inside the sqrt
                    avg = square_avg.add(group['eps']).sqrt_()

                if not group['momentum'] > 0:
                    p.addcdiv_(grad, avg, value=-group['lr'])
                    continue

                buf = state['momentum_buffer']
                buf.mul_(group['momentum'])
                if group['lr_in_momentum']:
                    # Tensorflow accumulates the LR scaling in the momentum buffer
                    buf.addcdiv_(grad, avg, value=group['lr'])
                    p.add_(-buf)
                else:
                    # PyTorch scales the param update by LR
                    buf.addcdiv_(grad, avg)
                    p.add_(buf, alpha=-group['lr'])

        return loss
+MIT license +""" + +import torch +import torch.nn.functional as F +from torch.optim.optimizer import Optimizer, required +import math + +from .adamp import projection + + +class SGDP(Optimizer): + def __init__(self, params, lr=required, momentum=0, dampening=0, + weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): + defaults = dict( + lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, + nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) + super(SGDP, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + state = self.state[p] + + # State initialization + if len(state) == 0: + state['momentum'] = torch.zeros_like(p) + + # SGD + buf = state['momentum'] + buf.mul_(momentum).add_(grad, alpha=1. - dampening) + if nesterov: + d_p = grad + momentum * buf + else: + d_p = buf + + # Projection + wd_ratio = 1. + if len(p.shape) > 1: + d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) + + # Weight decay + if weight_decay != 0: + p.mul_(1. 
class CosineLRScheduler(Scheduler):
    """
    Cosine decay with restarts.
    This is described in the paper https://arxiv.org/abs/1608.03983.

    Inspiration from
    https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py

    k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909

    Args:
        optimizer: optimizer whose ``lr`` param-group field is scheduled
        t_initial: length of the first cycle (must be > 0)
        lr_min: floor of the cosine curve and value used once ``cycle_limit`` is reached
        cycle_mul: factor by which each successive cycle length is multiplied
        cycle_decay: factor applied to the peak LR for each completed cycle
        cycle_limit: number of cycles after which the LR stays at ``lr_min``
        warmup_t: number of warmup steps with a linear ramp from ``warmup_lr_init``
        warmup_lr_init: LR at the start of warmup
        warmup_prefix: if True, ``warmup_t`` is subtracted from ``t`` before the
            cosine curve is evaluated
        t_in_epochs: if True the schedule answers ``get_epoch_values``, otherwise
            ``get_update_values``
        noise_range_t, noise_pct, noise_std, noise_seed, initialize: forwarded
            unchanged to the ``Scheduler`` base class
        k_decay: exponent applied to both ``t_curr`` and ``t_i`` inside the cosine
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lr_min: float = 0.,
                 cycle_mul: float = 1.,
                 cycle_decay: float = 1.,
                 cycle_limit: int = 1,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 k_decay=1.0,
                 initialize=True) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        assert t_initial > 0
        assert lr_min >= 0
        if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
            # The message names the parameters that the condition above actually tests.
            _logger.warning("Cosine annealing scheduler will have no effect on the learning "
                            "rate since t_initial = cycle_mul = cycle_decay = 1.")
        self.t_initial = t_initial
        self.lr_min = lr_min
        self.cycle_mul = cycle_mul
        self.cycle_decay = cycle_decay
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs
        self.k_decay = k_decay
        if self.warmup_t:
            # Per-step linear increment from warmup_lr_init up to each base value.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        """Return one LR per base value for step ``t`` (linear warmup, then cosine cycles)."""
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            if self.warmup_prefix:
                t = t - self.warmup_t

            if self.cycle_mul != 1:
                # Cycle lengths form a geometric series: solve for the cycle index i,
                # its length t_i and the position t_curr inside it.
                i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
                t_i = self.cycle_mul ** i * self.t_initial
                t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
            else:
                i = t // self.t_initial
                t_i = self.t_initial
                t_curr = t - (self.t_initial * i)

            gamma = self.cycle_decay ** i
            lr_max_values = [v * gamma for v in self.base_values]
            k = self.k_decay

            if i < self.cycle_limit:
                lrs = [
                    self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k))
                    for lr_max in lr_max_values
                ]
            else:
                # Past the last allowed cycle: hold at the minimum.
                lrs = [self.lr_min for _ in self.base_values]

        return lrs

    def get_epoch_values(self, epoch: int):
        """Return the LRs for ``epoch`` when scheduling per epoch, else None."""
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        """Return the LRs for ``num_updates`` when scheduling per update, else None."""
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None

    def get_cycle_length(self, cycles=0):
        """Return the total number of steps spanned by ``cycles`` cycles.

        A falsy ``cycles`` falls back to ``cycle_limit``; the count is at least 1.
        """
        cycles = max(1, cycles or self.cycle_limit)
        if self.cycle_mul == 1.0:
            return self.t_initial * cycles
        else:
            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
""" MultiStep LR Scheduler

Basic multi step LR schedule with warmup, noise.
"""
import bisect
from typing import List

import torch
from timm.scheduler.scheduler import Scheduler


class MultiStepLRScheduler(Scheduler):
    """Multi-step LR schedule with optional linear warmup and LR noise.

    After warmup, each group's base LR is multiplied by ``decay_rate`` once for
    every milestone in ``decay_t`` that is ``<= t + 1``.

    Args:
        optimizer: optimizer whose ``lr`` field is scheduled.
        decay_t: milestone steps; any order is accepted (a sorted copy is kept).
        decay_rate: multiplicative factor applied per milestone passed.
        warmup_t: number of warmup steps (0 disables warmup).
        warmup_lr_init: LR at step 0 of the warmup ramp.
        t_in_epochs: if True the schedule is driven by ``step(epoch)``,
            otherwise by ``step_update(num_updates)``.
        noise_range_t, noise_pct, noise_std, noise_seed, initialize:
            forwarded unchanged to ``Scheduler.__init__``.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 decay_t: List[int],
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        # bisect only gives a correct count on ascending input; sorting a private
        # copy enforces that and avoids aliasing the caller's list.
        self.decay_t = sorted(decay_t)
        self.decay_rate = decay_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            # per-step LR increment so that the ramp reaches the base LR at t == warmup_t
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def get_curr_decay_steps(self, t):
        """Return how many milestones in ``decay_t`` are ``<= t + 1``."""
        return bisect.bisect_right(self.decay_t, t + 1)

    def _get_lr(self, t):
        """Return the per-group LR list for step ``t`` (warmup ramp, then stepped decay)."""
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values]
        return lrs

    def get_epoch_values(self, epoch: int):
        """Return LRs for ``epoch`` when scheduling per epoch, else ``None``."""
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        """Return LRs for ``num_updates`` when scheduling per update, else ``None``."""
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None
""" Plateau Scheduler

Adapts PyTorch plateau scheduler and allows application of noise, warmup.

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch

from .scheduler import Scheduler


class PlateauLRScheduler(Scheduler):
    """Decay the LR by a factor every time the validation loss plateaus.

    Wraps ``torch.optim.lr_scheduler.ReduceLROnPlateau`` and layers a linear
    warmup ramp and an optional multiplicative LR perturbation on top of it.
    Noise settings are kept on this class (``noise_range`` etc.); the base
    class is initialised without a noise range.
    """

    def __init__(self,
                 optimizer,
                 decay_rate=0.1,
                 patience_t=10,
                 verbose=True,
                 threshold=1e-4,
                 cooldown_t=0,
                 warmup_t=0,
                 warmup_lr_init=0,
                 lr_min=0,
                 mode='max',
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize=True,
                 ):
        super().__init__(optimizer, 'lr', initialize=initialize)

        plateau_kwargs = dict(
            patience=patience_t,
            factor=decay_rate,
            verbose=verbose,
            threshold=threshold,
            cooldown=cooldown_t,
            mode=mode,
            min_lr=lr_min,
        )
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **plateau_kwargs)

        # noise configuration used by step() / _apply_noise()
        self.noise_range = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = 42 if noise_seed is None else noise_seed

        # warmup configuration
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        if self.warmup_t:
            self.warmup_steps = [(base - warmup_lr_init) / self.warmup_t for base in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

        # un-perturbed LRs saved by _apply_noise, restored on the next step()
        self.restore_lr = None

    def state_dict(self):
        """Return the wrapped plateau scheduler's ``best`` and ``last_epoch``."""
        inner = self.lr_scheduler
        return {
            'best': inner.best,
            'last_epoch': inner.last_epoch,
        }

    def load_state_dict(self, state_dict):
        """Restore ``best`` (required) and ``last_epoch`` (if present)."""
        self.lr_scheduler.best = state_dict['best']
        if 'last_epoch' in state_dict:
            self.lr_scheduler.last_epoch = state_dict['last_epoch']

    # override the base class step fn completely
    def step(self, epoch, metric=None):
        """Apply the warmup ramp, or step the plateau scheduler and optionally add noise."""
        # NOTE(review): with warmup_t == 0 this branch is still taken for epoch == 0
        # and writes warmup_lr_init into the optimizer — confirm that is intended.
        if epoch <= self.warmup_t:
            warm_lrs = [self.warmup_lr_init + epoch * inc for inc in self.warmup_steps]
            super().update_groups(warm_lrs)
            return

        if self.restore_lr is not None:
            # restore actual LR from before our last noise perturbation before stepping base
            for idx, group in enumerate(self.optimizer.param_groups):
                group['lr'] = self.restore_lr[idx]
            self.restore_lr = None

        self.lr_scheduler.step(metric, epoch)  # step the base scheduler

        if self.noise_range is None:
            return
        if isinstance(self.noise_range, (list, tuple)):
            in_range = self.noise_range[0] <= epoch < self.noise_range[1]
        else:
            in_range = epoch >= self.noise_range
        if in_range:
            self._apply_noise(epoch)

    def _apply_noise(self, epoch):
        """Scale every group's LR by ``1 + noise`` and remember the clean values."""
        gen = torch.Generator()
        gen.manual_seed(self.noise_seed + epoch)
        if self.noise_type == 'normal':
            # resample if noise out of percent limit, brute force but shouldn't spin much
            noise = torch.randn(1, generator=gen).item()
            while abs(noise) >= self.noise_pct:
                noise = torch.randn(1, generator=gen).item()
        else:
            noise = 2 * (torch.rand(1, generator=gen).item() - 0.5) * self.noise_pct

        # apply the noise on top of previous LR, cache the old value so we can restore for normal
        # stepping of base scheduler
        saved = []
        for group in self.optimizer.param_groups:
            current = float(group['lr'])
            saved.append(current)
            group['lr'] = current + current * noise
        self.restore_lr = saved
""" Polynomial Scheduler

Polynomial LR schedule with warmup, noise.

Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
import math
from typing import Any, Dict

import torch


_logger = logging.getLogger(__name__)


# Base class from timm/scheduler/scheduler.py. In the package layout poly_lr.py
# obtains it via `from .scheduler import Scheduler`; both definitions share the
# same flattened lines here, so they are kept together with the base class first.
class Scheduler:
    """ Parameter Scheduler Base Class
    A scheduler base class that can be used to schedule any optimizer parameter groups.

    Unlike the builtin PyTorch schedulers, this is intended to be consistently called
    * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
    * At the END of each optimizer update, after incrementing the update count, to calculate next update's value

    The schedulers built on this should try to remain as stateless as possible (for simplicity).

    This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
    and -1 values for special behaviour. All epoch and update counts must be tracked in the training
    code and explicitly passed in to the schedulers on the corresponding step or step_update call.

    Based on ideas from:
     * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
     * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 param_group_field: str,
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize: bool = True) -> None:
        """Record base values of ``param_group_field`` and write them to the optimizer.

        Raises:
            KeyError: if ``initialize`` is True and a group lacks ``param_group_field``,
                or ``initialize`` is False and a group lacks ``initial_<field>``.
        """
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        if initialize:
            for i, group in enumerate(self.optimizer.param_groups):
                if param_group_field not in group:
                    raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
                # setdefault keeps an already-recorded initial value
                group.setdefault(self._initial_param_group_field, group[param_group_field])
        else:
            for i, group in enumerate(self.optimizer.param_groups):
                if self._initial_param_group_field not in group:
                    raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
        self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
        self.metric = None  # any point to having this for all?
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std  # stored only; _add_noise does not read it
        self.noise_seed = noise_seed if noise_seed is not None else 42
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        """Return every instance attribute except the optimizer reference."""
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Overwrite instance attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_epoch_values(self, epoch: int):
        """Per-epoch values; subclasses override. ``None`` means "no change"."""
        return None

    def get_update_values(self, num_updates: int):
        """Per-update values; subclasses override. ``None`` means "no change"."""
        return None

    def step(self, epoch: int, metric: float = None) -> None:
        """Apply the (optionally noised) values for ``epoch`` to the optimizer."""
        self.metric = metric
        values = self.get_epoch_values(epoch)
        if values is not None:
            values = self._add_noise(values, epoch)
            self.update_groups(values)

    def step_update(self, num_updates: int, metric: float = None):
        """Apply the (optionally noised) values for ``num_updates`` to the optimizer."""
        self.metric = metric
        values = self.get_update_values(num_updates)
        if values is not None:
            values = self._add_noise(values, num_updates)
            self.update_groups(values)

    def update_groups(self, values):
        """Write ``values`` into the scheduled field; a scalar is broadcast to all groups."""
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for param_group, value in zip(self.optimizer.param_groups, values):
            param_group[self.param_group_field] = value

    def _add_noise(self, lrs, t):
        """Return ``lrs`` scaled by ``1 + noise`` when ``t`` lies in the noise range."""
        if self.noise_range_t is not None:
            if isinstance(self.noise_range_t, (list, tuple)):
                apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
            else:
                apply_noise = t >= self.noise_range_t
            if apply_noise:
                # seeding with (seed + t) makes the perturbation a pure function of t
                g = torch.Generator()
                g.manual_seed(self.noise_seed + t)
                if self.noise_type == 'normal':
                    while True:
                        # resample if noise out of percent limit, brute force but shouldn't spin much
                        noise = torch.randn(1, generator=g).item()
                        if abs(noise) < self.noise_pct:
                            break
                else:
                    noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
                lrs = [v + v * noise for v in lrs]
        return lrs


class PolyLRScheduler(Scheduler):
    """ Polynomial LR Scheduler w/ warmup, noise, and k-decay

    k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909

    Within a cycle of length ``t_i`` the LR is
    ``lr_min + (lr_max - lr_min) * (1 - t_curr**k / t_i**k) ** power``;
    once ``cycle_limit`` cycles have elapsed it stays at ``lr_min``.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 power: float = 0.5,
                 lr_min: float = 0.,
                 cycle_mul: float = 1.,
                 cycle_decay: float = 1.,
                 cycle_limit: int = 1,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 k_decay=1.0,
                 initialize=True) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        assert t_initial > 0
        assert lr_min >= 0
        if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
            # message previously named the cosine scheduler and non-existent parameters
            _logger.warning("Polynomial scheduler will have no effect on the learning "
                            "rate since t_initial = cycle_mul = cycle_decay = 1.")
        self.t_initial = t_initial
        self.power = power
        self.lr_min = lr_min
        self.cycle_mul = cycle_mul
        self.cycle_decay = cycle_decay
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs
        self.k_decay = k_decay
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        """Return the per-group LR list for step ``t``."""
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            if self.warmup_prefix:
                # warmup steps do not count towards the decay schedule
                t = t - self.warmup_t

            if self.cycle_mul != 1:
                # cycle lengths form a geometric series; invert its partial sum to find cycle i
                i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
                t_i = self.cycle_mul ** i * self.t_initial
                t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
            else:
                i = t // self.t_initial
                t_i = self.t_initial
                t_curr = t - (self.t_initial * i)

            gamma = self.cycle_decay ** i
            lr_max_values = [v * gamma for v in self.base_values]
            k = self.k_decay

            if i < self.cycle_limit:
                lrs = [
                    self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power
                    for lr_max in lr_max_values
                ]
            else:
                lrs = [self.lr_min for _ in self.base_values]

        return lrs

    def get_epoch_values(self, epoch: int):
        """Return LRs for ``epoch`` when scheduling per epoch, else ``None``."""
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        """Return LRs for ``num_updates`` when scheduling per update, else ``None``."""
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None

    def get_cycle_length(self, cycles=0):
        """Return the total number of steps spanned by ``cycles`` cycles (default: ``cycle_limit``)."""
        cycles = max(1, cycles or self.cycle_limit)
        if self.cycle_mul == 1.0:
            return self.t_initial * cycles
        else:
            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
""" Scheduler Factory
Hacked together by / Copyright 2021 Ross Wightman
"""
from .cosine_lr import CosineLRScheduler
from .multistep_lr import MultiStepLRScheduler
from .plateau_lr import PlateauLRScheduler
from .poly_lr import PolyLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler


def create_scheduler(args, optimizer):
    """Build the LR scheduler selected by ``args.sched``.

    Returns:
        ``(lr_scheduler, num_epochs)``. ``lr_scheduler`` is ``None`` when
        ``args.sched`` matches none of the known names. For the cyclic
        schedulers (cosine, tanh, poly) ``num_epochs`` is replaced by the
        scheduler's cycle length plus ``args.cooldown_epochs``; otherwise it
        is ``args.epochs``.
    """
    num_epochs = args.epochs

    # lr_noise is given as fraction(s) of the run; convert to absolute epochs
    lr_noise = getattr(args, 'lr_noise', None)
    if lr_noise is None:
        noise_range = None
    elif isinstance(lr_noise, (list, tuple)):
        noise_range = [n * num_epochs for n in lr_noise]
        if len(noise_range) == 1:
            noise_range = noise_range[0]
    else:
        noise_range = lr_noise * num_epochs

    noise_args = {
        'noise_range_t': noise_range,
        'noise_pct': getattr(args, 'lr_noise_pct', 0.67),
        'noise_std': getattr(args, 'lr_noise_std', 1.),
        'noise_seed': getattr(args, 'seed', 42),
    }
    cycle_args = {
        'cycle_mul': getattr(args, 'lr_cycle_mul', 1.),
        'cycle_decay': getattr(args, 'lr_cycle_decay', 0.1),
        'cycle_limit': getattr(args, 'lr_cycle_limit', 1),
    }

    sched_name = args.sched
    lr_scheduler = None
    if sched_name == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            k_decay=getattr(args, 'lr_k_decay', 1.0),
            **cycle_args,
            **noise_args,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif sched_name == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            t_in_epochs=True,
            **cycle_args,
            **noise_args,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif sched_name == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_args,
        )
    elif sched_name == 'multistep':
        lr_scheduler = MultiStepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_args,
        )
    elif sched_name == 'plateau':
        # a metric whose name contains 'loss' is minimised, anything else maximised
        eval_metric = getattr(args, 'eval_metric', '')
        mode = 'min' if 'loss' in eval_metric else 'max'
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=args.min_lr,
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **noise_args,
        )
    elif sched_name == 'poly':
        lr_scheduler = PolyLRScheduler(
            optimizer,
            power=args.decay_rate,  # overloading 'decay_rate' as polynomial power
            t_initial=num_epochs,
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            k_decay=getattr(args, 'lr_k_decay', 1.0),
            **cycle_args,
            **noise_args,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs

    return lr_scheduler, num_epochs
""" Step Scheduler

Basic step LR schedule with warmup, noise.

Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import torch

from .scheduler import Scheduler


class StepLRScheduler(Scheduler):
    """Step LR schedule with optional linear warmup and LR noise.

    After warmup each group's base LR is scaled by
    ``decay_rate ** (t // decay_t)``.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 decay_t: float,
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        self.decay_t = decay_t
        self.decay_rate = decay_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs

        if not self.warmup_t:
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # per-step increment of the linear ramp from warmup_lr_init to each base LR
            self.warmup_steps = [(base - warmup_lr_init) / self.warmup_t for base in self.base_values]
            super().update_groups(self.warmup_lr_init)

    def _get_lr(self, t):
        """Return the per-group LR list for step ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * inc for inc in self.warmup_steps]
        return [base * (self.decay_rate ** (t // self.decay_t)) for base in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return LRs for ``epoch`` when scheduling per epoch, else ``None``."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return LRs for ``num_updates`` when scheduling per update, else ``None``."""
        return None if self.t_in_epochs else self._get_lr(num_updates)
""" TanH Scheduler

TanH schedule with warmup, cycle/restarts, noise.

Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
import math
import numpy as np
import torch

from .scheduler import Scheduler


_logger = logging.getLogger(__name__)


class TanhLRScheduler(Scheduler):
    """
    Hyperbolic-tangent decay with restarts.
    This is described in the paper https://arxiv.org/abs/1806.01593

    Within a cycle the LR is
    ``lr_min + 0.5 * (lr_max - lr_min) * (1 - tanh(lb * (1 - tr) + ub * tr))``
    with ``tr = t_curr / t_i``; after ``cycle_limit`` cycles it stays at ``lr_min``.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lb: float = -7.,
                 ub: float = 3.,
                 lr_min: float = 0.,
                 cycle_mul: float = 1.,
                 cycle_decay: float = 1.,
                 cycle_limit: int = 1,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        assert t_initial > 0
        assert lr_min >= 0
        assert lb < ub
        assert cycle_limit >= 0
        assert warmup_t >= 0
        assert warmup_lr_init >= 0

        self.lb = lb
        self.ub = ub
        self.t_initial = t_initial
        self.lr_min = lr_min
        self.cycle_mul = cycle_mul
        self.cycle_decay = cycle_decay
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs

        if not self.warmup_t:
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # ramp target: the base LR when warmup is a prefix, otherwise the
            # schedule's own value at the step where warmup ends
            if self.warmup_prefix:
                targets = self.base_values
            else:
                targets = self._get_lr(self.warmup_t)
            self.warmup_steps = [(target - warmup_lr_init) / self.warmup_t for target in targets]
            super().update_groups(self.warmup_lr_init)

    def _get_lr(self, t):
        """Return the per-group LR list for step ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * inc for inc in self.warmup_steps]

        if self.warmup_prefix:
            t = t - self.warmup_t

        if self.cycle_mul != 1:
            # cycle lengths form a geometric series; invert its partial sum to find the cycle index
            cycle = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
            cycle_len = self.cycle_mul ** cycle * self.t_initial
            t_in_cycle = t - (1 - self.cycle_mul ** cycle) / (1 - self.cycle_mul) * self.t_initial
        else:
            cycle = t // self.t_initial
            cycle_len = self.t_initial
            t_in_cycle = t - (self.t_initial * cycle)

        if cycle >= self.cycle_limit:
            return [self.lr_min for _ in self.base_values]

        gamma = self.cycle_decay ** cycle
        peak_lrs = [base * gamma for base in self.base_values]

        tr = t_in_cycle / cycle_len
        return [
            self.lr_min + 0.5 * (peak - self.lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr))
            for peak in peak_lrs
        ]

    def get_epoch_values(self, epoch: int):
        """Return LRs for ``epoch`` when scheduling per epoch, else ``None``."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return LRs for ``num_updates`` when scheduling per update, else ``None``."""
        return None if self.t_in_epochs else self._get_lr(num_updates)

    def get_cycle_length(self, cycles=0):
        """Return the total number of steps spanned by ``cycles`` cycles (default: ``cycle_limit``)."""
        cycles = max(1, cycles or self.cycle_limit)
        if self.cycle_mul == 1.0:
            return self.t_initial * cycles
        return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))


# --- data_processing/MANIQA/timm/utils/__init__.py (shares these flattened lines) ---
from .agc import adaptive_clip_grad
from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler
from .distributed import distribute_bn, reduce_tensor
from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg
from .model import unwrap_model, get_state_dict, freeze, unfreeze
from .model_ema import ModelEma, ModelEmaV2
from .random import random_seed
from .summary import update_summary, get_outdir
""" Adaptive Gradient Clipping

An impl of AGC, as per (https://arxiv.org/abs/2102.06171):

@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:},
  year={2021}
}

Code references:
  * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
  * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c

Hacked together by / Copyright 2021 Ross Wightman
"""
import torch


def unitwise_norm(x, norm_type=2.0):
    """Return the p-norm of ``x`` per leading-dim unit.

    Tensors with at most one dimension are reduced to a single norm; otherwise
    the norm is taken over every dim except the first, with ``keepdim=True``.
    """
    if x.ndim > 1:
        # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
        # might need special cases for other weights (possibly MHA) where this may not be true
        reduce_dims = tuple(range(1, x.ndim))
        return x.norm(norm_type, dim=reduce_dims, keepdim=True)
    return x.norm(norm_type)


def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
    """Clip each parameter's gradient in place relative to the parameter's own norm.

    For every unit whose gradient norm is not below
    ``clip_factor * max(||param||, eps)`` the gradient is rescaled to that
    limit. Parameters without a gradient are skipped. A single tensor is
    accepted in place of an iterable.
    """
    params = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    for param in params:
        if param.grad is None:
            continue
        weight = param.detach()
        grad = param.grad.detach()
        limit = unitwise_norm(weight, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
        grad_norm = unitwise_norm(grad, norm_type=norm_type)
        # the clamp guards the division when a unit's gradient norm is (near) zero
        rescaled = grad * (limit / grad_norm.clamp(min=1e-6))
        result = torch.where(grad_norm < limit, grad, rescaled)
        param.grad.detach().copy_(result)
""" Checkpoint Saver

Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.

Hacked together by / Copyright 2020 Ross Wightman
"""

import glob
import operator
import os
import logging

import torch

from .model import unwrap_model, get_state_dict


_logger = logging.getLogger(__name__)


class CheckpointSaver:
    """Keep the best ``max_history`` checkpoints plus ``last``/``model_best`` links and recovery files."""

    def __init__(
            self,
            model,
            optimizer,
            args=None,
            model_ema=None,
            amp_scaler=None,
            checkpoint_prefix='checkpoint',
            recovery_prefix='recovery',
            checkpoint_dir='',
            recovery_dir='',
            decreasing=False,
            max_history=10,
            unwrap_fn=unwrap_model):

        # objects to save state_dicts of
        self.model = model
        self.optimizer = optimizer
        self.args = args
        self.model_ema = model_ema
        self.amp_scaler = amp_scaler

        # state
        self.checkpoint_files = []  # (filename, metric) tuples in order of decreasing betterness
        self.best_epoch = None
        self.best_metric = None
        self.curr_recovery_file = ''
        self.last_recovery_file = ''

        # config
        self.checkpoint_dir = checkpoint_dir
        self.recovery_dir = recovery_dir
        self.save_prefix = checkpoint_prefix
        self.recovery_prefix = recovery_prefix
        self.extension = '.pth.tar'
        self.decreasing = decreasing  # a lower metric is better if True
        self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
        self.max_history = max_history
        self.unwrap_fn = unwrap_fn
        assert self.max_history >= 1

    def save_checkpoint(self, epoch, metric=None):
        """Write ``last``, and keep a per-epoch copy if it ranks within the history.

        Checkpoints saved without a metric are ranked after every scored
        checkpoint, newest first, so they never break the ordering and the
        oldest unscored one is evicted first.

        Returns:
            ``(best_metric, best_epoch)``, or ``(None, None)`` if no scored
            checkpoint has been saved yet.
        """
        assert epoch >= 0
        tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
        last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
        # write to a temp name first so an interrupted save never leaves a truncated 'last'
        self._save(tmp_save_path, epoch, metric)
        if os.path.exists(last_save_path):
            os.unlink(last_save_path)  # required for Windows support.
        os.rename(tmp_save_path, last_save_path)
        worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
        if (len(self.checkpoint_files) < self.max_history
                or metric is None
                # an unscored worst entry is beaten by any scored metric (and cannot be compared)
                or worst_file[1] is None
                or self.cmp(metric, worst_file[1])):
            if len(self.checkpoint_files) >= self.max_history:
                self._cleanup_checkpoints(1)
            filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
            save_path = os.path.join(self.checkpoint_dir, filename)
            os.link(last_save_path, save_path)

            # None metrics are not orderable, so rank scored and unscored entries separately
            entry = (save_path, metric)
            scored = [c for c in self.checkpoint_files if c[1] is not None]
            unscored = [c for c in self.checkpoint_files if c[1] is None]
            if metric is None:
                unscored.insert(0, entry)
            else:
                scored.append(entry)
                # sort in descending order if a lower metric is not better
                scored.sort(key=lambda x: x[1], reverse=not self.decreasing)
            self.checkpoint_files = scored + unscored

            checkpoints_str = "Current checkpoints:\n"
            for c in self.checkpoint_files:
                checkpoints_str += ' {}\n'.format(c)
            _logger.info(checkpoints_str)

            if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
                self.best_epoch = epoch
                self.best_metric = metric
                best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
                if os.path.exists(best_save_path):
                    os.unlink(best_save_path)
                os.link(last_save_path, best_save_path)

        return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

    def _save(self, save_path, epoch, metric=None):
        """Serialise model/optimizer (and optional args, AMP scaler, EMA, metric) to ``save_path``."""
        save_state = {
            'epoch': epoch,
            'arch': type(self.model).__name__.lower(),
            'state_dict': get_state_dict(self.model, self.unwrap_fn),
            'optimizer': self.optimizer.state_dict(),
            'version': 2,  # version < 2 increments epoch before save
        }
        if self.args is not None:
            save_state['arch'] = self.args.model
            save_state['args'] = self.args
        if self.amp_scaler is not None:
            save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
        if self.model_ema is not None:
            save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
        if metric is not None:
            save_state['metric'] = metric
        torch.save(save_state, save_path)

    def _cleanup_checkpoints(self, trim=0):
        """Delete the worst-ranked files so that at most ``max_history - trim`` remain."""
        trim = min(len(self.checkpoint_files), trim)
        delete_index = self.max_history - trim
        if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
            return
        to_delete = self.checkpoint_files[delete_index:]
        for d in to_delete:
            try:
                _logger.debug("Cleaning checkpoint: {}".format(d))
                os.remove(d[0])
            except Exception as e:
                # best effort: a failed delete is logged and the entry is still dropped
                _logger.error("Exception '{}' while deleting checkpoint".format(e))
        self.checkpoint_files = self.checkpoint_files[:delete_index]

    def save_recovery(self, epoch, batch_idx=0):
        """Write a recovery checkpoint and remove the one from two calls ago."""
        assert epoch >= 0
        filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
        save_path = os.path.join(self.recovery_dir, filename)
        self._save(save_path, epoch)
        if os.path.exists(self.last_recovery_file):
            try:
                _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
                os.remove(self.last_recovery_file)
            except Exception as e:
                _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
        self.last_recovery_file = self.curr_recovery_file
        self.curr_recovery_file = save_path

    def find_recovery(self):
        """Return the first recovery file in lexicographic order, or ``''`` if none exist."""
        recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
        files = glob.glob(recovery_path + '*' + self.extension)
        files = sorted(files)
        # NOTE(review): lexicographic order is not numeric epoch/batch order, and
        # index 0 is the smallest name rather than the latest — confirm this is intended.
        return files[0] if len(files) else ''
# Reconstructed from three files that share these flattened lines:
#   timm/utils/clip_grad.py   - dispatch_clip_grad
#   timm/utils/cuda.py        - ApexScaler, NativeScaler ("CUDA / AMP utils")
#   timm/utils/distributed.py - reduce_tensor, distribute_bn ("Distributed training/validation utils")
# Hacked together by / Copyright 2020 Ross Wightman
# cuda.py's `from .clip_grad import dispatch_clip_grad` is omitted because the
# function is defined directly below.
import torch
from torch import distributed as dist

from timm.utils.agc import adaptive_clip_grad
from .model import unwrap_model

try:
    from apex import amp
    has_apex = True
except ImportError:
    amp = None
    has_apex = False


def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
    """ Dispatch to gradient clipping method

    Args:
        parameters (Iterable): model parameters to clip
        value (float): clipping value/factor/norm, mode dependent
        mode (str): clipping mode, one of 'norm', 'value', 'agc'
        norm_type (float): p-norm, default 2.0
    """
    if mode == 'norm':
        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
        return
    if mode == 'value':
        torch.nn.utils.clip_grad_value_(parameters, value)
        return
    if mode == 'agc':
        adaptive_clip_grad(parameters, value, norm_type=norm_type)
        return
    assert False, f"Unknown clip mode ({mode})."


class ApexScaler:
    """Loss-scaling step built on ``apex.amp``."""

    state_dict_key = "amp"

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
        """Backward through the scaled loss, optionally clip, then step the optimizer."""
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward(create_graph=create_graph)
        if clip_grad is not None:
            master_params = amp.master_params(optimizer)
            dispatch_clip_grad(master_params, clip_grad, mode=clip_mode)
        optimizer.step()

    def state_dict(self):
        """Return ``amp.state_dict()`` when amp exposes it (otherwise implicitly ``None``)."""
        if 'state_dict' in amp.__dict__:
            return amp.state_dict()

    def load_state_dict(self, state_dict):
        """Forward to ``amp.load_state_dict`` when amp exposes it."""
        if 'load_state_dict' in amp.__dict__:
            amp.load_state_dict(state_dict)


class NativeScaler:
    """Loss-scaling step built on ``torch.cuda.amp.GradScaler``."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
        """Backward through the scaled loss, optionally unscale and clip, then step and update the scaler."""
        scaler = self._scaler
        scaler.scale(loss).backward(create_graph=create_graph)
        if clip_grad is not None:
            assert parameters is not None
            scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
        scaler.step(optimizer)
        scaler.update()

    def state_dict(self):
        """Return the wrapped GradScaler's state."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load state into the wrapped GradScaler."""
        self._scaler.load_state_dict(state_dict)


def reduce_tensor(tensor, n):
    """Return a copy of ``tensor`` summed across the process group and divided by ``n``."""
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    reduced /= n
    return reduced


def distribute_bn(model, world_size, reduce=False):
    """Synchronise buffers named ``running_mean`` / ``running_var`` across processes.

    With ``reduce=True`` the buffers are averaged over ``world_size``;
    otherwise rank 0's values are broadcast.
    """
    # ensure every node has the same running bn stats
    for buf_name, buf in unwrap_model(model).named_buffers(recurse=True):
        is_bn_stat = ('running_mean' in buf_name) or ('running_var' in buf_name)
        if not is_bn_stat:
            continue
        if reduce:
            # average bn stats across whole group
            torch.distributed.all_reduce(buf, op=dist.ReduceOp.SUM)
            buf /= float(world_size)
        else:
            # broadcast bn stats from rank 0 to whole group
            torch.distributed.broadcast(buf, 0)
+ torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + torch._C._jit_override_can_fuse_on_gpu(True) + #torch._C._jit_set_texpr_fuser_enabled(True) + + +def set_jit_fuser(fuser): + if fuser == "te": + # default fuser should be == 'te' + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(True) + torch._C._jit_set_texpr_fuser_enabled(True) + elif fuser == "old" or fuser == "legacy": + torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + torch._C._jit_override_can_fuse_on_gpu(True) + torch._C._jit_set_texpr_fuser_enabled(False) + elif fuser == "nvfuser" or fuser == "nvf": + os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1' + os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1' + os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0' + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_can_fuse_on_cpu() + torch._C._jit_can_fuse_on_gpu() + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_nvfuser_guard_mode(True) + torch._C._jit_set_nvfuser_enabled(True) + else: + assert False, f"Invalid jit fuser ({fuser})" diff --git a/data_processing/MANIQA/timm/utils/log.py b/data_processing/MANIQA/timm/utils/log.py new file mode 100644 index 0000000..c99469e --- /dev/null +++ b/data_processing/MANIQA/timm/utils/log.py @@ -0,0 +1,28 @@ +""" Logging helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import logging.handlers + + +class FormatterNoInfo(logging.Formatter): + def __init__(self, fmt='%(levelname)s: %(message)s'): + logging.Formatter.__init__(self, fmt) + + def format(self, record): + if record.levelno == logging.INFO: + return str(record.getMessage()) + return logging.Formatter.format(self, 
record) + + +def setup_default_logging(default_level=logging.INFO, log_path=''): + console_handler = logging.StreamHandler() + console_handler.setFormatter(FormatterNoInfo()) + logging.root.addHandler(console_handler) + logging.root.setLevel(default_level) + if log_path: + file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) + file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") + file_handler.setFormatter(file_formatter) + logging.root.addHandler(file_handler) diff --git a/data_processing/MANIQA/timm/utils/metrics.py b/data_processing/MANIQA/timm/utils/metrics.py new file mode 100644 index 0000000..9fdbe13 --- /dev/null +++ b/data_processing/MANIQA/timm/utils/metrics.py @@ -0,0 +1,32 @@ +""" Eval metrics and related + +Hacked together by / Copyright 2020 Ross Wightman +""" + + +class AverageMeter: + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + maxk = min(max(topk), output.size()[1]) + batch_size = target.size(0) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. 
/ batch_size for k in topk] diff --git a/data_processing/MANIQA/timm/utils/misc.py b/data_processing/MANIQA/timm/utils/misc.py new file mode 100644 index 0000000..39c0097 --- /dev/null +++ b/data_processing/MANIQA/timm/utils/misc.py @@ -0,0 +1,18 @@ +""" Misc utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import re + + +def natural_key(string_): + """See http://www.codinghorror.com/blog/archives/001018.html""" + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def add_bool_arg(parser, name, default=False, help=''): + dest_name = name.replace('-', '_') + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) + group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) + parser.set_defaults(**{dest_name: default}) diff --git a/data_processing/MANIQA/timm/utils/model.py b/data_processing/MANIQA/timm/utils/model.py new file mode 100644 index 0000000..b95c453 --- /dev/null +++ b/data_processing/MANIQA/timm/utils/model.py @@ -0,0 +1,273 @@ +""" Model / state_dict utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import fnmatch + +import torch +from torchvision.ops.misc import FrozenBatchNorm2d + +from .model_ema import ModelEma + + +def unwrap_model(model): + if isinstance(model, ModelEma): + return unwrap_model(model.ema) + else: + return model.module if hasattr(model, 'module') else model + + +def get_state_dict(model, unwrap_fn=unwrap_model): + return unwrap_fn(model).state_dict() + + +def avg_sq_ch_mean(model, input, output): + """ calculate average channel square mean of output activations + """ + return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item() + + +def avg_ch_var(model, input, output): + """ calculate average channel variance of output activations + """ + return torch.mean(output.var(axis=[0, 2, 3])).item() + + +def 
avg_ch_var_residual(model, input, output): + """ calculate average channel variance of output activations + """ + return torch.mean(output.var(axis=[0, 2, 3])).item() + + +class ActivationStatsHook: + """Iterates through each of `model`'s modules and matches modules using unix pattern + matching based on `hook_fn_locs` and registers `hook_fn` to the module if there is + a match. + + Arguments: + model (nn.Module): model from which we will extract the activation stats + hook_fn_locs (List[str]): List of `hook_fn` locations based on Unix type string + matching with the name of model's modules. + hook_fns (List[Callable]): List of hook functions to be registered at every + module in `layer_names`. + + Inspiration from https://docs.fast.ai/callback.hook.html. + + Refer to https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 for an example + on how to plot Signal Propogation Plots using `ActivationStatsHook`. 
+ """ + + def __init__(self, model, hook_fn_locs, hook_fns): + self.model = model + self.hook_fn_locs = hook_fn_locs + self.hook_fns = hook_fns + if len(hook_fn_locs) != len(hook_fns): + raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \ + their lengths are different.") + self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns) + for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns): + self.register_hook(hook_fn_loc, hook_fn) + + def _create_hook(self, hook_fn): + def append_activation_stats(module, input, output): + out = hook_fn(module, input, output) + self.stats[hook_fn.__name__].append(out) + + return append_activation_stats + + def register_hook(self, hook_fn_loc, hook_fn): + for name, module in self.model.named_modules(): + if not fnmatch.fnmatch(name, hook_fn_loc): + continue + module.register_forward_hook(self._create_hook(hook_fn)) + + +def extract_spp_stats( + model, + hook_fn_locs, + hook_fns, + input_shape=[8, 3, 224, 224]): + """Extract average square channel mean and variance of activations during + forward pass to plot Signal Propogation Plots (SPP). + + Paper: https://arxiv.org/abs/2101.08692 + + Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 + """ + x = torch.normal(0., 1., input_shape) + hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns) + _ = model(x) + return hook.stats + + +def freeze_batch_norm_2d(module): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. 
+ + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = freeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def unfreeze_batch_norm_2d(module): + """ + Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance + of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked + recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. 
+ + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + if isinstance(module, FrozenBatchNorm2d): + res = torch.nn.BatchNorm2d(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = unfreeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'): + """ + Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is + done in place. + Args: + root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced. + submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as + named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list + means that the whole root module will be (un)frozen. Defaults to [] + include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm 2d layers. + Defaults to `True`. + mode (bool): Whether to freeze ("freeze") or unfreeze ("unfreeze"). Defaults to `"freeze"`. + """ + assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"' + + if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + # Raise assertion here because we can't convert it in place + raise AssertionError( + "You have provided a batch norm layer as the `root module`. 
Please use " + "`timm.utils.model.freeze_batch_norm_2d` or `timm.utils.model.unfreeze_batch_norm_2d` instead.") + + if isinstance(submodules, str): + submodules = [submodules] + + named_modules = submodules + submodules = [root_module.get_submodule(m) for m in submodules] + + if not len(submodules): + named_modules, submodules = list(zip(*root_module.named_children())) + + for n, m in zip(named_modules, submodules): + # (Un)freeze parameters + for p in m.parameters(): + p.requires_grad = False if mode == 'freeze' else True + if include_bn_running_stats: + # Helper to add submodule specified as a named_module + def _add_submodule(module, name, submodule): + split = name.rsplit('.', 1) + if len(split) > 1: + module.get_submodule(split[0]).add_module(split[1], submodule) + else: + module.add_module(name, submodule) + + # Freeze batch norm + if mode == 'freeze': + res = freeze_batch_norm_2d(m) + # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't + # convert it in place, but will return the converted result. In this case `res` holds the converted + # result and we may try to re-assign the named module + if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + _add_submodule(root_module, n, res) + # Unfreeze batch norm + else: + res = unfreeze_batch_norm_2d(m) + # Ditto. See note above in mode == 'freeze' branch + if isinstance(m, FrozenBatchNorm2d): + _add_submodule(root_module, n, res) + + +def freeze(root_module, submodules=[], include_bn_running_stats=True): + """ + Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: + root_module (nn.Module): Root module relative to which `submodules` are referenced. + submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as + named modules relative to the root module (accessible via `root_module.named_modules()`). 
An empty list + means that the whole root module will be frozen. Defaults to `[]`. + include_bn_running_stats (bool): Whether to also freeze the running statistics of `BatchNorm2d` and + `SyncBatchNorm` layers. These will be converted to `FrozenBatchNorm2d` in place. Hint: During fine tuning, + it's good practice to freeze batch norm stats. And note that these are different to the affine parameters + which are just normal PyTorch parameters. Defaults to `True`. + + Hint: If you want to freeze batch norm ONLY, use `timm.utils.model.freeze_batch_norm_2d`. + + Examples:: + + >>> model = timm.create_model('resnet18') + >>> # Freeze up to and including layer2 + >>> submodules = [n for n, _ in model.named_children()] + >>> print(submodules) + ['conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'global_pool', 'fc'] + >>> freeze(model, submodules[:submodules.index('layer2') + 1]) + >>> # Check for yourself that it works as expected + >>> print(model.layer2[0].conv1.weight.requires_grad) + False + >>> print(model.layer3[0].conv1.weight.requires_grad) + True + >>> # Unfreeze + >>> unfreeze(model) + """ + _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="freeze") + + +def unfreeze(root_module, submodules=[], include_bn_running_stats=True): + """ + Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: + root_module (nn.Module): Root module relative to which `submodules` are referenced. + submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided + as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty + list means that the whole root module will be unfrozen. Defaults to `[]`. + include_bn_running_stats (bool): Whether to also unfreeze the running statistics of `FrozenBatchNorm2d` layers. 
+ These will be converted to `BatchNorm2d` in place. Defaults to `True`. + + See example in docstring for `freeze`. + """ + _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze") diff --git a/data_processing/MANIQA/timm/utils/model_ema.py b/data_processing/MANIQA/timm/utils/model_ema.py new file mode 100644 index 0000000..073d5c5 --- /dev/null +++ b/data_processing/MANIQA/timm/utils/model_ema.py @@ -0,0 +1,126 @@ +""" Exponential Moving Average (EMA) of model updates + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn + +_logger = logging.getLogger(__name__) + + +class ModelEma: + """ Model Exponential Moving Average (DEPRECATED) + + Keep a moving average of everything in the model state_dict (parameters and buffers). + This version is deprecated, it does not work with scripted models. Will be removed eventually. + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. 
+ """ + def __init__(self, model, decay=0.9999, device='', resume=''): + # make a copy of the model for accumulating moving average of weights + self.ema = deepcopy(model) + self.ema.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if device: + self.ema.to(device=device) + self.ema_has_module = hasattr(self.ema, 'module') + if resume: + self._load_checkpoint(resume) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def _load_checkpoint(self, checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + assert isinstance(checkpoint, dict) + if 'state_dict_ema' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict_ema'].items(): + # ema model may have been wrapped by DataParallel, and need module prefix + if self.ema_has_module: + name = 'module.' + k if not k.startswith('module') else k + else: + name = k + new_state_dict[name] = v + self.ema.load_state_dict(new_state_dict) + _logger.info("Loaded state_dict_ema") + else: + _logger.warning("Failed to find state_dict_ema, starting from loaded model weights") + + def update(self, model): + # correct a mismatch in state dict keys + needs_module = hasattr(model, 'module') and not self.ema_has_module + with torch.no_grad(): + msd = model.state_dict() + for k, ema_v in self.ema.state_dict().items(): + if needs_module: + k = 'module.' + k + model_v = msd[k].detach() + if self.device: + model_v = model_v.to(device=self.device) + ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + + +class ModelEmaV2(nn.Module): + """ Model Exponential Moving Average V2 + + Keep a moving average of everything in the model state_dict (parameters and buffers). + V2 of this module is simpler, it does not match params/buffers based on name but simply + iterates in order. It works with torchscript (JIT of full model). 
+ + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__(self, model, decay=0.9999, device=None): + super(ModelEmaV2, self).__init__() + # make a copy of the model for accumulating moving average of weights + self.module = deepcopy(model) + self.module.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if self.device is not None: + self.module.to(device=device) + + def _update(self, model, update_fn): + with torch.no_grad(): + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. 
- self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m) diff --git a/data_processing/MANIQA/timm/utils/random.py b/data_processing/MANIQA/timm/utils/random.py new file mode 100644 index 0000000..a967998 --- /dev/null +++ b/data_processing/MANIQA/timm/utils/random.py @@ -0,0 +1,9 @@ +import random +import numpy as np +import torch + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) diff --git a/data_processing/MANIQA/timm/utils/summary.py b/data_processing/MANIQA/timm/utils/summary.py new file mode 100644 index 0000000..9f5af9a --- /dev/null +++ b/data_processing/MANIQA/timm/utils/summary.py @@ -0,0 +1,39 @@ +""" Summary utilities + +Hacked together by / Copyright 2020 Ross Wightman +""" +import csv +import os +from collections import OrderedDict +try: + import wandb +except ImportError: + pass + +def get_outdir(path, *paths, inc=False): + outdir = os.path.join(path, *paths) + if not os.path.exists(outdir): + os.makedirs(outdir) + elif inc: + count = 1 + outdir_inc = outdir + '-' + str(count) + while os.path.exists(outdir_inc): + count = count + 1 + outdir_inc = outdir + '-' + str(count) + assert count < 100 + outdir = outdir_inc + os.makedirs(outdir) + return outdir + + +def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False): + rowd = OrderedDict(epoch=epoch) + rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) + rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) + if log_wandb: + wandb.log(rowd) + with open(filename, mode='a') as cf: + dw = csv.DictWriter(cf, fieldnames=rowd.keys()) + if write_header: # first iteration (epoch == 1 can't be used) + dw.writeheader() + dw.writerow(rowd) diff --git a/data_processing/MANIQA/utils/inference_process.py b/data_processing/MANIQA/utils/inference_process.py new file mode 100644 index 0000000..9e93e48 --- /dev/null +++ 
b/data_processing/MANIQA/utils/inference_process.py @@ -0,0 +1,129 @@ +import torch +import numpy as np + + +def sort_file(file_path): + f2 = open(file_path, "r") + lines = f2.readlines() + ret = [] + for line in lines: + line = line[:-1] + ret.append(line) + ret.sort() + + with open('./output.txt', 'w') as f: + for i in ret: + f.write(i + '\n') + + +def five_point_crop(idx, d_img, config): + new_h = config.crop_size + new_w = config.crop_size + b, c, h, w = d_img.shape + if idx == 0: + top = 0 + left = 0 + elif idx == 1: + top = 0 + left = w - new_w + elif idx == 2: + top = h - new_h + left = 0 + elif idx == 3: + top = h - new_h + left = w - new_w + elif idx == 4: + center_h = h // 2 + center_w = w // 2 + top = center_h - new_h // 2 + left = center_w - new_w // 2 + d_img_org = crop_image(top, left, config.crop_size, img=d_img) + + return d_img_org + + +def random_crop(d_img, config): + b, c, h, w = d_img.shape + top = np.random.randint(0, h - config.crop_size) + left = np.random.randint(0, w - config.crop_size) + d_img_org = crop_image(top, left, config.crop_size, img=d_img) + return d_img_org + + +def crop_image(top, left, patch_size, img=None): + tmp_img = img[:, :, top:top + patch_size, left:left + patch_size] + return tmp_img + + +class RandCrop(object): + def __init__(self, patch_size): + self.patch_size = patch_size + + def __call__(self, sample): + # r_img : C x H x W (numpy) + d_img = sample['d_img_org'] + d_name = sample['d_name'] + + c, h, w = d_img.shape + new_h = self.patch_size + new_w = self.patch_size + + top = np.random.randint(0, h - new_h) + left = np.random.randint(0, w - new_w) + ret_d_img = d_img[:, top: top + new_h, left: left + new_w] + sample = { + 'd_img_org': ret_d_img, + 'd_name': d_name + } + + return sample + + +class Normalize(object): + def __init__(self, mean, var): + self.mean = mean + self.var = var + + def __call__(self, sample): + # r_img: C x H x W (numpy) + d_img = sample['d_img_org'] + d_name = sample['d_name'] + + d_img = 
(d_img - self.mean) / self.var + + sample = {'d_img_org': d_img, 'd_name': d_name} + return sample + + +class RandHorizontalFlip(object): + def __init__(self): + pass + + def __call__(self, sample): + d_img = sample['d_img_org'] + d_name = sample['d_name'] + prob_lr = np.random.random() + # np.fliplr needs HxWxC + if prob_lr > 0.5: + d_img = np.fliplr(d_img).copy() + + sample = { + 'd_img_org': d_img, + 'd_name': d_name + } + return sample + + +class ToTensor(object): + def __init__(self): + pass + + def __call__(self, sample): + d_img = sample['d_img_org'] + d_name = sample['d_name'] + d_img = torch.from_numpy(d_img).type(torch.FloatTensor) + sample = { + 'd_img_org': d_img, + 'd_name': d_name + } + return sample \ No newline at end of file diff --git a/data_processing/MANIQA/utils/process.py b/data_processing/MANIQA/utils/process.py new file mode 100644 index 0000000..8acad63 --- /dev/null +++ b/data_processing/MANIQA/utils/process.py @@ -0,0 +1,239 @@ +import torch +import numpy as np + + +def random_crop(d_img, config): + b, c, h, w = d_img.shape + top = np.random.randint(0, h - config.crop_size) + left = np.random.randint(0, w - config.crop_size) + d_img_org = crop_image(top, left, config.crop_size, img=d_img) + return d_img_org + + +def crop_image(top, left, patch_size, img=None): + tmp_img = img[:, :, top:top + patch_size, left:left + patch_size] + return tmp_img + + +def five_point_crop(idx, d_img, config): + new_h = config.crop_size + new_w = config.crop_size + b, c, h, w = d_img.shape + if idx == 0: + top = 0 + left = 0 + elif idx == 1: + top = 0 + left = w - new_w + elif idx == 2: + top = h - new_h + left = 0 + elif idx == 3: + top = h - new_h + left = w - new_w + elif idx == 4: + center_h = h // 2 + center_w = w // 2 + top = center_h - new_h // 2 + left = center_w - new_w // 2 + d_img_org = crop_image(top, left, config.crop_size, img=d_img) + + return d_img_org + + +def split_dataset_koniq10k(txt_file_name, split_seed=20): + np.random.seed(split_seed) + 
object_data = [] + with open(txt_file_name, 'r') as listFile: + for line in listFile: + dis, score = line.split() + dis = dis + if dis not in object_data: + object_data.append(dis) + + np.random.shuffle(object_data) + np.random.seed(20) + + l = len(object_data) + train_name = object_data[:int(l * 0.8)] + val_name = object_data[int(l * 0.8):] + return train_name, val_name + + +def split_dataset_kadid10k(txt_file_name, split_seed=20): + np.random.seed(split_seed) + object_data = [] + with open(txt_file_name, 'r') as listFile: + for line in listFile: + dis, score = line.split() + dis = dis[:-1] + if dis[1:3] not in object_data: + object_data.append(dis[1:3]) + + np.random.shuffle(object_data) + np.random.seed(20) + + l = len(object_data) + train_name = object_data[:int(l * 0.8)] + val_name = object_data[int(l * 0.8):] + return train_name, val_name + + +def split_dataset_tid2013(txt_file_name, split_seed=20): + np.random.seed(split_seed) + object_data = [] + with open(txt_file_name, 'r') as listFile: + for line in listFile: + score, dis = line.split() + if dis[1:3] not in object_data: + object_data.append(dis[1:3]) + + np.random.shuffle(object_data) + np.random.seed(20) + + l = len(object_data) + train_name = object_data[:int(l * 0.8)] + val_name = object_data[int(l * 0.8):] + return train_name, val_name + + +def split_dataset_live(txt_file_name, split_seed=20): + np.random.seed(split_seed) + object_data = [] + with open(txt_file_name, 'r') as listFile: + for line in listFile: + i1, i2, ref, dis, score, h, w = line.split() + if ref[8:] not in object_data: + object_data.append(ref[8:]) + + np.random.shuffle(object_data) + np.random.seed(20) + + l = len(object_data) + train_name = object_data[:int(l * 0.8)] + val_name = object_data[int(l * 0.8):] + return train_name, val_name + + +def split_dataset_csiq(txt_file_name, split_seed=20): + np.random.seed(split_seed) + object_data = [] + with open(txt_file_name, 'r') as listFile: + for line in listFile: + dis, score= 
class RandCrop(object):
    """Randomly crop a ``patch_size`` x ``patch_size`` window out of the image.

    Operates on a sample dict with the keys ``'d_img_org'`` (an array that is
    unpacked as channels, height, width) and ``'score'``. The score is passed
    through untouched.

    Raises:
        ValueError: if the patch is larger than the image in either dimension.
    """

    def __init__(self, patch_size):
        self.patch_size = patch_size

    def __call__(self, sample):
        # d_img: C x H x W (numpy)
        d_img = sample['d_img_org']
        score = sample['score']

        _, h, w = d_img.shape
        new_h = self.patch_size
        new_w = self.patch_size
        if new_h > h or new_w > w:
            raise ValueError(
                'patch_size {} is larger than the image ({} x {})'.format(
                    self.patch_size, h, w))

        # For koniq10k
        if h == new_h and w == new_w:
            ret_d_img = d_img
        else:
            # randint's upper bound is exclusive, so "+ 1" is required both to
            # reach the last valid offset and to keep the call legal when only
            # one of the two dimensions already matches the patch size.
            top = np.random.randint(0, h - new_h + 1)
            left = np.random.randint(0, w - new_w + 1)
            ret_d_img = d_img[:, top: top + new_h, left: left + new_w]

        return {
            'd_img_org': ret_d_img,
            'score': score
        }


class Normalize(object):
    """Shift by ``mean`` and divide by ``var``: ``(d_img - mean) / var``.

    Despite its name, ``var`` is used directly as the divisor.
    """

    def __init__(self, mean, var):
        self.mean = mean
        self.var = var

    def __call__(self, sample):
        # d_img: C x H x W (numpy)
        d_img = sample['d_img_org']
        score = sample['score']
        d_img = (d_img - self.mean) / self.var
        return {'d_img_org': d_img, 'score': score}


class RandHorizontalFlip(object):
    """Mirror the image left-to-right with probability ``prob_aug``."""

    def __init__(self, prob_aug):
        self.prob_aug = prob_aug

    def __call__(self, sample):
        d_img = sample['d_img_org']
        score = sample['score']

        p_aug = np.array([self.prob_aug, 1 - self.prob_aug])
        prob_lr = np.random.choice([1, 0], p=p_aug.ravel())

        if prob_lr > 0.5:
            # Flip the last (width) axis. np.fliplr flips axis 1, which for a
            # C x H x W array is the height, i.e. a vertical flip.
            d_img = np.flip(d_img, axis=-1).copy()

        return {
            'd_img_org': d_img,
            'score': score
        }


class RandRotation(object):
    """Rotate the image by 90, 180 or 270 degrees with probability ``prob_aug``.

    The rotation is applied in the plane of axes 1 and 2. ``aug_count`` counts
    how many samples this instance has rotated so far and is returned with
    every sample under the ``'aug_count'`` key.
    """

    def __init__(self, prob_aug=0.5):
        self.prob_aug = prob_aug
        self.aug_count = 0

    def __call__(self, sample):
        d_img = sample['d_img_org']
        score = sample['score']

        p_aug = np.array([self.prob_aug, 1 - self.prob_aug])
        prob_lr = np.random.choice([1, 0], p=p_aug.ravel())

        if prob_lr > 0.5:
            # Number of quarter turns, drawn (almost) uniformly from 1..3.
            p = np.array([0.33, 0.33, 0.34])
            idx = np.random.choice([1, 2, 3], p=p.ravel())
            d_img = np.rot90(d_img, idx, axes=(1, 2)).copy()
            self.aug_count += 1

        return {
            'd_img_org': d_img,
            'score': score,
            'aug_count': self.aug_count
        }


class ToTensor(object):
    """Convert the numpy image and score of a sample into float32 tensors."""

    def __init__(self):
        pass

    def __call__(self, sample):
        d_img = sample['d_img_org']
        score = sample['score']
        d_img = torch.from_numpy(d_img).type(torch.FloatTensor)
        score = torch.from_numpy(score).type(torch.FloatTensor)
        return {
            'd_img_org': d_img,
            'score': score
        }
+``` +${ROOT} +|-- data +| |-- J_regressor_extra.npy +| |-- CrowdPose +| | |-- annotations +| | |-- images +| |-- Human36M +| | |-- images +| | |-- annotations +| | |-- J_regressor_h36m_correct.npy +| |-- MuCo +| | |-- data +| | | |-- augmented_set +| | | |-- unaugmented_set +| | | |-- MuCo-3DHP.json +| | | |-- smpl_param.json +| |-- MSCOCO +| | |-- images +| | | |-- train2017 +| | | |-- val2017 +| | |-- annotations +| | |-- J_regressor_coco_hip_smpl.npy +| |-- MPII +| | |-- data +| | | |-- annotations +| | | |-- images +| |-- PW3D +| | |-- data +| | | |-- 3DPW_latest_train.json +| | | |-- 3DPW_latest_validation.json +| | | |-- 3DPW_latest_test.json +| | | |-- 3DPW_validation_crowd_hhrnet_result.json +| | | |-- imageFiles +``` +* Download `J_regressor_*.npy` [[data](https://drive.google.com/drive/folders/187Azod6z13-dS7W5wHerCTgniHYet-yh?usp=sharing)] +* Download CrowdPose data [[data](https://drive.google.com/drive/folders/1qV5Cx5DJLhJVXlfB0vmQrB3ndJXsTZVM?usp=sharing)] +* Download Human3.6M parsed data and SMPL parameters [[data](https://drive.google.com/drive/folders/1r0B9I3XxIIW_jsXjYinDpL6NFcxTZart?usp=share_link)][[SMPL parameters from SMPLify-X](https://drive.google.com/drive/folders/12fCumEgs9PXT-dAaOGq0EDpl9dGKKorF?usp=share_link)] +* Download MuCo parsed/composited data and SMPL parameters [[data](https://drive.google.com/drive/folders/1dfhFa1kBHYKLTKuprNc7xixt3yyKEky5?usp=share_link)][[SMPL parameters from SMPLify-X](https://drive.google.com/drive/folders/1Wm1_6tn1u-_RE1iUlibIWfS75O79aJRz?usp=share_link)] +* Download MS COCO [[data](https://cocodataset.org/#download)] +* Download 
MPII parsed data [[data](https://drive.google.com/drive/folders/1zQZpfNu0s19tA7Z1SmulP1cDaVfNDDd3?usp=sharing)] +* Download 3DPW parsed data [[data](https://drive.google.com/drive/folders/1_wi6G6h-JFfb9HGccysJwI02zc_S2DVJ?usp=sharing)] +* Download MS COCO / MPII / CrowdPose SMPL parameters from [NeuralAnnot](https://github.com/mks0601/NeuralAnnot_RELEASE) +* All annotation files follow [MS COCO format](http://cocodataset.org/#format-data). +* If you want to add your own dataset, you have to convert it to [MS COCO format](http://cocodataset.org/#format-data). + +If you run into a 'Download limit' error when trying to download a dataset from a Google Drive link, please try this trick. +``` +* Go to the shared folder, which contains files you want to copy to your drive +* Select all the files you want to copy +* In the upper right corner click on three vertical dots and select “make a copy” +* Then, the file is copied to your personal google drive account. You can download it from your personal account. +``` + + +### Pytorch SMPL layer and VPoser +* For the SMPL layer, I used [smplpytorch](https://github.com/gulvarol/smplpytorch). The repo is already included in `common/utils/smplpytorch`. +* Download `basicModel_f_lbs_10_207_0_v1.0.0.pkl`, `basicModel_m_lbs_10_207_0_v1.0.0.pkl`, and `basicModel_neutral_lbs_10_207_0_v1.0.0.pkl` from [here](https://smpl.is.tue.mpg.de/download.php) (female & male) and [here](http://smplify.is.tue.mpg.de/) (neutral) to `${ROOT}/smplpytorch/smplpytorch/native/models`. 
+* Download [VPoser](https://github.com/nghorbani/human_body_prior) from [here](https://drive.google.com/drive/folders/1KNw99d4-_6DqYXfBp2S3_4OMQ_nMW0uQ?usp=sharing) and place it under `${ROOT}/common/utils/human_model_files/smpl/VPOSER_CKPT`. + +### Output +* Create the `output` folder as a soft link (recommended) rather than a regular folder, because it can take up a large amount of storage. +* The experiments' directory structure will be created as below. +``` +${ROOT} +|-- output +| |-- ${current_experiment_name} +| | |-- log +| | |-- checkpoint +| | |-- result +| | |-- vis +``` +* `log` folder contains training log file. +* `checkpoint` folder contains saved checkpoints for each epoch. +* `result` folder contains final estimation files of MuPoTS generated in the testing stage. +* `vis` folder contains visualized results. diff --git a/data_processing/assets/front_figure.png b/data_processing/assets/front_figure.png new file mode 100644 index 0000000..296b860 Binary files /dev/null and b/data_processing/assets/front_figure.png differ diff --git a/data_processing/assets/running.md b/data_processing/assets/running.md new file mode 100644 index 0000000..d65eaa1 --- /dev/null +++ b/data_processing/assets/running.md @@ -0,0 +1,27 @@ +## Running 3DCrowdNet +In this repository, we provide training and testing codes for 3DPW-Crowd (Table 5) and 3DPW (Table 8). +We use the pre-trained ResNet-50 weights of [xiao2018simple](https://github.com/microsoft/human-pose-estimation.pytorch) to achieve faster convergence, but you can get the same result by training longer. +Download the [file of weights](https://drive.google.com/drive/folders/1UsntO3wdIHOiajcb8oicMhQ82SmFvulp?usp=sharing) and place it under `${ROOT}/tool/`. 
+ +### Train +Use the appropriate config file to reproduce results. +For example, to reproduce 3DPW-Crowd (Table 5), run +```bash +python train.py --amp --continue --gpu 0-3 --cfg ../assets/yaml/3dpw_crowd.yml +``` +Remove `--continue` if you don't want to use the pre-trained ResNet-50 weights. +Add `--exp_dir` argument to resume training. + +> Note: The CUDA version may affect the training time. It normally takes 2 hours per epoch with cuda-10.1, but 4~6 hours per epoch with cuda-10.2. The PyTorch version is 1.6.0. + +### Test +Download the experiment directories from [here](https://drive.google.com/drive/folders/19ntGuC0zaXQa3cCN_2Ox_hWYX3nLLP2J?usp=sharing) and place them under `${ROOT}/output/`. +To evaluate on 3DPW-Crowd (Table 5), run +```bash +python test.py --gpu 0-3 --cfg ../assets/yaml/3dpw_crowd.yml --exp_dir ../output/exp_03-28_18:26 --test_epoch 6 +``` +To evaluate on 3DPW (Table 8), run +```bash +python test.py --gpu 0-3 --cfg ../assets/yaml/3dpw.yml --exp_dir ../output/exp_04-06_23:43 --test_epoch 10 +``` +You can replace the `--exp_dir` with your own experiments. 
class Base(object):
    """Shared scaffolding for ``Trainer`` and ``Tester``.

    Holds the epoch counter, three timers and the logger, and declares the two
    hooks (`_make_batch_generator`, `_make_model`) that subclasses implement.
    """

    # NOTE(review): ``__metaclass__`` is the Python 2 spelling; Python 3
    # ignores it, so the abstract methods below are not enforced when the
    # class is instantiated. Left unchanged to keep instantiation behavior.
    __metaclass__ = abc.ABCMeta

    def __init__(self, log_name='logs.txt'):
        self.cur_epoch = 0

        # timer
        self.tot_timer = Timer()
        self.gpu_timer = Timer()
        self.read_timer = Timer()

        # logger
        self.logger = colorlogger(cfg.log_dir, log_name=log_name)

    @abc.abstractmethod
    def _make_batch_generator(self):
        return

    @abc.abstractmethod
    def _make_model(self):
        return


class Trainer(Base):
    """Builds the training data loader, model and optimizer, and manages snapshots."""

    def __init__(self):
        super(Trainer, self).__init__(log_name = 'train_logs.txt')

    def get_optimizer(self, model):
        """Return an Adam optimizer over the four sub-networks.

        The backbone gets its own learning rate (``cfg.lr_backbone``); the
        other groups use ``cfg.lr``.
        """
        optimizer = torch.optim.Adam([
            {'params': model.module.backbone.parameters(), 'lr': cfg.lr_backbone},
            {'params': model.module.pose2feat.parameters()},
            {'params': model.module.position_net.parameters()},
            {'params': model.module.rotation_net.parameters()},
        ],
        lr=cfg.lr)
        print('The parameters of backbone, pose2feat, position_net, rotation_net, are added to the optimizer.')

        return optimizer

    def save_model(self, state, epoch,itr = None):
        """Save ``state`` as ``snapshot_<epoch>[_<itr>].pth.tar`` in ``cfg.model_dir``."""
        if itr is None:
            file_path = osp.join(cfg.model_dir, 'snapshot_{}.pth.tar'.format(str(epoch)))
        else:
            file_path = osp.join(cfg.model_dir, 'snapshot_{}_{}.pth.tar'.format(str(epoch), str(itr)))
        torch.save(state, file_path)
        self.logger.info("Write snapshot into {}".format(file_path))

    def save_visualization(self, inputs, targets, meta_info, epoch,itr):
        """Write one PNG per visualization returned by the model into ``cfg.vis_dir``."""
        viz_predicts = self.model.module.get_visualization(inputs, targets, meta_info)

        for idx,viz in enumerate(viz_predicts):
            file_path = osp.join(cfg.vis_dir, f'epoch_{epoch:05d}_itr_{itr:05d}_sample_{idx}.png')
            if idx ==0:
                # Only the first path is logged to keep the log short.
                self.logger.info(f'Write visualization into {file_path}')
            cv2.imwrite(file_path, viz)

    @staticmethod
    def _latest_snapshot_epoch(file_names):
        """Return the highest epoch among ``snapshot_<epoch>.pth.tar`` files.

        Only the file's base name is inspected, and mid-epoch snapshots
        (``snapshot_<epoch>_<itr>.pth.tar``, as written by `save_model` when
        ``itr`` is given) and unrelated files are skipped.

        Raises:
            ValueError: if no epoch-level snapshot is present.
        """
        prefix, suffix = 'snapshot_', '.pth.tar'
        epochs = []
        for file_name in file_names:
            base = osp.basename(file_name)
            if not (base.startswith(prefix) and base.endswith(suffix)):
                continue
            tag = base[len(prefix):-len(suffix)]
            if tag.isdigit():
                epochs.append(int(tag))
        if not epochs:
            raise ValueError('No snapshot_<epoch>.pth.tar checkpoint found to resume from.')
        return max(epochs)

    def load_model(self, model, optimizer):
        """Load the most recent epoch snapshot from ``cfg.model_dir`` into ``model``.

        Returns:
            ``(start_epoch, model, optimizer)`` where ``start_epoch`` is the
            stored epoch plus one. The optimizer is returned unchanged; its
            state is not restored from the checkpoint.
        """
        model_file_list = glob.glob(osp.join(cfg.model_dir,'*.pth.tar'))
        cur_epoch = self._latest_snapshot_epoch(model_file_list)
        ckpt_path = osp.join(cfg.model_dir, 'snapshot_' + str(cur_epoch) + '.pth.tar')
        ckpt = torch.load(ckpt_path)
        start_epoch = ckpt['epoch'] + 1

        model.load_state_dict(ckpt['network'], strict=False)

        self.logger.info('Load checkpoint from {}'.format(ckpt_path))
        return start_epoch, model, optimizer

    def set_lr(self, epoch):
        """Apply the step schedule from ``cfg.lr_dec_epoch`` to every param group.

        NOTE(review): every group, including the backbone group created with
        ``cfg.lr_backbone`` in `get_optimizer`, is set from ``cfg.lr`` here.
        Confirm this is intended.
        """
        # Find the first decay boundary that has not been reached yet.
        for e in cfg.lr_dec_epoch:
            if epoch < e:
                break
        if epoch < cfg.lr_dec_epoch[-1]:
            idx = cfg.lr_dec_epoch.index(e)
            for g in self.optimizer.param_groups:
                g['lr'] = cfg.lr / (cfg.lr_dec_factor ** idx)
        else:
            for g in self.optimizer.param_groups:
                g['lr'] = cfg.lr / (cfg.lr_dec_factor ** len(cfg.lr_dec_epoch))

    def get_lr(self):
        """Return the learning rate of the last optimizer param group."""
        for g in self.optimizer.param_groups:
            cur_lr = g['lr']
        return cur_lr

    def _make_batch_generator(self):
        """Instantiate the configured datasets and wrap them in a shuffled DataLoader."""
        # data load and construct batch generator
        self.logger.info("Creating dataset...")
        # NOTE(review): dataset names from the config are resolved with eval();
        # this is only safe while the config file is trusted.
        trainset3d_loader = []
        for name in cfg.trainset_3d:
            print(f'Creating 3d dataset {name}...')
            trainset3d_loader.append(eval(name)(transforms.ToTensor(), "train"))
        trainset2d_loader = []
        for name in cfg.trainset_2d:
            print(f'Creating 2d dataset {name}...')
            trainset2d_loader.append(eval(name)(transforms.ToTensor(), "train"))

        if len(trainset3d_loader) > 0 and len(trainset2d_loader) > 0:
            self.vertex_num = trainset3d_loader[0].vertex_num
            self.joint_num = trainset3d_loader[0].joint_num
            trainset3d_loader = MultipleDatasets(trainset3d_loader, make_same_len=False)
            trainset2d_loader = MultipleDatasets(trainset2d_loader, make_same_len=False)
            trainset_loader = MultipleDatasets([trainset3d_loader, trainset2d_loader], make_same_len=True)
        elif len(trainset3d_loader) > 0:
            self.vertex_num = trainset3d_loader[0].vertex_num
            self.joint_num = trainset3d_loader[0].joint_num
            trainset_loader = MultipleDatasets(trainset3d_loader, make_same_len=False)
        elif len(trainset2d_loader) > 0:
            self.vertex_num = trainset2d_loader[0].vertex_num
            self.joint_num = trainset2d_loader[0].joint_num
            trainset_loader = MultipleDatasets(trainset2d_loader, make_same_len=False)
        else:
            assert 0, "Both 3D training set and 2D training set have zero length."

        self.itr_per_epoch = math.ceil(len(trainset_loader) / cfg.num_gpus / cfg.train_batch_size)
        self.batch_generator = DataLoader(dataset=trainset_loader, batch_size=cfg.num_gpus*cfg.train_batch_size, shuffle=True, num_workers=cfg.num_thread, pin_memory=True)

    def _make_model(self):
        """Create the DataParallel model and optimizer, optionally resuming from a snapshot."""
        # prepare network
        self.logger.info("Creating graph and optimizer...")
        model = get_model(self.vertex_num, self.joint_num, 'train')
        model = DataParallel(model).cuda()
        optimizer = self.get_optimizer(model)
        if cfg.continue_train:
            start_epoch, model, optimizer = self.load_model(model, optimizer)
            if cfg.finetune:
                # Fine-tuning reuses the weights but restarts the epoch counter.
                start_epoch = 0
        else:
            start_epoch = 0
        model.train()

        self.start_epoch = start_epoch
        self.model = model
        self.optimizer = optimizer
# ANSI escape sequences used to colour log output.
OK = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
END = '\033[0m'

PINK = '\033[95m'
BLUE = '\033[94m'
GREEN = OK
RED = FAIL
WHITE = END
YELLOW = WARNING

class colorlogger():
    """Thin wrapper around ``logging`` that logs to a file and to the console.

    Records are written to ``<log_dir>/<log_name>`` (append mode) and to the
    default stream of ``logging.StreamHandler``. The timestamp is wrapped in
    the green escape sequence, and warning/error/critical messages get a
    coloured prefix.

    Args:
        log_dir: directory for the log file; created if missing.
        log_name: log file name, also used as the ``logging`` logger name.
    """

    def __init__(self, log_dir, log_name='train_logs.txt'):
        # set log
        self._logger = logging.getLogger(log_name)
        self._logger.setLevel(logging.INFO)
        log_file = os.path.join(log_dir, log_name)
        # exist_ok avoids the check-then-create race between processes.
        os.makedirs(log_dir, exist_ok=True)
        formatter = logging.Formatter(
            "{}%(asctime)s{} %(message)s".format(GREEN, END),
            "%m-%d %H:%M:%S")

        # logging.getLogger returns the same object for the same name, so a
        # second colorlogger with this log_name must not attach the handlers
        # again, or every record would be emitted twice.
        target = os.path.abspath(log_file)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in self._logger.handlers)
        has_console = any(
            isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
            for h in self._logger.handlers)

        if not has_file:
            file_log = logging.FileHandler(log_file, mode='a')
            file_log.setLevel(logging.INFO)
            file_log.setFormatter(formatter)
            self._logger.addHandler(file_log)
        if not has_console:
            console_log = logging.StreamHandler()
            console_log.setLevel(logging.INFO)
            console_log.setFormatter(formatter)
            self._logger.addHandler(console_log)

    def debug(self, msg):
        self._logger.debug(str(msg))

    def info(self, msg):
        self._logger.info(str(msg))

    def warning(self, msg):
        self._logger.warning(WARNING + 'WRN: ' + str(msg) + END)

    def critical(self, msg):
        self._logger.critical(RED + 'CRI: ' + str(msg) + END)

    def error(self, msg):
        self._logger.error(RED + 'ERR: ' + str(msg) + END)
def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
    """Build a stack of ``nn.Conv2d`` layers with BatchNorm2d + ReLU in between.

    Args:
        feat_dims: channel sizes; layer ``i`` maps ``feat_dims[i]`` to ``feat_dims[i+1]``.
        kernel, stride, padding: forwarded to every convolution.
        bnrelu_final: whether the last convolution is also followed by BN + ReLU.

    Returns:
        An ``nn.Sequential`` holding the layers.
    """
    layers = []
    last = len(feat_dims) - 2
    for i, (dim_in, dim_out) in enumerate(zip(feat_dims[:-1], feat_dims[1:])):
        layers.append(
            nn.Conv2d(
                in_channels=dim_in,
                out_channels=dim_out,
                kernel_size=kernel,
                stride=stride,
                padding=padding
                ))
        # Do not use BN and ReLU for final estimation
        if i < last or bnrelu_final:
            layers.append(nn.BatchNorm2d(dim_out))
            layers.append(nn.ReLU(inplace=True))

    return nn.Sequential(*layers)

def make_conv1d_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
    """1-D counterpart of `make_conv_layers`: ``nn.Conv1d`` + BatchNorm1d + ReLU."""
    layers = []
    last = len(feat_dims) - 2
    for i, (dim_in, dim_out) in enumerate(zip(feat_dims[:-1], feat_dims[1:])):
        layers.append(
            nn.Conv1d(
                in_channels=dim_in,
                out_channels=dim_out,
                kernel_size=kernel,
                stride=stride,
                padding=padding
                ))
        # Do not use BN and ReLU for final estimation
        if i < last or bnrelu_final:
            layers.append(nn.BatchNorm1d(dim_out))
            layers.append(nn.ReLU(inplace=True))

    return nn.Sequential(*layers)

def make_deconv_layers(feat_dims, bnrelu_final=True):
    """Build a stack of bias-free ``nn.ConvTranspose2d`` layers (kernel 4, stride 2, padding 1).

    Each layer is followed by BatchNorm2d + ReLU; ``bnrelu_final`` controls
    whether that also applies to the last one.
    """
    layers = []
    last = len(feat_dims) - 2
    for i, (dim_in, dim_out) in enumerate(zip(feat_dims[:-1], feat_dims[1:])):
        layers.append(
            nn.ConvTranspose2d(
                in_channels=dim_in,
                out_channels=dim_out,
                kernel_size=4,
                stride=2,
                padding=1,
                output_padding=0,
                bias=False))

        # Do not use BN and ReLU for final estimation
        if i < last or bnrelu_final:
            layers.append(nn.BatchNorm2d(dim_out))
            layers.append(nn.ReLU(inplace=True))

    return nn.Sequential(*layers)


class GraphConvBlock(nn.Module):
    """Graph convolution with a separate Linear + BatchNorm1d per vertex.

    ``forward`` applies the per-vertex layers to ``feat[:, i, :]``, mixes the
    vertices by multiplying with the adjacency matrix ``adj`` and applies ReLU.
    """

    def __init__(self, adj, dim_in, dim_out):
        super(GraphConvBlock, self).__init__()
        self.adj = adj
        self.vertex_num = adj.shape[0]
        self.fcbn_list = nn.ModuleList([nn.Sequential(*[nn.Linear(dim_in, dim_out), nn.BatchNorm1d(dim_out)]) for _ in range(self.vertex_num)])

    def forward(self, feat):
        batch_size = feat.shape[0]

        # apply kernel for each vertex
        feat = torch.stack([fcbn(feat[:,i,:]) for i,fcbn in enumerate(self.fcbn_list)],1)

        # apply adj; follow the input's device instead of assuming CUDA so the
        # block also runs on CPU or on a non-default GPU
        adj = self.adj.to(feat.device)[None,:,:].repeat(batch_size,1,1)
        feat = torch.bmm(adj, feat)

        # apply activation function
        out = F.relu(feat)
        return out


class GraphResBlock(nn.Module):
    """Two `GraphConvBlock`s of equal width with a residual connection around them."""

    def __init__(self, adj, dim):
        super(GraphResBlock, self).__init__()
        self.adj = adj
        self.graph_block1 = GraphConvBlock(adj, dim, dim)
        self.graph_block2 = GraphConvBlock(adj, dim, dim)

    def forward(self, feat):
        feat_out = self.graph_block1(feat)
        feat_out = self.graph_block2(feat_out)
        out = feat_out + feat
        return out
class EdgeLengthLoss(nn.Module):
    """L1 difference between predicted and ground-truth triangle edge lengths.

    Args:
        face: integer vertex indices, one row of three per triangle.
    """

    def __init__(self, face):
        super(EdgeLengthLoss, self).__init__()
        self.face = face

    def _face_index(self, device):
        """Return ``self.face`` as an int64 tensor on ``device``."""
        if torch.is_tensor(self.face):
            return self.face.to(device=device, dtype=torch.long)
        return torch.from_numpy(np.asarray(self.face, dtype=np.int64)).to(device)

    @staticmethod
    def _edge_length(coord, idx_a, idx_b):
        """Euclidean distance between vertices ``idx_a`` and ``idx_b`` (last dim kept)."""
        return torch.sqrt(torch.sum((coord[:,idx_a,:] - coord[:,idx_b,:])**2,2,keepdim=True))

    def forward(self, coord_out, coord_gt, valid):
        """Return the per-edge absolute length error.

        The three edges of every face (0-1, 0-2, 1-2) are concatenated along
        dim 1. An edge counts only when both of its end points are marked
        valid in ``valid`` (the two flags are multiplied).
        """
        # Build the index on the prediction's device rather than assuming CUDA.
        face = self._face_index(coord_out.device)
        edges = ((face[:,0], face[:,1]), (face[:,0], face[:,2]), (face[:,1], face[:,2]))

        diffs = []
        for idx_a, idx_b in edges:
            d_out = self._edge_length(coord_out, idx_a, idx_b)
            d_gt = self._edge_length(coord_gt, idx_a, idx_b)
            valid_mask = valid[:,idx_a,:] * valid[:,idx_b,:]
            diffs.append(torch.abs(d_out - d_gt) * valid_mask)

        loss = torch.cat(diffs,1)
        return loss
def soft_argmax_3d(self, heatmap3d):
    """Turn per-joint 3D heatmaps into expected (x, y, z) coordinates.

    The input is reshaped to ``(-1, joint_num, hm_shape[0], hm_shape[1],
    hm_shape[2])``, soft-maxed over each joint's whole volume, and the
    expectation of the index along each axis is taken.

    Returns:
        Tensor of shape ``(batch, joint_num, 3)`` ordered (x, y, z), where x
        indexes the last heatmap axis and z the first.
    """
    size_z, size_y, size_x = self.hm_shape[0], self.hm_shape[1], self.hm_shape[2]

    heatmap3d = heatmap3d.reshape((-1, self.joint_num, size_z * size_y * size_x))
    heatmap3d = F.softmax(heatmap3d, 2)
    heatmap3d = heatmap3d.reshape((-1, self.joint_num, size_z, size_y, size_x))

    # Marginal distribution along each axis.
    accu_x = heatmap3d.sum(dim=(2, 3))
    accu_y = heatmap3d.sum(dim=(2, 4))
    accu_z = heatmap3d.sum(dim=(3, 4))

    # Create the index ranges on the heatmap's device instead of assuming CUDA.
    device = heatmap3d.device
    accu_x = accu_x * torch.arange(size_x, device=device).float()[None, None, :]
    accu_y = accu_y * torch.arange(size_y, device=device).float()[None, None, :]
    accu_z = accu_z * torch.arange(size_z, device=device).float()[None, None, :]

    accu_x = accu_x.sum(dim=2, keepdim=True)
    accu_y = accu_y.sum(dim=2, keepdim=True)
    accu_z = accu_z.sum(dim=2, keepdim=True)

    coord_out = torch.cat((accu_x, accu_y, accu_z), dim=2)
    return coord_out
class RotationNet(nn.Module):
    """Regress root pose, pose code, shape and camera parameters from per-joint features.

    Image features are sampled at every joint location, concatenated with the
    joint coordinates and scores, passed through a stack of graph blocks and
    fed to four linear heads.
    """

    def __init__(self):
        super(RotationNet, self).__init__()

        # MANO is used when FreiHAND appears in any configured dataset, SMPL otherwise.
        configured_sets = cfg.trainset_3d + cfg.trainset_2d + [cfg.testset]
        model_cls = MANO if 'FreiHAND' in configured_sets else SMPL
        self.human_model = model_cls()
        self.joint_num = self.human_model.graph_joint_num
        self.graph_adj = torch.from_numpy(self.human_model.graph_adj).float()

        # graph convs: one input block followed by four residual blocks
        self.graph_block = nn.Sequential(
            GraphConvBlock(self.graph_adj, 2048+4, 128),
            *(GraphResBlock(self.graph_adj, 128) for _ in range(4)))

        self.hm_shape = [dim // 8 for dim in cfg.output_hm_shape[:3]]

        flat_dim = self.joint_num*128
        self.root_pose_out = make_linear_layers([flat_dim, 6], relu_final=False)
        self.pose_out = make_linear_layers([flat_dim, self.human_model.vposer_code_dim], relu_final=False) # vposer latent code
        self.shape_out = make_linear_layers([flat_dim, self.human_model.shape_param_dim], relu_final=False)
        self.cam_out = make_linear_layers([flat_dim,3], relu_final=False)

    def sample_image_feature(self, img_feat, joint_coord_img):
        """Bilinearly sample ``img_feat`` at each joint's (x, y) heatmap position.

        Returns a tensor of shape (batch_size, joint_num, channel_dim).
        """
        width_span = self.hm_shape[2]-1
        height_span = self.hm_shape[1]-1
        per_joint = []
        for j in range(self.joint_num):
            # Map heatmap coordinates to the [-1, 1] range expected by grid_sample.
            x = joint_coord_img[:,j,0] / width_span * 2 - 1
            y = joint_coord_img[:,j,1] / height_span * 2 - 1
            grid = torch.stack((x, y),1)[:,None,None,:]
            img_feat = img_feat.float()
            sampled = F.grid_sample(img_feat, grid, align_corners=True)[:,:,0,0] # (batch_size, channel_dim)
            per_joint.append(sampled)
        stacked = torch.stack(per_joint) # (joint_num, batch_size, channel_dim)
        return stacked.permute(1, 0, 2) # (batch_size, joint_num, channel_dim)

    def forward(self, img_feat, joint_coord_img, joint_score):
        """Return ``(root_pose, pose_param, shape_param, cam_param)``."""
        img_feat_joints = self.sample_image_feature(img_feat, joint_coord_img)
        feat = torch.cat((img_feat_joints, joint_coord_img, joint_score),2)
        feat = self.graph_block(feat)

        # All heads consume the same flattened per-joint features.
        flat = feat.view(-1,self.joint_num*128)
        # pose parameter
        root_pose = self.root_pose_out(flat)
        pose_param = self.pose_out(flat)
        # shape parameter
        shape_param = self.shape_out(flat)
        # camera parameter
        cam_param = self.cam_out(flat)

        return root_pose, pose_param, shape_param, cam_param
(Bottleneck, [3, 8, 36, 3], [64, 256, 512, 1024, 2048], 'resnet152')} + block, layers, channels, name = resnet_spec[resnet_type] + + self.name = name + self.inplanes = 64 + super(ResNetBackbone, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.normal_(m.weight, mean=0, std=0.001) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, skip_early=False): + if not skip_early: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + return x + + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + + return x4 + + def init_weights(self): + org_resnet = torch.utils.model_zoo.load_url(model_urls[self.name]) + # drop orginal resnet fc layer, add 'None' in case of no fc layer, that will raise error + 
org_resnet.pop('fc.weight', None) + org_resnet.pop('fc.bias', None) + + self.load_state_dict(org_resnet) + print("Initialize resnet from model zoo") + + diff --git a/data_processing/common/timer.py b/data_processing/common/timer.py new file mode 100644 index 0000000..7152ae9 --- /dev/null +++ b/data_processing/common/timer.py @@ -0,0 +1,38 @@ +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +import time + +class Timer(object): + """A simple timer.""" + def __init__(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. + self.warm_up = 0 + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + if self.warm_up < 10: + self.warm_up += 1 + return self.diff + else: + self.total_time += self.diff + self.calls += 1 + self.average_time = self.total_time / self.calls + + if average: + return self.average_time + else: + return self.diff diff --git a/data_processing/common/utils/__init__.py b/data_processing/common/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/common/utils/dir.py b/data_processing/common/utils/dir.py new file mode 100644 index 0000000..410433d --- /dev/null +++ b/data_processing/common/utils/dir.py @@ -0,0 +1,11 @@ +import os +import sys + +def make_folder(folder_name): + if not os.path.exists(folder_name): + os.makedirs(folder_name) + +def add_pypath(path): + if path not in sys.path: + sys.path.insert(0, path) + diff --git a/data_processing/common/utils/mano.py b/data_processing/common/utils/mano.py new file mode 100644 index 0000000..04fe1d2 --- /dev/null +++ 
b/data_processing/common/utils/mano.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +import os.path as osp +import json +from config import cfg + +import sys +sys.path.insert(0, cfg.mano_path) +import manopth +from manopth.manolayer import ManoLayer + +class MANO(object): + def __init__(self): + self.layer = self.get_layer() + self.vertex_num = 778 + self.face = self.layer.th_faces.numpy() + self.joint_regressor = self.layer.th_J_regressor.numpy() + + self.joint_num = 21 + self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinly_4') + self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) ) + self.root_joint_idx = self.joints_name.index('Wrist') + + # add fingertips to joint_regressor + self.fingertip_vertex_idx = [745, 317, 444, 556, 673] # mesh vertex idx (right hand) + thumbtip_onehot = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + indextip_onehot = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + middletip_onehot = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + ringtip_onehot = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + pinkytip_onehot = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot)) + self.joint_regressor = self.joint_regressor[[0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 
12, 19, 7, 8, 9, 20],:] + + def get_layer(self): + return ManoLayer(mano_root=osp.join(cfg.mano_path, 'mano', 'models'), flat_hand_mean=False, use_pca=False) # load right hand MANO model + diff --git a/data_processing/common/utils/manopth/.gitignore b/data_processing/common/utils/manopth/.gitignore new file mode 100644 index 0000000..5d99a91 --- /dev/null +++ b/data_processing/common/utils/manopth/.gitignore @@ -0,0 +1,12 @@ +*.sw* +*.bak +*_bak.py + +.cache/ +__pycache__/ +build/ +dist/ +manopth_hassony2.egg-info/ + +mano/models +assets/mano_layer.svg diff --git a/data_processing/common/utils/manopth/LICENSE b/data_processing/common/utils/manopth/LICENSE new file mode 100644 index 0000000..f288702 --- /dev/null +++ b/data_processing/common/utils/manopth/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. 
Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. 
If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. 
Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. 
+ + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. 
Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. 
+ + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. 
+ + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. 
If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. 
+ + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. 
+ + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the 
material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. 
+ + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. 
If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. 
+ + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. 
The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. 
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. 
+ + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/data_processing/common/utils/manopth/README.md b/data_processing/common/utils/manopth/README.md new file mode 100644 index 0000000..07ba23d --- /dev/null +++ b/data_processing/common/utils/manopth/README.md @@ -0,0 +1,135 @@ +Manopth +======= + +[MANO](http://mano.is.tue.mpg.de) layer for [PyTorch](https://pytorch.org/) (tested with v0.4 and v1.x) + +ManoLayer is a differentiable PyTorch layer that deterministically maps from pose and shape parameters to hand joints and vertices. +It can be integrated into any architecture as a differentiable layer to predict hand meshes. + +![image](assets/mano_layer.png) + +ManoLayer takes **batched** hand pose and shape vectors and outputs corresponding hand joints and vertices. + +The code is mostly a PyTorch port of the original [MANO](http://mano.is.tue.mpg.de) model from [chumpy](https://github.com/mattloper/chumpy) to [PyTorch](https://pytorch.org/). +It therefore builds directly upon the work of Javier Romero, Dimitrios Tzionas and Michael J. Black. + +This layer was developed and used for the paper *Learning joint reconstruction of hands and manipulated objects* for CVPR19. +See [project page](https://github.com/hassony2/obman) and [demo+training code](https://github.com/hassony2/obman_train).
+ + +It [reuses](https://github.com/hassony2/manopth/blob/master/manopth/rodrigues_layer.py) [part of the great code](https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py) from the [Pytorch layer for the SMPL body model](https://github.com/MandyMo/pytorch_HMR/blob/master/README.md) by Zhang Xiong ([MandyMo](https://github.com/MandyMo)) to implement the rotation utilities! + +It also includes, in `mano/webuser`, partial content of files from the original [MANO](http://mano.is.tue.mpg.de) code ([posemapper.py](mano/webuser/posemapper.py), [serialization.py](mano/webuser/serialization.py), [lbs.py](mano/webuser/lbs.py), [verts.py](mano/webuser/verts.py), [smpl_handpca_wrapper_HAND_only.py](mano/webuser/smpl_handpca_wrapper_HAND_only.py)). + +If you find this code useful for your research, consider citing: + +- the original [MANO](http://mano.is.tue.mpg.de) publication: + +``` +@article{MANO:SIGGRAPHASIA:2017, + title = {Embodied Hands: Modeling and Capturing Hands and Bodies Together}, + author = {Romero, Javier and Tzionas, Dimitrios and Black, Michael J.}, + journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)}, + publisher = {ACM}, + month = nov, + year = {2017}, + url = {http://doi.acm.org/10.1145/3130800.3130883}, + month_numeric = {11} +} +``` + +- the publication this PyTorch port was developed for: + +``` +@INPROCEEDINGS{hasson19_obman, + title = {Learning joint reconstruction of hands and manipulated objects}, + author = {Hasson, Yana and Varol, G{\"u}l and Tzionas, Dimitris and Kalevatykh, Igor and Black, Michael J.
and Laptev, Ivan and Schmid, Cordelia}, + booktitle = {CVPR}, + year = {2019} +} +``` + +The training code associated with this paper, compatible with manopth, can be found [here](https://github.com/hassony2/obman_train). The release includes a model trained on a variety of hand datasets. + +# Installation + +## Get code and dependencies + +- `git clone https://github.com/hassony2/manopth` +- `cd manopth` +- Install the dependencies listed in [environment.yml](environment.yml) + - In an existing conda environment, `conda env update -f environment.yml` + - In a new environment, `conda env create -f environment.yml` will create a conda environment named `manopth` + +## Download MANO pickle data-structures + +- Go to [MANO website](http://mano.is.tue.mpg.de/) +- Create an account by clicking *Sign Up* and provide your information +- Download Models and Code (the downloaded file should have the format `mano_v*_*.zip`). Note that all code and data from this download fall under the [MANO license](http://mano.is.tue.mpg.de/license). +- Unzip and copy the `models` folder into the `manopth/mano` folder +- Your folder structure should look like this: +``` +manopth/ + mano/ + models/ + MANO_LEFT.pkl + MANO_RIGHT.pkl + ... + manopth/ + __init__.py + ... +``` + +To check that everything is going well, run `python examples/manopth_mindemo.py`, which should generate a random hand using the MANO layer! + +## Install `manopth` package + +To be able to import and use `ManoLayer` in another project, go to your `manopth` folder and run `pip install .` + + +`cd /path/to/other/project` + +You can now use `from manopth import ManoLayer` in this other project!
+ +# Usage + +## Minimal usage script + +See [examples/manopth_mindemo.py](examples/manopth_mindemo.py) + +Simple forward pass with random pose and shape parameters through MANO layer + +```python +import torch +from manopth.manolayer import ManoLayer +from manopth import demo + +batch_size = 10 +# Select number of principal components for pose space +ncomps = 6 + +# Initialize MANO layer +mano_layer = ManoLayer(mano_root='mano/models', use_pca=True, ncomps=ncomps) + +# Generate random shape parameters +random_shape = torch.rand(batch_size, 10) +# Generate random pose parameters, including 3 values for global axis-angle rotation +random_pose = torch.rand(batch_size, ncomps + 3) + +# Forward pass through MANO layer +hand_verts, hand_joints = mano_layer(random_pose, random_shape) +demo.display_hand({'verts': hand_verts, 'joints': hand_joints}, mano_faces=mano_layer.th_faces) +``` + +Result : + +![random hand](assets/random_hand.png) + +## Demo + +With more options, forward and backward pass, and a loop for quick profiling, look at [examples/manopth_demo.py](examples/manopth_demo.py). 
+
+You can run it locally with:
+
+`python examples/manopth_demo.py`
+
diff --git a/data_processing/common/utils/manopth/assets/mano_layer.png b/data_processing/common/utils/manopth/assets/mano_layer.png
new file mode 100644
index 0000000..2365263
Binary files /dev/null and b/data_processing/common/utils/manopth/assets/mano_layer.png differ
diff --git a/data_processing/common/utils/manopth/assets/random_hand.png b/data_processing/common/utils/manopth/assets/random_hand.png
new file mode 100644
index 0000000..1331322
Binary files /dev/null and b/data_processing/common/utils/manopth/assets/random_hand.png differ
diff --git a/data_processing/common/utils/manopth/environment.yml b/data_processing/common/utils/manopth/environment.yml
new file mode 100644
index 0000000..667ae3a
--- /dev/null
+++ b/data_processing/common/utils/manopth/environment.yml
@@ -0,0 +1,12 @@
+name: manopth
+
+dependencies:
+  - opencv
+  - python=3.7
+  - matplotlib
+  - numpy
+  - pytorch
+  - tqdm
+  - git
+  - pip:
+    - git+https://github.com/hassony2/chumpy.git
diff --git a/data_processing/common/utils/manopth/examples/manopth_demo.py b/data_processing/common/utils/manopth/examples/manopth_demo.py
new file mode 100644
index 0000000..72a5186
--- /dev/null
+++ b/data_processing/common/utils/manopth/examples/manopth_demo.py
@@ -0,0 +1,91 @@
+import argparse
+
+from matplotlib import pyplot as plt
+from mpl_toolkits.mplot3d import Axes3D
+import torch
+from tqdm import tqdm
+
+from manopth import argutils
+from manopth.manolayer import ManoLayer
+from manopth.demo import display_hand
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', default=1, type=int)
+    parser.add_argument('--cuda', action='store_true')
+    parser.add_argument(
+        '--no_display',
+        action='store_true',
+        help="Disable display output of ManoLayer given random inputs")
+    parser.add_argument('--side', default='left', choices=['left', 'right'])
+    parser.add_argument('--random_shape', action='store_true', help="Random hand shape")
+    parser.add_argument('--rand_mag', type=float, default=1, help="Controls pose variability")
+    parser.add_argument(
+        '--flat_hand_mean',
+        action='store_true',
+        help="Use flat hand as mean instead of average hand pose")
+    parser.add_argument(
+        '--iters',
+        type=int,
+        default=1,
+        help=
+        "Use for quick profiling of forward and backward pass accross ManoLayer"
+    )
+    parser.add_argument('--mano_root', default='mano/models')
+    parser.add_argument('--root_rot_mode', default='axisang', choices=['rot6d', 'axisang'])
+    parser.add_argument('--no_pca', action='store_true', help="Give axis-angle or rotation matrix as inputs instead of PCA coefficients")
+    parser.add_argument('--joint_rot_mode', default='axisang', choices=['rotmat', 'axisang'], help="Joint rotation inputs")
+    parser.add_argument(
+        '--mano_ncomps', default=6, type=int, help="Number of PCA components")
+    args = parser.parse_args()
+
+    argutils.print_args(args)
+
+    layer = ManoLayer(
+        flat_hand_mean=args.flat_hand_mean,
+        side=args.side,
+        mano_root=args.mano_root,
+        ncomps=args.mano_ncomps,
+        use_pca=not args.no_pca,
+        root_rot_mode=args.root_rot_mode,
+        joint_rot_mode=args.joint_rot_mode)
+    if args.root_rot_mode == 'axisang':
+        rot = 3  # axis-angle root rotation: 3 values
+    else:
+        rot = 6  # 'rot6d' root rotation: 6 values
+    print(rot)  # NOTE(review): looks like leftover debug output - confirm
+    if args.no_pca:
+        args.mano_ncomps = 45  # without PCA, 45 pose values are sampled
+
+    # Generate random pose coefficients
+    pose_params = args.rand_mag * torch.rand(args.batch_size, args.mano_ncomps + rot)
+    pose_params.requires_grad = True
+    if args.random_shape:
+        shape = torch.rand(args.batch_size, 10)
+    else:
+        shape = torch.zeros(1)  # Hack to act like None for PyTorch JIT
+    if args.cuda:
+        pose_params = pose_params.cuda()
+        shape = shape.cuda()
+        layer.cuda()
+
+    # Loop for forward/backward quick profiling
+    for idx in tqdm(range(args.iters)):
+        # Forward pass
+        verts, Jtr = layer(pose_params, th_betas=shape)
+
+        # Backward pass
+        loss = torch.norm(verts)
+        loss.backward()
+
+    if not args.no_display:
+        verts, Jtr = layer(pose_params, th_betas=shape)
+        joints = Jtr.cpu().detach()
+        verts = verts.cpu().detach()
+
+        # Draw obtained vertices and joints
+        display_hand({
+            'verts': verts,
+            'joints': joints
+        },
+                     mano_faces=layer.th_faces)
diff --git a/data_processing/common/utils/manopth/examples/manopth_mindemo.py b/data_processing/common/utils/manopth/examples/manopth_mindemo.py
new file mode 100644
index 0000000..10098a0
--- /dev/null
+++ b/data_processing/common/utils/manopth/examples/manopth_mindemo.py
@@ -0,0 +1,24 @@
+import torch
+from manopth.manolayer import ManoLayer
+from manopth import demo
+
+batch_size = 10  # number of random hands passed through the layer at once
+# Select number of principal components for pose space
+ncomps = 6
+
+# Initialize MANO layer
+mano_layer = ManoLayer(
+    mano_root='mano/models', use_pca=True, ncomps=ncomps, flat_hand_mean=False)
+
+# Generate random shape parameters
+random_shape = torch.rand(batch_size, 10)
+# Generate random pose parameters, including 3 values for global axis-angle rotation
+random_pose = torch.rand(batch_size, ncomps + 3)
+
+# Forward pass through MANO layer
+hand_verts, hand_joints = mano_layer(random_pose, random_shape)
+demo.display_hand({
+    'verts': hand_verts,
+    'joints': hand_joints
+},
+                  mano_faces=mano_layer.th_faces)
diff --git a/data_processing/common/utils/manopth/mano/__init__.py b/data_processing/common/utils/manopth/mano/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/data_processing/common/utils/manopth/mano/webuser/__init__.py b/data_processing/common/utils/manopth/mano/webuser/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/data_processing/common/utils/manopth/mano/webuser/lbs.py b/data_processing/common/utils/manopth/mano/webuser/lbs.py
new file mode 100644
index 0000000..5acaf84
--- /dev/null
+++ b/data_processing/common/utils/manopth/mano/webuser/lbs.py
@@ -0,0 +1,84 @@
+'''
+Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft.
All rights reserved.
+This software is provided for research purposes only.
+By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
+
+More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
+For comments or questions, please email us at: mano@tue.mpg.de
+
+
+About this file:
+================
+This file implements the linear blend skinning (LBS) core of the MANO model.
+
+Functions included:
+- global_rigid_transformation:
+  chains the per-joint rigid transforms along the kinematic tree.
+- verts_core: blends them with the skinning weights to pose the vertices.
+
+'''
+
+
+from mano.webuser.posemapper import posemap
+import chumpy
+import numpy as np
+
+
+def global_rigid_transformation(pose, J, kintree_table, xp):
+    results = {}
+    pose = pose.reshape((-1, 3))
+    id_to_col = {kintree_table[1, i]: i for i in range(kintree_table.shape[1])}
+    parent = {  # column i -> column of the joint whose id is kintree_table[0, i]
+        i: id_to_col[kintree_table[0, i]]
+        for i in range(1, kintree_table.shape[1])
+    }
+
+    if xp == chumpy:
+        from mano.webuser.posemapper import Rodrigues
+        rodrigues = lambda x: Rodrigues(x)  # chumpy path
+    else:
+        import cv2
+        rodrigues = lambda x: cv2.Rodrigues(x)[0]  # keep only the rotation matrix
+
+    with_zeros = lambda x: xp.vstack((x, xp.array([[0.0, 0.0, 0.0, 1.0]])))
+    results[0] = with_zeros(
+        xp.hstack((rodrigues(pose[0, :]), J[0, :].reshape((3, 1)))))
+
+    for i in range(1, kintree_table.shape[1]):
+        results[i] = results[parent[i]].dot(
+            with_zeros(
+                xp.hstack((rodrigues(pose[i, :]), ((J[i, :] - J[parent[i], :]
+                                                    ).reshape((3, 1)))))))
+
+    pack = lambda x: xp.hstack([np.zeros((4, 3)), x.reshape((4, 1))])  # 4x4, x in last column
+
+    results = [results[i] for i in sorted(results.keys())]
+    results_global = results  # chained transforms, kept before the adjustment below
+
+    if True:  # always taken
+        results2 = [  # NOTE(review): np.concatenate rejects the 0-d `0`; confirm xp=numpy works
+            results[i] - (pack(results[i].dot(xp.concatenate(((J[i, :]), 0)))))
+            for i in range(len(results))
+        ]
+        results = results2
+    result = xp.dstack(results)
+    return result, results_global
+
+
+def verts_core(pose, v, J, weights, kintree_table, want_Jtr=False, xp=chumpy):
+    A, A_global = global_rigid_transformation(pose, J, kintree_table, xp)
+    T = A.dot(weights.T)  # per-vertex blend of the joint transforms
+
+    rest_shape_h = xp.vstack((v.T, np.ones((1, v.shape[0]))))  # homogeneous rest vertices
+
+    v = (T[:, 0, :] * rest_shape_h[0, :].reshape(
+        (1, -1)) + T[:, 1, :] * rest_shape_h[1, :].reshape(
+            (1, -1)) + T[:, 2, :] * rest_shape_h[2, :].reshape(
+                (1, -1)) + T[:, 3, :] * rest_shape_h[3, :].reshape((1, -1))).T
+
+    v = v[:, :3]  # drop the homogeneous coordinate
+
+    if not want_Jtr:
+        return v
+    Jtr = xp.vstack([g[:3, 3] for g in A_global])  # translation column of each transform
+    return (v, Jtr)
diff --git a/data_processing/common/utils/manopth/mano/webuser/posemapper.py b/data_processing/common/utils/manopth/mano/webuser/posemapper.py
new file mode 100644
index 0000000..9a9ae42
--- /dev/null
+++ b/data_processing/common/utils/manopth/mano/webuser/posemapper.py
@@ -0,0 +1,55 @@
+'''
+Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
+This software is provided for research purposes only.
+By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
+
+More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
+For comments or questions, please email us at: mano@tue.mpg.de
+
+
+About this file:
+================
+This file defines the pose mapping applied to the pose before 'posedirs'.
+
+Contents:
+- Rodrigues: chumpy wrapper around cv2.Rodrigues (value and derivative).
+- lrotmin: flattened (rotation matrix - identity) terms, global rotation skipped.
+- posemap: returns the pose-mapping function registered under a name.
+
+'''
+
+
+import chumpy as ch
+import numpy as np
+import cv2
+
+
+class Rodrigues(ch.Ch):  # chumpy node wrapping cv2.Rodrigues
+    dterms = 'rt'  # input term name; compute_dr_wrt differentiates w.r.t. self.rt
+
+    def compute_r(self):
+        return cv2.Rodrigues(self.rt.r)[0]  # rotation matrix
+
+    def compute_dr_wrt(self, wrt):
+        if wrt is self.rt:
+            return cv2.Rodrigues(self.rt.r)[1].T  # transposed Jacobian from cv2
+
+
+def lrotmin(p):
+    if isinstance(p, np.ndarray):
+        p = p.ravel()[3:]  # skip the first 3 values (first rotation vector)
+        return np.concatenate(
+            [(cv2.Rodrigues(np.array(pp))[0] - np.eye(3)).ravel()
+             for pp in p.reshape((-1, 3))]).ravel()
+    if p.ndim != 2 or p.shape[1] != 3:
+        p = p.reshape((-1, 3))
+    p = p[1:]  # skip the first rotation vector, as in the ndarray branch
+    return ch.concatenate([(Rodrigues(pp) - ch.eye(3)).ravel()
+                           for pp in p]).ravel()
+
+
+def posemap(s):
+    if s == 'lrotmin':
+        return lrotmin
+    else:
+        raise Exception('Unknown posemapping: %s' % (str(s), ))
diff --git a/data_processing/common/utils/manopth/mano/webuser/serialization.py b/data_processing/common/utils/manopth/mano/webuser/serialization.py
new file mode 100644
index 0000000..9cbdd7e
--- /dev/null
+++ b/data_processing/common/utils/manopth/mano/webuser/serialization.py
@@ -0,0 +1,117 @@
+'''
+Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
+This software is provided for research purposes only.
+By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
+
+More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
+For comments or questions, please email us at: mano@tue.mpg.de
+
+
+About this file:
+================
+This file defines a wrapper for the loading functions of the MANO model.
+
+Modules included:
+- load_model:
+  loads the MANO model from a given file location (i.e. a .pkl file location),
+  or a dictionary object.
+
+'''
+
+
+__all__ = ['load_model']
+
+import numpy as np
+import pickle
+import chumpy as ch
+from chumpy.ch import MatVecMult
+from mano.webuser.posemapper import posemap
+from mano.webuser.verts import verts_core
+
+def ready_arguments(fname_or_dict):
+    '''Load a MANO model dictionary and add the derived chumpy terms.
+
+    fname_or_dict: path of a pickled model dictionary, or an already
+        loaded dict (used and modified in place).
+
+    Keys are used as stored (no legacy-key renaming).  Missing 'trans',
+    'pose' and 'betas' entries are zero-filled, the array entries are
+    wrapped with ch.array, and 'v_posed' is computed (together with
+    'v_shaped' and 'J' when 'shapedirs' is present).  Returns the dict.
+
+    The pickle must come from a trusted source: unpickling can execute
+    arbitrary code.
+    '''
+    if not isinstance(fname_or_dict, dict):
+        # Context manager so the file handle is closed once loaded.
+        with open(fname_or_dict, 'rb') as f:
+            dd = pickle.load(f, encoding='latin1')
+    else:
+        dd = fname_or_dict
+
+    want_shapemodel = 'shapedirs' in dd
+    nposeparms = dd['kintree_table'].shape[1] * 3
+
+    if 'trans' not in dd:
+        dd['trans'] = np.zeros(3)
+    if 'pose' not in dd:
+        dd['pose'] = np.zeros(nposeparms)
+    if 'shapedirs' in dd and 'betas' not in dd:
+        dd['betas'] = np.zeros(dd['shapedirs'].shape[-1])
+
+    for s in [
+            'v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs',
+            'betas', 'J'
+    ]:
+        # Wrap plain arrays only; entries that already expose 'dterms'
+        # (chumpy objects) are left untouched.
+        if (s in dd) and not hasattr(dd[s], 'dterms'):
+            dd[s] = ch.array(dd[s])
+
+    if want_shapemodel:
+        dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template']
+        v_shaped = dd['v_shaped']
+        # Apply 'J_regressor' to the shaped vertices, one coordinate at a
+        # time, to obtain the joint locations 'J'.
+        J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0])
+        J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1])
+        J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2])
+        dd['J'] = ch.vstack((J_tmpx, J_tmpy, J_tmpz)).T
+        dd['v_posed'] = v_shaped + dd['posedirs'].dot(
+            posemap(dd['bs_type'])(dd['pose']))
+    else:
+        dd['v_posed'] = dd['v_template'] + dd['posedirs'].dot(
+            posemap(dd['bs_type'])(dd['pose']))
+
+    return dd
+
+
+def load_model(fname_or_dict):
+    '''Build the posed MANO mesh from a model file or dictionary.
+
+    The dictionary prepared by ready_arguments is passed to verts_core; the
+    model translation is added to the resulting vertices and joints, and
+    every dictionary entry is attached to the result as an attribute.
+    The translated joints are exposed as result.J_transformed.
+    '''
+    dd = ready_arguments(fname_or_dict)
+
+    args = {
+        'pose': dd['pose'],
+        'v': dd['v_posed'],
+        'J': dd['J'],
+        'weights': dd['weights'],
+        'kintree_table': dd['kintree_table'],
+        'xp': ch,
+        'want_Jtr': True,
+        'bs_style': dd['bs_style']
+    }
+
+    result, Jtr = verts_core(**args)
+    result = result + dd['trans'].reshape((1, 3))
+    result.J_transformed = Jtr + dd['trans'].reshape((1, 3))
+
+    for k, v in dd.items():
+        setattr(result, k, v)
+
+    return result
diff --git
a/data_processing/common/utils/manopth/mano/webuser/smpl_handpca_wrapper_HAND_only.py b/data_processing/common/utils/manopth/mano/webuser/smpl_handpca_wrapper_HAND_only.py new file mode 100644 index 0000000..de279f9 --- /dev/null +++ b/data_processing/common/utils/manopth/mano/webuser/smpl_handpca_wrapper_HAND_only.py @@ -0,0 +1,150 @@ +''' +Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved. +This software is provided for research purposes only. +By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license + +More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de. +For comments or questions, please email us at: mano@tue.mpg.de + + +About this file: +================ +This file defines a wrapper for the loading functions of the MANO model. + +Modules included: +- load_model: + loads the MANO model from a given file location (i.e. a .pkl file location), + or a dictionary object. 
def _load_model_dict(fname_or_dict):
    """Return the model dictionary, unpickling it when a file path is given.

    The file is read inside a ``with`` block so the handle is closed as soon
    as the pickle has been loaded.
    """
    import pickle

    if isinstance(fname_or_dict, dict):
        return fname_or_dict
    # NOTE(review): unpickling executes arbitrary code; only load trusted
    # model files through this function.
    with open(fname_or_dict, 'rb') as model_file:
        return pickle.load(model_file, encoding='latin1')


def ready_arguments(fname_or_dict, posekey4vposed='pose'):
    """Load a model dict and fill in defaults and derived quantities.

    Args:
        fname_or_dict: path to a pickled model, or an already loaded dict.
            A dict argument is modified in place and returned.
        posekey4vposed: key of the pose vector used to compute ``v_posed``.

    Returns:
        The model dict with ``trans``/``pose``/``betas`` defaults added, the
        listed array entries wrapped as chumpy arrays, and ``v_posed`` (plus
        ``v_shaped`` and ``J`` when the model has ``shapedirs``) computed.
    """
    import numpy as np
    import chumpy as ch
    from chumpy.ch import MatVecMult
    from mano.webuser.posemapper import posemap

    dd = _load_model_dict(fname_or_dict)

    want_shapemodel = 'shapedirs' in dd
    nposeparms = dd['kintree_table'].shape[1] * 3

    if 'trans' not in dd:
        dd['trans'] = np.zeros(3)
    if 'pose' not in dd:
        dd['pose'] = np.zeros(nposeparms)
    if 'shapedirs' in dd and 'betas' not in dd:
        dd['betas'] = np.zeros(dd['shapedirs'].shape[-1])

    # Wrap plain arrays so they become differentiable chumpy objects;
    # entries that already expose `dterms` are left untouched.
    for s in [
            'v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs',
            'betas', 'J'
    ]:
        if (s in dd) and not hasattr(dd[s], 'dterms'):
            dd[s] = ch.array(dd[s])

    assert (posekey4vposed in dd)
    pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed])
    if want_shapemodel:
        dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template']
        v_shaped = dd['v_shaped']
        # Regress the joints one coordinate axis at a time.
        J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0])
        J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1])
        J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2])
        dd['J'] = ch.vstack((J_tmpx, J_tmpy, J_tmpz)).T
        dd['v_posed'] = v_shaped + dd['posedirs'].dot(pose_map_res)
    else:
        dd['v_posed'] = dd['v_template'] + dd['posedirs'].dot(pose_map_res)

    return dd


def load_model(fname_or_dict, ncomps=6, flat_hand_mean=False, v_template=None):
    """Load the articulated hand model with its pose expressed in PCA space.

    The returned chumpy object exposes ``pose`` as 3 global-orientation
    values followed by ``ncomps`` PCA coefficients, while the full
    axis-angle pose is kept as ``fullpose``.

    Args:
        fname_or_dict: path to a pickled model, or an already loaded dict
            (modified in place).
        ncomps: number of PCA pose components to expose.
        flat_hand_mean: when True a zero mean pose is used instead of the
            model's ``hands_mean``.
        v_template: optional replacement written into ``result.v_template``.

    Returns:
        The posed, translated vertices with every model entry attached as an
        attribute.
    """
    from mano.webuser.verts import verts_core
    import numpy as np
    import chumpy as ch
    import scipy.sparse as sp
    np.random.seed(1)

    smpl_data = _load_model_dict(fname_or_dict)

    rot = 3  # for global orientation!!!

    hands_components = smpl_data['hands_components']
    hands_mean = np.zeros(hands_components.shape[
        1]) if flat_hand_mean else smpl_data['hands_mean']

    selected_components = np.vstack((hands_components[:ncomps]))
    hands_mean = hands_mean.copy()

    pose_coeffs = ch.zeros(rot + selected_components.shape[0])
    full_hand_pose = pose_coeffs[rot:(rot + ncomps)].dot(selected_components)

    smpl_data['fullpose'] = ch.concatenate((pose_coeffs[:rot],
                                            hands_mean + full_hand_pose))
    smpl_data['pose'] = pose_coeffs

    Jreg = smpl_data['J_regressor']
    if not sp.issparse(Jreg):
        smpl_data['J_regressor'] = (sp.csc_matrix(
            (Jreg.data, (Jreg.row, Jreg.col)), shape=Jreg.shape))

    # slightly modify ready_arguments to make sure that it uses the fullpose
    # (which will NOT be pose) for the computation of posedirs
    dd = ready_arguments(smpl_data, posekey4vposed='fullpose')

    # create the smpl formula with the fullpose,
    # but expose the PCA coefficients as smpl.pose for compatibility
    args = {
        'pose': dd['fullpose'],
        'v': dd['v_posed'],
        'J': dd['J'],
        'weights': dd['weights'],
        'kintree_table': dd['kintree_table'],
        'xp': ch,
        'want_Jtr': True,
        'bs_style': dd['bs_style'],
    }

    result_previous, meta = verts_core(**args)

    trans = dd['trans'].reshape((1, 3))
    result = result_previous + trans
    result.no_translation = result_previous

    if meta is not None:
        for field in ['Jtr', 'A', 'A_global', 'A_weighted']:
            if (hasattr(meta, field)):
                setattr(result, field, getattr(meta, field))

        # `meta` is stored as the transformed joints; guarding on None avoids
        # adding the translation to a missing value.
        result.Jtr = meta
        result.J_transformed = result.Jtr + trans

    for k, v in dd.items():
        setattr(result, k, v)

    if v_template is not None:
        result.v_template[:] = v_template

    return result
== '__main__': + load_model() diff --git a/data_processing/common/utils/manopth/mano/webuser/verts.py b/data_processing/common/utils/manopth/mano/webuser/verts.py new file mode 100644 index 0000000..5fd9550 --- /dev/null +++ b/data_processing/common/utils/manopth/mano/webuser/verts.py @@ -0,0 +1,124 @@ +''' +Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved. +This software is provided for research purposes only. +By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license + +More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de. +For comments or questions, please email us at: mano@tue.mpg.de + + +About this file: +================ +This file defines a wrapper for the loading functions of the MANO model. + +Modules included: +- load_model: + loads the MANO model from a given file location (i.e. a .pkl file location), + or a dictionary object. 
def ischumpy(x):
    """Return True when `x` exposes `dterms`, i.e. looks like a chumpy object."""
    return hasattr(x, 'dterms')


def verts_decorated(trans,
                    pose,
                    v_template,
                    J_regressor,
                    weights,
                    kintree_table,
                    bs_style,
                    f,
                    bs_type=None,
                    posedirs=None,
                    betas=None,
                    shapedirs=None,
                    want_Jtr=False):
    """Build posed, translated vertices and attach the model terms to them.

    Args:
        trans, pose, v_template, weights: chumpy model terms.
        J_regressor: sparse joint regressor applied to the shaped vertices,
            or a chumpy array used directly as the joints.
        kintree_table: kinematic tree passed through to `lbs.verts_core`.
        bs_style: blend-skinning style; only 'lbs' is accepted.
        f: faces, stored on the result.
        bs_type: pose-map type handed to `posemap`.
        posedirs, betas, shapedirs: optional blend-shape terms.
        want_Jtr: when True the translated joints are attached as
            `J_transformed`.

    Returns:
        The vertex object with the inputs attached as attributes.
    """
    for which in [
            trans, pose, v_template, weights, posedirs, betas, shapedirs
    ]:
        if which is not None:
            assert ischumpy(which)

    v = v_template

    if shapedirs is not None:
        if betas is None:
            betas = chumpy.zeros(shapedirs.shape[-1])
        v_shaped = v + shapedirs.dot(betas)
    else:
        v_shaped = v

    if posedirs is not None:
        v_posed = v_shaped + posedirs.dot(posemap(bs_type)(pose))
    else:
        v_posed = v_shaped

    v = v_posed

    if sp.issparse(J_regressor):
        J_tmpx = MatVecMult(J_regressor, v_shaped[:, 0])
        J_tmpy = MatVecMult(J_regressor, v_shaped[:, 1])
        J_tmpz = MatVecMult(J_regressor, v_shaped[:, 2])
        J = chumpy.vstack((J_tmpx, J_tmpy, J_tmpz)).T
    else:
        # `J` was previously read here before being bound. A non-sparse
        # argument is taken to be the joints themselves.
        # NOTE(review): confirm this matches how callers pass dense input.
        J = J_regressor
        assert (ischumpy(J))

    assert (bs_style == 'lbs')
    result, Jtr = lbs.verts_core(
        pose, v, J, weights, kintree_table, want_Jtr=True, xp=chumpy)

    tr = trans.reshape((1, 3))
    result = result + tr
    Jtr = Jtr + tr

    result.trans = trans
    result.f = f
    result.pose = pose
    result.v_template = v_template
    result.J = J
    result.J_regressor = J_regressor
    result.weights = weights
    result.kintree_table = kintree_table
    result.bs_style = bs_style
    result.bs_type = bs_type
    if posedirs is not None:
        result.posedirs = posedirs
        result.v_posed = v_posed
    if shapedirs is not None:
        result.shapedirs = shapedirs
        result.betas = betas
        result.v_shaped = v_shaped
    if want_Jtr:
        result.J_transformed = Jtr
    return result


def verts_core(pose,
               v,
               J,
               weights,
               kintree_table,
               bs_style,
               want_Jtr=False,
               xp=chumpy):
    """Check the inputs and delegate to `lbs.verts_core`.

    With the chumpy backend every array argument must be a chumpy object.
    Only the 'lbs' blend-skinning style is accepted. The return value is
    whatever `lbs.verts_core` returns for the given `want_Jtr`.
    """
    if xp == chumpy:
        assert (hasattr(pose, 'dterms'))
        assert (hasattr(v, 'dterms'))
        assert (hasattr(J, 'dterms'))
        assert (hasattr(weights, 'dterms'))

    assert (bs_style == 'lbs')
    result = lbs.verts_core(pose, v, J, weights, kintree_table, want_Jtr, xp)
    return result
def print_args(args):
    """Print every option of an argparse-style namespace, sorted by name."""
    opts = vars(args)
    print('======= Options ========')
    for k, v in sorted(opts.items()):
        print('{}: {}'.format(k, v))
    print('========================')


def _git_output(cmd):
    """Return the stripped text output of a git command, or None on failure.

    Failure covers both a missing git executable and a non-zero exit status
    (for instance when the working directory is not a git checkout).
    """
    try:
        raw = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
    except (OSError, subprocess.CalledProcessError):
        return None
    return raw.decode('utf-8', errors='replace').strip()


def save_args(args, save_folder, opt_prefix='opt', verbose=True):
    """Append the options to a text log and dump them as a pickle.

    Args:
        args: namespace whose attributes are the options (read with `vars`).
        save_folder: output directory, created when missing.
        opt_prefix: base name of the `<prefix>.txt` and `<prefix>.pkl` files.
        verbose: print the path of the text file when done.

    The text file is opened in append mode, so repeated launches accumulate.
    Git information is added only when git is usable from the current
    working directory; otherwise it is skipped instead of raising.
    """
    opts = vars(args)
    # Create checkpoint folder
    os.makedirs(save_folder, exist_ok=True)

    # Save options
    opt_filename = '{}.txt'.format(opt_prefix)
    opt_path = os.path.join(save_folder, opt_filename)
    with open(opt_path, 'a') as opt_file:
        opt_file.write('====== Options ======\n')
        for k, v in sorted(opts.items()):
            opt_file.write(
                '{option}: {value}\n'.format(option=str(k), value=str(v)))
        opt_file.write('=====================\n')
        opt_file.write('launched {} at {}\n'.format(
            str(sys.argv[0]), str(datetime.datetime.now())))

        # Add git info (decoded to text so no b'...' reprs end up in the log)
        label = _git_output(['git', 'describe', '--always'])
        commit = _git_output(['git', 'rev-parse', 'HEAD'])
        if label is not None and commit is not None:
            opt_file.write('=== Git info ====\n')
            opt_file.write('{}\n'.format(label))
            opt_file.write('commit : {}\n'.format(commit))

    opt_picklename = '{}.pkl'.format(opt_prefix)
    opt_picklepath = os.path.join(save_folder, opt_picklename)
    with open(opt_picklepath, 'wb') as opt_file:
        pickle.dump(opts, opt_file)
    if verbose:
        print('Saved options to {}'.format(opt_path))
def generate_random_hand(batch_size=1, ncomps=6, mano_root='mano/models'):
    """Run a ManoLayer on random PCA pose coefficients.

    Args:
        batch_size: number of hands to generate.
        ncomps: number of PCA pose components; it is forwarded to the layer
            so the coefficient width and the layer always agree.
        mano_root: directory handed to ManoLayer as `mano_root`.

    Returns:
        dict with the layer outputs under 'verts' and 'joints' and the
        layer's `th_faces` under 'faces'.
    """
    nfull_comps = ncomps + 3  # Add global orientation dims to PCA
    random_pcapose = torch.rand(batch_size, nfull_comps)
    mano_layer = ManoLayer(mano_root=mano_root, ncomps=ncomps)
    verts, joints = mano_layer(random_pcapose)
    return {'verts': verts, 'joints': joints, 'faces': mano_layer.th_faces}


def display_hand(hand_info, mano_faces=None, ax=None, alpha=0.2, batch_idx=0, show=True):
    """
    Displays hand batch_idx in batch of hand_info, hand_info as returned by
    generate_random_hand

    Vertices are drawn as a scatter plot when `mano_faces` is None and as a
    shaded mesh otherwise; joints are always drawn in red. A new 3D axis is
    created when `ax` is None, and `plt.show()` is called when `show` is True.
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
    verts, joints = hand_info['verts'][batch_idx], hand_info['joints'][
        batch_idx]
    if mano_faces is None:
        ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.1)
    else:
        mesh = Poly3DCollection(verts[mano_faces], alpha=alpha)
        face_color = (141 / 255, 184 / 255, 226 / 255)
        edge_color = (50 / 255, 50 / 255, 50 / 255)
        mesh.set_edgecolor(edge_color)
        mesh.set_facecolor(face_color)
        ax.add_collection3d(mesh)
    ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r')
    # detach/cpu so tensors that track gradients or live on an accelerator
    # can still be converted to numpy.
    cam_equal_aspect_3d(ax, verts.detach().cpu().numpy())
    if show:
        plt.show()
def cam_equal_aspect_3d(ax, verts, flip_x=False):
    """
    Centers the view on the cube containing the points and inverts the
    y and z axes.

    Args:
        ax: 3D axis; only `set_xlim`, `set_ylim` and `set_zlim` are called.
        verts: array-like of points, one row per point with x, y, z columns.
        flip_x: when True the x axis is inverted as well.

    All three axes get the same half-extent (half of the largest side of the
    bounding box), so the aspect ratio is equal.
    """
    # Accept lists or other array-likes, not only numpy arrays.
    verts = np.asarray(verts)
    extents = np.stack([verts.min(0), verts.max(0)], axis=1)
    sz = extents[:, 1] - extents[:, 0]
    centers = np.mean(extents, axis=1)
    maxsize = max(abs(sz))
    r = maxsize / 2
    if flip_x:
        ax.set_xlim(centers[0] + r, centers[0] - r)
    else:
        ax.set_xlim(centers[0] - r, centers[0] + r)
    # Invert y and z axis
    ax.set_ylim(centers[1] + r, centers[1] - r)
    ax.set_zlim(centers[2] + r, centers[2] - r)
class ManoLayer(Module):
    """Torch module that turns MANO pose (and shape) inputs into hand
    vertices and joints."""

    __constants__ = [
        'use_pca', 'rot', 'ncomps', 'kintree_parents', 'check',
        'side', 'center_idx', 'joint_rot_mode'
    ]

    def __init__(self,
                 center_idx=None,
                 flat_hand_mean=True,
                 ncomps=6,
                 side='right',
                 mano_root='mano/models',
                 use_pca=True,
                 root_rot_mode='axisang',
                 joint_rot_mode='axisang',
                 robust_rot=False):
        """
        Args:
            center_idx: index (in the output joint order) of the joint that
                vertices and joints are centered on when no translation is
                given; None disables centering
            flat_hand_mean: if True, (0, 0, 0, ...) pose coefficients match
                flat hand, else match average hand pose
            ncomps: number of PCA components from pose space (<45),
                only used if use_pca
            side: 'right' or 'left'
            mano_root: path to MANO pkl files for left and right hand
            use_pca: Use PCA decomposition for pose space.
            root_rot_mode: 'axisang' (3 root values); any other value makes
                the root rotation a 6D representation
            joint_rot_mode: 'axisang' or 'rotmat', ignored if use_pca
            robust_rot: use the robust 6D-to-matrix conversion for the root
        Raises:
            ValueError: if side is neither 'right' nor 'left'
        """
        super().__init__()

        # Fail early with a clear message instead of hitting a missing
        # `mano_path` attribute further down.
        if side not in ('right', 'left'):
            raise ValueError(
                "side should be 'right' or 'left', got {!r}".format(side))

        self.center_idx = center_idx
        self.robust_rot = robust_rot
        if root_rot_mode == 'axisang':
            self.rot = 3
        else:
            self.rot = 6
        self.flat_hand_mean = flat_hand_mean
        self.side = side
        self.use_pca = use_pca
        self.joint_rot_mode = joint_rot_mode
        self.root_rot_mode = root_rot_mode
        if use_pca:
            self.ncomps = ncomps
        else:
            self.ncomps = 45

        if side == 'right':
            self.mano_path = os.path.join(mano_root, 'MANO_RIGHT.pkl')
        else:
            self.mano_path = os.path.join(mano_root, 'MANO_LEFT.pkl')

        smpl_data = ready_arguments(self.mano_path)

        hands_components = smpl_data['hands_components']

        self.smpl_data = smpl_data

        self.register_buffer('th_betas',
                             torch.Tensor(smpl_data['betas'].r).unsqueeze(0))
        self.register_buffer('th_shapedirs',
                             torch.Tensor(smpl_data['shapedirs'].r))
        self.register_buffer('th_posedirs',
                             torch.Tensor(smpl_data['posedirs'].r))
        self.register_buffer(
            'th_v_template',
            torch.Tensor(smpl_data['v_template'].r).unsqueeze(0))
        self.register_buffer(
            'th_J_regressor',
            torch.Tensor(np.array(smpl_data['J_regressor'].toarray())))
        self.register_buffer('th_weights',
                             torch.Tensor(smpl_data['weights'].r))
        self.register_buffer('th_faces',
                             torch.Tensor(smpl_data['f'].astype(np.int32)).long())

        # Get hand mean
        hands_mean = np.zeros(hands_components.shape[1]
                              ) if flat_hand_mean else smpl_data['hands_mean']
        hands_mean = hands_mean.copy()
        th_hands_mean = torch.Tensor(hands_mean).unsqueeze(0)
        if self.use_pca or self.joint_rot_mode == 'axisang':
            # Save as axis-angle
            self.register_buffer('th_hands_mean', th_hands_mean)
            selected_components = hands_components[:ncomps]
            self.register_buffer('th_selected_comps',
                                 torch.Tensor(selected_components))
        else:
            th_hands_mean_rotmat = rodrigues_layer.batch_rodrigues(
                th_hands_mean.view(15, 3)).reshape(15, 3, 3)
            self.register_buffer('th_hands_mean_rotmat', th_hands_mean_rotmat)

        # Kinematic chain params
        self.kintree_table = smpl_data['kintree_table']
        parents = list(self.kintree_table[0].tolist())
        self.kintree_parents = parents

    def forward(self,
                th_pose_coeffs,
                th_betas=torch.zeros(1),
                th_trans=torch.zeros(1),
                root_palm=torch.Tensor([0]),
                share_betas=torch.Tensor([0]),
                ):
        """
        Args:
            th_pose_coeffs: root rotation followed by the hand pose; a
                (batch_size x (rot + ncomps)) tensor in the axis-angle / PCA
                modes, or 3x3 rotation matrices in a 4-dim tensor otherwise
            th_betas (Tensor (batch_size x 10)): if provided, uses given
                shape parameters for hand shape; a single-element tensor or
                None falls back to the model's own betas
            th_trans (Tensor (batch_size x 3)): if provided and non-zero,
                applies trans to joints and vertices, else centers on the
                `center_idx` joint when one was configured
            root_palm: return palm as hand root instead of wrist
            share_betas: if truthy, replaces th_betas by its batch mean
        Returns:
            (th_verts, th_jtr): vertices and 21 joints (16 model joints plus
            5 finger tips sampled from the vertices), both multiplied by 1000.
        """
        # if len(th_pose_coeffs) == 0:
        #     return th_pose_coeffs.new_empty(0), th_pose_coeffs.new_empty(0)

        batch_size = th_pose_coeffs.shape[0]
        # Get axis angle from PCA components and coefficients
        if self.use_pca or self.joint_rot_mode == 'axisang':
            # Remove global rot coeffs
            th_hand_pose_coeffs = th_pose_coeffs[:, self.rot:self.rot +
                                                 self.ncomps]
            if self.use_pca:
                # PCA components --> axis angles
                th_full_hand_pose = th_hand_pose_coeffs.mm(self.th_selected_comps)
            else:
                th_full_hand_pose = th_hand_pose_coeffs

            # Concatenate back global rot
            th_full_pose = torch.cat([
                th_pose_coeffs[:, :self.rot],
                self.th_hands_mean + th_full_hand_pose
            ], 1)
            if self.root_rot_mode == 'axisang':
                # compute rotation matrices from axis-angle while skipping global rotation
                th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose)
                root_rot = th_rot_map[:, :9].view(batch_size, 3, 3)
                th_rot_map = th_rot_map[:, 9:]
                th_pose_map = th_pose_map[:, 9:]
            else:
                # The 6 leading values hold the root rotation, so only the
                # remaining columns go through the axis-angle pose map.
                th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose[:, 6:])
                if self.robust_rot:
                    root_rot = rot6d.robust_compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
                else:
                    root_rot = rot6d.compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
        else:
            assert th_pose_coeffs.dim() == 4, (
                'When not self.use_pca, '
                'th_pose_coeffs should have 4 dims, got {}'.format(
                    th_pose_coeffs.dim()))
            assert th_pose_coeffs.shape[2:4] == (3, 3), (
                'When not self.use_pca, th_pose_coeffs have 3x3 matrix for two '
                'last dims, got {}'.format(th_pose_coeffs.shape[2:4]))
            th_pose_rots = rotproj.batch_rotprojs(th_pose_coeffs)
            th_rot_map = th_pose_rots[:, 1:].view(batch_size, -1)
            th_pose_map = subtract_flat_id(th_rot_map)
            root_rot = th_pose_rots[:, 0]

        # Full axis angle representation with root joint
        if th_betas is None or th_betas.numel() == 1:
            th_v_shaped = torch.matmul(self.th_shapedirs,
                                       self.th_betas.transpose(1, 0)).permute(
                                           2, 0, 1) + self.th_v_template
            th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
                batch_size, 1, 1)

        else:
            if share_betas:
                th_betas = th_betas.mean(0, keepdim=True).expand(th_betas.shape[0], 10)
            th_v_shaped = torch.matmul(self.th_shapedirs,
                                       th_betas.transpose(1, 0)).permute(
                                           2, 0, 1) + self.th_v_template
            th_j = torch.matmul(self.th_J_regressor, th_v_shaped)
            # th_pose_map should have shape 20x135

        th_v_posed = th_v_shaped + torch.matmul(
            self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
        # Final T pose with transformation done !

        # Global rigid transformation

        root_j = th_j[:, 0, :].contiguous().view(batch_size, 3, 1)
        root_trans = th_with_zeros(torch.cat([root_rot, root_j], 2))

        all_rots = th_rot_map.view(th_rot_map.shape[0], 15, 3, 3)
        lev1_idxs = [1, 4, 7, 10, 13]
        lev2_idxs = [2, 5, 8, 11, 14]
        lev3_idxs = [3, 6, 9, 12, 15]
        # Joint indices count the root as 0, the rotation list does not,
        # hence the `idx - 1` when selecting rotations.
        lev1_rots = all_rots[:, [idx - 1 for idx in lev1_idxs]]
        lev2_rots = all_rots[:, [idx - 1 for idx in lev2_idxs]]
        lev3_rots = all_rots[:, [idx - 1 for idx in lev3_idxs]]
        lev1_j = th_j[:, lev1_idxs]
        lev2_j = th_j[:, lev2_idxs]
        lev3_j = th_j[:, lev3_idxs]

        # From base to tips
        # Get lev1 results
        all_transforms = [root_trans.unsqueeze(1)]
        lev1_j_rel = lev1_j - root_j.transpose(1, 2)
        lev1_rel_transform_flt = th_with_zeros(torch.cat([lev1_rots, lev1_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
        root_trans_flt = root_trans.unsqueeze(1).repeat(1, 5, 1, 1).view(root_trans.shape[0] * 5, 4, 4)
        lev1_flt = torch.matmul(root_trans_flt, lev1_rel_transform_flt)
        all_transforms.append(lev1_flt.view(all_rots.shape[0], 5, 4, 4))

        # Get lev2 results
        lev2_j_rel = lev2_j - lev1_j
        lev2_rel_transform_flt = th_with_zeros(torch.cat([lev2_rots, lev2_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
        lev2_flt = torch.matmul(lev1_flt, lev2_rel_transform_flt)
        all_transforms.append(lev2_flt.view(all_rots.shape[0], 5, 4, 4))

        # Get lev3 results
        lev3_j_rel = lev3_j - lev2_j
        lev3_rel_transform_flt = th_with_zeros(torch.cat([lev3_rots, lev3_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
        lev3_flt = torch.matmul(lev2_flt, lev3_rel_transform_flt)
        all_transforms.append(lev3_flt.view(all_rots.shape[0], 5, 4, 4))

        # Transforms were stacked level by level; put them back in joint order.
        reorder_idxs = [0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 5, 10, 15]
        th_results = torch.cat(all_transforms, 1)[:, reorder_idxs]
        th_results_global = th_results

        joint_js = torch.cat([th_j, th_j.new_zeros(th_j.shape[0], 16, 1)], 2)
        tmp2 = torch.matmul(th_results, joint_js.unsqueeze(3))
        th_results2 = (th_results - torch.cat([tmp2.new_zeros(*tmp2.shape[:2], 4, 3), tmp2], 3)).permute(0, 2, 3, 1)

        th_T = torch.matmul(th_results2, self.th_weights.transpose(0, 1))

        th_rest_shape_h = torch.cat([
            th_v_posed.transpose(2, 1),
            torch.ones((batch_size, 1, th_v_posed.shape[1]),
                       dtype=th_T.dtype,
                       device=th_T.device),
        ], 1)

        th_verts = (th_T * th_rest_shape_h.unsqueeze(1)).sum(2).transpose(2, 1)
        th_verts = th_verts[:, :, :3]
        th_jtr = th_results_global[:, :, :3, 3]
        # In addition to MANO reference joints we sample vertices on each finger
        # to serve as finger tips
        if self.side == 'right':
            tips = th_verts[:, [745, 317, 444, 556, 673]]
        else:
            tips = th_verts[:, [745, 317, 445, 556, 673]]
        if bool(root_palm):
            palm = (th_verts[:, 95] + th_verts[:, 22]).unsqueeze(1) / 2
            th_jtr = torch.cat([palm, th_jtr[:, 1:]], 1)
        th_jtr = torch.cat([th_jtr, tips], 1)

        # Reorder joints to match visualization utilities
        th_jtr = th_jtr[:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]]

        if th_trans is None or bool(torch.norm(th_trans) == 0):
            if self.center_idx is not None:
                center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
                th_jtr = th_jtr - center_joint
                th_verts = th_verts - center_joint
        else:
            th_jtr = th_jtr + th_trans.unsqueeze(1)
            th_verts = th_verts + th_trans.unsqueeze(1)

        # Scale to millimeters
        th_verts = th_verts * 1000
        th_jtr = th_jtr * 1000
        return th_verts, th_jtr
def quat2mat(quat):
    """Convert quaternion coefficients to rotation matrix.
    Args:
        quat: size = [batch_size, 4] 4 <===>(w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3]
    """
    # Normalize so that non-unit quaternions still give a rotation.
    norm_quat = quat / quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = (norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2],
                  norm_quat[:, 3])

    batch_size = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    rotMat = torch.stack([
        w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
        w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
        w2 - x2 - y2 + z2
    ],
                         dim=1).view(batch_size, 3, 3)
    return rotMat


def batch_rodrigues(axisang):
    """Convert axis-angle vectors to flattened rotation matrices.

    Args:
        axisang: size = [N, 3]
    Returns:
        Rotation matrices flattened row by row -- size = [N, 9]
    """
    # The small offset keeps the norm non-zero for an all-zero input.
    axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1)
    angle = torch.unsqueeze(axisang_norm, -1)
    axisang_normalized = torch.div(axisang, angle)
    # Go through the quaternion (cos(a/2), sin(a/2) * axis).
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1)
    rot_mat = quat2mat(quat)
    rot_mat = rot_mat.view(rot_mat.shape[0], 9)
    return rot_mat


def th_get_axis_angle(vector):
    """Split [N, 3] axis-angle vectors into unit axes [N, 3] and angles [N].

    There is no guard against zero-length vectors: they divide by zero.
    """
    angle = torch.norm(vector, 2, 1)
    axes = vector / angle.unsqueeze(1)
    return axes, angle


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--cuda', action='store_true')
    args = parser.parse_args()

    argutils.print_args(args)

    rot = 3
    inputs = torch.rand(args.batch_size, rot, dtype=torch.double)
    if args.cuda:
        # Move before enabling gradients so the checked tensor is the one
        # that actually lives on the GPU.
        inputs = inputs.cuda()
    inputs.requires_grad_(True)
    # The former checks of `th_cv2_rod_sub_id` / `th_cv2_rod` were dropped:
    # neither name is defined or imported in this module.
    gradcheck(batch_rodrigues, (inputs, ))
    print('batch test passed !')
def compute_rotation_matrix_from_ortho6d(poses):
    """
    Code from
    https://github.com/papagina/RotationContinuity
    On the Continuity of Rotation Representations in Neural Networks
    Zhou et al. CVPR19
    https://zhouyisjtu.github.io/project_rotation/rotation.html

    Args:
        poses: batch*6, two raw 3D vectors per row
    Returns:
        batch*3*3 matrices whose columns are x, y, z, where x is the
        normalized first vector and y, z complete the frame through cross
        products.
    """
    x_raw = poses[:, 0:3]  # batch*3
    y_raw = poses[:, 3:6]  # batch*3

    x = normalize_vector(x_raw)  # batch*3
    z = cross_product(x, y_raw)  # batch*3
    z = normalize_vector(z)  # batch*3
    y = cross_product(z, x)  # batch*3

    # Stack the three vectors as matrix columns.
    return torch.stack((x, y, z), 2)  # batch*3*3


def robust_compute_rotation_matrix_from_ortho6d(poses):
    """
    Instead of making 2nd vector orthogonal to first
    create a base that takes into account the two predicted
    directions equally

    Args:
        poses: batch*6, two raw 3D vectors per row
    Returns:
        batch*3*3 matrices whose columns are x, y, z
    Raises:
        AssertionError: if any resulting matrix has a negative determinant
    """
    x_raw = poses[:, 0:3]  # batch*3
    y_raw = poses[:, 3:6]  # batch*3

    x = normalize_vector(x_raw)  # batch*3
    y = normalize_vector(y_raw)  # batch*3
    middle = normalize_vector(x + y)
    orthmid = normalize_vector(x - y)
    x = normalize_vector(middle + orthmid)
    y = normalize_vector(middle - orthmid)
    # Their scalar product should be small !
    # assert torch.einsum("ij,ij->i", [x, y]).abs().max() < 0.00001
    z = normalize_vector(cross_product(x, y))

    matrix = torch.stack((x, y, z), 2)  # batch*3*3
    # Check for reflection in matrix ! If found, flip last vector TODO
    # torch.det works on the whole batch, no Python loop needed.
    assert (torch.det(matrix) < 0).sum() == 0
    return matrix


def normalize_vector(v):
    """Scale each row of `v` (batch*dim) to unit length.

    Row norms are floored at 1e-8, so all-zero rows stay zero instead of
    becoming NaN.
    """
    v_mag = torch.sqrt(v.pow(2).sum(1, keepdim=True)).clamp(min=1e-8)
    return v / v_mag


def cross_product(u, v):
    """Row-wise cross product of two batch*3 tensors."""
    i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
    j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
    k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]

    return torch.stack((i, j, k), 1)
def batch_rotprojs(batches_rotmats):
    """Project every 3x3 matrix of a nested batch onto a rotation matrix.

    Args:
        batches_rotmats: tensor iterated over its two leading dimensions,
            with a 3x3 matrix at each position.
    Returns:
        Tensor of the same layout holding, for each input matrix, U @ V^T
        from its SVD with the last column negated when the determinant is
        negative. Results are returned on the device of the input matrices.
    """
    proj_rotmats = []
    for batch_rotmats in batches_rotmats:
        proj_batch_rotmats = []
        for rotmat in batch_rotmats:
            source_device = rotmat.device
            # GPU implementation of svd is VERY slow
            # ~ 2 10^-3 per hit vs 5 10^-5 on cpu
            U, S, V = rotmat.cpu().svd()
            rotmat = torch.matmul(U, V.transpose(0, 1))
            orth_det = rotmat.det()
            # Remove reflection
            if orth_det < 0:
                rotmat[:, 2] = -1 * rotmat[:, 2]

            # Go back to where the input lived rather than assuming CUDA,
            # so CPU-only machines work too.
            rotmat = rotmat.to(source_device)
            proj_batch_rotmats.append(rotmat)
        proj_rotmats.append(torch.stack(proj_batch_rotmats))
    return torch.stack(proj_rotmats)
def th_posemap_axisang(pose_vectors):
    """Turn flattened axis-angle poses into pose maps and rotation matrices.

    Each row of `pose_vectors` is read as consecutive 3-vectors. Returns
    `(pose_maps, rot_mats)`, where `rot_mats` has 9 values per 3-vector and
    `pose_maps` is `rot_mats` with the flattened identity subtracted.
    """
    n_rows = pose_vectors.shape[0]
    rot_nb = int(pose_vectors.shape[1] / 3)
    flat_axisang = pose_vectors.contiguous().view(-1, 3)
    rot_mats = rodrigues_layer.batch_rodrigues(flat_axisang).view(
        n_rows, rot_nb * 9)
    return subtract_flat_id(rot_mats), rot_mats


def th_with_zeros(tensor):
    """Append the row (0, 0, 0, 1) below every matrix of a batch.

    The input is concatenated along dim 1 with a 1x4 row, e.g. a batch of
    3x4 matrices becomes a batch of 4x4 matrices.
    """
    bottom_row = tensor.new([0.0, 0.0, 0.0, 1.0])
    bottom_row.requires_grad = False

    tiled_rows = bottom_row.view(1, 1, 4).repeat(tensor.shape[0], 1, 1)
    return torch.cat([tensor, tiled_rows], 1)


def th_pack(tensor):
    """Prefix every batch element with a 4x3 block of zeros along dim 2."""
    zero_block = tensor.new_zeros((tensor.shape[0], 4, 3))
    zero_block.requires_grad = False
    return torch.cat([zero_block, tensor], 2)


def subtract_flat_id(rot_mats):
    """Subtract the flattened 3x3 identity from every 9-value chunk of a row."""
    # Subtracts identity as a flattened tensor
    n_rows = rot_mats.shape[0]
    rot_nb = int(rot_mats.shape[1] / 9)
    flat_eye = torch.eye(3, dtype=rot_mats.dtype, device=rot_mats.device)
    id_flat = flat_eye.view(1, 9).repeat(n_rows, rot_nb)
    # id_flat.requires_grad = False
    return rot_mats - id_flat


def make_list(tensor):
    # type: (List[int]) -> List[int]
    """Return the argument unchanged."""
    return tensor
def load_pascal_occluder(pascal_voc_root_path):
    """Cut segmented objects out of a Pascal VOC tree for use as occluders.

    Args:
        pascal_voc_root_path: directory containing the `Annotations`,
            `JPEGImages` and `SegmentationObject` sub-folders.

    Returns:
        List of half-resolution arrays made of the cropped image plus a mask
        channel (255 inside the object, 192 on its border, 0 elsewhere).
        Objects flagged difficult or truncated, and masks with fewer than
        500 non-zero pixels, are skipped.
    """
    occluders = []
    structuring_element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (8, 8))

    annotation_paths = list_filepaths(os.path.join(pascal_voc_root_path, 'Annotations'))
    for annotation_path in annotation_paths:
        xml_root = xml.etree.ElementTree.parse(annotation_path).getroot()
        is_segmented = (xml_root.find('segmented').text != '0')

        if not is_segmented:
            continue

        boxes = []
        for i_obj, obj in enumerate(xml_root.findall('object')):
            # NOTE(review): every object class is kept, persons included;
            # confirm this is intended.
            is_difficult = (obj.find('difficult').text != '0')
            is_truncated = (obj.find('truncated').text != '0')
            if not is_difficult and not is_truncated:
                bndbox = obj.find('bndbox')
                box = [int(bndbox.find(s).text) for s in ['xmin', 'ymin', 'xmax', 'ymax']]
                boxes.append((i_obj, box))

        if not boxes:
            continue

        im_filename = xml_root.find('filename').text
        # Swap only the extension; a plain str.replace would also rewrite
        # 'jpg' occurring elsewhere in the name.
        seg_filename = os.path.splitext(im_filename)[0] + '.png'

        im_path = os.path.join(pascal_voc_root_path, 'JPEGImages', im_filename)
        seg_path = os.path.join(pascal_voc_root_path, 'SegmentationObject', seg_filename)

        # Context managers release the file handles once the pixels are read.
        with PIL.Image.open(im_path) as pil_im:
            im = np.asarray(pil_im)
        with PIL.Image.open(seg_path) as pil_seg:
            labels = np.asarray(pil_seg)

        for i_obj, (xmin, ymin, xmax, ymax) in boxes:
            # Segmentation labels are 1-based object indices.
            object_mask = (labels[ymin:ymax, xmin:xmax] == i_obj + 1).astype(np.uint8) * 255
            object_image = im[ymin:ymax, xmin:xmax]
            if cv2.countNonZero(object_mask) < 500:
                # Ignore small objects
                continue

            # Reduce the opacity of the mask along the border for smoother blending
            eroded = cv2.erode(object_mask, structuring_element)
            object_mask[eroded < object_mask] = 192
            object_with_mask = np.concatenate([object_image, object_mask[..., np.newaxis]], axis=-1)

            if object_with_mask.size == 0:
                continue

            # Downscale for efficiency
            object_with_mask = resize_by_factor(object_with_mask, 0.5)
            occluders.append(object_with_mask)

    print("total # of occluders: ", len(occluders))
    return occluders
def load_coco_person_occluder(data_path, data_split):
    """Cut person crops out of COCO images using stored DensePose results.

    Args:
        data_path: dataset root containing `<data_split>2017` images and
            `densepose_output/DensePose_maskRCNN_output`.
        data_split: split name used to build the image folder and result file.

    Returns:
        List of half-resolution arrays made of the cropped image plus a mask
        channel (255 inside the person, 192 on the border, 0 elsewhere).
        Loading stops once more than 5000 occluders have been collected.
    """
    img_dir_path = os.path.join(data_path, f'{data_split}2017')
    part_seg_path = os.path.join(data_path, 'densepose_output', 'DensePose_maskRCNN_output')

    dp_dict = load_dp_result(part_seg_path, data_split)
    print("loaded dp result..., total imgs: ", len(dp_dict.keys()))
    from densepose.data.structures import DensePoseResult
    from timer import Timer
    load_timer = Timer()

    max_occluders = 5000
    occluders = []
    structuring_element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (8, 8))
    for img_name, dp_outputs in dp_dict.items():
        # The inner `break` only leaves the per-image loop; without this
        # check every remaining image would still be read and the cap ignored.
        if len(occluders) > max_occluders:
            break

        img_path = os.path.join(img_dir_path, img_name)
        load_timer.tic()
        img = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        load_timer.toc()
        if img is None:
            # cv2.imread returns None for missing/unreadable files.
            print("could not read image: ", img_path)
            continue
        # Reverse the channel order of the loaded image.
        img = img[:, :, ::-1].copy()

        for output in dp_outputs:
            encoded_dp = output['dp']
            iuv_arr = DensePoseResult.decode_png_data(*encoded_dp)
            _, h, w = iuv_arr.shape
            dp_bbox = output['bbox']
            xmin, ymin = int(dp_bbox[0] + 0.5), int(dp_bbox[1] + 0.5)
            xmax, ymax = xmin+w, ymin+h

            object_mask = (iuv_arr[0] != 0).astype(np.uint8) * 255
            object_image = img[ymin:ymax, xmin:xmax]
            if cv2.countNonZero(object_mask) < 5000:
                # Ignore small objects or low resolution objects
                continue

            # Reduce the opacity of the mask along the border for smoother blending
            eroded = cv2.erode(object_mask, structuring_element)
            object_mask[eroded < object_mask] = 192
            object_with_mask = np.concatenate([object_image, object_mask[..., np.newaxis]], axis=-1)

            if object_with_mask.size == 0:
                continue

            # Downscale for efficiency
            object_with_mask = resize_by_factor(object_with_mask, 0.5)
            occluders.append(object_with_mask)

            if len(occluders) > max_occluders:
                break

    print("img load time: ", load_timer.total_time)
    print("total # of occluders: ", len(occluders))
    return occluders
def load_dp_result(part_seg_path, data_split):
    """Load pickled DensePose predictions and index them by image file name.

    Reads ``<part_seg_path>/coco_<data_split>.pkl``, which must hold an
    iterable of records with the keys ``file_name``, ``scores``,
    ``pred_boxes_XYXY`` and ``pred_densepose`` (an object exposing
    ``.results``). Only detections whose score is above 0.5 are kept.

    Args:
        part_seg_path: Directory containing the pickle file.
        data_split: Split name used to build the file name (e.g. ``'train'``).

    Returns:
        A dict mapping the last ``/``-separated component of each record's
        ``file_name`` to a list of ``{'bbox': ..., 'dp': ...}`` dicts, one per
        kept detection. A later record with the same file name overwrites an
        earlier one.
    """
    print(f'Load DensePose Result of COCO {data_split} set')
    data_path = os.path.join(part_seg_path, f'coco_{data_split}.pkl')
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserialising; only point this at files from a trusted source.
    with open(data_path, 'rb') as f:
        raw_data_list = pickle.load(f)

    data_dict = {}
    for rd in raw_data_list:
        key = rd['file_name'].split('/')[-1]
        # Box / densepose entries are looked up lazily so that records without
        # any confident detection are never indexed.
        data_dict[key] = [
            {'bbox': rd['pred_boxes_XYXY'][idx], 'dp': rd['pred_densepose'].results[idx]}
            for idx, score in enumerate(rd['scores'])
            if score > 0.5
        ]
    return data_dict
+ """ + + width_height_src = np.asarray([im_src.shape[1], im_src.shape[0]]) + width_height_dst = np.asarray([im_dst.shape[1], im_dst.shape[0]]) + + center = np.round(center).astype(np.int32) + raw_start_dst = center - width_height_src // 2 + raw_end_dst = raw_start_dst + width_height_src + + start_dst = np.clip(raw_start_dst, 0, width_height_dst) + end_dst = np.clip(raw_end_dst, 0, width_height_dst) + region_dst = im_dst[start_dst[1]:end_dst[1], start_dst[0]:end_dst[0]] + + start_src = start_dst - raw_start_dst + end_src = width_height_src + (end_dst - raw_end_dst) + region_src = im_src[start_src[1]:end_src[1], start_src[0]:end_src[0]] + color_src = region_src[..., 0:3] + alpha = region_src[..., 3:].astype(np.float32)/255 + + im_dst[start_dst[1]:end_dst[1], start_dst[0]:end_dst[0]] = ( + alpha * color_src + (1 - alpha) * region_dst) + + return im_dst + + +def resize_by_factor(im, factor): + """Returns a copy of `im` resized by `factor`, using bilinear interp for up and area interp + for downscaling. 
def list_filepaths(dirpath):
    """Return the sorted paths of the regular files directly inside `dirpath`.

    Sub-directories are skipped; the search is not recursive. Each returned
    path is `dirpath` joined with the entry name.
    """
    candidates = (os.path.join(dirpath, entry) for entry in os.listdir(dirpath))
    return sorted(path for path in candidates if os.path.isfile(path))
def replace_joint_img(joint_img_coco, bbox, near_joints, num_overlap, trans):
    """Overwrite the first 17 joints with a synthesized (noisy) pose, in place.

    The bounding box (``x, y, w, h``) is mapped through `trans`; the product
    of the lengths of its transformed top and right edges is handed to
    `synthesize_pose` as the object area.

    Args:
        joint_img_coco: Joint array; its first 17 rows are replaced.
        bbox: Box as ``(x, y, width, height)``.
        near_joints: Joints of neighbouring instances; only ``[:, :17, :]``
            is forwarded.
        num_overlap: Forwarded unchanged to `synthesize_pose`.
        trans: Transform accepted by `affine_transform`
            (presumably a 2x3 affine matrix; confirm against the caller).

    Returns:
        The same `joint_img_coco` object that was passed in.
    """
    left, top = bbox[0], bbox[1]
    right, bottom = bbox[0] + bbox[2], bbox[1] + bbox[3]

    # Three consecutive corners are enough to measure two adjacent edges.
    top_left = affine_transform(np.array([left, top]), trans)
    top_right = affine_transform(np.array([right, top]), trans)
    bottom_right = affine_transform(np.array([right, bottom]), trans)

    top_edge = math.sqrt(
        (top_right[0] - top_left[0]) ** 2 + (top_right[1] - top_left[1]) ** 2)
    right_edge = math.sqrt(
        (bottom_right[0] - top_right[0]) ** 2 + (bottom_right[1] - top_right[1]) ** 2)
    area = top_edge * right_edge

    synthesized = synthesize_pose(
        joint_img_coco[:17, :], near_joints[:, :17, :], area, num_overlap)
    joint_img_coco[:17, :] = synthesized
    return joint_img_coco
np.concatenate(coord_list) + + assert len(coord_list) == 4 + + # jitter error + synth_jitter = np.zeros(3) + if num_valid_joint <= 10: + if j == 0 or (j >= 13 and j <= 16): # nose, ankle, knee + jitter_prob = 0.15 + elif (j >= 1 and j <= 10): # ear, eye, upper body + jitter_prob = 0.20 + else: # hip + jitter_prob = 0.25 + else: + if j == 0 or (j >= 13 and j <= 16): # nose, ankle, knee + jitter_prob = 0.10 + elif (j >= 1 and j <= 10): # ear, eye, upper body + jitter_prob = 0.15 + else: # hip + jitter_prob = 0.20 + angle = np.random.uniform(0, 2 * math.pi, [N]) + r = np.random.uniform(ks_85_dist[j], ks_50_dist[j], [N]) + jitter_idx = 0 # gt + x = tot_coord_list[jitter_idx][0] + r * np.cos(angle) + y = tot_coord_list[jitter_idx][1] + r * np.sin(angle) + dist_mask = True + for i in range(len(tot_coord_list)): + if i == jitter_idx: + continue + dist_mask = np.logical_and(dist_mask, + np.sqrt((tot_coord_list[i][0] - x) ** 2 + (tot_coord_list[i][1] - y) ** 2) > r) + x = x[dist_mask].reshape(-1) + y = y[dist_mask].reshape(-1) + if len(x) > 0: + rand_idx = random.randrange(0, len(x)) + synth_jitter[0] = x[rand_idx] + synth_jitter[1] = y[rand_idx] + synth_jitter[2] = 1 + + # miss error + synth_miss = np.zeros(3) + if num_valid_joint <= 5: + if j >= 0 and j <= 4: # face + miss_prob = 0.15 + elif j == 5 or j == 6 or j == 15 or j == 16: # shoulder, ankle + miss_prob = 0.20 + else: # other parts + miss_prob = 0.25 + elif num_valid_joint <= 10: + if j >= 0 and j <= 4: # face + miss_prob = 0.10 + elif j == 5 or j == 6 or j == 15 or j == 16: # shoulder, ankle + miss_prob = 0.13 + else: # other parts + miss_prob = 0.15 + else: + if j >= 0 and j <= 4: # face + miss_prob = 0.02 + elif j == 5 or j == 6 or j == 15 or j == 16: # shoulder, ankle + miss_prob = 0.05 + else: # other parts + miss_prob = 0.10 + + miss_pt_list = [] + for miss_idx in range(len(tot_coord_list)): + angle = np.random.uniform(0, 2 * math.pi, [4 * N]) + r = np.random.uniform(ks_50_dist[j], ks_10_dist[j], [4 * N]) + x 
= tot_coord_list[miss_idx][0] + r * np.cos(angle) + y = tot_coord_list[miss_idx][1] + r * np.sin(angle) + dist_mask = True + for i in range(len(tot_coord_list)): + if i == miss_idx: + continue + dist_mask = np.logical_and(dist_mask, + np.sqrt((tot_coord_list[i][0] - x) ** 2 + (tot_coord_list[i][1] - y) ** 2) > + ks_50_dist[j]) + x = x[dist_mask].reshape(-1) + y = y[dist_mask].reshape(-1) + if len(x) > 0: + if miss_idx == 0: + coord = np.transpose(np.vstack([x, y]), [1, 0]) + miss_pt_list.append(coord) + else: + rand_idx = np.random.choice(range(len(x)), size=len(x) // 4) + x = np.take(x, rand_idx) + y = np.take(y, rand_idx) + coord = np.transpose(np.vstack([x, y]), [1, 0]) + miss_pt_list.append(coord) + if len(miss_pt_list) > 0: + miss_pt_list = np.concatenate(miss_pt_list, axis=0).reshape(-1, 2) + rand_idx = random.randrange(0, len(miss_pt_list)) + synth_miss[0] = miss_pt_list[rand_idx][0] + synth_miss[1] = miss_pt_list[rand_idx][1] + synth_miss[2] = 1 + + # inversion prob + synth_inv = np.zeros(3) + if j <= 4: # face + inv_prob = 0.01 + elif j >= 5 and j <= 10: # upper body + inv_prob = 0.03 + else: # lower body + inv_prob = 0.06 + if pair_exist and joints[pair_idx, 2] > 0: + angle = np.random.uniform(0, 2 * math.pi, [N]) + r = np.random.uniform(0, ks_50_dist[j], [N]) + inv_idx = (len(coord_list[0]) + len(coord_list[1])) + x = tot_coord_list[inv_idx][0] + r * np.cos(angle) + y = tot_coord_list[inv_idx][1] + r * np.sin(angle) + dist_mask = True + for i in range(len(tot_coord_list)): + if i == inv_idx: + continue + dist_mask = np.logical_and(dist_mask, np.sqrt( + (tot_coord_list[i][0] - x) ** 2 + (tot_coord_list[i][1] - y) ** 2) > r) + x = x[dist_mask].reshape(-1) + y = y[dist_mask].reshape(-1) + if len(x) > 0: + rand_idx = random.randrange(0, len(x)) + synth_inv[0] = x[rand_idx] + synth_inv[1] = y[rand_idx] + synth_inv[2] = 1 + + # swap prob + synth_swap = np.zeros(3) + swap_exist = (len(coord_list[1]) > 0) or (len(coord_list[3]) > 0) + if (num_valid_joint <= 10 
and num_overlap > 0) or (num_valid_joint <= 15 and num_overlap >= 3): + if j >= 0 and j <= 4: # face + swap_prob = 0.02 + elif j >= 5 and j <= 10: # upper body + swap_prob = 0.15 + else: # lower body + swap_prob = 0.10 + else: + if j >= 0 and j <= 4: # face + swap_prob = 0.01 + elif j >= 5 and j <= 10: # upper body + swap_prob = 0.06 + else: # lower body + swap_prob = 0.03 + if swap_exist: + + swap_pt_list = [] + for swap_idx in range(len(tot_coord_list)): + if swap_idx == 0 or swap_idx == len(coord_list[0]) + len(coord_list[1]): + continue + angle = np.random.uniform(0, 2 * math.pi, [N]) + r = np.random.uniform(0, ks_50_dist[j], [N]) + x = tot_coord_list[swap_idx][0] + r * np.cos(angle) + y = tot_coord_list[swap_idx][1] + r * np.sin(angle) + dist_mask = True + for i in range(len(tot_coord_list)): + if i == 0 or i == len(coord_list[0]) + len(coord_list[1]): + dist_mask = np.logical_and(dist_mask, np.sqrt( + (tot_coord_list[i][0] - x) ** 2 + (tot_coord_list[i][1] - y) ** 2) > r) + x = x[dist_mask].reshape(-1) + y = y[dist_mask].reshape(-1) + if len(x) > 0: + coord = np.transpose(np.vstack([x, y]), [1, 0]) + swap_pt_list.append(coord) + if len(swap_pt_list) > 0: + swap_pt_list = np.concatenate(swap_pt_list, axis=0).reshape(-1, 2) + rand_idx = random.randrange(0, len(swap_pt_list)) + synth_swap[0] = swap_pt_list[rand_idx][0] + synth_swap[1] = swap_pt_list[rand_idx][1] + synth_swap[2] = 1 + + # TEMP + # jitter_prob, miss_prob, inv_prob, swap_prob = jitter_prob * 0.5, miss_prob * 0.5, inv_prob * 0.5, swap_prob + + # good prob + synth_good = np.zeros(3) + good_prob = 1 - (jitter_prob + miss_prob + inv_prob + swap_prob) + assert good_prob >= 0 + angle = np.random.uniform(0, 2 * math.pi, [N // 4]) + r = np.random.uniform(0, ks_85_dist[j], [N // 4]) + good_idx = 0 # gt + x = tot_coord_list[good_idx][0] + r * np.cos(angle) + y = tot_coord_list[good_idx][1] + r * np.sin(angle) + dist_mask = True + for i in range(len(tot_coord_list)): + if i == good_idx: + continue + dist_mask 
def get_bbox(joint_img, joint_valid):
    """Return a box around the valid joints, enlarged by 20% about its centre.

    Args:
        joint_img: Array-like of shape ``(J, >=2)``; columns 0 and 1 are the
            x and y image coordinates.
        joint_valid: Array-like of length ``J``; joints whose entry equals 1
            are used, all others are ignored.

    Returns:
        ``np.float32`` array ``[xmin, ymin, width, height]`` of the tight box
        scaled by 1.2 around its centre.

    Raises:
        ValueError: If no joint is marked valid.
    """
    joint_img = np.asarray(joint_img)
    # Converting up front lets plain lists work as masks too; ``list == 1``
    # would otherwise evaluate to a single ``False``.
    valid = np.asarray(joint_valid) == 1

    x_img = joint_img[:, 0][valid]
    y_img = joint_img[:, 1][valid]
    if x_img.size == 0:
        raise ValueError('get_bbox needs at least one joint with joint_valid == 1')

    xmin, xmax = x_img.min(), x_img.max()
    ymin, ymax = y_img.min(), y_img.max()

    x_center = (xmin + xmax) / 2.
    width = xmax - xmin
    xmin = x_center - 0.5 * width * 1.2
    xmax = x_center + 0.5 * width * 1.2

    y_center = (ymin + ymax) / 2.
    height = ymax - ymin
    ymin = y_center - 0.5 * height * 1.2
    ymax = y_center + 0.5 * height * 1.2

    bbox = np.array([xmin, ymin, xmax - xmin, ymax - ymin]).astype(np.float32)
    return bbox
+ cropped_height = height * 0.5 + height * 0.3 * random.random() # 0.5 ~ 0.8 + ymax = ymin + cropped_height # 0.5 ~ 0.9 + + tight_bbox = np.array([xmin, ymin, xmax - xmin, ymax - ymin]).astype(np.float32) + + # Since we crop the tight bbox to simulate upper body, we need to set the lower part to 0 + img[int(ymax):, :, :] = 0 + img[:int(ymin), :, :] = 0 + + return tight_bbox, img + +def process_bbox(bbox, img_width, img_height, is_3dpw_test=False): + # sanitize bboxes + x, y, w, h = bbox + x1 = np.max((0, x)) + y1 = np.max((0, y)) + x2 = np.min((img_width - 1, x1 + np.max((0, w - 1)))) + y2 = np.min((img_height - 1, y1 + np.max((0, h - 1)))) + if is_3dpw_test: + bbox = np.array([x1, y1, x2-x1, y2-y1], dtype=np.float32) + else: + if w*h > 0 and x2 >= x1 and y2 >= y1: + bbox = np.array([x1, y1, x2-x1, y2-y1], dtype=np.float32) + else: + return None + + # aspect ratio preserving bbox + w = bbox[2] + h = bbox[3] + c_x = bbox[0] + w/2. + c_y = bbox[1] + h/2. + aspect_ratio = cfg.input_img_shape[1]/cfg.input_img_shape[0] + if w > aspect_ratio * h: + h = w / aspect_ratio + elif w < aspect_ratio * h: + w = h * aspect_ratio + bbox[2] = w*1.25 + bbox[3] = h*1.25 + bbox[0] = c_x - bbox[2]/2. + bbox[1] = c_y - bbox[3]/2. 
def get_aug_config(exclude_flip):
    """Draw one random set of augmentation parameters.

    Args:
        exclude_flip: When truthy, horizontal flipping is never selected.

    Returns:
        Tuple ``(scale, rot, color_scale, do_flip, do_crop_bbox)``:
        scale in [0.75, 1.25]; rot in [-60, 60] degrees with probability 0.2,
        otherwise 0; color_scale as three per-channel factors in [0.8, 1.2];
        do_flip with probability 0.5 unless excluded; do_crop_bbox with
        probability 0.7.
    """
    scale_factor, rot_factor, color_factor = 0.25, 30, 0.2

    scale = np.clip(np.random.randn(), -1.0, 1.0) * scale_factor + 1.0

    # The coin flip comes first; the normal sample is drawn only when a
    # rotation is applied.
    if random.random() <= 0.2:
        rot = np.clip(np.random.randn(), -2.0, 2.0) * rot_factor
    else:
        rot = 0

    c_low, c_up = 1.0 - color_factor, 1.0 + color_factor
    color_scale = np.array([random.uniform(c_low, c_up) for _ in range(3)])

    do_flip = False if exclude_flip else random.random() <= 0.5
    do_crop_bbox = random.random() <= 0.7

    return scale, rot, color_scale, do_flip, do_crop_bbox
+ return img, trans, inv_trans, rot, do_flip, bbox + + +def generate_patch_image(cvimg, bbox, scale, rot, do_flip, out_shape, enable_padding=False): + img = cvimg.copy() + img_height, img_width, img_channels = img.shape + + + + bb_c_x = float(bbox[0] + 0.5*bbox[2]) + bb_c_y = float(bbox[1] + 0.5*bbox[3]) + bb_width = float(bbox[2]) + bb_height = float(bbox[3]) + + if do_flip: + img = img[:, ::-1, :] + bb_c_x = img_width - bb_c_x - 1 + + if enable_padding and (bbox[0]<0 or bbox[1]<0 or bbox[0]+bbox[2]>img_width or bbox[1]+bbox[3]>img_height): + assert do_flip == False + trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height, out_shape[1], out_shape[0], scale, rot) + # print('trans:',trans.shape,trans) + # img_patch = cv2.warpAffine(img, trans, (int(out_shape[1]), int(out_shape[0])), flags=cv2.INTER_LINEAR) + # img_patch = img_patch.astype(np.float32) + inv_trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height, out_shape[1], out_shape[0], scale, rot, + inv=True) + # reflection padding + # top, bottom, left, right + padding_top = max(int(-bbox[1]),0) + padding_bottom = max(int(bbox[1]+bbox[3]-img_height),0) + padding_left = max(int(-bbox[0]),0) + padding_right = max(int(bbox[0]+bbox[2]-img_width),0) + img_padding = cv2.copyMakeBorder(img, padding_top, padding_bottom, padding_left, padding_right, cv2.BORDER_REFLECT) + #print(img_padding.shape,np.pad(img.astype(np.float32), ((padding_top, padding_bottom), (padding_left, padding_right), (0, 0)), 'reflect').shape) + blur_size = int(img.shape[0]//512*5)//2*2 +1 + + img_padding = img_padding.astype(np.float32) + h, w, _ = img_padding.shape + y, x, _ = np.ogrid[:h, :w, :1] + pad = [padding_left+1e-6, padding_top+1e-6, padding_right+1e-6, padding_bottom+1e-6] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) + + low_res = cv2.resize(img_padding, (0, 0), fx=0.1, fy=0.1, 
def rotate_2d(pt_2d, rot_rad):
    """Rotate a 2D point about the origin by `rot_rad` radians.

    Applies the standard rotation ``(x*cos - y*sin, x*sin + y*cos)`` and
    returns the result as a ``np.float32`` array of shape ``(2,)``.
    """
    px, py = pt_2d[0], pt_2d[1]
    sin_a = np.sin(rot_rad)
    cos_a = np.cos(rot_rad)
    rotated = (px * cos_a - py * sin_a, px * sin_a + py * cos_a)
    return np.array(rotated, dtype=np.float32)
class Renderer:
    """Offscreen pyrender wrapper that overlays a mesh on an image.

    One scene with three directional lights is built at construction time.
    Each `render` call adds a mesh and a weak-perspective camera to that
    scene, renders it, and removes both nodes again.
    """

    def __init__(self, face, resolution=(224, 224), orig_img=False, wireframe=False):
        """Create the offscreen renderer and the lit scene.

        Args:
            face: Mesh faces passed to ``trimesh.Trimesh`` on every render.
            resolution: ``(viewport_width, viewport_height)`` in pixels.
            orig_img: Stored on the instance; not read by this class.
            wireframe: When true, `render` adds ``RenderFlags.ALL_WIREFRAME``.
        """
        self.resolution = resolution

        self.faces = face
        self.orig_img = orig_img
        self.wireframe = wireframe
        self.renderer = pyrender.OffscreenRenderer(
            viewport_width=self.resolution[0],
            viewport_height=self.resolution[1],
            point_size=1.0
        )

        # set the scene
        self.scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.3, 0.3, 0.3))

        light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)

        # The same pose array is updated and reused for all three lights.
        light_pose = np.eye(4)
        light_pose[:3, 3] = [0, -1, 1]
        self.scene.add(light, pose=light_pose)

        light_pose[:3, 3] = [0, 1, 1]
        self.scene.add(light, pose=light_pose)

        light_pose[:3, 3] = [1, 1, 2]
        self.scene.add(light, pose=light_pose)

    def render(self, img, verts, cam, angle=None, axis=None, mesh_filename=None,
               color=(1.0, 1.0, 0.9), rotate=False):
        """Render `verts` with the stored faces and composite over `img`.

        Args:
            img: Background image, used wherever the rendered depth is not
                positive. Must broadcast against the rendered colour buffer.
            verts: Mesh vertices for ``trimesh.Trimesh``.
            cam: Weak-perspective camera as ``(sx, sy, tx, ty)``.
            angle: Optional extra rotation in degrees, applied about `axis`.
            axis: Rotation axis for `angle`; both must be given to take effect.
            mesh_filename: If not None, the mesh is exported to this path
                before the optional `angle`/`axis` rotation.
            color: RGB base colour of the mesh material.
            rotate: When true, additionally rotate 60 degrees about the y axis.

        Returns:
            The composited image as a ``np.uint8`` array.
        """
        mesh = trimesh.Trimesh(vertices=verts, faces=self.faces, process=False)

        # Flip the mesh 180 degrees about x before rendering.
        Rx = trimesh.transformations.rotation_matrix(math.radians(180), [1, 0, 0])
        mesh.apply_transform(Rx)

        if rotate:
            rot = trimesh.transformations.rotation_matrix(
                np.radians(60), [0, 1, 0])
            mesh.apply_transform(rot)

        if mesh_filename is not None:
            mesh.export(mesh_filename)

        # Explicit None checks: truth-testing an ndarray axis raises.
        if angle is not None and axis is not None:
            R = trimesh.transformations.rotation_matrix(math.radians(angle), axis)
            mesh.apply_transform(R)

        sx, sy, tx, ty = cam

        camera = WeakPerspectiveCamera(
            scale=[sx, sy],
            translation=[tx, ty],
            zfar=1000.
        )

        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            smooth=True,
            wireframe=True,
            roughnessFactor=1.0,
            emissiveFactor=(0.1, 0.1, 0.1),
            baseColorFactor=(color[0], color[1], color[2], 1.0)
        )

        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)

        mesh_node = self.scene.add(mesh, 'mesh')
        camera_pose = np.eye(4)
        cam_node = self.scene.add(camera, pose=camera_pose)

        try:
            if self.wireframe:
                render_flags = RenderFlags.RGBA | RenderFlags.ALL_WIREFRAME
            else:
                render_flags = RenderFlags.RGBA

            rgb, depth = self.renderer.render(self.scene, flags=render_flags)
            # Keep rendered pixels where depth is positive, background elsewhere.
            valid_mask = (depth > 0)[:, :, np.newaxis]
            output_img = rgb * valid_mask + (1 - valid_mask) * img
            image = output_img.astype(np.uint8)
        finally:
            # Always detach the per-call nodes so that a failed render does
            # not leave stale meshes/cameras in the shared scene.
            self.scene.remove_node(mesh_node)
            self.scene.remove_node(cam_node)

        return image
class SMPL(object):
    """Wrapper around the neutral SMPL layer with an extended joint regressor.

    The layer's joint regressor is extended with five facial landmark rows
    (nose, left/right eye, left/right ear) and one head-top row, giving the
    30 joints named in ``joints_name``. A reduced 15-joint set
    (``graph_joints_name``) and its adjacency matrix are also prepared.
    """

    def __init__(self):
        """Load the neutral SMPL layer and build the joint metadata.

        Reads ``J_regressor_extra.npy`` from ``cfg.data_dir``.
        """
        # Only the neutral model is loaded; the gendered layers stay disabled.
        self.layer = {'neutral': self.get_layer(),
                      #'male': self.get_layer('male'), 'female': self.get_layer('female')
                      }
        self.vertex_num = 6890
        self.face = self.layer['neutral'].th_faces.numpy()
        self.joint_regressor = self.layer['neutral'].th_J_regressor.numpy()
        self.shape_param_dim = 10
        self.vposer_code_dim = 32

        # add nose, L/R eye, L/R ear: each landmark is a single mesh vertex, so
        # its regressor row is a one-hot vector over the vertices.
        self.face_kps_vertex = (331, 2802, 6262, 3489, 3990)  # mesh vertex idx
        face_kps_onehot = self._vertex_onehot(self.face_kps_vertex, self.joint_regressor.shape[1])
        self.joint_regressor = np.concatenate((self.joint_regressor, face_kps_onehot))
        # add head top (row 3 of the extra regressor file)
        self.joint_regressor_extra = np.load(osp.join(cfg.data_dir, 'J_regressor_extra.npy'))
        self.joint_regressor = np.concatenate(
            (self.joint_regressor, self.joint_regressor_extra[3:4, :])).astype(np.float32)

        self.orig_joint_num = 24
        self.joint_num = 30  # original: 24. manually add nose, L/R eye, L/R ear, head top
        self.joints_name = ('Pelvis', 'L_Hip', 'R_Hip', 'Torso', 'L_Knee', 'R_Knee', 'Spine', 'L_Ankle', 'R_Ankle', 'Chest', 'L_Toe', 'R_Toe', 'Neck', 'L_Thorax', 'R_Thorax',
                            'Head', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand', 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'Head_top')
        self.flip_pairs = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), (22, 23), (25, 26), (27, 28))
        self.skeleton = ((0, 1), (1, 4), (4, 7), (7, 10), (0, 2), (2, 5), (5, 8), (8, 11), (0, 3), (3, 6), (6, 9), (9, 14), (14, 17), (17, 19), (19, 21), (21, 23), (9, 13), (13, 16), (16, 18), (18, 20), (20, 22), (9, 12), (12, 24), (24, 15), (24, 25), (24, 26), (25, 27), (26, 28), (24, 29))
        self.root_joint_idx = self.joints_name.index('Pelvis')

        # joint set for PositionNet prediction
        self.graph_joint_num = 15
        self.graph_joints_name = ('Pelvis', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Neck', 'Head_top', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist')
        self.graph_flip_pairs = ((1, 2), (3, 4), (5, 6), (9, 10), (11, 12), (13, 14))
        self.graph_skeleton = ((0, 1), (1, 3), (3, 5), (0, 2), (2, 4), (4, 6), (0, 7), (7, 8), (7, 9), (9, 11), (11, 13), (7, 10), (10, 12), (12, 14))
        # construct graph adj
        self.graph_adj = self.get_graph_adj()

    @staticmethod
    def _vertex_onehot(vertex_ids, vertex_num):
        """Return a float32 array with one one-hot row per entry of ``vertex_ids``.

        Row ``r`` is 1 at column ``vertex_ids[r]`` and 0 elsewhere; the result
        has shape ``(len(vertex_ids), vertex_num)``.

        Raises:
            IndexError: If an id is not a valid column index for ``vertex_num``.
        """
        onehot = np.zeros((len(vertex_ids), vertex_num), dtype=np.float32)
        onehot[np.arange(len(vertex_ids)), list(vertex_ids)] = 1.0
        return onehot

    def reduce_joint_set(self, joint):
        """Select the graph joints from ``joint``, in ``graph_joints_name`` order.

        Args:
            joint: Tensor indexed as ``joint[:, idx, :]``, with the joints of
                ``joints_name`` along axis 1.

        Returns:
            Tensor with the selected joints stacked along axis 1.
        """
        graph_idx = [self.joints_name.index(name) for name in self.graph_joints_name]
        return torch.stack([joint[:, idx, :] for idx in graph_idx], 1)

    def get_graph_adj(self):
        """Build the graph-joint adjacency with ``build_adj`` and normalise it with ``normalize_adj``."""
        adj_mat = build_adj(self.graph_joint_num, self.graph_skeleton, self.graph_flip_pairs)
        normalized_adj = normalize_adj(adj_mat)
        return normalized_adj

    def get_layer(self, gender='neutral'):
        """Create an ``SMPL_Layer`` for ``gender`` from the models under ``cfg.smpl_path``."""
        model_root = osp.join(cfg.smpl_path, 'smplpytorch', 'native', 'models')
        return SMPL_Layer(gender=gender, model_root=model_root)
b/data_processing/common/utils/smplpytorch/LICENSE new file mode 100644 index 0000000..f288702 --- /dev/null +++ b/data_processing/common/utils/smplpytorch/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. 
And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. 
+ + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. 
+ + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. 
Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. 
+ + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. 
This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. 
+ + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. 
If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. 
If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. 
+ + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. 
For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. 
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. 
The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. 
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. 
+ + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/data_processing/common/utils/smplpytorch/README.md b/data_processing/common/utils/smplpytorch/README.md new file mode 100644 index 0000000..1bf0d9f --- /dev/null +++ b/data_processing/common/utils/smplpytorch/README.md @@ -0,0 +1,67 @@ +SMPL layer for PyTorch +======= + +[SMPL](http://smpl.is.tue.mpg.de) human body [\[1\]](#references) layer for [PyTorch](https://pytorch.org/) (tested with v0.4 and v1.x) +is a differentiable PyTorch layer that deterministically maps from pose and shape parameters to human body joints and vertices. +It can be integrated into any architecture as a differentiable layer to predict body meshes. +The code is adapted from the [manopth](https://github.com/hassony2/manopth) repository by [Yana Hasson](https://github.com/hassony2). + +

+smpl +

+ + +## Setup + +### 1. The `smplpytorch` package +* **Run without installing:** You will need to install the dependencies listed in [environment.yml](environment.yml): + * `conda env update -f environment.yml` in an existing environment, or + * `conda env create -f environment.yml`, for a new `smplpytorch` environment +* **Install:** To import `SMPL_Layer` in another project with `from smplpytorch.pytorch.smpl_layer import SMPL_Layer` do one of the following. + * Option 1: This should automatically install the dependencies. + ``` bash + git clone https://github.com/gulvarol/smplpytorch.git + cd smplpytorch + pip install . + ``` + * Option 2: You can install `smplpytorch` from [PyPI](https://pypi.org/project/smplpytorch/). Additionally, you might need to install [chumpy](https://github.com/hassony2/chumpy.git). + ``` bash + pip install smplpytorch + ``` + +### 2. Download SMPL pickle files + * Download the models from the [SMPL website](http://smpl.is.tue.mpg.de/) by choosing "SMPL for Python users". Note that you need to comply with the [SMPL model license](http://smpl.is.tue.mpg.de/license_model). + * Extract and copy the `models` folder into the `smplpytorch/native/` folder (or set the `model_root` parameter accordingly). 
+ +## Demo + +Forward pass the randomly created pose and shape parameters from the SMPL layer and display the human body mesh and joints: + +`python demo.py` + +## Acknowledgements +The code **largely** builds on the [manopth](https://github.com/hassony2/manopth) repository from [Yana Hasson](https://github.com/hassony2), which implements the [MANO](http://mano.is.tue.mpg.de) hand model [\[2\]](#references) layer. + +The code is a PyTorch port of the original [SMPL](http://smpl.is.tue.mpg.de) model from [chumpy](https://github.com/mattloper/chumpy). It builds on the work of [Loper](https://github.com/mattloper) et al. [\[1\]](#references). + +The code [reuses](https://github.com/gulvarol/smpl/pytorch/rodrigues_layer.py) [part of the code](https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py) by [Zhang Xiong](https://github.com/MandyMo) to compute the rotation utilities. + +If you find this code useful for your research, please cite the original [SMPL](http://smpl.is.tue.mpg.de) publication: + +``` +@article{SMPL:2015, + author = {Loper, Matthew and Mahmood, Naureen and Romero, Javier and Pons-Moll, Gerard and Black, Michael J.}, + title = {{SMPL}: A Skinned Multi-Person Linear Model}, + journal = {ACM Trans. Graphics (Proc. 
if __name__ == '__main__':
    # Flip to True to move the layer and its inputs onto the GPU.
    cuda = False
    batch_size = 1

    # Build the gender-neutral SMPL layer; outputs are centred on joint 0.
    smpl_layer = SMPL_Layer(center_idx=0,
                            gender='neutral',
                            model_root='smplpytorch/native/models')

    # Small random pose (72 values) and shape (10 values) vectors per sample.
    pose_params = 0.2 * torch.rand(batch_size, 72)
    shape_params = 0.03 * torch.rand(batch_size, 10)

    # Optional GPU mode.
    if cuda:
        pose_params, shape_params = pose_params.cuda(), shape_params.cuda()
        smpl_layer.cuda()

    # Run the layer: posed vertices and joint locations.
    verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)

    # Plot the mesh with its skeleton, save it to image.png and show it.
    model_info = {
        'verts': verts.cpu().detach(),
        'joints': Jtr.cpu().detach(),
    }
    display_model(model_info,
                  model_faces=smpl_layer.th_faces,
                  with_joints=True,
                  kintree_table=smpl_layer.kintree_table,
                  savepath='image.png',
                  show=True)
def display_model(
        model_info,
        model_faces=None,
        with_joints=False,
        kintree_table=None,
        ax=None,
        batch_idx=0,
        show=True,
        savepath=None):
    """Plot one mesh (and optionally its skeleton) from a batch.

    Args:
        model_info: mapping with 'verts' and 'joints' entries; element
            ``batch_idx`` of each is plotted.
        model_faces: index array used as ``verts[model_faces]`` to build the
            mesh polygons. When None the vertices are drawn as a scatter plot.
        with_joints: if True, overlay the skeleton via ``draw_skeleton``.
        kintree_table: kinematic tree forwarded to ``draw_skeleton``; needed
            when ``with_joints`` is True.
        ax: existing 3D axes to draw into. A new figure is created when None.
        batch_idx: which batch element to display.
        show: if True, call ``plt.show()`` at the end.
        savepath: if truthy, the figure is saved to this path.

    Returns:
        The axes that were drawn into.
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
    else:
        # Previously `fig` was left undefined on this path, so passing an
        # existing axes crashed at subplots_adjust(). Use the owning figure.
        fig = ax.get_figure()

    verts = model_info['verts'][batch_idx]
    joints = model_info['joints'][batch_idx]

    if model_faces is None:
        ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
    else:
        mesh = Poly3DCollection(verts[model_faces], alpha=0.2)
        face_color = (141 / 255, 184 / 255, 226 / 255)
        edge_color = (50 / 255, 50 / 255, 50 / 255)
        mesh.set_edgecolor(edge_color)
        mesh.set_facecolor(face_color)
        ax.add_collection3d(mesh)

    if with_joints:
        draw_skeleton(joints, kintree_table=kintree_table, ax=ax)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_xlim(-0.7, 0.7)
    ax.set_ylim(-0.7, 0.7)
    ax.set_zlim(-0.7, 0.7)
    ax.view_init(azim=-90, elev=100)
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

    if savepath:
        print('Saving figure at {}.'.format(savepath))
        # Save the figure that owns `ax`, not whichever figure is current.
        fig.savefig(savepath, bbox_inches='tight', pad_inches=0)
    if show:
        plt.show()
    return ax


def draw_skeleton(joints3D, kintree_table, ax=None, with_numbers=True):
    """Draw the bones described by ``kintree_table`` on a 3D axes.

    Args:
        joints3D: array indexed as ``joints3D[joint, axis]``.
        kintree_table: 2-row table; for column ``i`` a line is drawn between
            joints ``kintree_table[0][i]`` and ``kintree_table[1][i]``.
            Column 0 is skipped.
        ax: 3D axes to draw into. A new frameless figure is created when None.
        with_numbers: if True, label each bone's second joint with its index.

    Returns:
        The axes that were drawn into.
    """
    if ax is None:
        fig = plt.figure(frameon=False)
        ax = fig.add_subplot(111, projection='3d')

    # One colour per kintree column: index 0 -> 'r', 1 -> 'g', 2 -> 'b'
    # (left / right / middle, going by the variable name).
    left_right_mid = ['r', 'g', 'b']
    kintree_colors = [2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 0, 1, 0, 1, 0, 1]
    # The table above has 24 entries, so a wider kintree_table is not supported.
    colors = [left_right_mid[c] for c in kintree_colors]

    for i in range(1, kintree_table.shape[1]):
        j1 = kintree_table[0][i]
        j2 = kintree_table[1][i]
        ax.plot([joints3D[j1, 0], joints3D[j2, 0]],
                [joints3D[j1, 1], joints3D[j2, 1]],
                [joints3D[j1, 2], joints3D[j2, 2]],
                color=colors[i], linestyle='-', linewidth=2, marker='o', markersize=5)
        if with_numbers:
            ax.text(joints3D[j2, 0], joints3D[j2, 1], joints3D[j2, 2], j2)
    return ax
url="https://github.com/gulvarol/smplpytorch", + packages=setuptools.find_packages(), + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Operating System :: OS Independent", + ], +) diff --git a/data_processing/common/utils/smplpytorch/smplpytorch/__init__.py b/data_processing/common/utils/smplpytorch/smplpytorch/__init__.py new file mode 100644 index 0000000..f3d4e53 --- /dev/null +++ b/data_processing/common/utils/smplpytorch/smplpytorch/__init__.py @@ -0,0 +1 @@ +name = "smplpytorch" diff --git a/data_processing/common/utils/smplpytorch/smplpytorch/native/__init__.py b/data_processing/common/utils/smplpytorch/smplpytorch/native/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/common/utils/smplpytorch/smplpytorch/native/models/README.md b/data_processing/common/utils/smplpytorch/smplpytorch/native/models/README.md new file mode 100644 index 0000000..9e113e9 --- /dev/null +++ b/data_processing/common/utils/smplpytorch/smplpytorch/native/models/README.md @@ -0,0 +1 @@ +Here copy the .pkl model files. 
def ready_arguments(fname_or_dict):
    """Load an SMPL model dictionary and fill in its derived chumpy terms.

    Args:
        fname_or_dict: either a path to a pickled model file or an already
            loaded model dict. A dict is modified in place and returned.

    Returns:
        The model dict, with default 'trans', 'pose' and 'betas' added when
        missing, selected entries wrapped as chumpy arrays, and 'v_posed'
        (plus 'v_shaped' and 'J' when 'shapedirs' is present) computed.
    """
    import numpy as np
    import pickle
    import chumpy as ch
    from chumpy.ch import MatVecMult
    from smplpytorch.native.webuser.posemapper import posemap

    if not isinstance(fname_or_dict, dict):
        # NOTE(security): unpickling executes arbitrary code; only load model
        # files obtained from a trusted source.
        # The context manager closes the handle, which the previous
        # pickle.load(open(...)) call leaked.
        with open(fname_or_dict, 'rb') as model_file:
            dd = pickle.load(model_file, encoding='latin1')
    else:
        dd = fname_or_dict

    want_shapemodel = 'shapedirs' in dd
    # Three pose parameters per column of the kinematic tree.
    nposeparms = dd['kintree_table'].shape[1] * 3

    if 'trans' not in dd:
        dd['trans'] = np.zeros(3)
    if 'pose' not in dd:
        dd['pose'] = np.zeros(nposeparms)
    if 'shapedirs' in dd and 'betas' not in dd:
        dd['betas'] = np.zeros(dd['shapedirs'].shape[-1])

    # Wrap plain arrays as chumpy arrays; entries that already expose
    # `dterms` are left untouched.
    for s in ['v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs', 'betas', 'J']:
        if (s in dd) and not hasattr(dd[s], 'dterms'):
            dd[s] = ch.array(dd[s])

    if want_shapemodel:
        dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template']
        v_shaped = dd['v_shaped']
        # Regress joint locations from the shaped vertices, one axis at a time.
        J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0])
        J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1])
        J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2])
        dd['J'] = ch.vstack((J_tmpx, J_tmpy, J_tmpz)).T
        dd['v_posed'] = v_shaped + dd['posedirs'].dot(posemap(dd['bs_type'])(dd['pose']))
    else:
        dd['v_posed'] = dd['v_template'] + dd['posedirs'].dot(posemap(dd['bs_type'])(dd['pose']))

    return dd
def quat2mat(quat):
    """Turn a batch of quaternions into rotation matrices.

    Args:
        quat: tensor of size [batch_size, 4], components ordered (w, x, y, z).
            Each row is L2-normalised before the conversion.
    Returns:
        Tensor of size [batch_size, 3, 3] holding the rotation matrices.
    """
    unit = quat / quat.norm(p=2, dim=1, keepdim=True)
    w = unit[:, 0]
    x = unit[:, 1]
    y = unit[:, 2]
    z = unit[:, 3]

    ww, xx, yy, zz = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    # Row-major entries of the 3x3 matrix, one tensor of length batch_size each.
    entries = [
        ww + xx - yy - zz, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
        2 * wz + 2 * xy, ww - xx + yy - zz, 2 * yz - 2 * wx,
        2 * xz - 2 * wy, 2 * wx + 2 * yz, ww - xx - yy + zz,
    ]
    return torch.stack(entries, dim=1).view(quat.size(0), 3, 3)


def batch_rodrigues(axisang):
    """Convert axis-angle vectors to flattened rotation matrices.

    Args:
        axisang: tensor of size [N, 3]; direction is the rotation axis and
            the norm is the rotation angle.
    Returns:
        Tensor of size [N, 9], each row a row-major 3x3 rotation matrix.
    """
    # The 1e-8 offset keeps the norm non-zero for an all-zero input.
    theta = torch.norm(axisang + 1e-8, p=2, dim=1).unsqueeze(-1)
    unit_axis = axisang / theta
    half_theta = theta * 0.5
    # Quaternion (w, x, y, z) = (cos(theta/2), sin(theta/2) * axis).
    quat = torch.cat(
        [torch.cos(half_theta), torch.sin(half_theta) * unit_axis], dim=1)
    matrices = quat2mat(quat)
    return matrices.view(matrices.shape[0], 9)


def th_get_axis_angle(vector):
    """Split [N, 3] axis-angle vectors into unit axes [N, 3] and angles [N]."""
    angle = vector.norm(p=2, dim=1)
    return vector / angle.unsqueeze(1), angle
class SMPL_Layer(Module):
    """Differentiable SMPL layer: maps pose/shape parameters to vertices and joints.

    The model arrays are loaded once from a pickle file and stored as buffers,
    so they follow the module across devices.
    """
    __constants__ = ['kintree_parents', 'gender', 'center_idx', 'num_joints']

    def __init__(self,
                 center_idx=None,
                 gender='neutral',
                 model_root='smpl/native/models'):
        """
        Args:
            center_idx: index of center joint in our computations,
            model_root: path to pkl files for the model
            gender: 'neutral' (default) or 'female' or 'male'

        Raises:
            ValueError: if ``gender`` is not one of the supported values.
        """
        super().__init__()

        self.center_idx = center_idx
        self.gender = gender

        model_files = {
            'neutral': 'basicModel_neutral_lbs_10_207_0_v1.0.0.pkl',
            'female': 'basicModel_f_lbs_10_207_0_v1.0.0.pkl',
            'male': 'basicModel_m_lbs_10_207_0_v1.0.0.pkl',
        }
        if gender not in model_files:
            # Previously an unknown gender left `model_path` unset and failed
            # on the next line with an unrelated AttributeError.
            raise ValueError(
                "gender must be 'neutral', 'female' or 'male', got {!r}".format(gender))
        self.model_path = os.path.join(model_root, model_files[gender])

        smpl_data = ready_arguments(self.model_path)
        self.smpl_data = smpl_data

        self.register_buffer('th_betas',
                             torch.Tensor(smpl_data['betas'].r).unsqueeze(0))
        self.register_buffer('th_shapedirs',
                             torch.Tensor(smpl_data['shapedirs'].r))
        self.register_buffer('th_posedirs',
                             torch.Tensor(smpl_data['posedirs'].r))
        self.register_buffer(
            'th_v_template',
            torch.Tensor(smpl_data['v_template'].r).unsqueeze(0))
        self.register_buffer(
            'th_J_regressor',
            torch.Tensor(np.array(smpl_data['J_regressor'].toarray())))
        self.register_buffer('th_weights',
                             torch.Tensor(smpl_data['weights'].r))
        self.register_buffer('th_faces',
                             torch.Tensor(smpl_data['f'].astype(np.int32)).long())

        # Kinematic chain params: row 0 of the table holds each joint's parent.
        self.kintree_table = smpl_data['kintree_table']
        parents = list(self.kintree_table[0].tolist())
        self.kintree_parents = parents
        self.num_joints = len(parents)  # 24

    def forward(self,
                th_pose_axisang,
                th_betas=torch.zeros(1),
                th_trans=torch.zeros(1)):
        """
        Args:
            th_pose_axisang (Tensor (batch_size x 72)): pose parameters in axis-angle representation
            th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters.
                None or an all-zero tensor selects the betas stored in the model.
            th_trans (Tensor (batch_size x 3)): if provided, applies trans to joints and vertices.
                None or an all-zero tensor means no translation; the outputs are then
                re-centred on joint ``center_idx`` when that is set.

        Returns:
            Tuple ``(th_verts, th_jtr)``: vertices (batch_size x n_vertices x 3)
            and joints (batch_size x num_joints x 3).
        """

        batch_size = th_pose_axisang.shape[0]
        # Convert axis-angle representation to rotation matrix rep.
        th_pose_rotmat = th_posemap_axisang(th_pose_axisang)
        # Take out the first rotmat (global rotation)
        root_rot = th_pose_rotmat[:, :9].view(batch_size, 3, 3)
        # Take out the remaining rotmats (23 joints)
        th_pose_rotmat = th_pose_rotmat[:, 9:]
        th_pose_map = subtract_flat_id(th_pose_rotmat)

        # Below does: v_shaped = v_template + shapedirs * betas
        # If shape parameters are not provided
        if th_betas is None or bool(torch.norm(th_betas) == 0):
            th_v_shaped = self.th_v_template + torch.matmul(
                self.th_shapedirs, self.th_betas.transpose(1, 0)).permute(2, 0, 1)
            # The stored betas have batch size 1, so tile joints to the batch.
            th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
                batch_size, 1, 1)
        else:
            th_v_shaped = self.th_v_template + torch.matmul(
                self.th_shapedirs, th_betas.transpose(1, 0)).permute(2, 0, 1)
            th_j = torch.matmul(self.th_J_regressor, th_v_shaped)

        # Below does: v_posed = v_shaped + posedirs * pose_map
        th_v_posed = th_v_shaped + torch.matmul(
            self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
        # Final T pose with transformation done!

        # Global rigid transformation
        th_results = []

        root_j = th_j[:, 0, :].contiguous().view(batch_size, 3, 1)
        th_results.append(th_with_zeros(torch.cat([root_rot, root_j], 2)))

        # Rotate each part: chain every joint's local transform onto its parent's.
        for i in range(self.num_joints - 1):
            i_val = int(i + 1)
            joint_rot = th_pose_rotmat[:, (i_val - 1) * 9:i_val *
                                       9].contiguous().view(batch_size, 3, 3)
            joint_j = th_j[:, i_val, :].contiguous().view(batch_size, 3, 1)
            parent = make_list(self.kintree_parents)[i_val]
            parent_j = th_j[:, parent, :].contiguous().view(batch_size, 3, 1)
            joint_rel_transform = th_with_zeros(
                torch.cat([joint_rot, joint_j - parent_j], 2))
            th_results.append(
                torch.matmul(th_results[parent], joint_rel_transform))
        th_results_global = th_results

        th_results2 = torch.zeros((batch_size, 4, 4, self.num_joints),
                                  dtype=root_j.dtype,
                                  device=root_j.device)

        # Subtract each joint's transformed rest location from its global transform.
        for i in range(self.num_joints):
            padd_zero = torch.zeros(1, dtype=th_j.dtype, device=th_j.device)
            joint_j = torch.cat(
                [th_j[:, i],
                 padd_zero.view(1, 1).repeat(batch_size, 1)], 1)
            tmp = torch.bmm(th_results[i], joint_j.unsqueeze(2))
            th_results2[:, :, :, i] = th_results[i] - th_pack(tmp)

        # Blend the per-joint transforms with the skinning weights.
        th_T = torch.matmul(th_results2, self.th_weights.transpose(0, 1))

        # Posed rest vertices in homogeneous coordinates.
        th_rest_shape_h = torch.cat([
            th_v_posed.transpose(2, 1),
            torch.ones((batch_size, 1, th_v_posed.shape[1]),
                       dtype=th_T.dtype,
                       device=th_T.device),
        ], 1)

        th_verts = (th_T * th_rest_shape_h.unsqueeze(1)).sum(2).transpose(2, 1)
        th_verts = th_verts[:, :, :3]
        th_jtr = torch.stack(th_results_global, dim=1)[:, :, :3, 3]

        # If translation is not provided
        if th_trans is None or bool(torch.norm(th_trans) == 0):
            if self.center_idx is not None:
                center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
                th_jtr = th_jtr - center_joint
                th_verts = th_verts - center_joint
        else:
            th_jtr = th_jtr + th_trans.unsqueeze(1)
            th_verts = th_verts + th_trans.unsqueeze(1)

        # Vertices and joints in meters
        return th_verts, th_jtr
def th_posemap_axisang(pose_vectors):
    '''
    Converts axis-angle to rotmat
    pose_vectors (Tensor (batch_size x 72)): pose parameters in axis-angle representation

    Returns a (batch_size x 9 * n_rotations) tensor: one flattened 3x3
    rotation matrix per consecutive triple of input columns.
    '''
    rot_nb = pose_vectors.shape[1] // 3
    rot_mats = [
        rodrigues_layer.batch_rodrigues(
            pose_vectors[:, joint_idx * 3:(joint_idx + 1) * 3])
        for joint_idx in range(rot_nb)
    ]
    return torch.cat(rot_mats, 1)


def th_with_zeros(tensor):
    '''
    Appends the row [0, 0, 0, 1] to every matrix of a (batch_size x 3 x 4)
    tensor, giving (batch_size x 4 x 4) homogeneous transforms.
    '''
    batch_size = tensor.shape[0]
    # new_tensor keeps the input's dtype and device; it replaces the
    # undocumented legacy `tensor.new(data)` constructor.
    padding = tensor.new_tensor([0.0, 0.0, 0.0, 1.0])
    return torch.cat([tensor, padding.view(1, 1, 4).repeat(batch_size, 1, 1)], 1)


def th_pack(tensor):
    '''
    Left-pads a (batch_size x 4 x 1) tensor with three zero columns, giving
    (batch_size x 4 x 4).
    '''
    batch_size = tensor.shape[0]
    padding = tensor.new_zeros((batch_size, 4, 3))
    return torch.cat([padding, tensor], 2)


def subtract_flat_id(rot_mats):
    '''
    Subtracts the flattened 3x3 identity from every rotation in a
    (batch_size x 9 * n_rotations) tensor.
    '''
    # Derive the rotation count from the input width instead of assuming 23
    # joints; a 207-column input behaves exactly as before.
    num_rots = rot_mats.shape[1] // 9
    id_flat = torch.eye(
        3, dtype=rot_mats.dtype, device=rot_mats.device).view(1, 9).repeat(
            rot_mats.shape[0], num_rots)
    return rot_mats - id_flat


def make_list(tensor):
    # type: (List[int]) -> List[int]
    # Identity helper; returns its argument unchanged.
    return tensor
config import cfg +import torchgeometry as tgm +from torch.nn import functional as F + + +def denorm_joints(pose_out_img, body_bb2img_trans): + pose_out_img[:, 0] = pose_out_img[:, 0] / cfg.output_hm_shape[2] * cfg.input_img_shape[1] + pose_out_img[:, 1] = pose_out_img[:, 1] / cfg.output_hm_shape[1] * cfg.input_img_shape[0] + pose_out_img_xy1 = np.concatenate((pose_out_img[:, :2], np.ones_like(pose_out_img[:, :1])), 1) + pose_out_img[:, :2] = np.dot(body_bb2img_trans, pose_out_img_xy1.transpose(1, 0)).transpose(1, 0)[:, :2] + + return pose_out_img + +def cam2pixel(cam_coord, f, c): + x = cam_coord[:,0] / cam_coord[:,2] * f[0] + c[0] + y = cam_coord[:,1] / cam_coord[:,2] * f[1] + c[1] + z = cam_coord[:,2] + return np.stack((x,y,z),1) + +def pixel2cam(pixel_coord, f, c): + x = (pixel_coord[:,0] - c[0]) / f[0] * pixel_coord[:,2] + y = (pixel_coord[:,1] - c[1]) / f[1] * pixel_coord[:,2] + z = pixel_coord[:,2] + return np.stack((x,y,z),1) + +def world2cam(world_coord, R, t): + cam_coord = np.dot(R, world_coord.transpose(1,0)).transpose(1,0) + t.reshape(1,3) + return cam_coord + +def cam2world(cam_coord, R, t): + world_coord = np.dot(np.linalg.inv(R), (cam_coord - t.reshape(1,3)).transpose(1,0)).transpose(1,0) + return world_coord + +def rigid_transform_3D(A, B): + n, dim = A.shape + centroid_A = np.mean(A, axis = 0) + centroid_B = np.mean(B, axis = 0) + H = np.dot(np.transpose(A - centroid_A), B - centroid_B) / n + U, s, V = np.linalg.svd(H) + R = np.dot(np.transpose(V), np.transpose(U)) + if np.linalg.det(R) < 0: + s[-1] = -s[-1] + V[2] = -V[2] + R = np.dot(np.transpose(V), np.transpose(U)) + + varP = np.var(A, axis=0).sum() + c = 1/varP * np.sum(s) + + t = -np.dot(c*R, np.transpose(centroid_A)) + np.transpose(centroid_B) + return c, R, t + +def rigid_align(A, B): + c, R, t = rigid_transform_3D(A, B) + A2 = np.transpose(np.dot(c*R, np.transpose(A))) + t + return A2 + +def transform_joint_to_other_db(src_joint, src_name, dst_name): + src_joint_num = len(src_name) + 
def build_adj(vertex_num, skeleton, flip_pairs):
    """Build a symmetric 0/1 adjacency matrix from bone and flip pairs.

    Args:
        vertex_num: Number of graph nodes.
        skeleton: Iterable of ``(i, j)`` index pairs to connect.
        flip_pairs: Iterable of ``(i, j)`` index pairs that are also
            connected.

    Returns:
        Float array of shape (vertex_num, vertex_num).
    """
    adj_matrix = np.zeros((vertex_num, vertex_num))
    for pairs in (skeleton, flip_pairs):
        for pair in pairs:
            # Index both elements explicitly: ``adj_matrix[pair]`` only
            # addresses a single cell when ``pair`` is a tuple; a list or
            # ndarray pair would select (and overwrite) two whole rows.
            i, j = pair[0], pair[1]
            adj_matrix[i, j] = 1
            adj_matrix[j, i] = 1
    return adj_matrix


def normalize_adj(adj):
    """Return ``D^-1/2 (A + I) D^-1/2`` for adjacency matrix ``A``.

    ``D`` is the diagonal matrix of column sums of ``A + I``.

    Args:
        adj: Square array of shape (V, V).

    Returns:
        Normalized array of shape (V, V).
    """
    vertex_num = adj.shape[0]
    adj_self = adj + np.eye(vertex_num)
    # Work on the degree vector directly; the epsilon mirrors the original
    # guard against a zero degree.
    degree = adj_self.sum(0) + np.spacing(np.array(0))
    inv_sqrt_degree = 1 / np.sqrt(degree)
    # Row-scale then column-scale: equivalent to multiplying by the diagonal
    # matrix on both sides without building dense V x V intermediates.
    return inv_sqrt_degree[:, None] * adj_self * inv_sqrt_degree[None, :]


def rot6d_to_axis_angle(x):
    """Convert 6D rotation representations to axis-angle vectors.

    The input is viewed as (-1, 3, 2); the two columns are orthonormalized
    with Gram-Schmidt, completed by a cross product, and the resulting
    rotation matrices are converted by
    ``tgm.rotation_matrix_to_angle_axis``.

    Args:
        x: Tensor whose elements can be viewed as (-1, 3, 2).

    Returns:
        Tensor of shape (N, 3); NaN entries are replaced by 0.
    """
    x = x.view(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = F.normalize(a1)
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
    # Explicit dim: without it torch.cross picks the first size-3 dimension,
    # which is the batch axis whenever exactly three rotations are passed.
    b3 = torch.cross(b1, b2, dim=-1)
    rot_mat = torch.stack((b1, b2, b3), dim=-1)  # 3x3 rotation matrix

    # Pad to 3x4 on the input's own device/dtype (no hard-coded .cuda()) and
    # size the padding from the reshaped batch so multi-joint inputs work.
    pad = rot_mat.new_zeros((rot_mat.shape[0], 3, 1))
    rot_mat = torch.cat([rot_mat, pad], 2)  # 3x4 rotation matrix
    axis_angle = tgm.rotation_matrix_to_angle_axis(rot_mat).reshape(-1, 3)  # axis-angle
    axis_angle[torch.isnan(axis_angle)] = 0.0
    return axis_angle
+ sx = cam[:,0] * (1. / (img_width / h)) + sy = cam[:,0] * (1. / (img_height / h)) + tx = ((cx - hw) / hw / sx) + cam[:,1] + ty = ((cy - hh) / hh / sy) + cam[:,2] + orig_cam = np.stack([sx, sy, tx, ty]).T + return orig_cam + diff --git a/data_processing/common/utils/vis.py b/data_processing/common/utils/vis.py new file mode 100644 index 0000000..d0913b7 --- /dev/null +++ b/data_processing/common/utils/vis.py @@ -0,0 +1,268 @@ +import os +import cv2 +import numpy as np +from mpl_toolkits.mplot3d import Axes3D +import matplotlib.pyplot as plt +import matplotlib as mpl +import trimesh +#os.environ['PYOPENGL_PLATFORM'] = 'egl' # comment it if use windows +import pyrender + + +def vis_bbox(img, bbox, alpha=1): + + kp_mask = np.copy(img) + bbox = bbox.astype(np.int32) # x, y, w, h + + b1 = bbox[0], bbox[1] + b2 = bbox[0] + bbox[2], bbox[1] + b3 = bbox[0] + bbox[2], bbox[1] + bbox[3] + b4 = bbox[0], bbox[1] + bbox[3] + + cv2.line(kp_mask, b1, b2, color=(255, 255, 0), thickness=1, lineType=cv2.LINE_AA) + cv2.line(kp_mask, b2, b3, color=(255, 255, 0), thickness=1, lineType=cv2.LINE_AA) + cv2.line(kp_mask, b3, b4, color=(255, 255, 0), thickness=1, lineType=cv2.LINE_AA) + cv2.line(kp_mask, b4, b1, color=(255, 255, 0), thickness=1, lineType=cv2.LINE_AA) + + return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0) + + +def vis_coco_skeleton(img, kps, kps_lines, alpha=1): + colors = [ + # face + (255/255, 153/255, 51/255), + (255/255, 153/255, 51/255), + (255/255, 153/255, 51/255), + (255/255, 153/255, 51/255), + + # left arm + (102/255, 255/255, 102/255), + (51/255, 255/255, 51/255), + + # right leg + (255 / 255, 102 / 255, 255 / 255), + (255 / 255, 51 / 255, 255 / 255), + + + # left leg + + (255 / 255, 102 / 255, 102 / 255), + (255 / 255, 51 / 255, 51 / 255), + + # shoulder-thorax, hip-pevlis, + (153/255, 255/255, 153/255), # l shoulder - thorax + (153/255, 204/255, 255/255), # r shoulder - thorax + (255/255, 153/255, 153/255), # l hip - pelvis + (255/255, 153/255, 
def vis_keypoints_with_skeleton(img, kps, kps_lines, kp_thresh=0.4, alpha=1, kps_scores=None):
    """Draw a skeleton (bones, joints and joint-index labels) on an image.

    Args:
        img: Image array; it is copied, not modified.
        kps: Array indexed as ``kps[0, j]`` (x), ``kps[1, j]`` (y) and
            ``kps[2, j]`` (score) for joint ``j``.
        kps_lines: Sequence of ``(joint_a, joint_b)`` index pairs.
        kp_thresh: A joint is drawn only if its score exceeds this value.
        alpha: Blend weight of the drawing over ``img``.
        kps_scores: Optional array; ``kps_scores[j, 0]`` is printed next to
            the second endpoint of each bone.

    Returns:
        The blended image from ``cv2.addWeighted``.
    """
    # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(kps_lines) + 2)]
    colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]

    # Perform the drawing on a copy of the image, to allow for blending.
    kp_mask = np.copy(img)

    # Draw the keypoints.
    for l, line in enumerate(kps_lines):
        i1 = line[0]
        i2 = line[1]
        p1 = kps[0, i1].astype(np.int32), kps[1, i1].astype(np.int32)
        p2 = kps[0, i2].astype(np.int32), kps[1, i2].astype(np.int32)
        if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
            cv2.line(
                kp_mask, p1, p2,
                color=colors[l], thickness=2, lineType=cv2.LINE_AA)
        if kps[2, i1] > kp_thresh:
            cv2.circle(
                kp_mask, p1,
                radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
            # Label with the joint index. The previous ``str([l][0])`` indexed
            # a throw-away one-element list instead of ``kps_lines[l]``.
            cv2.putText(kp_mask, str(i1), p1, cv2.FONT_HERSHEY_SIMPLEX, 0.5, colors[l])
        if kps[2, i2] > kp_thresh:
            cv2.circle(
                kp_mask, p2,
                radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
            # ``str([l][1])`` raised IndexError on every visible endpoint.
            cv2.putText(kp_mask, str(i2), p2, cv2.FONT_HERSHEY_SIMPLEX, 0.5, colors[l])
            # NOTE(review): nesting of this score label under the visibility
            # check was reconstructed from whitespace-collapsed text -- confirm.
            if kps_scores is not None:
                cv2.putText(kp_mask, str(kps_scores[i2, 0]), p2, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255))

    # Blend the keypoints.
    return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)


def vis_keypoints(img, kps, alpha=1, kps_vis=None):
    """Draw every keypoint as a filled dot with a text label.

    Args:
        img: Image array; it is copied, not modified.
        kps: Sequence of keypoints; ``kps[i][0]`` / ``kps[i][1]`` are x / y.
        alpha: Blend weight of the drawing over ``img``.
        kps_vis: Optional array; when given, ``kps_vis[i, 0]`` is printed
            instead of the keypoint index.

    Returns:
        The blended image from ``cv2.addWeighted``.
    """
    # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(kps) + 2)]
    colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]

    # Perform the drawing on a copy of the image, to allow for blending.
    kp_mask = np.copy(img)

    # Draw the keypoints.
    for i in range(len(kps)):
        p = kps[i][0].astype(np.int32), kps[i][1].astype(np.int32)
        cv2.circle(kp_mask, p, radius=3, color=colors[i], thickness=-1, lineType=cv2.LINE_AA)
        label = str(kps_vis[i, 0]) if kps_vis is not None else str(i)
        cv2.putText(kp_mask, label, p, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255))

    # Blend the keypoints.
    return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)
def vis_3d_skeleton(kpt_3d, kpt_3d_vis, kps_lines, filename=None):
    """Show a 3D skeleton in an interactive matplotlib window.

    Args:
        kpt_3d: Array indexed as ``kpt_3d[j, 0..2]`` (x, y, z) per joint.
        kpt_3d_vis: Array; joint ``j`` is drawn when ``kpt_3d_vis[j, 0] > 0``.
        kps_lines: Sequence of ``(joint_a, joint_b)`` index pairs.
        filename: Optional plot title; defaults to ``'3D vis'``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # One color per bone, taken from the rainbow colormap with the first and
    # third channels swapped (these are 0-1 floats for matplotlib).
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(kps_lines) + 2)]
    colors = [np.array((c[2], c[1], c[0])) for c in colors]

    for l in range(len(kps_lines)):
        i1 = kps_lines[l][0]
        i2 = kps_lines[l][1]
        x = np.array([kpt_3d[i1, 0], kpt_3d[i2, 0]])
        y = np.array([kpt_3d[i1, 1], kpt_3d[i2, 1]])
        z = np.array([kpt_3d[i1, 2], kpt_3d[i2, 2]])

        # Data y is plotted on the (negated) vertical axis, data z as depth.
        if kpt_3d_vis[i1, 0] > 0 and kpt_3d_vis[i2, 0] > 0:
            ax.plot(x, z, -y, c=colors[l], linewidth=2)
        if kpt_3d_vis[i1, 0] > 0:
            ax.scatter(kpt_3d[i1, 0], kpt_3d[i1, 2], -kpt_3d[i1, 1], c=colors[l], marker='o')
        if kpt_3d_vis[i2, 0] > 0:
            ax.scatter(kpt_3d[i2, 0], kpt_3d[i2, 2], -kpt_3d[i2, 1], c=colors[l], marker='o')

    if filename is None:
        ax.set_title('3D vis')
    else:
        ax.set_title(filename)

    ax.set_xlabel('X Label')
    ax.set_ylabel('Z Label')
    ax.set_zlabel('Y Label')
    ax.legend()

    plt.show()
    cv2.waitKey(0)


def save_obj(v, f, file_name='output.obj'):
    """Write a mesh as Wavefront OBJ text.

    Each vertex becomes a ``v x y z`` line and each face an
    ``f a/a b/b c/c`` line with 1-based indices.

    Args:
        v: Sequence of vertices; the first three components are written.
        f: Sequence of faces holding three 0-based vertex indices each.
        file_name: Output path; an existing file is overwritten.
    """
    # The context manager closes the handle even if a write raises.
    with open(file_name, 'w') as obj_file:
        for vertex in v:
            obj_file.write('v ' + str(vertex[0]) + ' ' + str(vertex[1]) + ' ' + str(vertex[2]) + '\n')
        for face in f:
            # OBJ indices are 1-based, hence the +1 on every corner.
            a, b, c = str(face[0] + 1), str(face[1] + 1), str(face[2] + 1)
            obj_file.write('f ' + a + '/' + a + ' ' + b + '/' + b + ' ' + c + '/' + c + '\n')
flags=pyrender.RenderFlags.RGBA) + rgb = rgb[:, :, :3].astype(np.float32) + valid_mask = (depth > 0)[:, :, None] + + # save to image + img = rgb * valid_mask + img * (1 - valid_mask) + return img.astype(np.uint8) \ No newline at end of file diff --git a/data_processing/data/CrowdPose/CrowdPose.py b/data_processing/data/CrowdPose/CrowdPose.py new file mode 100644 index 0000000..c169603 --- /dev/null +++ b/data_processing/data/CrowdPose/CrowdPose.py @@ -0,0 +1,335 @@ +import os +import os.path as osp +import numpy as np +from config import cfg +import copy +import json +import scipy.io as sio +import cv2 +import random +import math +import torch +import transforms3d +from pycocotools.coco import COCO + +from utils.posefix import replace_joint_img +from utils.smpl import SMPL +from utils.preprocessing import load_img, process_bbox, augmentation, compute_iou, get_bbox +from utils.vis import vis_keypoints, vis_mesh, save_obj, vis_keypoints_with_skeleton +from utils.transforms import world2cam, cam2pixel, pixel2cam, transform_joint_to_other_db + + +class CrowdPose(torch.utils.data.Dataset): + def __init__(self, transform, data_split): + self.transform = transform + self.data_split = data_split + self.img_path = osp.join(cfg.data_dir, 'CrowdPose', 'images') + self.annot_path = osp.join(cfg.data_dir, 'CrowdPose', 'annotations') + self.target_data_split = 'val' + self.fitting_thr = 5.0 # pixel in cfg.output_hm_shape space + + # mscoco skeleton + self.coco_joint_num = 18 # original: 17, manually added pelvis + self.coco_joints_name = ('Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Pelvis') + self.coco_skeleton = ((1, 2), (0, 1), (0, 2), (2, 4), (1, 3), (6, 8), (8, 10), (5, 7), (7, 9), (12, 14), (14, 16), (11, 13), (13, 15), (5, 6), (11, 12)) + self.coco_flip_pairs = ((1, 2), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16)) + 
self.coco_joint_regressor = np.load(osp.join(cfg.data_dir, 'MSCOCO', 'J_regressor_coco_hip_smpl.npy')) + + # crowdpose skeleton + self.crowdpose_jonit_num = 14+1 # manually added pelvis + self.crowdpose_joints_name = ('L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Head_top', 'Neck', 'Pelvis') + self.crowdpose_skeleton = ((0,2), (0,13), (1,3), (1,13), (2,4), (3,5), (6,14), (7,14), (6,8), (7,9), (8,10), (9,11), (12,13), (13,14) ) + self.crowdpose_flip_pairs = ((0, 1), (1, 2), (3, 4), (5, 6), (6, 7), (8, 9), (10, 11)) + self.crowdpose_coco_common_jidx = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14) # for posefix, exclude pelvis + + # smpl skeleton + self.smpl = SMPL() + self.face = self.smpl.face + self.joint_regressor = self.smpl.joint_regressor + self.vertex_num = self.smpl.vertex_num + self.joint_num = self.smpl.joint_num + self.joints_name = self.smpl.joints_name + self.flip_pairs = self.smpl.flip_pairs + self.skeleton = self.smpl.skeleton + self.root_joint_idx = self.smpl.root_joint_idx + self.face_kps_vertex = self.smpl.face_kps_vertex + + self.datalist = self.load_data() + print("crowdpose data len: ", len(self.datalist)) + + def add_pelvis(self, joint_coord): + lhip_idx = self.crowdpose_joints_name.index('L_Hip') + rhip_idx = self.crowdpose_joints_name.index('R_Hip') + pelvis = (joint_coord[lhip_idx, :] + joint_coord[rhip_idx, :]) * 0.5 + pelvis[2] = joint_coord[lhip_idx, 2] * joint_coord[rhip_idx, 2] # joint_valid + pelvis = pelvis.reshape(1, 3) + joint_coord = np.concatenate((joint_coord, pelvis)) + return joint_coord + + def load_data(self): + datalist = [] + if self.data_split == 'train': + split_list = ['train'] if self.data_split == 'train' else [self.target_data_split] + + datalist = [] + for split in split_list: + db = COCO(osp.join(self.annot_path, f'crowdpose_{split}.json')) + # smpl parameter load + with open(osp.join(self.annot_path, 
f'CrowdPose_{split}_SMPL_NeuralAnnot.json'), 'r') as f: + smpl_params = json.load(f) + + for iid in db.imgs.keys(): + aids = db.getAnnIds([iid]) + + tmplist = [] + for aid in aids: + ann = db.anns[aid] + img = db.loadImgs(ann['image_id'])[0] + img_path = osp.join(self.img_path, img['file_name']) + # bbox + if split != 'val': # correct reversed img width,height info + width, height = img['height'], img['width'] + else: + width, height = img['width'], img['height'] + + if sum(ann['keypoints']) == 0: + continue + + # bbox + # tight_bbox = np.array(ann['bbox']) + # bbox = process_bbox(tight_bbox, width, height) + # if bbox is None: continue + + # joint coordinates + joint_img = np.array(ann['keypoints'], dtype=np.float32).reshape(-1, 3) + joint_img = self.add_pelvis(joint_img) + joint_valid = (joint_img[:, 2].copy().reshape(-1, 1) > 0).astype(np.float32) + joint_img[:, 2] = joint_valid[:, 0] # for posefix, only good for 2d datasets + + # bbox + if cfg.use_bbox_in_ann: + tight_bbox = np.array(ann['bbox']) + else: + tight_bbox = get_bbox(joint_img, np.ones_like(joint_img[:, 0]), crop_bottom_body=True) + # bbox = process_bbox(tight_bbox, width, height) + # if bbox is None: continue + + if str(aid) in smpl_params: + smpl_param = smpl_params[str(aid)] + if smpl_param['fit_err'] < self.fitting_thr: + smpl_param = None + else: + smpl_param = None + + tmplist.append({'img_path': img_path, + 'img_shape': (height, width), + #'bbox': bbox, + 'tight_bbox': tight_bbox, 'joint_img': joint_img, 'joint_valid': joint_valid, 'neural_annot_result': smpl_param}) + + for i, person in enumerate(tmplist): + tight_bbox = person['tight_bbox'] + + # for swap + num_overlap = 0 + near_joints = [] + other_persons = tmplist[:i] + tmplist[i + 1:] + for other in other_persons: + other_bbox = other['tight_bbox'] + iou = compute_iou(tight_bbox[None, :], other_bbox[None, :]) + if iou < 0.1: + continue + num_overlap += 1 + other_joint = transform_joint_to_other_db(other['joint_img'], 
self.crowdpose_joints_name, self.coco_joints_name) + near_joints.append(other_joint) + + person['num_overlap'] = num_overlap + person['near_joints'] = near_joints + + datalist.extend(tmplist) + + return datalist + + def get_smpl_coord(self, smpl_param, cam_param, do_flip, img_shape): + pose, shape, trans = smpl_param['pose'], smpl_param['shape'], smpl_param['trans'] + smpl_pose = torch.FloatTensor(pose).view(1, -1); + smpl_shape = torch.FloatTensor(shape).view(1, -1); # smpl parameters (pose: 72 dimension, shape: 10 dimension) + smpl_trans = torch.FloatTensor(trans).view(1, -1) # translation vector + + # flip smpl pose parameter (axis-angle) + if do_flip: + smpl_pose = smpl_pose.view(-1, 3) + for pair in self.flip_pairs: + if pair[0] < len(smpl_pose) and pair[1] < len(smpl_pose): # face keypoints are already included in self.flip_pairs. However, they are not included in smpl_pose. + smpl_pose[pair[0], :], smpl_pose[pair[1], :] = smpl_pose[pair[1], :].clone(), smpl_pose[pair[0], :].clone() + smpl_pose[:, 1:3] *= -1; # multiply -1 to y and z axis of axis-angle + smpl_pose = smpl_pose.view(1, -1) + + # get mesh and joint coordinates + smpl_mesh_coord, smpl_joint_coord = self.smpl.layer['neutral'](smpl_pose, smpl_shape, smpl_trans) + + # incorporate face keypoints + smpl_mesh_coord = smpl_mesh_coord.numpy().astype(np.float32).reshape(-1, 3); + # smpl_joint_coord = smpl_joint_coord.numpy().astype(np.float32).reshape(-1,3) + # smpl_face_kps_coord = smpl_mesh_coord[self.face_kps_vertex,:].reshape(-1,3) + # smpl_joint_coord = np.concatenate((smpl_joint_coord, smpl_face_kps_coord)) + smpl_joint_coord = np.dot(self.joint_regressor, smpl_mesh_coord) + + # flip translation + if do_flip: # avg of old and new root joint should be image center. + focal, princpt = cam_param['focal'], cam_param['princpt'] + flip_trans_x = 2 * (((img_shape[1] - 1) / 2. 
- princpt[0]) / focal[0] * (smpl_joint_coord[self.root_joint_idx, 2])) - 2 * smpl_joint_coord[self.root_joint_idx][0] + smpl_mesh_coord[:, 0] += flip_trans_x + smpl_joint_coord[:, 0] += flip_trans_x + + # change to mean shape if beta is too far from it + smpl_shape[(smpl_shape.abs() > 3).any(dim=1)] = 0. + + return smpl_mesh_coord, smpl_joint_coord, smpl_pose[0].numpy(), smpl_shape[0].numpy() + + def __len__(self): + return len(self.datalist) + + def __getitem__(self, idx): + data = copy.deepcopy(self.datalist[idx]) + img_path, img_shape, tight_bbox = data['img_path'], data['img_shape'], data['tight_bbox'] + # check if image is full body + # self.crowdpose_jonit_num = 14+1 # manually added pelvis + # self.crowdpose_joints_name = ('L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Head_top', 'Neck', 'Pelvis') + # data['joint_valid'].shape = (15, 1) + is_full_body = np.sum( + data['joint_valid'][6:12, :]) > 3 # 6:12 is the index of L_Hip, R_Hip, L_Knee, R_Knee, L_Ankle, R_Ankle + + # image load and affine transform + img = load_img(img_path) + img, img2bb_trans, bb2img_trans, rot, do_flip,bbox = augmentation(img, tight_bbox, self.data_split,is_full_body = is_full_body) + img = self.transform(img.astype(np.float32)) / 255. 
+ + if self.data_split == 'train': + # coco gt + crowdpose_joint_img = data['joint_img'] + crowdpose_joint_valid = data['joint_valid'] + if do_flip: + crowdpose_joint_img[:, 0] = img_shape[1] - 1 - crowdpose_joint_img[:, 0] + for pair in self.crowdpose_flip_pairs: + crowdpose_joint_img[pair[0], :], crowdpose_joint_img[pair[1], :] = crowdpose_joint_img[pair[1], :].copy(), crowdpose_joint_img[pair[0], :].copy() + crowdpose_joint_valid[pair[0], :], crowdpose_joint_valid[pair[1], :] = crowdpose_joint_valid[pair[1], :].copy(), crowdpose_joint_valid[pair[0], :].copy() + + crowdpose_joint_img_xy1 = np.concatenate((crowdpose_joint_img[:, :2], np.ones_like(crowdpose_joint_img[:, :1])), 1) + crowdpose_joint_img[:, :2] = np.dot(img2bb_trans, crowdpose_joint_img_xy1.transpose(1, 0)).transpose(1, 0) + # for swap + if len(data['near_joints']) > 0: + near_joint_list = [] + for nj in data['near_joints']: + near_joint = np.ones((self.coco_joint_num, 3), dtype=np.float32) + nj_xy1 = np.concatenate((nj[:, :2], np.ones_like(nj[:, :1])), axis=1) + near_joint[:, :2] = np.dot(img2bb_trans, nj_xy1.transpose(1, 0)).transpose(1, 0) + near_joint_list.append(near_joint) + near_joints = np.asarray(near_joint_list, dtype=np.float32) + else: + near_joints = np.zeros((1, self.coco_joint_num, 3), dtype=np.float32) + + input_crowdpose_joint_img = crowdpose_joint_img.copy() + crowdpose_joint_img[:, 0] = crowdpose_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + crowdpose_joint_img[:, 1] = crowdpose_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + + # check truncation + crowdpose_joint_trunc = crowdpose_joint_valid * ((crowdpose_joint_img[:, 0] >= 0) * (crowdpose_joint_img[:, 0] < cfg.output_hm_shape[2]) * (crowdpose_joint_img[:, 1] >= 0) * (crowdpose_joint_img[:, 1] < cfg.output_hm_shape[1])).reshape( + -1, 1).astype(np.float32) + + # transform coco joints to target db joints + crowdpose_joint_img = transform_joint_to_other_db(crowdpose_joint_img, 
self.crowdpose_joints_name, self.joints_name) + crowdpose_joint_cam = np.zeros((self.joint_num, 3), dtype=np.float32) # dummy + crowdpose_joint_valid = transform_joint_to_other_db(crowdpose_joint_valid, self.crowdpose_joints_name, self.joints_name) + crowdpose_joint_trunc = transform_joint_to_other_db(crowdpose_joint_trunc, self.crowdpose_joints_name, self.joints_name) + + # apply PoseFix + tmp_joint_img = transform_joint_to_other_db(input_crowdpose_joint_img, self.crowdpose_joints_name, self.coco_joints_name) + tmp_joint_img = replace_joint_img(tmp_joint_img, data['tight_bbox'], near_joints, data['num_overlap'], img2bb_trans) + tmp_joint_img = transform_joint_to_other_db(tmp_joint_img, self.coco_joints_name, self.crowdpose_joints_name) + input_crowdpose_joint_img[self.crowdpose_coco_common_jidx, :2] = tmp_joint_img[self.crowdpose_coco_common_jidx, :2] + """ + # debug PoseFix result + newimg = vis_keypoints_with_skeleton(img.numpy().transpose(1, 2, 0), input_crowdpose_joint_img.T, self.crowdpose_skeleton) + cv2.imshow(f'{img_path}', newimg) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + # import pdb; pdb.set_trace() + """ + input_crowdpose_joint_img[:, 0] = input_crowdpose_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + input_crowdpose_joint_img[:, 1] = input_crowdpose_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + input_crowdpose_joint_img = transform_joint_to_other_db(input_crowdpose_joint_img, self.crowdpose_joints_name, self.joints_name) + + neural_annot_result = data['neural_annot_result'] + if neural_annot_result is not None: + # use fitted mesh + smpl_param, cam_param = neural_annot_result['smpl_param'], neural_annot_result['cam_param'] + smpl_mesh_cam, smpl_joint_cam, smpl_pose, smpl_shape = self.get_smpl_coord(smpl_param, cam_param, do_flip, img_shape) + smpl_coord_cam = np.concatenate((smpl_mesh_cam, smpl_joint_cam)) + smpl_coord_img = cam2pixel(smpl_coord_cam, cam_param['focal'], 
cam_param['princpt']) + + # x,y affine transform, root-relative depth + smpl_coord_img_xy1 = np.concatenate((smpl_coord_img[:, :2], np.ones_like(smpl_coord_img[:, 0:1])), 1) + smpl_coord_img[:, :2] = np.dot(img2bb_trans, smpl_coord_img_xy1.transpose(1, 0)).transpose(1, 0)[:, :2] + """ + # vis smpl + newimg = vis_keypoints_with_skeleton(img.numpy().transpose(1, 2, 0), smpl_coord_img[6890:].T, self.skeleton) + cv2.imshow(f'{img_path}', newimg) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + """ + smpl_coord_img[:, 2] = smpl_coord_img[:, 2] - smpl_coord_cam[self.vertex_num + self.root_joint_idx][2] + smpl_coord_img[:, 0] = smpl_coord_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + smpl_coord_img[:, 1] = smpl_coord_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + smpl_coord_img[:, 2] = (smpl_coord_img[:, 2] / (cfg.bbox_3d_size / 2) + 1) / 2. * cfg.output_hm_shape[0] + + # check truncation + smpl_trunc = ( + (smpl_coord_img[:, 0] >= 0) * (smpl_coord_img[:, 0] < cfg.output_hm_shape[2]) * (smpl_coord_img[:, 1] >= 0) * (smpl_coord_img[:, 1] < cfg.output_hm_shape[1]) * (smpl_coord_img[:, 2] >= 0) * ( + smpl_coord_img[:, 2] < cfg.output_hm_shape[0])).reshape(-1, 1).astype(np.float32) + + # split mesh and joint coordinates + smpl_mesh_img = smpl_coord_img[:self.vertex_num]; + smpl_joint_img = smpl_coord_img[self.vertex_num:]; + smpl_mesh_trunc = smpl_trunc[:self.vertex_num]; + smpl_joint_trunc = smpl_trunc[self.vertex_num:]; + + # already checked in load_data() + is_valid_fit = True + + else: + smpl_joint_img = np.zeros((self.joint_num, 3), dtype=np.float32) # dummy + smpl_joint_cam = np.zeros((self.joint_num, 3), dtype=np.float32) # dummy + smpl_mesh_img = np.zeros((self.vertex_num, 3), dtype=np.float32) # dummy + smpl_pose = np.zeros((72), dtype=np.float32) # dummy + smpl_shape = np.zeros((10), dtype=np.float32) # dummy + smpl_joint_trunc = np.zeros((self.joint_num, 1), dtype=np.float32) + smpl_mesh_trunc = np.zeros((self.vertex_num, 
1), dtype=np.float32) + is_valid_fit = False + + # 3D data rotation augmentation + rot_aug_mat = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], [0, 0, 1]], dtype=np.float32) + # parameter + smpl_pose = smpl_pose.reshape(-1, 3) + root_pose = smpl_pose[self.root_joint_idx, :] + root_pose, _ = cv2.Rodrigues(root_pose) + root_pose, _ = cv2.Rodrigues(np.dot(rot_aug_mat, root_pose)) + smpl_pose[self.root_joint_idx] = root_pose.reshape(3) + smpl_pose = smpl_pose.reshape(-1) + # smpl coordinate + smpl_joint_cam = smpl_joint_cam - smpl_joint_cam[self.root_joint_idx, None] # root-relative + smpl_joint_cam = np.dot(rot_aug_mat, smpl_joint_cam.transpose(1, 0)).transpose(1, 0) + + # SMPL pose parameter validity + smpl_param_valid = np.ones((self.smpl.orig_joint_num, 3), dtype=np.float32) + for name in ('L_Ankle', 'R_Ankle', 'L_Toe', 'R_Toe', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand'): + smpl_param_valid[self.joints_name.index(name)] = 0 + smpl_param_valid = smpl_param_valid.reshape(-1) + + inputs = {'img': img, 'joints': input_crowdpose_joint_img[:, :2], 'joints_mask': crowdpose_joint_trunc} + targets = {'orig_joint_img': crowdpose_joint_img, 'fit_joint_img': smpl_joint_img, 'orig_joint_cam': crowdpose_joint_cam, 'fit_joint_cam': smpl_joint_cam, 'pose_param': smpl_pose, 'shape_param': smpl_shape} + meta_info = {'orig_joint_valid': crowdpose_joint_valid, 'orig_joint_trunc': crowdpose_joint_trunc, 'fit_param_valid': smpl_param_valid, 'fit_joint_trunc': smpl_joint_trunc, 'is_valid_fit': float(is_valid_fit),'bbox': bbox, + 'is_3D': float(False)} + return inputs, targets, meta_info + diff --git a/data_processing/data/Human36M/Human36M.py b/data_processing/data/Human36M/Human36M.py new file mode 100644 index 0000000..2f78ad2 --- /dev/null +++ b/data_processing/data/Human36M/Human36M.py @@ -0,0 +1,482 @@ +import os +import os.path as osp +import numpy as np +import torch +import cv2 +import random +import 
json +import math +import copy +import transforms3d +from pycocotools.coco import COCO +from config import cfg +from utils.posefix import replace_joint_img +from utils.smpl import SMPL +from utils.preprocessing import load_img, get_bbox, process_bbox, generate_patch_image, augmentation +from utils.transforms import world2cam, cam2pixel, pixel2cam, rigid_align, transform_joint_to_other_db +from utils.vis import vis_keypoints, vis_mesh, save_obj, vis_keypoints_with_skeleton + + +class Human36M(torch.utils.data.Dataset): + def __init__(self, transform, data_split): + self.transform = transform + self.data_split = data_split + self.img_dir = osp.join(cfg.data_dir, 'Human36M', 'images') + self.annot_path = osp.join(cfg.data_dir, 'Human36M', 'annotations') + self.human_bbox_root_dir = osp.join(cfg.data_dir, 'Human36M', 'rootnet_output', 'bbox_root_human36m_output.json') + self.action_name = ['Directions', 'Discussion', 'Eating', 'Greeting', 'Phoning', 'Posing', 'Purchases', 'Sitting', 'SittingDown', 'Smoking', 'Photo', 'Waiting', 'Walking', 'WalkDog', 'WalkTogether'] + self.fitting_thr = 25 # milimeter + + # COCO joint set + self.coco_joint_num = 17 # original: 17 + self.coco_joints_name = ( + 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', + 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle') + + # H36M joint set + self.h36m_joint_num = 17 + self.h36m_joints_name = ('Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head_top', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist') + self.h36m_flip_pairs = ( (1, 4), (2, 5), (3, 6), (14, 11), (15, 12), (16, 13) ) + self.h36m_skeleton = ( (0, 7), (7, 8), (8, 9), (9, 10), (8, 11), (11, 12), (12, 13), (8, 14), (14, 15), (15, 16), (0, 1), (1, 2), (2, 3), (0, 4), (4, 5), (5, 6) ) + self.h36m_root_joint_idx = self.h36m_joints_name.index('Pelvis') + self.h36m_eval_joint = (1, 2, 3, 4, 5, 6, 8, 
10, 11, 12, 13, 14, 15, 16) + self.h36m_joint_regressor = np.load(osp.join(cfg.data_dir, 'Human36M', 'J_regressor_h36m_correct.npy')) + self.h36m_coco_common_jidx = (1, 2, 3, 4, 5, 6, 9, 11, 12, 13, 14, 15, 16) # for posefix, exclude pelvis + + # SMPL joint set + self.smpl = SMPL() + self.face = self.smpl.face + self.joint_regressor = self.smpl.joint_regressor + self.vertex_num = self.smpl.vertex_num + self.joint_num = self.smpl.joint_num + self.joints_name = self.smpl.joints_name + self.flip_pairs = self.smpl.flip_pairs + self.skeleton = self.smpl.skeleton + self.root_joint_idx = self.smpl.root_joint_idx + self.face_kps_vertex = self.smpl.face_kps_vertex + + self.datalist = self.load_data() + print("h36m data len: ", len(self.datalist)) + + def get_subsampling_ratio(self): + if self.data_split == 'train': + return 5 + elif self.data_split == 'test': + return 64 + else: + assert 0, print('Unknown subset') + + def get_subject(self): + if self.data_split == 'train': + subject = [1,5,6,7,8] + elif self.data_split == 'test': + subject = [9,11] + else: + assert 0, print("Unknown subset") + + return subject + + def load_data(self): + subject_list = self.get_subject() + sampling_ratio = self.get_subsampling_ratio() + + # aggregate annotations from each subject + db = COCO() + cameras = {} + joints = {} + smpl_params = {} + for subject in subject_list: + # data load + with open(osp.join(self.annot_path, 'Human36M_subject' + str(subject) + '_data.json'),'r') as f: + annot = json.load(f) + if len(db.dataset) == 0: + for k,v in annot.items(): + db.dataset[k] = v + else: + for k,v in annot.items(): + db.dataset[k] += v + # camera load + with open(osp.join(self.annot_path, 'Human36M_subject' + str(subject) + '_camera.json'),'r') as f: + cameras[str(subject)] = json.load(f) + # joint coordinate load + with open(osp.join(self.annot_path, 'Human36M_subject' + str(subject) + '_joint_3d.json'),'r') as f: + joints[str(subject)] = json.load(f) + # smpl parameter load + with 
open(osp.join(self.annot_path, 'Human36M_subject' + str(subject) + '_smpl_param.json'),'r') as f: + smpl_params[str(subject)] = json.load(f) + db.createIndex() + + if self.data_split == 'test' and not cfg.use_gt_info: + print("Get bounding box and root from " + self.human_bbox_root_dir) + bbox_root_result = {} + with open(self.human_bbox_root_dir) as f: + annot = json.load(f) + for i in range(len(annot)): + bbox_root_result[str(annot[i]['image_id'])] = {'bbox': np.array(annot[i]['bbox']), 'root': np.array(annot[i]['root_cam'])} + else: + print("Get bounding box and root from groundtruth") + + datalist = [] + for aid in db.anns.keys(): + ann = db.anns[aid] + image_id = ann['image_id'] + img = db.loadImgs(image_id)[0] + img_path = osp.join(self.img_dir, img['file_name']) + img_shape = (img['height'], img['width']) + + # check subject and frame_idx + frame_idx = img['frame_idx']; + if frame_idx % sampling_ratio != 0: + continue + + # check smpl parameter exist + subject = img['subject']; action_idx = img['action_idx']; subaction_idx = img['subaction_idx']; frame_idx = img['frame_idx']; + try: + smpl_param = smpl_params[str(subject)][str(action_idx)][str(subaction_idx)][str(frame_idx)] + except KeyError: + smpl_param = None + + # camera parameter + cam_idx = img['cam_idx'] + cam_param = cameras[str(subject)][str(cam_idx)] + R,t,f,c = np.array(cam_param['R'], dtype=np.float32), np.array(cam_param['t'], dtype=np.float32), np.array(cam_param['f'], dtype=np.float32), np.array(cam_param['c'], dtype=np.float32) + cam_param = {'R': R, 't': t, 'focal': f, 'princpt': c} + + # only use frontal camera following previous works (HMR and SPIN) + if self.data_split == 'test' and str(cam_idx) != '4': + continue + + # project world coordinate to cam, image coordinate space + joint_world = np.array(joints[str(subject)][str(action_idx)][str(subaction_idx)][str(frame_idx)], dtype=np.float32) + joint_cam = world2cam(joint_world, R, t) + joint_img = cam2pixel(joint_cam, f, c) + joint_valid 
= np.ones((self.h36m_joint_num,1)) + + if cfg.use_bbox_in_ann: + tight_bbox = np.array(ann['bbox']) + else: + tight_bbox = get_bbox(joint_img, np.ones_like(joint_img[:, 0]), crop_bottom_body=True) + if self.data_split == 'test' and not cfg.use_gt_info: + bbox = bbox_root_result[str(image_id)]['bbox'] # bbox should be aspect ratio preserved-extended. It is done in RootNet. + root_joint_depth = bbox_root_result[str(image_id)]['root'][2] + datalist.append({ + 'img_path': img_path, + 'img_id': image_id, + 'img_shape': img_shape, + 'bbox': bbox, + 'tight_bbox': tight_bbox, + 'joint_img': joint_img, + 'joint_cam': joint_cam, + 'joint_valid': joint_valid, + 'smpl_param': smpl_param, + 'root_joint_depth': root_joint_depth, + 'cam_param': cam_param, + 'num_overlap': 0, + 'near_joints': np.zeros((1, self.coco_joint_num, 3), dtype=np.float32) # coco_joint_num + + }) + else: + # bbox = process_bbox(np.array(ann['bbox']), img['width'], img['height']) + # if bbox is None: continue + root_joint_depth = joint_cam[self.h36m_root_joint_idx][2] + + datalist.append({ + 'img_path': img_path, + 'img_id': image_id, + 'img_shape': img_shape, + #'bbox': bbox, + 'tight_bbox': tight_bbox, + 'joint_img': joint_img, + 'joint_cam': joint_cam, + 'joint_valid': joint_valid, + 'smpl_param': smpl_param, + 'root_joint_depth': root_joint_depth, + 'cam_param': cam_param, + 'num_overlap': 0, + 'near_joints': np.zeros((1, self.coco_joint_num, 3), dtype=np.float32) # coco_joint_num + + }) + + return datalist + + def get_smpl_coord(self, smpl_param, cam_param, do_flip, img_shape): + pose, shape, trans = smpl_param['pose'], smpl_param['shape'], smpl_param['trans'] + smpl_pose = torch.FloatTensor(pose).view(-1,3); smpl_shape = torch.FloatTensor(shape).view(1,-1); # smpl parameters (pose: 72 dimension, shape: 10 dimension) + R, t = np.array(cam_param['R'], dtype=np.float32).reshape(3,3), np.array(cam_param['t'], dtype=np.float32).reshape(3) # camera rotation and translation + + # merge root pose and camera 
rotation + root_pose = smpl_pose[self.root_joint_idx,:].numpy() + root_pose, _ = cv2.Rodrigues(root_pose) + root_pose, _ = cv2.Rodrigues(np.dot(R,root_pose)) + smpl_pose[self.root_joint_idx] = torch.from_numpy(root_pose).view(3) + + # flip smpl pose parameter (axis-angle) + if do_flip: + for pair in self.flip_pairs: + if pair[0] < len(smpl_pose) and pair[1] < len(smpl_pose): # face keypoints are already included in self.flip_pairs. However, they are not included in smpl_pose. + smpl_pose[pair[0], :], smpl_pose[pair[1], :] = smpl_pose[pair[1], :].clone(), smpl_pose[pair[0], :].clone() + smpl_pose[:,1:3] *= -1; # multiply -1 to y and z axis of axis-angle + smpl_pose = smpl_pose.view(1,-1) + + # get mesh and joint coordinates + smpl_mesh_coord, smpl_joint_coord = self.smpl.layer['neutral'](smpl_pose, smpl_shape) + + # incorporate face keypoints + smpl_mesh_coord = smpl_mesh_coord.numpy().astype(np.float32).reshape(-1,3); + # smpl_joint_coord = smpl_joint_coord.numpy().astype(np.float32).reshape(-1,3) + # smpl_face_kps_coord = smpl_mesh_coord[self.face_kps_vertex,:].reshape(-1,3) + # smpl_joint_coord = np.concatenate((smpl_joint_coord, smpl_face_kps_coord)) + smpl_joint_coord = np.dot(self.joint_regressor, smpl_mesh_coord) + + # compenstate rotation (translation from origin to root joint was not cancled) + smpl_trans = np.array(trans, dtype=np.float32).reshape(3) # translation vector from smpl coordinate to h36m world coordinate + smpl_trans = np.dot(R, smpl_trans[:,None]).reshape(1,3) + t.reshape(1,3)/1000 + root_joint_coord = smpl_joint_coord[self.root_joint_idx].reshape(1,3) + smpl_trans = smpl_trans - root_joint_coord + np.dot(R, root_joint_coord.transpose(1,0)).transpose(1,0) + smpl_mesh_coord = smpl_mesh_coord + smpl_trans + smpl_joint_coord = smpl_joint_coord + smpl_trans + + # flip translation + if do_flip: # avg of old and new root joint should be image center. 
+ focal, princpt = cam_param['focal'], cam_param['princpt'] + flip_trans_x = 2 * (((img_shape[1] - 1)/2. - princpt[0]) / focal[0] * (smpl_joint_coord[self.root_joint_idx,2] * 1000)) / 1000 - 2 * smpl_joint_coord[self.root_joint_idx][0] + smpl_mesh_coord[:,0] += flip_trans_x + smpl_joint_coord[:,0] += flip_trans_x + + # change to mean shape if beta is too far from it + smpl_shape[(smpl_shape.abs() > 3).any(dim=1)] = 0. + + # meter -> milimeter + smpl_mesh_coord *= 1000; smpl_joint_coord *= 1000; + return smpl_mesh_coord, smpl_joint_coord, smpl_pose[0].numpy(), smpl_shape[0].numpy() + + def get_fitting_error(self, h36m_joint, smpl_mesh, do_flip): + h36m_joint = h36m_joint - h36m_joint[self.h36m_root_joint_idx,None,:] # root-relative + if do_flip: + h36m_joint[:,0] = -h36m_joint[:,0] + for pair in self.h36m_flip_pairs: + h36m_joint[pair[0],:] , h36m_joint[pair[1],:] = h36m_joint[pair[1],:].copy(), h36m_joint[pair[0],:].copy() + + h36m_from_smpl = np.dot(self.h36m_joint_regressor, smpl_mesh) + h36m_from_smpl = h36m_from_smpl - np.mean(h36m_from_smpl,0)[None,:] + np.mean(h36m_joint,0)[None,:] # translation alignment + + error = np.sqrt(np.sum((h36m_joint - h36m_from_smpl)**2,1)).mean() + return error + + def __len__(self): + return len(self.datalist) + + def __getitem__(self, idx): + data = copy.deepcopy(self.datalist[idx]) + img_path, img_shape, tight_bbox, smpl_param, cam_param = data['img_path'], data['img_shape'], data['tight_bbox'], data['smpl_param'], data['cam_param'] + + # img + img = load_img(img_path) + img, img2bb_trans, bb2img_trans, rot, do_flip,bbox = augmentation(img, tight_bbox, self.data_split,is_full_body = True ) # always full body + img = self.transform(img.astype(np.float32))/255. 
+ + if self.data_split == 'train': + # h36m gt + h36m_joint_img = data['joint_img'] + h36m_joint_cam = data['joint_cam'] + h36m_joint_cam = h36m_joint_cam - h36m_joint_cam[self.h36m_root_joint_idx,None,:] # root-relative + h36m_joint_valid = data['joint_valid'] + if do_flip: + h36m_joint_cam[:,0] = -h36m_joint_cam[:,0] + h36m_joint_img[:,0] = img_shape[1] - 1 - h36m_joint_img[:,0] + for pair in self.h36m_flip_pairs: + h36m_joint_img[pair[0],:], h36m_joint_img[pair[1],:] = h36m_joint_img[pair[1],:].copy(), h36m_joint_img[pair[0],:].copy() + h36m_joint_cam[pair[0],:], h36m_joint_cam[pair[1],:] = h36m_joint_cam[pair[1],:].copy(), h36m_joint_cam[pair[0],:].copy() + h36m_joint_valid[pair[0],:], h36m_joint_valid[pair[1],:] = h36m_joint_valid[pair[1],:].copy(), h36m_joint_valid[pair[0],:].copy() + + h36m_joint_img_xy1 = np.concatenate((h36m_joint_img[:,:2], np.ones_like(h36m_joint_img[:,:1])),1) + h36m_joint_img[:,:2] = np.dot(img2bb_trans, h36m_joint_img_xy1.transpose(1,0)).transpose(1,0) + input_h36m_joint_img = h36m_joint_img.copy() + h36m_joint_img[:,0] = h36m_joint_img[:,0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + h36m_joint_img[:,1] = h36m_joint_img[:,1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + h36m_joint_img[:,2] = h36m_joint_img[:,2] - h36m_joint_img[self.h36m_root_joint_idx][2] # root-relative + h36m_joint_img[:,2] = (h36m_joint_img[:,2] / (cfg.bbox_3d_size * 1000 / 2) + 1)/2. 
* cfg.output_hm_shape[0] # change cfg.bbox_3d_size from meter to milimeter + + # check truncation + h36m_joint_trunc = h36m_joint_valid * ((h36m_joint_img[:,0] >= 0) * (h36m_joint_img[:,0] < cfg.output_hm_shape[2]) * \ + (h36m_joint_img[:,1] >= 0) * (h36m_joint_img[:,1] < cfg.output_hm_shape[1]) * \ + (h36m_joint_img[:,2] >= 0) * (h36m_joint_img[:,2] < cfg.output_hm_shape[0])).reshape(-1,1).astype(np.float32) + + """ + print(f'{img_path} trunc:\n', h36m_joint_trunc.nonzero()) + tmp_coord = h36m_joint_img[:, :2] * np.array([[cfg.input_img_shape[1] / cfg.output_hm_shape[2], cfg.input_img_shape[0]/ cfg.output_hm_shape[1]]]) + newimg = vis_keypoints(img.numpy().transpose(1,2,0), tmp_coord) + cv2.imshow(f'{img_path}', newimg) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + """ + + # transform h36m joints to target db joints + h36m_joint_img = transform_joint_to_other_db(h36m_joint_img, self.h36m_joints_name, self.joints_name) + h36m_joint_cam = transform_joint_to_other_db(h36m_joint_cam, self.h36m_joints_name, self.joints_name) + h36m_joint_valid = transform_joint_to_other_db(h36m_joint_valid, self.h36m_joints_name, self.joints_name) + h36m_joint_trunc = transform_joint_to_other_db(h36m_joint_trunc, self.h36m_joints_name, self.joints_name) + + # apply PoseFix + input_h36m_joint_img[:, 2] = 1 # joint valid + tmp_joint_img = transform_joint_to_other_db(input_h36m_joint_img, self.h36m_joints_name, self.coco_joints_name) + tmp_joint_img = replace_joint_img(tmp_joint_img, data['tight_bbox'], data['near_joints'], data['num_overlap'], img2bb_trans) + tmp_joint_img = transform_joint_to_other_db(tmp_joint_img, self.coco_joints_name, self.h36m_joints_name) + input_h36m_joint_img[self.h36m_coco_common_jidx, :2] = tmp_joint_img[self.h36m_coco_common_jidx, :2] + """ + # debug PoseFix result + newimg = vis_keypoints_with_skeleton(img.numpy().transpose(1, 2, 0), input_h36m_joint_img.T, self.h36m_skeleton) + cv2.imshow(f'{img_path}', newimg) + cv2.waitKey(0) + 
cv2.destroyAllWindows() + cv2.waitKey(1) + import pdb; pdb.set_trace() + """ + input_h36m_joint_img[:, 0] = input_h36m_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + input_h36m_joint_img[:, 1] = input_h36m_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + input_h36m_joint_img = transform_joint_to_other_db(input_h36m_joint_img, self.h36m_joints_name, self.joints_name) + joint_mask = h36m_joint_trunc + + if smpl_param is not None: + # smpl coordinates + smpl_mesh_cam, smpl_joint_cam, smpl_pose, smpl_shape = self.get_smpl_coord(smpl_param, cam_param, do_flip, img_shape) + smpl_coord_cam = np.concatenate((smpl_mesh_cam, smpl_joint_cam)) + focal, princpt = cam_param['focal'], cam_param['princpt'] + smpl_coord_img = cam2pixel(smpl_coord_cam, focal, princpt) + + """ + # vis smpl joint coord + tmpimg = cv2.imread(img_path) + newimg = vis_keypoints(tmpimg, smpl_coord_img[6890:]) + cv2.imshow(f'{img_path}', newimg) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + import pdb; pdb.set_trace() + """ + + # affine transform x,y coordinates, root-relative depth + smpl_coord_img_xy1 = np.concatenate((smpl_coord_img[:,:2], np.ones_like(smpl_coord_img[:,:1])),1) + smpl_coord_img[:,:2] = np.dot(img2bb_trans, smpl_coord_img_xy1.transpose(1,0)).transpose(1,0)[:,:2] + smpl_coord_img[:,2] = smpl_coord_img[:,2] - smpl_coord_cam[self.vertex_num + self.root_joint_idx][2] + # coordinates voxelize + smpl_coord_img[:,0] = smpl_coord_img[:,0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + smpl_coord_img[:,1] = smpl_coord_img[:,1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + smpl_coord_img[:,2] = (smpl_coord_img[:,2] / (cfg.bbox_3d_size * 1000 / 2) + 1)/2. 
* cfg.output_hm_shape[0] # change cfg.bbox_3d_size from meter to milimeter + + # check truncation + smpl_trunc = ((smpl_coord_img[:,0] >= 0) * (smpl_coord_img[:,0] < cfg.output_hm_shape[2]) * \ + (smpl_coord_img[:,1] >= 0) * (smpl_coord_img[:,1] < cfg.output_hm_shape[1]) * \ + (smpl_coord_img[:,2] >= 0) * (smpl_coord_img[:,2] < cfg.output_hm_shape[0])).reshape(-1,1).astype(np.float32) + + # split mesh and joint coordinates + smpl_mesh_img = smpl_coord_img[:self.vertex_num]; smpl_joint_img = smpl_coord_img[self.vertex_num:]; + smpl_mesh_trunc = smpl_trunc[:self.vertex_num]; smpl_joint_trunc = smpl_trunc[self.vertex_num:]; + + # if fitted mesh is too far from h36m gt, discard it + is_valid_fit = True + error = self.get_fitting_error(data['joint_cam'], smpl_mesh_cam, do_flip) + if error > self.fitting_thr: + is_valid_fit = False + + else: + smpl_joint_img = np.zeros((self.joint_num,3), dtype=np.float32) # dummy + smpl_joint_cam = np.zeros((self.joint_num,3), dtype=np.float32) # dummy + smpl_mesh_img = np.zeros((self.vertex_num,3), dtype=np.float32) # dummy + smpl_pose = np.zeros((72), dtype=np.float32) # dummy + smpl_shape = np.zeros((10), dtype=np.float32) # dummy + smpl_joint_trunc = np.zeros((self.joint_num,1), dtype=np.float32) # dummy + smpl_mesh_trunc = np.zeros((self.vertex_num,1), dtype=np.float32) # dummy + is_valid_fit = False + + # 3D data rotation augmentation + rot_aug_mat = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], + [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], + [0, 0, 1]], dtype=np.float32) + # h36m coordinate + h36m_joint_cam = np.dot(rot_aug_mat, h36m_joint_cam.transpose(1,0)).transpose(1,0) / 1000 # milimeter to meter + # parameter + smpl_pose = smpl_pose.reshape(-1,3) + root_pose = smpl_pose[self.root_joint_idx,:] + root_pose, _ = cv2.Rodrigues(root_pose) + root_pose, _ = cv2.Rodrigues(np.dot(rot_aug_mat,root_pose)) + smpl_pose[self.root_joint_idx] = root_pose.reshape(3) + smpl_pose = smpl_pose.reshape(-1) + # 
smpl coordinate + smpl_joint_cam = smpl_joint_cam - smpl_joint_cam[self.root_joint_idx,None] # root-relative + smpl_joint_cam = np.dot(rot_aug_mat, smpl_joint_cam.transpose(1,0)).transpose(1,0) / 1000 # milimeter to meter + + # SMPL pose parameter validity + smpl_param_valid = np.ones((self.smpl.orig_joint_num, 3), dtype=np.float32) + for name in ('L_Ankle', 'R_Ankle', 'L_Toe', 'R_Toe', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand'): + smpl_param_valid[self.joints_name.index(name)] = 0 + smpl_param_valid = smpl_param_valid.reshape(-1) + + inputs = {'img': img, 'joints': input_h36m_joint_img[:, :2], 'joints_mask': joint_mask} + targets = {'orig_joint_img': h36m_joint_img, 'fit_joint_img': smpl_joint_img, 'orig_joint_cam': h36m_joint_cam, 'fit_joint_cam': smpl_joint_cam, 'pose_param': smpl_pose, 'shape_param': smpl_shape} + meta_info = {'orig_joint_valid': h36m_joint_valid, 'orig_joint_trunc': h36m_joint_trunc, 'fit_param_valid': smpl_param_valid, 'fit_joint_trunc': smpl_joint_trunc, 'is_valid_fit': float(is_valid_fit), 'bbox': bbox, + 'is_3D': float(True)} + return inputs, targets, meta_info + else: + inputs = {'img': img} + targets = {} + meta_info = {'bb2img_trans': bb2img_trans} + return inputs, targets, meta_info + + def evaluate(self, outs, cur_sample_idx): + + annots = self.datalist + sample_num = len(outs) + eval_result = {'mpjpe_lixel': [], 'pa_mpjpe_lixel': [], 'mpjpe_param': [], 'pa_mpjpe_param': []} + for n in range(sample_num): + annot = annots[cur_sample_idx + n] + out = outs[n] + + # mesh from lixel + # x,y: resize to input image space and perform bbox to image affine transform + mesh_out_img = out['mesh_coord_img'] + mesh_out_img[:,0] = mesh_out_img[:,0] / cfg.output_hm_shape[2] * cfg.input_img_shape[1] + mesh_out_img[:,1] = mesh_out_img[:,1] / cfg.output_hm_shape[1] * cfg.input_img_shape[0] + mesh_out_img_xy1 = np.concatenate((mesh_out_img[:,:2], np.ones_like(mesh_out_img[:,:1])),1) + mesh_out_img[:,:2] = np.dot(out['bb2img_trans'], 
mesh_out_img_xy1.transpose(1,0)).transpose(1,0)[:,:2] + # z: devoxelize and translate to absolute depth + root_joint_depth = annot['root_joint_depth'] + mesh_out_img[:,2] = (mesh_out_img[:,2] / cfg.output_hm_shape[0] * 2. - 1) * (cfg.bbox_3d_size * 1000 / 2) + mesh_out_img[:,2] = mesh_out_img[:,2] + root_joint_depth + # camera back-projection + cam_param = annot['cam_param'] + focal, princpt = cam_param['focal'], cam_param['princpt'] + mesh_out_cam = pixel2cam(mesh_out_img, focal, princpt) + + # h36m joint from gt mesh + pose_coord_gt_h36m = annot['joint_cam'] + pose_coord_gt_h36m = pose_coord_gt_h36m - pose_coord_gt_h36m[self.h36m_root_joint_idx,None] # root-relative + pose_coord_gt_h36m = pose_coord_gt_h36m[self.h36m_eval_joint,:] + + # h36m joint from lixel mesh + pose_coord_out_h36m = np.dot(self.h36m_joint_regressor, mesh_out_cam) + pose_coord_out_h36m = pose_coord_out_h36m - pose_coord_out_h36m[self.h36m_root_joint_idx,None] # root-relative + pose_coord_out_h36m = pose_coord_out_h36m[self.h36m_eval_joint,:] + pose_coord_out_h36m_aligned = rigid_align(pose_coord_out_h36m, pose_coord_gt_h36m) + eval_result['mpjpe_lixel'].append(np.sqrt(np.sum((pose_coord_out_h36m - pose_coord_gt_h36m)**2,1)).mean()) + eval_result['pa_mpjpe_lixel'].append(np.sqrt(np.sum((pose_coord_out_h36m_aligned - pose_coord_gt_h36m)**2,1)).mean()) + + vis = False + if vis: + filename = annot['img_path'].split('/')[-1][:-4] + + img = load_img(annot['img_path'])[:,:,::-1] + img = vis_mesh(img, mesh_out_img, 0.5) + cv2.imwrite(filename + '.jpg', img) + + save_obj(mesh_out_cam, self.smpl.face, filename + '.obj') + + return eval_result + + def print_eval_result(self, eval_result): + print('MPJPE from lixel mesh: %.2f mm' % np.mean(eval_result['mpjpe_lixel'])) + print('PA MPJPE from lixel mesh: %.2f mm' % np.mean(eval_result['pa_mpjpe_lixel'])) + + print('MPJPE from param mesh: %.2f mm' % np.mean(eval_result['mpjpe_param'])) + print('PA MPJPE from param mesh: %.2f mm' % 
np.mean(eval_result['pa_mpjpe_param'])) diff --git a/data_processing/data/MPII/MPII.py b/data_processing/data/MPII/MPII.py new file mode 100644 index 0000000..4441017 --- /dev/null +++ b/data_processing/data/MPII/MPII.py @@ -0,0 +1,295 @@ +import os +import os.path as osp +import numpy as np +from config import cfg +import copy +import json +import cv2 +import torch +from pycocotools.coco import COCO + +from utils.posefix import replace_joint_img +from utils.preprocessing import compute_iou, process_bbox, load_img, augmentation,get_bbox +from utils.smpl import SMPL +from utils.transforms import transform_joint_to_other_db, cam2pixel +from utils.vis import vis_keypoints_with_skeleton + + +class MPII(torch.utils.data.Dataset): + def __init__(self, transform, data_split): + self.transform = transform + self.data_split = data_split + self.img_path = osp.join(cfg.data_dir, 'MPII', 'data') + self.annot_path = osp.join(cfg.data_dir, 'MPII', 'data', 'annotations') + + # mpii skeleton + self.mpii_joint_num = 16 + self.mpii_joints_name = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'Pelvis', 'Thorax', 'Neck', 'Head_top', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder', 'L_Elbow', 'L_Wrist') + self.mpii_flip_pairs = ((0, 5), (1, 4), (2, 3), (10, 15), (11, 14), (12, 13)) + self.mpii_skeleton = ((0,1), (1,2), (2,6), (3,6), (3, 4), (4, 5), (6, 7), (7, 8), (8, 9), (10, 11), (11, 12) , (7, 12), (7, 13), (13, 14), (14, 15)) + + # smpl skeleton + self.smpl = SMPL() + self.face = self.smpl.face + self.joint_regressor = self.smpl.joint_regressor + self.vertex_num = self.smpl.vertex_num + self.joint_num = self.smpl.joint_num + self.joints_name = self.smpl.joints_name + self.flip_pairs = self.smpl.flip_pairs + self.skeleton = self.smpl.skeleton + self.root_joint_idx = self.smpl.root_joint_idx + self.face_kps_vertex = self.smpl.face_kps_vertex + + self.coco_joint_num = 18 # original: 17, manually added pelvis + self.coco_joints_name = ('Nose', 'L_Eye', 'R_Eye', 'L_Ear', 
'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Pelvis') + self.mpii_coco_common_idx = (0, 1, 2, 3, 4, 5, 6, 10, 11, 12, 13, 14, 15) + + self.datalist = self.load_data() + print("mpii data len: ", len(self.datalist)) + + def load_data(self): + db = COCO(osp.join(self.annot_path, 'train.json')) + with open(osp.join(self.annot_path, 'MPII_train_SMPL_NeuralAnnot.json')) as f: + smpl_params = json.load(f) + + datalist = [] + for iid in db.imgs.keys(): + aids = db.getAnnIds([iid]) + + tmplist = [] + for aid in aids: + ann = db.anns[aid] + img = db.loadImgs(ann['image_id'])[0] + img_path = osp.join(self.img_path, img['file_name']) + width, height = img['width'], img['height'] + + # bbox + # tight_bbox = np.array(ann['bbox']) + # bbox = process_bbox(tight_bbox, width, height) + # if bbox is None: continue + + # joint coordinates + joint_img = np.array(ann['keypoints'], dtype=np.float32).reshape(-1, 3) + joint_valid = (joint_img[:, 2].copy().reshape(-1, 1) > 0).astype(np.float32) + joint_img[:, 2] = joint_valid[:, 0] # for posefix, only good for 2d datasets + + # bbox + if cfg.use_bbox_in_ann: + tight_bbox = np.array(ann['bbox']) + else: + tight_bbox = get_bbox(joint_img, np.ones_like(joint_img[:, 0]), crop_bottom_body=True) + + # smpl parameter + if str(aid) in smpl_params: + smpl_param = smpl_params[str(aid)] + else: + smpl_param = None + + tmplist.append({'img_path': img_path, 'img_shape': (height, width), + #'bbox': bbox, + 'tight_bbox': tight_bbox, 'joint_img': joint_img, 'joint_valid': joint_valid, 'smpl_param': smpl_param}) + + for i, person in enumerate(tmplist): + tight_bbox = person['tight_bbox'] + + # for swap + num_overlap = 0 + near_joints = [] + other_persons = tmplist[:i] + tmplist[i + 1:] + for other in other_persons: + other_bbox = other['tight_bbox'] + iou = compute_iou(tight_bbox[None, :], other_bbox[None, :]) + if iou < 0.1: + continue + num_overlap += 1 + 
other_joint = transform_joint_to_other_db(other['joint_img'], self.mpii_joints_name, self.coco_joints_name) + near_joints.append(other_joint) + + person['num_overlap'] = num_overlap + person['near_joints'] = near_joints + + datalist.extend(tmplist) + """ + if num_overlap > 2: + tmpimg = cv2.imread(img_path) + newimg = vis_keypoints_with_skeleton(tmpimg, joint_img.T, self.mpii_skeleton) + cv2.imshow(f'{img_path}', newimg / 255) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + import pdb; + pdb.set_trace() + """ + + return datalist + + def get_smpl_coord(self, smpl_param, cam_param, do_flip, img_shape): + pose, shape, trans = smpl_param['pose'], smpl_param['shape'], smpl_param['trans'] + smpl_pose = torch.FloatTensor(pose).view(1, -1); + smpl_shape = torch.FloatTensor(shape).view(1, -1); # smpl parameters (pose: 72 dimension, shape: 10 dimension) + smpl_trans = torch.FloatTensor(trans).view(1, -1) # translation vector + + # flip smpl pose parameter (axis-angle) + if do_flip: + smpl_pose = smpl_pose.view(-1, 3) + for pair in self.flip_pairs: + if pair[0] < len(smpl_pose) and pair[1] < len(smpl_pose): # face keypoints are already included in self.flip_pairs. However, they are not included in smpl_pose. 
+ smpl_pose[pair[0], :], smpl_pose[pair[1], :] = smpl_pose[pair[1], :].clone(), smpl_pose[pair[0], :].clone() + smpl_pose[:, 1:3] *= -1; # multiply -1 to y and z axis of axis-angle + smpl_pose = smpl_pose.view(1, -1) + + # get mesh and joint coordinates + smpl_mesh_coord, smpl_joint_coord = self.smpl.layer['neutral'](smpl_pose, smpl_shape, smpl_trans) + + # incorporate face keypoints + smpl_mesh_coord = smpl_mesh_coord.numpy().astype(np.float32).reshape(-1, 3); + # smpl_joint_coord = smpl_joint_coord.numpy().astype(np.float32).reshape(-1,3) + # smpl_face_kps_coord = smpl_mesh_coord[self.face_kps_vertex,:].reshape(-1,3) + # smpl_joint_coord = np.concatenate((smpl_joint_coord, smpl_face_kps_coord)) + smpl_joint_coord = np.dot(self.joint_regressor, smpl_mesh_coord) + + # flip translation + if do_flip: # avg of old and new root joint should be image center. + focal, princpt = cam_param['focal'], cam_param['princpt'] + flip_trans_x = 2 * (((img_shape[1] - 1) / 2. - princpt[0]) / focal[0] * (smpl_joint_coord[self.root_joint_idx, 2])) - 2 * smpl_joint_coord[self.root_joint_idx][0] + smpl_mesh_coord[:, 0] += flip_trans_x + smpl_joint_coord[:, 0] += flip_trans_x + + # change to mean shape if beta is too far from it + smpl_shape[(smpl_shape.abs() > 3).any(dim=1)] = 0. 
+ + return smpl_mesh_coord, smpl_joint_coord, smpl_pose[0].numpy(), smpl_shape[0].numpy() + + def __len__(self): + return len(self.datalist) + + def __getitem__(self, idx): + data = copy.deepcopy(self.datalist[idx]) + img_path, img_shape, tight_bbox = data['img_path'], data['img_shape'], data['tight_bbox'] + # check if image is full body + # self.mpii_joint_num = 16 + # self.mpii_joints_name = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'Pelvis', 'Thorax', 'Neck', 'Head_top', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder', 'L_Elbow', 'L_Wrist') + # data['joint_valid'].shape = (15, 1) + is_full_body = np.sum( + data['joint_valid'][0:6, :]) > 3 # 0:6 is the index of R_Ankle, R_Knee, R_Hip, L_Hip, L_Knee, L_Ankle + + # image load and affine transform + img = load_img(img_path) + img, img2bb_trans, bb2img_trans, rot, do_flip,bbox = augmentation(img, tight_bbox, self.data_split, is_full_body= is_full_body) + img = self.transform(img.astype(np.float32)) / 255. + + # mpii gt + mpii_joint_img = data['joint_img'] + mpii_joint_valid = data['joint_valid'] + if do_flip: + mpii_joint_img[:, 0] = img_shape[1] - 1 - mpii_joint_img[:, 0] + for pair in self.mpii_flip_pairs: + mpii_joint_img[pair[0], :], mpii_joint_img[pair[1], :] = mpii_joint_img[pair[1], :].copy(), mpii_joint_img[pair[0], :].copy() + mpii_joint_valid[pair[0], :], mpii_joint_valid[pair[1], :] = mpii_joint_valid[pair[1], :].copy(), mpii_joint_valid[pair[0], :].copy() + + mpii_joint_img_xy1 = np.concatenate((mpii_joint_img[:, :2], np.ones_like(mpii_joint_img[:, :1])), 1) + mpii_joint_img[:, :2] = np.dot(img2bb_trans, mpii_joint_img_xy1.transpose(1, 0)).transpose(1, 0) + # for swap + if len(data['near_joints']) > 0: + near_joint_list = [] + for nj in data['near_joints']: + near_joint = np.ones((self.coco_joint_num, 3), dtype=np.float32) + nj_xy1 = np.concatenate((nj[:, :2], np.ones_like(nj[:, :1])), axis=1) + near_joint[:, :2] = np.dot(img2bb_trans, nj_xy1.transpose(1, 0)).transpose(1, 0) + 
near_joint_list.append(near_joint) + near_joints = np.asarray(near_joint_list, dtype=np.float32) + else: + near_joints = np.zeros((1, self.coco_joint_num, 3), dtype=np.float32) + + input_mpii_joint_img = mpii_joint_img.copy() + mpii_joint_img[:, 0] = mpii_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + mpii_joint_img[:, 1] = mpii_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + + # check truncation + mpii_joint_trunc = mpii_joint_valid * ( + (mpii_joint_img[:, 0] >= 0) * (mpii_joint_img[:, 0] < cfg.output_hm_shape[2]) * (mpii_joint_img[:, 1] >= 0) * (mpii_joint_img[:, 1] < cfg.output_hm_shape[1])).reshape(-1, + 1).astype(np.float32) + + # transform coco joints to target db joints + mpii_joint_img = transform_joint_to_other_db(mpii_joint_img, self.mpii_joints_name, self.joints_name) + mpii_joint_cam = np.zeros((self.joint_num, 3), dtype=np.float32) # dummy + mpii_joint_valid = transform_joint_to_other_db(mpii_joint_valid, self.mpii_joints_name, self.joints_name) + mpii_joint_trunc = transform_joint_to_other_db(mpii_joint_trunc, self.mpii_joints_name, self.joints_name) + + # apply PoseFix + tmp_joint_img = transform_joint_to_other_db(input_mpii_joint_img, self.mpii_joints_name, self.coco_joints_name) + tmp_joint_img = replace_joint_img(tmp_joint_img, data['tight_bbox'], near_joints, data['num_overlap'], img2bb_trans) + tmp_joint_img = transform_joint_to_other_db(tmp_joint_img, self.coco_joints_name, self.mpii_joints_name) + input_mpii_joint_img[self.mpii_coco_common_idx, :2] = tmp_joint_img[self.mpii_coco_common_idx, :2] + """ + # debug PoseFix result + newimg = vis_keypoints_with_skeleton(img.numpy().transpose(1, 2, 0), input_mpii_joint_img.T, self.mpii_skeleton) + cv2.imshow(f'{img_path}', newimg) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + import pdb; pdb.set_trace() + """ + input_mpii_joint_img[:, 0] = input_mpii_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + input_mpii_joint_img[:, 1] 
= input_mpii_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + input_mpii_joint_img = transform_joint_to_other_db(input_mpii_joint_img, self.mpii_joints_name, self.joints_name) + + smpl_param = data['smpl_param'] + if smpl_param is not None: + # use fitted mesh + smpl_param, cam_param = smpl_param['smpl_param'], smpl_param['cam_param'] + smpl_mesh_cam, smpl_joint_cam, smpl_pose, smpl_shape = self.get_smpl_coord(smpl_param, cam_param, do_flip, img_shape) + smpl_coord_cam = np.concatenate((smpl_mesh_cam, smpl_joint_cam)) + smpl_coord_img = cam2pixel(smpl_coord_cam, cam_param['focal'], cam_param['princpt']) + + # x,y affine transform, root-relative depth + smpl_coord_img_xy1 = np.concatenate((smpl_coord_img[:, :2], np.ones_like(smpl_coord_img[:, 0:1])), 1) + smpl_coord_img[:, :2] = np.dot(img2bb_trans, smpl_coord_img_xy1.transpose(1, 0)).transpose(1, 0)[:, :2] + smpl_coord_img[:, 2] = smpl_coord_img[:, 2] - smpl_coord_cam[self.vertex_num + self.root_joint_idx][2] + smpl_coord_img[:, 0] = smpl_coord_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + smpl_coord_img[:, 1] = smpl_coord_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + smpl_coord_img[:, 2] = (smpl_coord_img[:, 2] / (cfg.bbox_3d_size / 2) + 1) / 2. 
* cfg.output_hm_shape[0] + + # check truncation + smpl_trunc = ((smpl_coord_img[:, 0] >= 0) * (smpl_coord_img[:, 0] < cfg.output_hm_shape[2]) * (smpl_coord_img[:, 1] >= 0) * (smpl_coord_img[:, 1] < cfg.output_hm_shape[1]) * (smpl_coord_img[:, 2] >= 0) * ( + smpl_coord_img[:, 2] < cfg.output_hm_shape[0])).reshape(-1, 1).astype(np.float32) + + # split mesh and joint coordinates + smpl_joint_img = smpl_coord_img[self.vertex_num:]; + smpl_joint_trunc = smpl_trunc[self.vertex_num:]; + + """ + # vis smpl joint coord + # tmpimg = cv2.imread(img_path) + newimg = vis_keypoints_with_skeleton(img.numpy().transpose(1, 2, 0), (smpl_joint_img.T)*4, self.skeleton) + cv2.imshow(f'{img_path}', newimg) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + import pdb; pdb.set_trace() + """ + + # if fitted mesh is too far from h36m gt, discard it + is_valid_fit = True + else: + smpl_joint_img = np.zeros((self.joint_num, 3), dtype=np.float32) # dummy + smpl_joint_cam = np.zeros((self.joint_num, 3), dtype=np.float32) # dummy + smpl_pose = np.zeros((72), dtype=np.float32) # dummy + smpl_shape = np.zeros((10), dtype=np.float32) # dummy + smpl_joint_trunc = np.zeros((self.joint_num, 1), dtype=np.float32) + is_valid_fit = False + + # SMPL pose parameter validity + smpl_param_valid = np.ones((self.smpl.orig_joint_num, 3), dtype=np.float32) + for name in ('L_Ankle', 'R_Ankle', 'L_Toe', 'R_Toe', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand'): + smpl_param_valid[self.joints_name.index(name)] = 0 + smpl_param_valid = smpl_param_valid.reshape(-1) + + inputs = {'img': img, 'joints': input_mpii_joint_img[:, :2], 'joints_mask': mpii_joint_trunc} + targets = {'orig_joint_img': mpii_joint_img, 'fit_joint_img': smpl_joint_img, 'orig_joint_cam': mpii_joint_cam, 'fit_joint_cam': smpl_joint_cam, 'pose_param': smpl_pose, 'shape_param': smpl_shape} + meta_info = {'orig_joint_valid': mpii_joint_valid, 'orig_joint_trunc': mpii_joint_trunc, 'fit_param_valid': smpl_param_valid, 'fit_joint_trunc': 
class MSCOCO(torch.utils.data.Dataset):
    """MSCOCO person-keypoint dataset with optional fitted SMPL parameters.

    The training split yields 2D COCO joints (with a synthetic pelvis joint
    appended), PoseFix-perturbed input joints and, when a fit exists for the
    annotation, SMPL pose/shape targets. Any other split is treated as the
    validation split and yields only the image, joints and bounding box.
    """

    def __init__(self, transform, data_split):
        """Set up paths and joint-set metadata, then load the annotations.

        Args:
            transform: callable applied to the augmented image in
                ``__getitem__``.
            data_split: ``'train'`` selects the training split; every other
                value is mapped to ``'val'``.
        """
        self.transform = transform
        self.data_split = 'train' if data_split == 'train' else 'val'
        self.img_path = osp.join(cfg.data_dir, 'MSCOCO', 'images')
        self.annot_path = osp.join(cfg.data_dir, 'MSCOCO', 'annotations')
        self.rootnet_output_path = osp.join(cfg.data_dir, 'MSCOCO', 'rootnet_output', 'bbox_root_coco_output.json')
        self.fitting_thr = 3.0  # pixel in cfg.output_hm_shape space

        # mscoco skeleton
        self.coco_joint_num = 18  # original: 17, manually added pelvis
        self.coco_joints_name = ('Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Pelvis')
        self.coco_skeleton = ((1, 2), (0, 1), (0, 2), (2, 4), (1, 3), (6, 8), (8, 10), (5, 7), (7, 9), (12, 14), (14, 16), (11, 13), (13, 15), (5, 6), (11, 12))
        self.coco_flip_pairs = ((1, 2), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16))
        self.coco_joint_regressor = np.load(osp.join(cfg.data_dir, 'MSCOCO', 'J_regressor_coco_hip_smpl.npy'))

        # smpl skeleton
        self.smpl = SMPL()
        self.face = self.smpl.face
        self.joint_regressor = self.smpl.joint_regressor
        self.vertex_num = self.smpl.vertex_num
        self.joint_num = self.smpl.joint_num
        self.joints_name = self.smpl.joints_name
        self.flip_pairs = self.smpl.flip_pairs
        self.skeleton = self.smpl.skeleton
        self.root_joint_idx = self.smpl.root_joint_idx
        self.face_kps_vertex = self.smpl.face_kps_vertex

        self.datalist = self.load_data()
        print("coco data len: ", len(self.datalist))

    def add_pelvis(self, joint_coord):
        """Append a pelvis row computed from the two hip joints.

        Args:
            joint_coord: array with one row per COCO joint and three columns;
                rows are looked up by the positions of ``'L_Hip'`` and
                ``'R_Hip'`` in ``self.coco_joints_name``.

        Returns:
            A new array with one extra trailing row: the first two columns are
            the hip midpoint, the third is the product of the hips' third
            columns (so it is zero as soon as either hip's value is zero).
        """
        lhip_idx = self.coco_joints_name.index('L_Hip')
        rhip_idx = self.coco_joints_name.index('R_Hip')
        pelvis = (joint_coord[lhip_idx, :] + joint_coord[rhip_idx, :]) * 0.5
        pelvis[2] = joint_coord[lhip_idx, 2] * joint_coord[rhip_idx, 2]  # joint_valid
        pelvis = pelvis.reshape(1, 3)
        return np.concatenate((joint_coord, pelvis))

    def load_data(self):
        """Build the list of per-person annotation records for the split.

        Crowd annotations and annotations without keypoints are skipped.
        Training records additionally carry the fitted SMPL parameters (or
        ``None``) and, for the PoseFix swap, the joints of other people in the
        same image whose box IoU with this person is at least 0.1.

        Returns:
            A list of dicts. Training records have the keys ``img_path``,
            ``img_shape``, ``tight_bbox``, ``joint_img``, ``joint_valid``,
            ``smpl_param``, ``num_overlap`` and ``near_joints``; validation
            records have ``img_path``, ``img_shape``, ``bbox``,
            ``tight_bbox``, ``joint_img``, ``joint_valid`` and ``smpl_param``.
        """
        db = COCO(osp.join(self.annot_path, 'person_keypoints_' + self.data_split + '2017.json'))

        datalist = []
        if self.data_split == 'train':
            # The SMPL fits only exist for (and are only used by) the training
            # split, so the file is not touched when loading validation data.
            with open(osp.join(self.annot_path, 'MSCOCO_train_SMPL_NeuralAnnot.json')) as f:
                smpl_params = json.load(f)

            for iid in db.imgs.keys():
                aids = db.getAnnIds([iid])

                tmplist = []
                for aid in aids:
                    ann = db.anns[aid]
                    img = db.loadImgs(ann['image_id'])[0]
                    imgname = osp.join('train2017', img['file_name'])
                    img_path = osp.join(self.img_path, imgname)
                    width, height = img['width'], img['height']

                    if ann['iscrowd'] or (ann['num_keypoints'] == 0):
                        continue

                    # joint coordinates
                    joint_img = np.array(ann['keypoints'], dtype=np.float32).reshape(-1, 3)
                    joint_img = self.add_pelvis(joint_img)
                    joint_valid = (joint_img[:, 2].copy().reshape(-1, 1) > 0).astype(np.float32)
                    joint_img[:, 2] = joint_valid[:, 0]  # for posefix, only good for 2d datasets

                    # bbox
                    if cfg.use_bbox_in_ann:
                        tight_bbox = np.array(ann['bbox'])
                    else:
                        # NOTE(review): every joint is passed as valid here, not
                        # joint_valid -- confirm get_bbox handles unannotated joints.
                        tight_bbox = get_bbox(joint_img, np.ones_like(joint_img[:, 0]), crop_bottom_body=True)

                    tmplist.append({
                        'img_path': img_path,
                        'img_shape': (height, width),
                        'tight_bbox': tight_bbox,
                        'joint_img': joint_img,
                        'joint_valid': joint_valid,
                        'smpl_param': smpl_params.get(str(aid)),
                    })

                for i, person in enumerate(tmplist):
                    tight_bbox = person['tight_bbox']

                    # for swap
                    num_overlap = 0
                    near_joints = []
                    other_persons = tmplist[:i] + tmplist[i + 1:]
                    for other in other_persons:
                        other_tight_bbox = other['tight_bbox']
                        iou = compute_iou(tight_bbox[None, :], other_tight_bbox[None, :])
                        if iou < 0.1:
                            continue
                        num_overlap += 1
                        near_joints.append(other['joint_img'])

                    person['num_overlap'] = num_overlap
                    person['near_joints'] = near_joints

                datalist.extend(tmplist)

        else:
            for aid in db.anns.keys():
                ann = db.anns[aid]
                img = db.loadImgs(ann['image_id'])[0]
                imgname = osp.join('val2017', img['file_name'])
                img_path = osp.join(self.img_path, imgname)
                width, height = img['width'], img['height']

                if ann['iscrowd'] or (ann['num_keypoints'] == 0):
                    continue

                # bbox
                tight_bbox = np.array(ann['bbox'])
                bbox = process_bbox(tight_bbox, width, height)
                if bbox is None:
                    continue

                # joint coordinates
                joint_img = np.array(ann['keypoints'], dtype=np.float32).reshape(-1, 3)
                joint_img = self.add_pelvis(joint_img)
                joint_valid = (joint_img[:, 2].copy().reshape(-1, 1) > 0).astype(np.float32)
                joint_img[:, 2] = joint_valid[:, 0]  # for posefix, only good for 2d datasets

                datalist.append({
                    'img_path': img_path,
                    'img_shape': (height, width),
                    'bbox': bbox,
                    'tight_bbox': tight_bbox,
                    'joint_img': joint_img,
                    'joint_valid': joint_valid,
                    'smpl_param': None,
                })

                # The validation list is deliberately cut short: loading stops
                # as soon as it holds more than 100 records.
                if len(datalist) > 100:
                    break

        return datalist

    def get_smpl_coord(self, smpl_param, cam_param, do_flip, img_shape):
        """Run the neutral SMPL layer and return mesh/joint coordinates.

        Args:
            smpl_param: dict with ``'pose'``, ``'shape'`` and ``'trans'``.
            cam_param: dict with ``'focal'`` and ``'princpt'``; only read when
                ``do_flip`` is true.
            do_flip: whether the sample is horizontally flipped.
            img_shape: ``(height, width)`` of the original image.

        Returns:
            ``(mesh, joints, pose, shape)``: the mesh vertices as an
            ``(N, 3)`` float32 array, the joints regressed from them with
            ``self.joint_regressor``, and the (possibly flipped) pose and
            (possibly zeroed) shape parameters as 1-D numpy arrays.
        """
        pose, shape, trans = smpl_param['pose'], smpl_param['shape'], smpl_param['trans']
        # smpl parameters (pose: 72 dimension, shape: 10 dimension)
        smpl_pose = torch.FloatTensor(pose).view(1, -1)
        smpl_shape = torch.FloatTensor(shape).view(1, -1)
        smpl_trans = torch.FloatTensor(trans).view(1, -1)  # translation vector

        # flip smpl pose parameter (axis-angle)
        if do_flip:
            smpl_pose = smpl_pose.view(-1, 3)
            for pair in self.flip_pairs:
                # face keypoints are already included in self.flip_pairs.
                # However, they are not included in smpl_pose.
                if pair[0] < len(smpl_pose) and pair[1] < len(smpl_pose):
                    smpl_pose[pair[0], :], smpl_pose[pair[1], :] = smpl_pose[pair[1], :].clone(), smpl_pose[pair[0], :].clone()
            smpl_pose[:, 1:3] *= -1  # multiply -1 to y and z axis of axis-angle
            smpl_pose = smpl_pose.view(1, -1)

        # get mesh and joint coordinates; the joints returned by the layer are
        # discarded in favour of the ones regressed from the mesh below
        smpl_mesh_coord, smpl_joint_coord = self.smpl.layer['neutral'](smpl_pose, smpl_shape, smpl_trans)
        smpl_mesh_coord = smpl_mesh_coord.numpy().astype(np.float32).reshape(-1, 3)
        smpl_joint_coord = np.dot(self.joint_regressor, smpl_mesh_coord)

        # flip translation
        if do_flip:  # avg of old and new root joint should be image center.
            focal, princpt = cam_param['focal'], cam_param['princpt']
            flip_trans_x = 2 * (((img_shape[1] - 1) / 2. - princpt[0]) / focal[0] * (smpl_joint_coord[self.root_joint_idx, 2])) - 2 * smpl_joint_coord[self.root_joint_idx][0]
            smpl_mesh_coord[:, 0] += flip_trans_x
            smpl_joint_coord[:, 0] += flip_trans_x

        # change to mean shape if beta is too far from it
        smpl_shape[(smpl_shape.abs() > 3).any(dim=1)] = 0.

        return smpl_mesh_coord, smpl_joint_coord, smpl_pose[0].numpy(), smpl_shape[0].numpy()

    def get_fitting_error(self, coco_joint, smpl_mesh, cam_param, img2bb_trans, coco_joint_valid):
        """Mean 2D distance between COCO joints and joints regressed from a mesh.

        The mesh is mapped to COCO joints with ``self.coco_joint_regressor``,
        projected with ``cam2pixel``, moved into the cropped patch with
        ``img2bb_trans`` and rescaled to ``cfg.output_hm_shape``; only joints
        whose ``coco_joint_valid`` entry equals 1 contribute.

        Returns:
            The mean Euclidean distance over the selected joints.
        """
        # get coco joint from smpl mesh
        coco_from_smpl = np.dot(self.coco_joint_regressor, smpl_mesh)
        coco_from_smpl = self.add_pelvis(coco_from_smpl)  # z-axis component will be removed
        coco_from_smpl = cam2pixel(coco_from_smpl, cam_param['focal'], cam_param['princpt'])
        coco_from_smpl_xy1 = np.concatenate((coco_from_smpl[:, :2], np.ones_like(coco_from_smpl[:, 0:1])), 1)
        coco_from_smpl[:, :2] = np.dot(img2bb_trans, coco_from_smpl_xy1.transpose(1, 0)).transpose(1, 0)
        coco_from_smpl[:, 0] = coco_from_smpl[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2]
        coco_from_smpl[:, 1] = coco_from_smpl[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1]

        # mask joint coordinates
        valid_xy = np.tile(coco_joint_valid, (1, 2)) == 1
        coco_joint = coco_joint[:, :2][valid_xy].reshape(-1, 2)
        coco_from_smpl = coco_from_smpl[:, :2][valid_xy].reshape(-1, 2)

        error = np.sqrt(np.sum((coco_joint - coco_from_smpl) ** 2, 1)).mean()
        return error

    def __len__(self):
        """Return the number of loaded annotation records."""
        return len(self.datalist)

    def __getitem__(self, idx):
        """Load, augment and package one sample.

        Returns:
            ``(inputs, targets, meta_info)`` dicts. For the training split
            they hold the image, PoseFix-perturbed input joints, COCO and SMPL
            joint targets, pose/shape parameters and validity/truncation
            masks. For the validation split ``targets`` is empty and
            ``meta_info`` only holds the bounding box.
        """
        data = copy.deepcopy(self.datalist[idx])
        img_path, img_shape, tight_bbox = data['img_path'], data['img_shape'], data['tight_bbox']

        # Treat the person as full body when more than three of the six
        # lower-body joints are annotated.
        # 11:17 = L_Hip, R_Hip, L_Knee, R_Knee, L_Ankle, R_Ankle
        is_full_body = np.sum(data['joint_valid'][11:17, :]) > 3

        # image load and affine transform
        img = load_img(img_path)
        img, img2bb_trans, bb2img_trans, rot, do_flip, bbox = augmentation(img, tight_bbox, self.data_split, is_full_body=is_full_body)
        img = self.transform(img.astype(np.float32)) / 255.

        if self.data_split == 'train':
            # coco gt
            coco_joint_img = data['joint_img']
            coco_joint_valid = data['joint_valid']
            if do_flip:
                coco_joint_img[:, 0] = img_shape[1] - 1 - coco_joint_img[:, 0]
                for pair in self.coco_flip_pairs:
                    coco_joint_img[pair[0], :], coco_joint_img[pair[1], :] = coco_joint_img[pair[1], :].copy(), coco_joint_img[pair[0], :].copy()
                    coco_joint_valid[pair[0], :], coco_joint_valid[pair[1], :] = coco_joint_valid[pair[1], :].copy(), coco_joint_valid[pair[0], :].copy()

            coco_joint_img_xy1 = np.concatenate((coco_joint_img[:, :2], np.ones_like(coco_joint_img[:, :1])), 1)
            coco_joint_img[:, :2] = np.dot(img2bb_trans, coco_joint_img_xy1.transpose(1, 0)).transpose(1, 0)

            # for swap
            if len(data['near_joints']) > 0:
                near_joint_list = []
                for nj in data['near_joints']:
                    near_joint = np.ones((self.coco_joint_num, 3), dtype=np.float32)
                    nj_xy1 = np.concatenate((nj[:, :2], np.ones_like(nj[:, :1])), axis=1)
                    near_joint[:, :2] = np.dot(img2bb_trans, nj_xy1.transpose(1, 0)).transpose(1, 0)
                    near_joint_list.append(near_joint)
                near_joints = np.asarray(near_joint_list, dtype=np.float32)
            else:
                near_joints = np.zeros((1, self.coco_joint_num, 3), dtype=np.float32)

            # keep a copy in input-image space for PoseFix before rescaling
            input_coco_joint_img = coco_joint_img.copy()
            coco_joint_img[:, 0] = coco_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2]
            coco_joint_img[:, 1] = coco_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1]

            # check truncation
            coco_joint_trunc = coco_joint_valid * (
                (coco_joint_img[:, 0] >= 0) * (coco_joint_img[:, 0] < cfg.output_hm_shape[2])
                * (coco_joint_img[:, 1] >= 0) * (coco_joint_img[:, 1] < cfg.output_hm_shape[1])
            ).reshape(-1, 1).astype(np.float32)

            # transform coco joints to target db joints
            coco_joint_img = transform_joint_to_other_db(coco_joint_img, self.coco_joints_name, self.joints_name)
            coco_joint_cam = np.zeros((self.joint_num, 3), dtype=np.float32)  # dummy
            coco_joint_valid = transform_joint_to_other_db(coco_joint_valid, self.coco_joints_name, self.joints_name)
            coco_joint_trunc = transform_joint_to_other_db(coco_joint_trunc, self.coco_joints_name, self.joints_name)

            # apply PoseFix
            input_coco_joint_img = replace_joint_img(input_coco_joint_img, data['tight_bbox'], near_joints, data['num_overlap'], img2bb_trans)
            input_coco_joint_img[:, 0] = input_coco_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2]
            input_coco_joint_img[:, 1] = input_coco_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1]
            input_coco_joint_img = transform_joint_to_other_db(input_coco_joint_img, self.coco_joints_name, self.joints_name)
            joint_mask = coco_joint_trunc

            smpl_param = data['smpl_param']
            if smpl_param is not None:
                # use fitted mesh
                smpl_param, cam_param = smpl_param['smpl_param'], smpl_param['cam_param']
                smpl_mesh_cam, smpl_joint_cam, smpl_pose, smpl_shape = self.get_smpl_coord(smpl_param, cam_param, do_flip, img_shape)
                smpl_coord_cam = np.concatenate((smpl_mesh_cam, smpl_joint_cam))
                smpl_coord_img = cam2pixel(smpl_coord_cam, cam_param['focal'], cam_param['princpt'])

                # x,y affine transform, root-relative depth
                smpl_coord_img_xy1 = np.concatenate((smpl_coord_img[:, :2], np.ones_like(smpl_coord_img[:, 0:1])), 1)
                smpl_coord_img[:, :2] = np.dot(img2bb_trans, smpl_coord_img_xy1.transpose(1, 0)).transpose(1, 0)[:, :2]
                smpl_coord_img[:, 2] = smpl_coord_img[:, 2] - smpl_coord_cam[self.vertex_num + self.root_joint_idx][2]
                smpl_coord_img[:, 0] = smpl_coord_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2]
                smpl_coord_img[:, 1] = smpl_coord_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1]
                smpl_coord_img[:, 2] = (smpl_coord_img[:, 2] / (cfg.bbox_3d_size / 2) + 1) / 2. * cfg.output_hm_shape[0]

                # check truncation
                smpl_trunc = (
                    (smpl_coord_img[:, 0] >= 0) * (smpl_coord_img[:, 0] < cfg.output_hm_shape[2])
                    * (smpl_coord_img[:, 1] >= 0) * (smpl_coord_img[:, 1] < cfg.output_hm_shape[1])
                    * (smpl_coord_img[:, 2] >= 0) * (smpl_coord_img[:, 2] < cfg.output_hm_shape[0])
                ).reshape(-1, 1).astype(np.float32)

                # only the joint rows (stored after the mesh vertices) are returned
                smpl_joint_img = smpl_coord_img[self.vertex_num:]
                smpl_joint_trunc = smpl_trunc[self.vertex_num:]

                # NOTE: filtering fits by get_fitting_error() against
                # self.fitting_thr is disabled, so every available fit is
                # treated as valid.
                is_valid_fit = True
            else:
                smpl_joint_img = np.zeros((self.joint_num, 3), dtype=np.float32)  # dummy
                smpl_joint_cam = np.zeros((self.joint_num, 3), dtype=np.float32)  # dummy
                smpl_pose = np.zeros((72), dtype=np.float32)  # dummy
                smpl_shape = np.zeros((10), dtype=np.float32)  # dummy
                smpl_joint_trunc = np.zeros((self.joint_num, 1), dtype=np.float32)
                is_valid_fit = False

            # 3D data rotation augmentation
            rot_aug_mat = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
                                    [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
                                    [0, 0, 1]], dtype=np.float32)
            # parameter: only the root joint's axis-angle is rotated
            smpl_pose = smpl_pose.reshape(-1, 3)
            root_pose = smpl_pose[self.root_joint_idx, :]
            root_pose, _ = cv2.Rodrigues(root_pose)
            root_pose, _ = cv2.Rodrigues(np.dot(rot_aug_mat, root_pose))
            smpl_pose[self.root_joint_idx] = root_pose.reshape(3)
            smpl_pose = smpl_pose.reshape(-1)
            # smpl coordinate
            smpl_joint_cam = smpl_joint_cam - smpl_joint_cam[self.root_joint_idx, None]  # root-relative
            smpl_joint_cam = np.dot(rot_aug_mat, smpl_joint_cam.transpose(1, 0)).transpose(1, 0)

            # SMPL pose parameter validity
            smpl_param_valid = np.ones((self.smpl.orig_joint_num, 3), dtype=np.float32)
            for name in ('L_Ankle', 'R_Ankle', 'L_Toe', 'R_Toe', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand'):
                smpl_param_valid[self.joints_name.index(name)] = 0
            smpl_param_valid = smpl_param_valid.reshape(-1)

            inputs = {'img': img, 'joints': input_coco_joint_img[:, :2], 'joints_mask': joint_mask}
            targets = {'orig_joint_img': coco_joint_img, 'fit_joint_img': smpl_joint_img, 'orig_joint_cam': coco_joint_cam, 'fit_joint_cam': smpl_joint_cam, 'pose_param': smpl_pose, 'shape_param': smpl_shape}
            meta_info = {'orig_joint_valid': coco_joint_valid, 'orig_joint_trunc': coco_joint_trunc, 'fit_param_valid': smpl_param_valid, 'fit_joint_trunc': smpl_joint_trunc, 'is_valid_fit': float(is_valid_fit), 'bbox': bbox,
                         'is_3D': float(False)}
            return inputs, targets, meta_info

        # validation split: coco gt only
        coco_joint_img = data['joint_img']
        coco_joint_valid = data['joint_valid']

        coco_joint_img_xy1 = np.concatenate((coco_joint_img[:, :2], np.ones_like(coco_joint_img[:, :1])), 1)
        coco_joint_img[:, :2] = np.dot(img2bb_trans, coco_joint_img_xy1.transpose(1, 0)).transpose(1, 0)
        coco_joint_img[:, 0] = coco_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2]
        coco_joint_img[:, 1] = coco_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1]

        # check truncation
        coco_joint_trunc = coco_joint_valid * (
            (coco_joint_img[:, 0] >= 0) * (coco_joint_img[:, 0] < cfg.output_hm_shape[2])
            * (coco_joint_img[:, 1] >= 0) * (coco_joint_img[:, 1] < cfg.output_hm_shape[1])
        ).reshape(-1, 1).astype(np.float32)

        coco_joint_img = transform_joint_to_other_db(coco_joint_img, self.coco_joints_name, self.joints_name)
        coco_joint_trunc = transform_joint_to_other_db(coco_joint_trunc, self.coco_joints_name, self.joints_name)

        inputs = {'img': img, 'joints': coco_joint_img, 'joints_mask': coco_joint_trunc}
        targets = {}
        meta_info = {'bbox': bbox}
        return inputs, targets, meta_info

    def evaluate(self, outs, cur_sample_idx):
        """Render each predicted mesh over its image and show it in a window.

        Each window blocks on ``cv2.waitKey(0)`` until a key is pressed. No
        metric is computed.

        Args:
            outs: per-sample outputs providing ``'mesh_cam_render'`` and
                ``'bbox'``.
            cur_sample_idx: index into ``self.datalist`` of the first sample
                in ``outs``.

        Returns:
            An empty dict.
        """
        annots = self.datalist
        sample_num = len(outs)
        eval_result = {}
        for n in range(sample_num):
            annot = annots[cur_sample_idx + n]
            out = outs[n]

            img = cv2.imread(annot['img_path'])
            mesh_cam_render = out['mesh_cam_render']
            bbox = out['bbox']
            # principal point is taken as the bounding-box centre
            princpt = (bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2)
            img = vis_bbox(img, bbox, alpha=1)

            rendered_img = render_mesh(img, mesh_cam_render, self.face, {'focal': cfg.focal, 'princpt': princpt})

            cv2.imshow(annot['img_path'], rendered_img / 255)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
            cv2.waitKey(1)

        return eval_result

    def print_eval_result(self, eval_result):
        """Do nothing; ``evaluate`` produces no metrics to print."""
        pass
import vis_keypoints, vis_mesh, save_obj, vis_keypoints_with_skeleton, vis_bbox +import transforms3d + + +class MuCo(torch.utils.data.Dataset): + def __init__(self, transform, data_split): + self.transform = transform + self.data_split = data_split + self.img_dir = osp.join(cfg.data_dir, 'MuCo', 'data') + self.annot_path = osp.join(cfg.data_dir, 'MuCo', 'data', 'MuCo-3DHP.json') + self.smpl_param_path = osp.join(cfg.data_dir, 'MuCo', 'data', 'smpl_param.json') + self.fitting_thr = 25 # milimeter + + # COCO joint set + self.coco_joint_num = 17 # original: 17 + self.coco_joints_name = ('Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle') + + # MuCo joint set + self.muco_joint_num = 21 + self.muco_joints_name = ('Head_top', 'Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Pelvis', 'Spine', 'Head', 'R_Hand', 'L_Hand', 'R_Toe', 'L_Toe') + self.muco_flip_pairs = ( (2, 5), (3, 6), (4, 7), (8, 11), (9, 12), (10, 13), (17, 18), (19, 20) ) + self.muco_skeleton = ( (0, 16), (16, 1), (1, 15), (15, 14), (14, 8), (14, 11), (8, 9), (9, 10), (10, 19), (11, 12), (12, 13), (13, 20), (1, 2), (2, 3), (3, 4), (4, 17), (1, 5), (5, 6), (6, 7), (7, 18) ) + self.muco_root_joint_idx = self.muco_joints_name.index('Pelvis') + self.muco_coco_common_jidx = (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13) + + # H36M joint set + self.h36m_joint_regressor = np.load(osp.join(cfg.data_dir, 'Human36M', 'J_regressor_h36m_correct.npy')) # use h36m joint regrssor (only use subset from original muco joint set) + self.h36m_flip_pairs = ( (1, 4), (2, 5), (3, 6), (14, 11), (15, 12), (16, 13) ) + self.h36m_joints_name = ('Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head_top', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist') + 
self.h36m_root_joint_idx = self.h36m_joints_name.index('Pelvis') + + # SMPL joint set + self.smpl = SMPL() + self.face = self.smpl.face + self.joint_regressor = self.smpl.joint_regressor + self.vertex_num = self.smpl.vertex_num + self.joint_num = self.smpl.joint_num + self.joints_name = self.smpl.joints_name + self.flip_pairs = self.smpl.flip_pairs + self.skeleton = self.smpl.skeleton + self.root_joint_idx = self.smpl.root_joint_idx + self.face_kps_vertex = self.smpl.face_kps_vertex + + self.datalist = self.load_data() + print("muco data len: ", len(self.datalist)) + + def load_data(self): + if self.data_split == 'train': + db = COCO(self.annot_path) + with open(self.smpl_param_path) as f: + smpl_params = json.load(f) + else: + print('Unknown data subset') + assert 0 + + datalist = [] + for iid in db.imgs.keys(): + img = db.imgs[iid] + img_id = img["id"] + img_width, img_height = img['width'], img['height'] + imgname = img['file_name'] + img_path = osp.join(self.img_dir, imgname) + focal = img["f"] + princpt = img["c"] + cam_param = {'focal': focal, 'princpt': princpt} + + # crop the closest person to the camera + ann_ids = db.getAnnIds(img_id) + anns = db.loadAnns(ann_ids) + + root_depths = [ann['keypoints_cam'][self.muco_root_joint_idx][2] for ann in anns] + closest_pid = root_depths.index(min(root_depths)) + pid_list = [closest_pid] + for pid in pid_list: + joint_cam = np.array(anns[pid]['keypoints_cam']) + joint_img = np.array(anns[pid]['keypoints_img']) + joint_img = np.concatenate([joint_img, joint_cam[:,2:]],1) + joint_valid = np.ones((self.muco_joint_num,1)) + + if cfg.use_bbox_in_ann: + tight_bbox = np.array(anns[pid]['bbox']) + else: + tight_bbox = get_bbox(joint_img, np.ones_like(joint_img[:, 0]), crop_bottom_body=True) + + # for swap + num_overlap = 0 + near_joints = [] + other_persons = anns[:pid] + anns[pid+1:] + for other in other_persons: + other_tight_bbox = np.array(other['bbox']) + iou = compute_iou(tight_bbox[None, :], other_tight_bbox[None, :]) 
+ if iou < 0.1: + continue + num_overlap += 1 + other_joint = np.array(other['keypoints_img']) + other_joint = np.concatenate((other_joint, np.ones_like(other_joint[:, :1])), axis=1) + other_joint = transform_joint_to_other_db(other_joint, self.muco_joints_name, self.coco_joints_name) + near_joints.append(other_joint) + if num_overlap == 0: + near_joints = [] + + # bbox = process_bbox(tight_bbox, img_width, img_height) + # if bbox is None: continue + + # check smpl parameter exist + try: + smpl_param = smpl_params[str(ann_ids[pid])] + except KeyError: + smpl_param = None + + datalist.append({ + 'img_path': img_path, + 'img_shape': (img_height, img_width), + #'bbox': bbox, + 'tight_bbox': tight_bbox, + 'joint_img': joint_img, + 'joint_cam': joint_cam, + 'joint_valid': joint_valid, + 'cam_param': cam_param, + 'smpl_param': smpl_param, + 'near_joints': near_joints, + 'num_overlap': num_overlap + }) + + return datalist + + def get_smpl_coord(self, smpl_param, cam_param, do_flip, img_shape): + pose, shape, trans = smpl_param['pose'], smpl_param['shape'], smpl_param['trans'] + smpl_pose = torch.FloatTensor(pose).view(1,-1); smpl_shape = torch.FloatTensor(shape).view(1,-1); # smpl parameters (pose: 72 dimension, shape: 10 dimension) + smpl_trans = torch.FloatTensor(trans).view(1,3) # translation vector from smpl coordinate to muco world coordinate + + # flip smpl pose parameter (axis-angle) + if do_flip: + smpl_pose = smpl_pose.view(-1,3) + for pair in self.flip_pairs: + if pair[0] < len(smpl_pose) and pair[1] < len(smpl_pose): # face keypoints are already included in self.flip_pairs. However, they are not included in smpl_pose. 
+ smpl_pose[pair[0], :], smpl_pose[pair[1], :] = smpl_pose[pair[1], :].clone(), smpl_pose[pair[0], :].clone() + smpl_pose[:,1:3] *= -1; # multiply -1 to y and z axis of axis-angle + smpl_pose = smpl_pose.view(1,-1) + + # get mesh and joint coordinates + smpl_mesh_coord, smpl_joint_coord = self.smpl.layer['neutral'](smpl_pose, smpl_shape, smpl_trans) + + # incorporate face keypoints + smpl_mesh_coord = smpl_mesh_coord.numpy().astype(np.float32).reshape(-1,3); + # smpl_joint_coord = smpl_joint_coord.numpy().astype(np.float32).reshape(-1,3) + # smpl_face_kps_coord = smpl_mesh_coord[self.face_kps_vertex,:].reshape(-1,3) + # smpl_joint_coord = np.concatenate((smpl_joint_coord, smpl_face_kps_coord)) + smpl_joint_coord = np.dot(self.joint_regressor, smpl_mesh_coord) + + # flip translation + if do_flip: # avg of old and new root joint should be image center. + focal, princpt = cam_param['focal'], cam_param['princpt'] + flip_trans_x = 2 * (((img_shape[1] - 1)/2. - princpt[0]) / focal[0] * (smpl_joint_coord[self.root_joint_idx,2] * 1000)) / 1000 - 2 * smpl_joint_coord[self.root_joint_idx][0] + smpl_mesh_coord[:,0] += flip_trans_x + smpl_joint_coord[:,0] += flip_trans_x + + # change to mean shape if beta is too far from it + smpl_shape[(smpl_shape.abs() > 3).any(dim=1)] = 0. 
+ + # meter -> milimeter + smpl_mesh_coord *= 1000; smpl_joint_coord *= 1000; + return smpl_mesh_coord, smpl_joint_coord, smpl_pose[0].numpy(), smpl_shape[0].numpy() + + def get_fitting_error(self, muco_joint, smpl_mesh, do_flip): + muco_joint = muco_joint.copy() + muco_joint = muco_joint - muco_joint[self.muco_root_joint_idx,None,:] # root-relative + if do_flip: + muco_joint[:,0] = -muco_joint[:,0] + for pair in self.muco_flip_pairs: + muco_joint[pair[0],:] , muco_joint[pair[1],:] = muco_joint[pair[1],:].copy(), muco_joint[pair[0],:].copy() + muco_joint_valid = np.ones((self.muco_joint_num,3), dtype=np.float32) + + # transform to h36m joint set + h36m_joint = transform_joint_to_other_db(muco_joint, self.muco_joints_name, self.h36m_joints_name) + h36m_joint_valid = transform_joint_to_other_db(muco_joint_valid, self.muco_joints_name, self.h36m_joints_name) + h36m_joint = h36m_joint[h36m_joint_valid==1].reshape(-1,3) + + h36m_from_smpl = np.dot(self.h36m_joint_regressor, smpl_mesh) + h36m_from_smpl = h36m_from_smpl[h36m_joint_valid==1].reshape(-1,3) + h36m_from_smpl = h36m_from_smpl - np.mean(h36m_from_smpl,0)[None,:] + np.mean(h36m_joint,0)[None,:] # translation alignment + error = np.sqrt(np.sum((h36m_joint - h36m_from_smpl)**2,1)).mean() + return error + + def __len__(self): + return len(self.datalist) + + def __getitem__(self, idx): + data = copy.deepcopy(self.datalist[idx]) + img_path, img_shape, tight_bbox, smpl_param, cam_param = data['img_path'], data['img_shape'], data['tight_bbox'], data['smpl_param'], data['cam_param'] + + # img + img = load_img(img_path) + img, img2bb_trans, bb2img_trans, rot, do_flip,bbox = augmentation(img, tight_bbox, self.data_split,is_full_body = True) # always full body + img = self.transform(img.astype(np.float32))/255. 
+ + # muco gt + muco_joint_img = data['joint_img'] + muco_joint_cam = data['joint_cam'] + muco_joint_cam = muco_joint_cam - muco_joint_cam[self.muco_root_joint_idx,None,:] # root-relative + muco_joint_valid = data['joint_valid'] + if do_flip: + muco_joint_img[:,0] = img_shape[1] - 1 - muco_joint_img[:,0] + muco_joint_cam[:,0] = -muco_joint_cam[:,0] + for pair in self.muco_flip_pairs: + muco_joint_img[pair[0],:], muco_joint_img[pair[1],:] = muco_joint_img[pair[1],:].copy(), muco_joint_img[pair[0],:].copy() + muco_joint_cam[pair[0],:], muco_joint_cam[pair[1],:] = muco_joint_cam[pair[1],:].copy(), muco_joint_cam[pair[0],:].copy() + muco_joint_valid[pair[0],:], muco_joint_valid[pair[1],:] = muco_joint_valid[pair[1],:].copy(), muco_joint_valid[pair[0],:].copy() + + muco_joint_img_xy1 = np.concatenate((muco_joint_img[:,:2], np.ones_like(muco_joint_img[:,:1])),1) + muco_joint_img[:,:2] = np.dot(img2bb_trans, muco_joint_img_xy1.transpose(1,0)).transpose(1,0) + # for swap + if len(data['near_joints']) > 0: + near_joint_list = [] + for nj in data['near_joints']: + near_joint = np.ones((self.coco_joint_num, 3), dtype=np.float32) + nj_xy1 = np.concatenate((nj[:, :2], np.ones_like(nj[:, :1])), axis=1) + near_joint[:, :2] = np.dot(img2bb_trans, nj_xy1.transpose(1,0)).transpose(1,0) + near_joint_list.append(near_joint) + near_joints = np.asarray(near_joint_list, dtype=np.float32) + else: + near_joints = np.zeros((1, self.coco_joint_num, 3), dtype=np.float32) + + input_muco_joint_img = muco_joint_img.copy() + muco_joint_img[:,0] = muco_joint_img[:,0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + muco_joint_img[:,1] = muco_joint_img[:,1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + muco_joint_img[:,2] = muco_joint_img[:,2] - muco_joint_img[self.muco_root_joint_idx][2] # root-relative + muco_joint_img[:,2] = (muco_joint_img[:,2] / (cfg.bbox_3d_size * 1000 / 2) + 1)/2. 
* cfg.output_hm_shape[0] # change cfg.bbox_3d_size from meter to milimeter + + # check truncation + muco_joint_trunc = muco_joint_valid * ((muco_joint_img[:,0] >= 0) * (muco_joint_img[:,0] < cfg.output_hm_shape[2]) * \ + (muco_joint_img[:,1] >= 0) * (muco_joint_img[:,1] < cfg.output_hm_shape[1]) * \ + (muco_joint_img[:,2] >= 0) * (muco_joint_img[:,2] < cfg.output_hm_shape[0])).reshape(-1,1).astype(np.float32) + + # transform muco joints to target db joints + muco_joint_img = transform_joint_to_other_db(muco_joint_img, self.muco_joints_name, self.joints_name) + muco_joint_cam = transform_joint_to_other_db(muco_joint_cam, self.muco_joints_name, self.joints_name) + muco_joint_valid = transform_joint_to_other_db(muco_joint_valid, self.muco_joints_name, self.joints_name) + muco_joint_trunc = transform_joint_to_other_db(muco_joint_trunc, self.muco_joints_name, self.joints_name) + + # apply PoseFix + input_muco_joint_img[:, 2] = 1 # joint valid + tmp_joint_img = transform_joint_to_other_db(input_muco_joint_img, self.muco_joints_name, self.coco_joints_name) + tmp_joint_img = replace_joint_img(tmp_joint_img, data['tight_bbox'], near_joints, data['num_overlap'], img2bb_trans) + tmp_joint_img = transform_joint_to_other_db(tmp_joint_img, self.coco_joints_name, self.muco_joints_name) + input_muco_joint_img[self.muco_coco_common_jidx, :2] = tmp_joint_img[self.muco_coco_common_jidx, :2] + """ + # debug PoseFix result + newimg = vis_keypoints_with_skeleton(img.numpy().transpose(1, 2, 0), input_muco_joint_img.T, self.muco_skeleton) + cv2.imshow(f'{img_path}', newimg) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + """ + input_muco_joint_img[:, 0] = input_muco_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + input_muco_joint_img[:, 1] = input_muco_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + input_muco_joint_img = transform_joint_to_other_db(input_muco_joint_img, self.muco_joints_name, self.joints_name) + + if smpl_param is not 
None: + # smpl coordinates + smpl_mesh_cam, smpl_joint_cam, smpl_pose, smpl_shape = self.get_smpl_coord(smpl_param, cam_param, do_flip, img_shape) + smpl_coord_cam = np.concatenate((smpl_mesh_cam, smpl_joint_cam)) + focal, princpt = cam_param['focal'], cam_param['princpt'] + smpl_coord_img = cam2pixel(smpl_coord_cam, focal, princpt) + + # affine transform x,y coordinates. root-relative depth + smpl_coord_img_xy1 = np.concatenate((smpl_coord_img[:,:2], np.ones_like(smpl_coord_img[:,:1])),1) + smpl_coord_img[:,:2] = np.dot(img2bb_trans, smpl_coord_img_xy1.transpose(1,0)).transpose(1,0)[:,:2] + """ + # vis smpl + newimg = vis_keypoints_with_skeleton(img.numpy().transpose(1, 2, 0), smpl_coord_img[6890:].T, self.skeleton) + cv2.imshow(f'{img_path}', newimg) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + """ + smpl_coord_img[:,2] = smpl_coord_img[:,2] - smpl_coord_cam[self.vertex_num + self.root_joint_idx][2] + smpl_coord_img[:,0] = smpl_coord_img[:,0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + smpl_coord_img[:,1] = smpl_coord_img[:,1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + smpl_coord_img[:,2] = (smpl_coord_img[:,2] / (cfg.bbox_3d_size * 1000 / 2) + 1)/2. 
* cfg.output_hm_shape[0] # change cfg.bbox_3d_size from meter to milimeter + + # check truncation + smpl_trunc = ((smpl_coord_img[:,0] >= 0) * (smpl_coord_img[:,0] < cfg.output_hm_shape[2]) * \ + (smpl_coord_img[:,1] >= 0) * (smpl_coord_img[:,1] < cfg.output_hm_shape[1]) * \ + (smpl_coord_img[:,2] >= 0) * (smpl_coord_img[:,2] < cfg.output_hm_shape[0])).reshape(-1,1).astype(np.float32) + + # split mesh and joint coordinates + smpl_mesh_img = smpl_coord_img[:self.vertex_num]; smpl_joint_img = smpl_coord_img[self.vertex_num:]; + smpl_mesh_trunc = smpl_trunc[:self.vertex_num]; smpl_joint_trunc = smpl_trunc[self.vertex_num:]; + + # if fitted mesh is too far from muco gt, discard it + is_valid_fit = True + error = self.get_fitting_error(data['joint_cam'], smpl_mesh_cam, do_flip) + if error > self.fitting_thr: + is_valid_fit = False + + else: + smpl_joint_img = np.zeros((self.joint_num,3), dtype=np.float32) # dummy + smpl_joint_cam = np.zeros((self.joint_num,3), dtype=np.float32) # dummy + smpl_mesh_img = np.zeros((self.vertex_num,3), dtype=np.float32) # dummy + smpl_pose = np.zeros((72), dtype=np.float32) # dummy + smpl_shape = np.zeros((10), dtype=np.float32) # dummy + smpl_joint_trunc = np.zeros((self.joint_num,1), dtype=np.float32) # dummy + smpl_mesh_trunc = np.zeros((self.vertex_num,1), dtype=np.float32) # dummy + is_valid_fit = False + + # 3D data rotation augmentation + rot_aug_mat = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], + [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], + [0, 0, 1]], dtype=np.float32) + # muco coordinate + muco_joint_cam = np.dot(rot_aug_mat, muco_joint_cam.transpose(1,0)).transpose(1,0) / 1000 # milimeter to meter + # parameter + smpl_pose = smpl_pose.reshape(-1,3) + root_pose = smpl_pose[self.root_joint_idx,:] + root_pose, _ = cv2.Rodrigues(root_pose) + root_pose, _ = cv2.Rodrigues(np.dot(rot_aug_mat,root_pose)) + smpl_pose[self.root_joint_idx] = root_pose.reshape(3) + smpl_pose = smpl_pose.reshape(-1) + # 
smpl coordinate + smpl_joint_cam = smpl_joint_cam - smpl_joint_cam[self.root_joint_idx,None] # root-relative + smpl_joint_cam = np.dot(rot_aug_mat, smpl_joint_cam.transpose(1,0)).transpose(1,0) / 1000 # milimeter to meter + + # SMPL pose parameter validity + smpl_param_valid = np.ones((self.smpl.orig_joint_num, 3), dtype=np.float32) + for name in ('L_Ankle', 'R_Ankle', 'L_Toe', 'R_Toe', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand'): + smpl_param_valid[self.joints_name.index(name)] = 0 + smpl_param_valid = smpl_param_valid.reshape(-1) + + inputs = {'img': img, 'joints': input_muco_joint_img[:, :2], 'joints_mask': muco_joint_trunc} + targets = {'orig_joint_img': muco_joint_img, 'fit_joint_img': smpl_joint_img, 'orig_joint_cam': muco_joint_cam, 'fit_joint_cam': smpl_joint_cam, 'pose_param': smpl_pose, 'shape_param': smpl_shape} + meta_info = {'orig_joint_valid': muco_joint_valid, 'orig_joint_trunc': muco_joint_trunc, 'fit_param_valid': smpl_param_valid, 'fit_joint_trunc': smpl_joint_trunc, 'is_valid_fit': float(is_valid_fit), 'bbox': bbox, + 'is_3D': float(True)} + + return inputs, targets, meta_info + + diff --git a/data_processing/data/MuPoTs/MuPoTs.py b/data_processing/data/MuPoTs/MuPoTs.py new file mode 100644 index 0000000..0ef7943 --- /dev/null +++ b/data_processing/data/MuPoTs/MuPoTs.py @@ -0,0 +1,311 @@ +import torch +import copy +import os +import os.path as osp +import scipy.io as sio +import numpy as np +from pycocotools.coco import COCO +from config import cfg +import json +import cv2 +import random +import math + +from utils.smpl import SMPL +from utils.transforms import pixel2cam, transform_joint_to_other_db, cam2pixel +from utils.preprocessing import load_img, augmentation, process_bbox, get_bbox +from utils.vis import vis_keypoints, vis_3d_skeleton, vis_keypoints_with_skeleton + + +class MuPoTs(torch.utils.data.Dataset): + def __init__(self, transform, data_split): + self.transform = transform + self.data_split = data_split + self.img_dir = 
osp.join(cfg.data_dir, 'MuPoTs', 'data', 'MultiPersonTestSet') + self.test_annot_path = osp.join(cfg.data_dir, 'MuPoTs', 'data', 'MuPoTS-3D.json') + self.hhrnet_result_path = osp.join(cfg.data_dir, 'MuPoTs', 'data', 'MuPoTs_test_hhrnet_result.json') + self.hhrnet_thr = 0.1 + self.openpose_result_path = osp.join(cfg.data_dir, 'MuPoTs', 'data', 'MuPoTs_test_openpose_result.json') + self.openpose_thr = 0.05 + + # SMPL joint set + self.smpl = SMPL() + self.face = self.smpl.face + self.joint_regressor = self.smpl.joint_regressor + self.vertex_num = self.smpl.vertex_num + self.joint_num = self.smpl.joint_num + self.joints_name = self.smpl.joints_name + self.skeleton = self.smpl.skeleton + self.root_joint_idx = self.smpl.root_joint_idx + self.face_kps_vertex = self.smpl.face_kps_vertex + + # MuCo-3DHP + self.muco_joint_num = 21 + self.muco_joints_name = ( + 'Head_top', 'Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Hip', 'R_Knee', + 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Pelvis', 'Spine', 'Head', 'R_Hand', 'L_Hand', 'R_Toe', + 'L_Toe') + + # MuPoTS + self.mupots_joint_num = 17 + self.mupots_joints_name = ( + 'Head_top', 'Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Hip', 'R_Knee', + 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Pelvis', 'Spine', 'Head') # + self.mupots_flip_pairs = ((2, 5), (3, 6), (4, 7), (8, 11), (9, 12), (10, 13)) + self.mupots_skeleton = ( + (0, 16), (16, 1), (1, 15), (15, 14), (14, 8), (14, 11), (8, 9), (9, 10), (11, 12), (12, 13), (1, 2), (2, 3), + (3, 4), (1, 5), (5, 6), (6, 7)) + self.mupots_eval_joint = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16) + self.mupots_root_idx = self.mupots_joints_name.index('Pelvis') + + # H36M joint set + # Spine Thorax, Head + self.h36m_joint_num = 17 + self.h36m_joints_name = ('Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Spine', 'Thorax', 'Head', 'Head_top', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 
'R_Shoulder', 'R_Elbow', 'R_Wrist') + self.h36m_flip_pairs = ((1, 4), (2, 5), (3, 6), (14, 11), (15, 12), (16, 13)) + self.h36m_skeleton = ((0, 7), (7, 8), (8, 9), (9, 10), (8, 11), (11, 12), (12, 13), (8, 14), (14, 15), (15, 16), (0, 1), (1, 2), (2, 3), (0, 4), (4, 5), (5, 6)) + self.h36m_root_joint_idx = self.h36m_joints_name.index('Pelvis') + self.h36m_eval_joint = (1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 15, 16) + # self.h36m_joint_regressor = np.load(osp.join('..', 'data', 'Human36M', 'J_regressor_h36m_from_pav.npy')) #'J_regressor_h36m_correct.npy')) + # self.h36m_pav_joints_name = ('Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Hip', 'R_Knee', 'R_Ankle', 'Spine', 'Thorax', 'Head', 'Head_top', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist') + # self.h36m_joint_regressor = transform_joint_to_other_db(self.h36m_joint_regressor, self.h36m_pav_joints_name, self.h36m_joints_name) + + # MPI-INF-3DHP joint set + self.mpii3d_joint_num = 17 + self.mpii3d_joints_name = ( + 'Head_top', 'Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'L_Shoulder', 'L_Elbow', + 'L_Wrist', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Pelvis', 'Spine', + 'Head') + self.mpii3d_flip_pairs = ((2, 5), (3, 6), (4, 7), (8, 11), (9, 12), (10, 13)) + self.mpii3d_smpl_regressor = np.load(osp.join(cfg.data_dir, 'MPI_INF_3DHP', 'J_regressor_mi_smpl.npy'))[:17] + self.mpii3d_root_idx = self.mpii3d_joints_name.index('Pelvis') + + # MSCOCO joint set + self.coco_joints_name = ('Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Pelvis', 'Neck') + # OpenPose joint set + self.openpose_joints_name = ('Nose', 'Neck', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Eye', 'L_Eye', 'R_Ear', 'L_Ear', 'Pelvis') + + self.datalist = self.load_data() + print('mupots data 
def add_pelvis(self, joint_coord, joints_name):
    """Return ``joint_coord`` with a synthetic pelvis joint appended as the last row.

    The pelvis is the midpoint of the 'L_Hip' and 'R_Hip' rows. Its third
    column is overwritten with the product of the two hips' third-column
    values (used as a confidence score for detector output such as OpenPose).

    Args:
        joint_coord: 2-D array with three columns, one row per joint.
        joints_name: sequence of joint names ordered like the rows.

    Returns:
        A new array with one extra row; the input array is left untouched.
    """
    left_hip = joint_coord[joints_name.index('L_Hip')]
    right_hip = joint_coord[joints_name.index('R_Hip')]

    midpoint = 0.5 * (left_hip + right_hip)
    # The averaged third column is replaced by the joint confidence product.
    midpoint[2] = left_hip[2] * right_hip[2]

    return np.concatenate((joint_coord, midpoint.reshape(1, 3)))


def add_neck(self, joint_coord, joints_name):
    """Return ``joint_coord`` with a synthetic neck joint appended as the last row.

    Same construction as ``add_pelvis`` but built from the 'L_Shoulder' and
    'R_Shoulder' rows.

    Args:
        joint_coord: 2-D array with three columns, one row per joint.
        joints_name: sequence of joint names ordered like the rows.

    Returns:
        A new array with one extra row; the input array is left untouched.
    """
    left_shoulder = joint_coord[joints_name.index('L_Shoulder')]
    right_shoulder = joint_coord[joints_name.index('R_Shoulder')]

    midpoint = 0.5 * (left_shoulder + right_shoulder)
    # The averaged third column is replaced by the joint confidence product.
    midpoint[2] = left_shoulder[2] * right_shoulder[2]

    return np.concatenate((joint_coord, midpoint.reshape(1, 3)))
self.add_neck(hhrnetpose, self.coco_joints_name) + + openpose = np.array(openpose_result[str(aid)]['coco_joints']) + openpose = self.add_pelvis(openpose, self.openpose_joints_name) + + if openpose.sum() == 0: + count_dummy += 1 + bbox = np.array(ann['bbox']) + img_width, img_height = img['width'], img['height'] + # bbox = process_bbox(bbox, img_width, img_height) + # if bbox is None: continue + + data.append({ + 'img_path': img_path, + 'img_shape': (img_height, img_width), + 'bbox': bbox, + 'tight_bbox': np.array(ann['bbox']), + 'joint_img': joint_img, # [org_img_x, org_img_y, depth - root_depth] + 'joint_cam': joint_cam, # [X, Y, Z] in camera coordinate + 'joint_valid': joint_valid, + 'root_cam': root_cam, # [X, Y, Z] in camera coordinate + 'f': f, + 'c': c, + 'hhrnetpose': hhrnetpose, + 'openpose': openpose + }) + + print("dummy predictions: ", count_dummy) + return data + + def __len__(self): + return len(self.datalist) + + def __getitem__(self, idx): + data = copy.deepcopy(self.datalist[idx]) + img_path = data['img_path'] + + input_joint_name = 'openpose' + if input_joint_name == 'gt': + joint_coord_img = data['joint_img'] + joint_coord_img[:, 2] = 1 + joint_valid = data['joint_valid'] + joint_coord_img = transform_joint_to_other_db(joint_coord_img, self.mupots_joints_name, self.joints_name) + joint_valid = transform_joint_to_other_db(joint_valid, self.mupots_joints_name, self.joints_name) + elif input_joint_name == 'hhrnet': + joint_coord_img = data['hhrnetpose'] + joint_valid = (joint_coord_img[:, 2:] > self.hhrnet_thr) + joint_coord_img = transform_joint_to_other_db(joint_coord_img, self.coco_joints_name, self.joints_name) + joint_valid = transform_joint_to_other_db(joint_valid, self.coco_joints_name, self.joints_name) + elif input_joint_name == 'openpose': + joint_coord_img = data['openpose'] + joint_valid = (joint_coord_img[:, 2:] > self.openpose_thr) + joint_coord_img = transform_joint_to_other_db(joint_coord_img, self.openpose_joints_name, 
def evaluate(self, outs, cur_sample_idx):
    """Turn a batch of mesh predictions into per-image 2D/3D keypoint lists.

    For every output, joints are regressed from the predicted mesh with
    ``self.mpii3d_smpl_regressor``, made root-relative, reordered to the
    MuPoTS joint order, shifted by the ground-truth root position and
    projected to the image with the ground-truth intrinsics.

    Args:
        outs: sequence of per-sample output dicts; each must contain
            'smpl_mesh_cam' (predicted mesh vertices in camera space).
        cur_sample_idx: index in ``self.datalist`` of the annotation that
            corresponds to ``outs[0]``.

    Returns:
        dict with keys '<sequence>_<frame>_2d' and '<sequence>_<frame>_3d'.
        Each value is a list holding one array per annotation that belongs to
        that image: the x/y pixel columns for '_2d', the camera-space joints
        for '_3d'.
    """
    pred_2d_save = {}
    pred_3d_save = {}

    for n, out in enumerate(outs):
        gt = self.datalist[cur_sample_idx + n]
        focal, princpt = gt['f'], gt['c']
        gt_3d_root = gt['root_cam']

        # Build "<sequence>_<frame>" (e.g. TS1_img_0001) with os.path helpers so
        # the name is correct whatever separator osp.join used for img_path.
        img_path = gt['img_path']
        seq_name = osp.basename(osp.dirname(img_path))
        frame_name = osp.basename(img_path).split('.')[0]
        img_name = seq_name + '_' + frame_name

        # Joints regressed from the output mesh.
        # NOTE(review): the factor 1000 looks like a meter -> millimeter
        # conversion to match the ground-truth root; confirm.
        mesh_out_cam = out['smpl_mesh_cam'] * 1000
        pred = np.dot(self.mpii3d_smpl_regressor, mesh_out_cam)
        pred = pred - pred[self.mpii3d_root_idx, None]  # root-relative
        pred_3d_kpt = transform_joint_to_other_db(pred, self.mpii3d_joints_name, self.mupots_joints_name)
        pred_3d_kpt += gt_3d_root  # back to absolute camera coordinates

        pred_3d_save.setdefault(img_name + '_3d', []).append(pred_3d_kpt)

        pred_2d_kpt = cam2pixel(pred_3d_kpt, focal, princpt)
        pred_2d_save.setdefault(img_name + '_2d', []).append(pred_2d_kpt[:, :2])

    return {**pred_2d_save, **pred_3d_save}
output_path) \ No newline at end of file diff --git a/data_processing/data/PW3D/PW3D.py b/data_processing/data/PW3D/PW3D.py new file mode 100644 index 0000000..76f4375 --- /dev/null +++ b/data_processing/data/PW3D/PW3D.py @@ -0,0 +1,373 @@ +import os +import os.path as osp +import numpy as np +import torch +import cv2 +import random +import json +import math +import copy +import transforms3d +from pycocotools.coco import COCO +from config import cfg +from utils.renderer import Renderer +from utils.smpl import SMPL +from utils.preprocessing import load_img, get_bbox, process_bbox, generate_patch_image, augmentation +from utils.transforms import cam2pixel, pixel2cam, rigid_align, transform_joint_to_other_db, denorm_joints, convert_crop_cam_to_orig_img +from utils.vis import vis_keypoints, vis_mesh, save_obj, vis_keypoints_with_skeleton, vis_bbox, render_mesh + + +class PW3D(torch.utils.data.Dataset): + def __init__(self, transform, data_split): + self.transform = transform + self.data_split ='validation' if cfg.crowd else 'test' # data_split + self.data_path = osp.join(cfg.data_dir, 'PW3D', 'data') + self.human_bbox_root_dir = osp.join(cfg.data_dir, 'PW3D', 'rootnet_output', 'bbox_root_pw3d_output.json') + + # SMPL joint set + self.smpl = SMPL() + self.face = self.smpl.face + self.joint_regressor = self.smpl.joint_regressor + self.vertex_num = self.smpl.vertex_num + self.joint_num = self.smpl.joint_num + self.joints_name = self.smpl.joints_name + self.skeleton = self.smpl.skeleton + self.root_joint_idx = self.smpl.root_joint_idx + self.face_kps_vertex = self.smpl.face_kps_vertex + + # H36M joint set + self.h36m_joints_name = ('Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head_top', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist') + self.h36m_root_joint_idx = self.h36m_joints_name.index('Pelvis') + self.h36m_eval_joint = (1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 15, 16) + 
def add_pelvis(self, joint_coord, joints_name):
    """Append a pelvis row computed from the two hip joints.

    Position: mean of the 'L_Hip' and 'R_Hip' rows. Third column: product of
    the hips' third-column values (the detection confidence for OpenPose).

    Args:
        joint_coord: 2-D array with three columns, one row per joint.
        joints_name: sequence of joint names ordered like the rows.

    Returns:
        A new array with the pelvis as its last row; the input is not modified.
    """
    hip_rows = [joint_coord[joints_name.index(name)] for name in ('L_Hip', 'R_Hip')]

    pelvis_row = (hip_rows[0] + hip_rows[1]) * 0.5
    pelvis_row[2] = hip_rows[0][2] * hip_rows[1][2]  # combined confidence

    extended = np.concatenate((joint_coord, pelvis_row.reshape(1, 3)))
    return extended


def add_neck(self, joint_coord, joints_name):
    """Append a neck row computed from the two shoulder joints.

    Position: mean of the 'L_Shoulder' and 'R_Shoulder' rows. Third column:
    product of the shoulders' third-column values.

    Args:
        joint_coord: 2-D array with three columns, one row per joint.
        joints_name: sequence of joint names ordered like the rows.

    Returns:
        A new array with the neck as its last row; the input is not modified.
    """
    shoulder_rows = [joint_coord[joints_name.index(name)] for name in ('L_Shoulder', 'R_Shoulder')]

    neck_row = (shoulder_rows[0] + shoulder_rows[1]) * 0.5
    neck_row[2] = shoulder_rows[0][2] * shoulder_rows[1][2]  # combined confidence

    extended = np.concatenate((joint_coord, neck_row.reshape(1, 3)))
    return extended
+ self.data_split + '.json')) + if self.data_split == 'test' and not cfg.use_gt_info: + print("Get bounding box and root from " + self.human_bbox_root_dir) + bbox_root_result = {} + with open(self.human_bbox_root_dir) as f: + annot = json.load(f) + for i in range(len(annot)): + ann_id = str(annot[i]['ann_id']) + bbox_root_result[ann_id] = {'bbox': np.array(annot[i]['bbox']), 'root': np.array(annot[i]['root_cam'])} + elif cfg.crowd: + with open(osp.join(self.data_path, f'3DPW_{self.data_split}_crowd_hhrnet_result.json')) as f: + hhrnet_result = json.load(f) + print("Load Higher-HRNet input") + + else: + print("Load OpenPose input") + + hhrnet_count = 0 + datalist = [] + for aid in db.anns.keys(): + aid = int(aid) + + ann = db.anns[aid] + image_id = ann['image_id'] + img = db.loadImgs(image_id)[0] + img_width, img_height = img['width'], img['height'] + sequence_name = img['sequence'] + img_name = img['file_name'] + + if cfg.crowd and self.data_split=='validation': + if 'courtyard_hug_00' not in sequence_name and 'courtyard_dancing_00' not in sequence_name: + continue + + img_path = osp.join(self.data_path, 'imageFiles', sequence_name, img_name) + cam_param = {k: np.array(v, dtype=np.float32) for k,v in img['cam_param'].items()} + smpl_param = ann['smpl_param'] + + if self.data_split == 'test' and not cfg.use_gt_info: + bbox = bbox_root_result[str(aid)]['bbox'] # bbox should be aspect ratio preserved-extended. It is done in RootNet. 
def get_smpl_coord(self, smpl_param):
    """Run the SMPL layer for one annotation and regress joints from its mesh.

    Args:
        smpl_param: mapping with keys 'pose', 'shape', 'trans' and 'gender'.
            'gender' selects the entry of ``self.smpl.layer`` to call.

    Returns:
        Tuple ``(mesh, joints)``: ``mesh`` is the layer's vertex output as a
        float32 array reshaped to (-1, 3); ``joints`` is
        ``self.joint_regressor`` applied to that mesh. The joints returned by
        the layer itself are discarded.
    """
    pose, shape, trans, gender = (
        smpl_param[key] for key in ('pose', 'shape', 'trans', 'gender')
    )

    # Batch-of-one tensors for the layer call.
    pose_tensor = torch.FloatTensor(pose).view(1, -1)
    shape_tensor = torch.FloatTensor(shape).view(1, -1)
    # NOTE(review): per the original comment this is the translation from SMPL
    # space to the 3DPW camera space - confirm against the annotation format.
    trans_tensor = torch.FloatTensor(trans).view(-1, 3)

    smpl_layer = self.smpl.layer[gender]
    mesh_tensor, _layer_joints = smpl_layer(pose_tensor, shape_tensor, trans_tensor)

    mesh_coord = mesh_tensor.numpy().astype(np.float32).reshape(-1, 3)
    # Joints come from the regressor rather than from the layer output.
    joint_coord = np.dot(self.joint_regressor, mesh_coord)

    return mesh_coord, joint_coord
+ + """ + # vis + joint_coord_img = transform_joint_to_other_db(joint_coord_img, self.joints_name, self.crowdpose_joints_name) + img = cv2.imread(img_path) + input_img = vis_keypoints_with_skeleton(img, joint_coord_img.T, self.crowdpose_skeleton, kp_thresh=self.openpose_thr, alpha=1, kps_scores=joint_coord_img[:,2:]) + cv2.imshow('open pose', input_img) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + import pdb; pdb.set_trace() + + # smpl_coord_img[:, 2] = 1 + # input_img = vis_keypoints_with_skeleton(img_copy, smpl_coord_img.T, self.skeleton, kp_thresh=0.4, alpha=1) + # cv2.imshow('smpl gt', input_img/255) + # cv2.waitKey(0) + # cv2.destroyAllWindows() + # cv2.waitKey(1) + """ + + # x,y affine transform, root-relative depth + joint_coord_img_xy1 = np.concatenate((joint_coord_img[:, :2], np.ones_like(joint_coord_img[:, 0:1])), 1) + joint_coord_img[:, :2] = np.dot(img2bb_trans, joint_coord_img_xy1.transpose(1, 0)).transpose(1, 0)[:, :2] + joint_coord_img[:, 0] = joint_coord_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + joint_coord_img[:, 1] = joint_coord_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + + # check truncation + joint_trunc = joint_valid * ( + (joint_coord_img[:, 0] >= 0) * (joint_coord_img[:, 0] < cfg.output_hm_shape[2]) * \ + (joint_coord_img[:, 1] >= 0) * (joint_coord_img[:, 1] < cfg.output_hm_shape[1])).reshape(-1, 1).astype(np.float32) + + """ + print(f'{img_path} trunc:\n', joint_trunc.nonzero()) + tmp_coord = joint_coord_img[:, :2] * np.array([[cfg.input_img_shape[1] / cfg.output_hm_shape[2], cfg.input_img_shape[0]/ cfg.output_hm_shape[1]]]) + newimg = vis_keypoints(img.numpy().transpose(1,2,0), tmp_coord) + cv2.imshow(f'{img_path}', newimg) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + """ + + inputs = {'img': img, 'joints': joint_coord_img, 'joints_mask': joint_trunc} + targets = {'smpl_mesh_cam': smpl_mesh_cam} + meta_info = {'bb2img_trans': bb2img_trans, 'img2bb_trans': img2bb_trans, 
'bbox': bbox, 'tight_bbox': data['tight_bbox'], 'aid': aid} + return inputs, targets, meta_info + + def evaluate(self, outs, cur_sample_idx): + annots = self.datalist + sample_num = len(outs) + eval_result = {'mpjpe': [], 'pa_mpjpe': [], 'mpvpe': []} + for n in range(sample_num): + annot = annots[cur_sample_idx + n] + out = outs[n] + + # h36m joint from gt mesh + mesh_gt_cam = out['smpl_mesh_cam_target'] + pose_coord_gt_h36m = np.dot(self.h36m_joint_regressor, mesh_gt_cam) + # debug + root_h36m_gt = pose_coord_gt_h36m[self.h36m_root_joint_idx, :] + pose_gt_img = cam2pixel(pose_coord_gt_h36m, annot['cam_param']['focal'], annot['cam_param']['princpt']) + pose_gt_img = transform_joint_to_other_db(pose_gt_img, self.h36m_joints_name, self.smpl.graph_joints_name) + + pose_coord_gt_h36m = pose_coord_gt_h36m - pose_coord_gt_h36m[self.h36m_root_joint_idx, None] # root-relative + pose_coord_gt_h36m = pose_coord_gt_h36m[self.h36m_eval_joint, :] + mesh_gt_cam -= np.dot(self.joint_regressor, mesh_gt_cam)[0, None, :] + + # TEMP: use PositionNet output + # pose_out_img = out['joint_img'] + # pose_out_img = denorm_joints(pose_out_img, out['bb2img_trans']) + # pose_out_img[:, 2] = (pose_out_img[:, 2] / cfg.output_hm_shape[0] * 2. 
- 1) * (cfg.bbox_3d_size / 2) + root_h36m_gt[None, 2] + # pose_out_cam = pixel2cam(pose_out_img, annot['cam_param']['focal'], annot['cam_param']['princpt']) + # pose_coord_out_h36m = transform_joint_to_other_db(pose_out_cam, self.smpl.graph_joints_name, self.h36m_joints_name) + + # h36m joint from output mesh + mesh_out_cam = out['smpl_mesh_cam'] + pose_coord_out_h36m = np.dot(self.h36m_joint_regressor, mesh_out_cam) + # # debug + # pose_out_img = cam2pixel(pose_coord_out_h36m + root_h36m_gt, annot['cam_param']['focal'], annot['cam_param']['princpt']) + # pose_out_img = transform_joint_to_other_db(pose_out_img, self.h36m_joints_name, self.smpl.graph_joints_name) + + pose_coord_out_h36m = pose_coord_out_h36m - pose_coord_out_h36m[self.h36m_root_joint_idx, None] # root-relative + pose_coord_out_h36m = pose_coord_out_h36m[self.h36m_eval_joint, :] + pose_coord_out_h36m_aligned = rigid_align(pose_coord_out_h36m, pose_coord_gt_h36m) + + eval_result['mpjpe'].append(np.sqrt( + np.sum((pose_coord_out_h36m - pose_coord_gt_h36m) ** 2, 1)).mean() * 1000) # meter -> milimeter + eval_result['pa_mpjpe'].append(np.sqrt(np.sum((pose_coord_out_h36m_aligned - pose_coord_gt_h36m) ** 2, + 1)).mean() * 1000) # meter -> milimeter + mesh_out_cam -= np.dot(self.joint_regressor, mesh_out_cam)[0, None, :] + + # compute MPVPE + mesh_error = np.sqrt(np.sum((mesh_gt_cam - mesh_out_cam) ** 2, 1)).mean() * 1000 + eval_result['mpvpe'].append(mesh_error) + + if cfg.render: + img = cv2.imread(annot['img_path']) + mesh_cam_render = out['mesh_cam_render'] + bbox = out['bbox'] + princpt = (bbox[0]+bbox[2]/2, bbox[1]+bbox[3]/2) + img = vis_bbox(img, bbox, alpha=1) + + rendered_img = render_mesh(img, mesh_cam_render, self.face, {'focal': cfg.focal, 'princpt': princpt}) + + cv2.imshow(annot['img_path'], rendered_img/255) + cv2.waitKey(0) + cv2.destroyAllWindows() + cv2.waitKey(1) + + if cfg.vis: + img = cv2.imread(annot['img_path']) + bbox_to_vis = out['bbox'] + + # vis input 2d pose + # pose_out_img = 
out['input_joints'] + # pose_out_img = denorm_joints(pose_out_img, out['bb2img_trans']) + # pose_scores = pose_out_img[:, 2:].round(3) + # newimg = vis_keypoints_with_skeleton(img.copy(), pose_out_img.T, self.skeleton, kp_thresh=self.openpose_thr, alpha=1, kps_scores=pose_scores) + # newimg = vis_bbox(newimg, bbox_to_vis, alpha=1) + # cv2.imwrite(f'./{annot["img_path"].split("_")[-1][:-4]}_{out["aid"]}_input_2dpose.jpg', newimg) + + # vis PositionNet output + pose_out_img = out['joint_img'] + pose_scores = (out['joint_score']).round(3) + pose_out_img = denorm_joints(pose_out_img, out['bb2img_trans']) + pose_out_img = np.concatenate((pose_out_img, pose_out_img[:, :1]), axis=1) + newimg = vis_keypoints_with_skeleton(img.copy(), pose_out_img.T, self.smpl.graph_skeleton, kp_thresh=0.4, alpha=1, kps_scores=pose_scores) + newimg = vis_bbox(newimg, bbox_to_vis, alpha=1) + cv2.imwrite(f'./{annot["img_path"].split("_")[-1][:-4]}_{out["aid"]}_positionnet.jpg', newimg) + + # vis RotationNet output + pose_out_img = out['joint_proj'] + + pose_out_img = denorm_joints(pose_out_img, out['bb2img_trans']) + pose_out_img = np.concatenate((pose_out_img, pose_out_img[:, :1]), axis=1) + newimg = vis_keypoints_with_skeleton(img.copy(), pose_out_img.T, self.skeleton, + kp_thresh=0.4, alpha=1) + newimg = vis_bbox(newimg, bbox_to_vis, alpha=1) + cv2.imwrite(f'./{annot["img_path"].split("_")[-1][:-4]}_{out["aid"]}_final.jpg', newimg) + + save_obj(mesh_out_cam, self.face, f'./{annot["img_path"].split("_")[-1][:-4]}_{out["aid"]}_final.obj') + + # vis gt + pose_gt_img[:, 2] = 1 + newimg = vis_keypoints_with_skeleton(img.copy(), pose_gt_img.T, self.smpl.graph_skeleton, + kp_thresh=0.4, alpha=1) + newimg = vis_bbox(newimg, bbox_to_vis, alpha=1) + cv2.imwrite(f'./{annot["img_path"].split("_")[-1][:-4]}_{out["aid"]}_gt.jpg', newimg) + + save_obj(mesh_gt_cam, self.face, f'./{annot["img_path"].split("_")[-1][:-4]}_{out["aid"]}_gt.obj') + + return eval_result + + def print_eval_result(self, 
class MultipleDatasets(Dataset):
    """Expose several map-style datasets through one flat index space.

    Two layouts are supported:

    * ``make_same_len=True``: every dataset is given ``max(len(db))`` slots,
      so all datasets are sampled equally often per epoch. Shorter datasets
      are cycled; the trailing partial cycle is filled by random sampling.
    * ``make_same_len=False``: plain concatenation in the order of ``dbs``.

    Args:
        dbs: sequence of indexable, sized datasets.
        make_same_len: choose between the two layouts described above.
    """

    def __init__(self, dbs, make_same_len=True):
        self.dbs = dbs
        self.db_num = len(self.dbs)
        db_lens = [len(db) for db in dbs]
        # default=0 keeps an empty collection constructible (len() == 0)
        # instead of failing inside max().
        self.max_db_data_num = max(db_lens, default=0)
        self.db_len_cumsum = np.cumsum(db_lens)
        self.make_same_len = make_same_len

    def __len__(self):
        # all dbs have the same (padded) length
        if self.make_same_len:
            return self.max_db_data_num * self.db_num
        # each db keeps its own length
        return sum(len(db) for db in self.dbs)

    def __getitem__(self, index):
        """Return the sample stored at flat position ``index``.

        Raises:
            IndexError: if ``index`` does not address any sample. Raising
                IndexError (rather than failing on an unbound local) keeps
                the legacy ``__getitem__`` iteration protocol working.
        """
        if self.make_same_len:
            if self.max_db_data_num == 0:
                raise IndexError('MultipleDatasets index out of range')
            db_idx = index // self.max_db_data_num
            data_idx = index % self.max_db_data_num
            db_len = len(self.dbs[db_idx])
            if db_len == 0:
                raise IndexError('MultipleDatasets index points into an empty dataset')
            # Slots covered by whole passes over this db; anything beyond
            # that is the incomplete last pass.
            full_cycle_slots = db_len * (self.max_db_data_num // db_len)
            if data_idx >= full_cycle_slots:  # last batch: random sampling
                data_idx = random.randint(0, db_len - 1)
            else:  # before last batch: use modular
                data_idx = data_idx % db_len
        else:
            total = int(self.db_len_cumsum[-1]) if self.db_num else 0
            if index < 0:
                # Python-style negative indexing over the concatenation.
                index += total
            if not 0 <= index < total:
                raise IndexError('MultipleDatasets index out of range')
            # First db whose cumulative length exceeds index (skips empty dbs).
            db_idx = int(np.searchsorted(self.db_len_cumsum, index, side='right'))
            offset = int(self.db_len_cumsum[db_idx - 1]) if db_idx else 0
            data_idx = index - offset

        return self.dbs[db_idx][data_idx]
def add_pelvis(joint_coord, joints_name):
    """Append a pelvis joint derived from the two hip joints.

    The new row is the midpoint of 'L_Hip' and 'R_Hip'; its third column is
    replaced by the product of the two hip values in that column (the
    confidence score when the input comes from a 2D pose detector).

    Args:
        joint_coord: array with one row of three values per joint.
        joints_name: sequence of joint names in the row order of ``joint_coord``.

    Returns:
        A new array equal to ``joint_coord`` with the pelvis row appended.
    """
    hip_rows = [joints_name.index(label) for label in ('L_Hip', 'R_Hip')]
    left_hip, right_hip = (joint_coord[row, :] for row in hip_rows)

    pelvis_row = (left_hip + right_hip) * 0.5
    # confidence for openpose: joint score of both hips instead of their mean
    pelvis_row[2] = left_hip[2] * right_hip[2]

    return np.concatenate((joint_coord, pelvis_row.reshape(1, 3)))
def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: optional list of argument strings; ``None`` (the default)
            makes argparse read ``sys.argv[1:]`` as before.

    Returns:
        The argparse namespace. ``gpu_ids`` is always a comma-separated
        string; an inclusive range such as ``"0-3"`` is expanded to
        ``"0,1,2,3"``.

    Exits through ``parser.error`` (usage message, status 2) when ``--gpu``
    is missing or the range is malformed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, dest='gpu_ids')
    parser.add_argument('--model_path', type=str, default='demo_checkpoint.pth.tar')
    parser.add_argument('--img_idx', type=str, default='101570')

    args = parser.parse_args(argv)

    # test gpus
    # An explicit error is used instead of `assert`, which disappears under
    # `python -O` and carried no message (print() returns None).
    if not args.gpu_ids:
        parser.error("Please set proper gpu ids")

    if '-' in args.gpu_ids:
        try:
            first, last = (int(bound) for bound in args.gpu_ids.split('-'))
        except ValueError:
            parser.error(f"Invalid gpu range {args.gpu_ids!r}, expected START-END")
        # the range given on the command line is inclusive on both ends
        args.gpu_ids = ','.join(map(str, range(first, last + 1)))

    return args
+(11, 17), (12,17), (17,18)) + +vis_joints_name = ('Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Thorax', 'Pelvis') +vis_skeleton = ((0, 1), (0, 2), (2, 4), (1, 3), (5, 7), (7, 9), (12, 14), (14, 16), (11, 13), (13, 15), (5, 17), (6, 17), (11, 18), (12, 18), (17, 18), (17, 0), (6, 8), (8, 10),) + +# snapshot load +model_path = args.model_path +assert osp.exists(model_path), 'Cannot find model at ' + model_path +print('Load checkpoint from {}'.format(model_path)) +model = get_model(vertex_num, joint_num, 'test') + +model = DataParallel(model).cuda() +ckpt = torch.load(model_path) +model.load_state_dict(ckpt['network'], strict=False) +model.eval() + +# prepare input image +transform = transforms.ToTensor() +pose2d_result_path = './input/2d_pose_result.json' +with open(pose2d_result_path) as f: + pose2d_result = json.load(f) + +img_dir = './input/images' +for img_name in sorted(pose2d_result.keys()): + img_path = osp.join(img_dir, img_name) + original_img = cv2.imread(img_path) + input = original_img.copy() + input2 = original_img.copy() + original_img_height, original_img_width = original_img.shape[:2] + coco_joint_list = pose2d_result[img_name] + + if args.img_idx not in img_name: + continue + + drawn_joints = [] + c = coco_joint_list + # manually assign the order of output meshes + # coco_joint_list = [c[2], c[0], c[1], c[4], c[3]] + + for idx in range(len(coco_joint_list)): + """ 2D pose input setting & hard-coding for filtering """ + pose_thr = 0.1 + coco_joint_img = np.asarray(coco_joint_list[idx])[:, :3] + coco_joint_img = add_pelvis(coco_joint_img, coco_joints_name) + coco_joint_img = add_neck(coco_joint_img, coco_joints_name) + coco_joint_valid = (coco_joint_img[:, 2].copy().reshape(-1, 1) > pose_thr).astype(np.float32) + + # filter inaccurate inputs + det_score = sum(coco_joint_img[:, 2]) + if det_score < 1.0: + continue + if 
len(coco_joint_img[:, 2:].nonzero()[0]) < 1: + continue + # filter the same targets + tmp_joint_img = coco_joint_img.copy() + continue_check = False + for ddx in range(len(drawn_joints)): + drawn_joint_img = drawn_joints[ddx] + drawn_joint_val = (drawn_joint_img[:, 2].copy().reshape(-1, 1) > pose_thr).astype(np.float32) + diff = np.abs(tmp_joint_img[:, :2] - drawn_joint_img[:, :2]) * coco_joint_valid * drawn_joint_val + diff = diff[diff != 0] + if diff.size == 0: + continue_check = True + elif diff.mean() < 20: + continue_check = True + if continue_check: + continue + drawn_joints.append(tmp_joint_img) + + """ Prepare model input """ + # prepare bbox + bbox = get_bbox(coco_joint_img, coco_joint_valid[:, 0]) # xmin, ymin, width, height + bbox = process_bbox(bbox, original_img_width, original_img_height) + if bbox is None: + continue + img, img2bb_trans, bb2img_trans = generate_patch_image(input2[:,:,::-1], bbox, 1.0, 0.0, False, cfg.input_img_shape) + img = transform(img.astype(np.float32))/255 + img = img.cuda()[None,:,:,:] + + coco_joint_img_xy1 = np.concatenate((coco_joint_img[:, :2], np.ones_like(coco_joint_img[:, :1])), 1) + coco_joint_img[:, :2] = np.dot(img2bb_trans, coco_joint_img_xy1.transpose(1, 0)).transpose(1, 0) + coco_joint_img[:, 0] = coco_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + coco_joint_img[:, 1] = coco_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + + coco_joint_img = transform_joint_to_other_db(coco_joint_img, coco_joints_name, joints_name) + coco_joint_valid = transform_joint_to_other_db(coco_joint_valid, coco_joints_name, joints_name) + coco_joint_valid[coco_joint_img[:, 2] <= pose_thr] = 0 + + # check truncation + coco_joint_trunc = coco_joint_valid * ((coco_joint_img[:, 0] >= 0) * (coco_joint_img[:, 0] < cfg.output_hm_shape[2]) * (coco_joint_img[:, 1] >= 0) * (coco_joint_img[:, 1] < cfg.output_hm_shape[1])).reshape( + -1, 1).astype(np.float32) + coco_joint_img, coco_joint_trunc, bbox = 
torch.from_numpy(coco_joint_img).cuda()[None, :, :], torch.from_numpy(coco_joint_trunc).cuda()[None, :, :], torch.from_numpy(bbox).cuda()[None, :] + + """ Model forward """ + inputs = {'img': img, 'joints': coco_joint_img, 'joints_mask': coco_joint_trunc} + targets = {} + meta_info = {'bbox': bbox} + with torch.no_grad(): + out = model(inputs, targets, meta_info, 'test') + + # draw output mesh + mesh_cam_render = out['mesh_cam_render'][0].cpu().numpy() + bbox = out['bbox'][0].cpu().numpy() + princpt = (bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2) + # original_img = vis_bbox(original_img, bbox, alpha=1) # for debug + + # generate random color + color = colorsys.hsv_to_rgb(np.random.rand(), 0.5, 1.0) + original_img = render_mesh(original_img, mesh_cam_render, face, {'focal': cfg.focal, 'princpt': princpt}, color=color) + + # Save output mesh + output_dir = 'output' + file_name = f'{output_dir}/{img_path.split("/")[-1][:-4]}_{idx}.jpg' + print("file name: ", file_name) + save_obj(mesh_cam_render, face, file_name=f'{output_dir}/{img_path.split("/")[-1][:-4]}_{idx}.obj') + cv2.imwrite(file_name, original_img) + + # Draw input 2d pose + tmp_joint_img[-1], tmp_joint_img[-2] = tmp_joint_img[-2].copy(), tmp_joint_img[-1].copy() + input = vis_coco_skeleton(input, tmp_joint_img.T, vis_skeleton) + cv2.imwrite(file_name[:-4] + '_2dpose.jpg', input) + + diff --git a/data_processing/demo/extract_camera_parameter.py b/data_processing/demo/extract_camera_parameter.py new file mode 100644 index 0000000..863e450 --- /dev/null +++ b/data_processing/demo/extract_camera_parameter.py @@ -0,0 +1,589 @@ +import glob +import shutil +import sys +import os +import os.path as osp +import argparse + +import matplotlib.pyplot as plt +import numpy as np +import cv2 +import colorsys +import json +import random +import torch +import torchvision.transforms as transforms +from torch.nn.parallel.data_parallel import DataParallel +import torch.backends.cudnn as cudnn +import pyrender + 
def add_neck(joint_coord, joints_name):
    """Append a neck joint placed halfway between the two shoulders.

    Columns 0 and 1 of the new row are the shoulder midpoint; column 2 is
    the product of the shoulders' column-2 values.

    Args:
        joint_coord: array with one row of three values per joint.
        joints_name: sequence of joint names in the row order of ``joint_coord``.

    Returns:
        A new array: ``joint_coord`` followed by the neck row.
    """
    l_sh = joints_name.index('L_Shoulder')
    r_sh = joints_name.index('R_Shoulder')
    left_row, right_row = joint_coord[l_sh, :], joint_coord[r_sh, :]

    # The sum allocates a fresh array, so editing it leaves the input untouched.
    neck_row = (left_row + right_row) * 0.5
    neck_row[2] = left_row[2] * right_row[2]

    return np.concatenate((joint_coord, neck_row.reshape(1, 3)))
def bad_image_vis(image, joint, vis_skeleton):
    """Draw a detection's 2D skeleton on ``image`` and resize it to 512 px wide.

    Args:
        image: image array passed to ``vis_coco_skeleton``.
        joint: per-joint keypoint array (one row per joint); it is not modified.
        vis_skeleton: joint-index pairs describing the bones to draw.

    Returns:
        The annotated image, 512 px wide with the original aspect ratio.
    """
    # Work on a copy so the caller's keypoint array is not reordered as a
    # hidden side effect of producing a visualization.
    joint = joint.copy()
    # Swap the last two joints; presumably this maps the appended
    # (Pelvis, Neck) rows onto the (Thorax, Pelvis) order that vis_skeleton
    # indexes -- TODO confirm against vis_joints_name.
    joint[-1], joint[-2] = joint[-2].copy(), joint[-1].copy()
    image = vis_coco_skeleton(image, joint.T, vis_skeleton)
    # Clamp to 1 px: a very wide image would otherwise truncate to height 0,
    # which cv2.resize rejects.
    target_height = max(1, int(image.shape[0] / image.shape[1] * 512))
    image = cv2.resize(image, (512, target_height))
    return image
def exit_function():
    """Write the accumulated ``result_json`` to ``result_json_path`` at exit.

    The data is first dumped to a sibling temporary file and then moved over
    the real file. Opening ``result_json_path`` directly in 'w' mode would
    truncate it up front, so a failure during ``json.dump`` (or an abrupt
    termination) would destroy the results of earlier runs that the script
    relies on to resume.
    """
    tmp_path = result_json_path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(result_json, f)
    # os.replace overwrites the destination in a single step.
    os.replace(tmp_path, result_json_path)
    print('结束')
+ + + if f'{image_name}_{idx}.png' in result_json or f'{image_name}_{idx}_h.png' in result_json or f'{image_name}_{idx}_s.png' in result_json: + result_count += 1 + continue + + + image = original_img.copy() + input = original_img.copy() + input2 = original_img.copy() + """ 2D pose input setting & hard-coding for filtering """ + pose_thr = 0.05 + coco_joint_img = np.asarray(coco_joint_list[idx])[:, :3] + + # if there is a similar joint in used_joints, skip this joint + if len(used_joints) > 0: + for joint in used_joints: + #print(np.linalg.norm(joint - coco_joint_img)/ np.linalg.norm(coco_joint_img)) + distance = max( + max(coco_joint_img[:, 0])-min(coco_joint_img[:, 0]), + max(coco_joint_img[:, 1])-min(coco_joint_img[:, 1]) + ) + #print( np.linalg.norm(joint - coco_joint_img)/ distance) + if np.linalg.norm(joint - coco_joint_img)/ distance < 0.15: + print('skip similar', np.linalg.norm(joint - coco_joint_img) / np.linalg.norm(coco_joint_img)) + continue + used_joints.append(coco_joint_img) + + coco_joint_img = add_pelvis(coco_joint_img, coco_joints_name) + coco_joint_img = add_neck(coco_joint_img, coco_joints_name) + coco_joint_valid = (coco_joint_img[:, 2].copy().reshape(-1, 1) > pose_thr).astype(np.float32) + + """head bbox""" + head_bbox = head_bbox_list[idx] + + # if len(head_bbox)<4: + # # bad_vis = bad_image_vis(image, coco_joint_img.copy(), vis_skeleton) + # # cv2.imwrite(os.path.join(bad_images_dir, file_name), bad_vis) + # continue + # filter inaccurate inputs + det_score = sum(coco_joint_img[:, 2]) + if det_score < 0.3: + print('skip low det score', det_score) + continue + if len(coco_joint_img[:, 2:].nonzero()[0]) < 1: + print('skip no det score', det_score) + continue + # filter the same targets + tmp_joint_img = coco_joint_img.copy() + continue_check = False + for ddx in range(len(drawn_joints)): + drawn_joint_img = drawn_joints[ddx] + drawn_joint_val = (drawn_joint_img[:, 2].copy().reshape(-1, 1) > pose_thr).astype(np.float32) + diff = 
np.abs(tmp_joint_img[:, :2] - drawn_joint_img[:, :2]) * coco_joint_valid * drawn_joint_val + diff = diff[diff != 0] + if diff.size == 0: + continue_check = True + elif diff.mean() < 20: + continue_check = True + if continue_check: + print('skip continue_check') + # bad_vis = bad_image_vis(image, coco_joint_img.copy(), vis_skeleton) + # cv2.imwrite(os.path.join(bad_images_dir, file_name), bad_vis) + continue + + + drawn_joints.append(tmp_joint_img) + + tmp_joint_img[-1], tmp_joint_img[-2] = tmp_joint_img[-2].copy(), tmp_joint_img[-1].copy() + + + + + """ Prepare model input """ + # prepare bbox + # bbox = get_bbox(coco_joint_img, coco_joint_valid[:, 0]) # xmin, ymin, width, height + bbox = get_bbox(coco_joint_img, np.ones_like(coco_joint_valid[:, 0])) + if bbox[2] < min_box_stride or bbox[3] < min_box_stride: + print('skip too small bbox', bbox[2], bbox[3]) + continue + orig_bbox = bbox.copy() + bbox = process_bbox(bbox, original_img_width, original_img_height) + if bbox is None: + print('skip invalid bbox') + continue + img, img2bb_trans, bb2img_trans = generate_patch_image(input2[:, :, ::-1], bbox, 1.0, 0.0, False, + cfg.input_img_shape) + img = transform(img.astype(np.float32)) / 255 + img = img.cuda()[None, :, :, :] + + coco_joint_img_xy1 = np.concatenate((coco_joint_img[:, :2], np.ones_like(coco_joint_img[:, :1])), 1) + coco_joint_img[:, :2] = np.dot(img2bb_trans, coco_joint_img_xy1.transpose(1, 0)).transpose(1, 0) + coco_joint_img[:, 0] = coco_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + coco_joint_img[:, 1] = coco_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + + coco_joint_img = transform_joint_to_other_db(coco_joint_img, coco_joints_name, joints_name) + coco_joint_valid = transform_joint_to_other_db(coco_joint_valid, coco_joints_name, joints_name) + coco_joint_valid[coco_joint_img[:, 2] <= pose_thr] = 0 + + # check truncation + coco_joint_trunc = coco_joint_valid * ( + (coco_joint_img[:, 0] >= 0) * 
(coco_joint_img[:, 0] < cfg.output_hm_shape[2]) * ( + coco_joint_img[:, 1] >= 0) * (coco_joint_img[:, 1] < cfg.output_hm_shape[1])).reshape( + -1, 1).astype(np.float32) + coco_joint_img, coco_joint_trunc, bbox = torch.from_numpy(coco_joint_img).cuda()[None, :, :], torch.from_numpy( + coco_joint_trunc).cuda()[None, :, :], torch.from_numpy(bbox).cuda()[None, :] + + """ Model forward """ + inputs = {'img': img, 'joints': coco_joint_img, 'joints_mask': coco_joint_trunc} + targets = {} + meta_info = {'bbox': bbox} + with torch.no_grad(): + out = model(inputs, targets, meta_info, 'test') + + + #print("file name: ", file_name) + + bbox = out['bbox'][0].cpu().numpy() + princpt = (bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2) + + + + color = colorsys.hsv_to_rgb(np.random.rand(), 0.5, 1.0) + + intrisics_full_image_dict = {'focal': cfg.focal, 'princpt': princpt} + camera_to_render_template_in_pyrender = out['camera_to_render_template_in_pyrender'].cpu().numpy() + camera_pose_in_pyrender = out['camera_pose_in_pyrender'].cpu().numpy() + + + + if args.debug: + viz = [original_img] + + image_mesh = render_mesh(image, out['mesh_cam_render'][0].cpu().numpy(), face, + intrisics_full_image_dict, + color=color) + + viz.append((image_mesh * alpha + original_img * (1.0 - alpha)).astype(np.uint8)) + + image_template = render_mesh(image.copy(), model.module.template_mesh_cam_render[0].cpu().numpy(), face, + intrisics_full_image_dict, + color=color, cam_pose=camera_to_render_template_in_pyrender) + viz.append((image_template * alpha + original_img * (1.0 - alpha)).astype(np.uint8)) + + image_camera_rotate = render_mesh(image.copy(),model.module.template_mesh_cam_render[0].cpu().numpy(), face, + intrisics_full_image_dict, + color=color, cam_pose=camera_pose_in_pyrender) + viz.append((image_camera_rotate * alpha + original_img * (1.0 - alpha)).astype(np.uint8)) + + + viz_full_image = np.concatenate(viz, axis=0 if original_img.shape[1] > original_img.shape[0] else 1) + viz_full_image = 
cv2.resize(viz_full_image, (int( viz_full_image.shape[1] / viz_full_image.shape[0] * 853),853 )) + + + + + + + + crop_image_size = args.crop_image_size + crop_outputs = model.module.crop_and_process_camera_matrix(out, + original_img.copy(), + joint_2d= tmp_joint_img, # used to realign + crop_image_size=crop_image_size, + model_input_bbox=bbox, + head_bbox = head_bbox) + model_input_bbox = bbox.copy() + if len(crop_outputs) == 0: + continue + if len(crop_outputs) == 1: + save_keys = [f'{image_name}_{idx}.png'] + else: + save_keys = [f'{image_name}_{idx}_h.png' ,f'{image_name}_{idx}_s.png'] + + + for crop_idx in range(len(crop_outputs)): + crop_output = crop_outputs[crop_idx] + if crop_output is None: + continue + save_key = save_keys[crop_idx] + intrisics_crop = np.eye(4) + intrisics_crop[0, 0] = crop_output['intrisics']['focal'][0] + intrisics_crop[1, 1] = crop_output['intrisics']['focal'][1] + intrisics_crop[0, 2] = crop_output['intrisics']['princpt'][0] + intrisics_crop[1, 2] = crop_output['intrisics']['princpt'][1] + intrisics_crop_dict = {'focal': (intrisics_crop[0, 0], intrisics_crop[1, 1]), + 'princpt': [intrisics_crop[0, 2], intrisics_crop[1, 2]]} + + intrisics_standard = np.eye(4) + intrisics_standard[0, 0] = cfg.focal[0] + intrisics_standard[1, 1] = cfg.focal[1] + intrisics_standard[0, 2] = crop_image_size / 2 + intrisics_standard[1, 2] = crop_image_size / 2 + intrisics_standard_dict = {'focal': cfg.focal, 'princpt': [crop_image_size / 2, crop_image_size / 2]} + + normalized_camerapose_in_pyrender = out['normalized_camerapose_in_pyrender'] + normalized_transformation_in_realworld = out['normalized_transformation_in_realworld'] + camerapose_in_realworld = np.linalg.inv(normalized_transformation_in_realworld) + + + # realign image + + + viz = [crop_output['cropped_image']] + + # image_mesh = render_mesh(crop_output['cropped_image'], + # out['mesh_cam_render'][0].cpu().numpy(), face, + # intrisics_crop_dict, + # color=color) + image_mesh = 
render_mesh(crop_output['cropped_image'].copy(),out['neck_head_rotated_template_mesh'][0].cpu().numpy(), face, + intrisics_standard_dict, + color=color, + cam_pose=normalized_camerapose_in_pyrender) + + #image_mesh,_,_ = generate_patch_image(image_mesh, crop_output['bbox'], 1.0, 0.0, False, (crop_image_size,crop_image_size)) + viz.append((image_mesh * alpha + crop_output['cropped_image'] * (1.0-alpha)).astype(np.uint8)) + + # image_template = render_mesh(crop_output['cropped_image'].copy(), + # model.module.template_mesh_cam_render[0].cpu().numpy(), face, + # intrisics_crop_dict, + # color=color, + # cam_pose=camera_to_render_template_in_pyrender) + # viz.append((image_template * alpha + crop_output['cropped_image'] * (1.0-alpha)).astype(np.uint8)) + + # image_camera_rotate = render_mesh(crop_output['cropped_image'].copy(), + # model.module.template_mesh_cam_render[0].cpu().numpy(), face, + # intrisics_standard_dict, + # color=color, + # cam_pose=normalized_camerapose_in_pyrender) + # viz.append((image_camera_rotate * alpha + crop_output['cropped_image'] * (1.0-alpha)).astype(np.uint8)) + + + if args.debug: + projected_vertexes = model.module.get_projected_vertex(torch.from_numpy(template_mesh).float().cuda(), + intrisics_standard @ normalized_transformation_in_realworld) + + vertex_vis = crop_output['cropped_image'].copy() + + + camera_forward_direction = (camerapose_in_realworld[:3, :3] @ np.reshape(np.array([0, 0, 1]), (3, 1)))[:, 0] # 3,1 + camera_position = camerapose_in_realworld[:3, 3] # 3,1 + + not_pass_check = 0 + in_screen = 0 + for i in range(projected_vertexes.shape[0]): + if projected_vertexes[i, 0] < 0 or projected_vertexes[i, 0] >= vertex_vis.shape[1] or \ + projected_vertexes[i, 1] < 0 or projected_vertexes[i, 1] >= vertex_vis.shape[0]: + continue + # print(template_mesh[0, i, :].shape, camera_position.shape, camera_forward_direction.shape) + check = np.sum((template_mesh[0, i, :] - camera_position) * camera_forward_direction) + in_screen += 1 + if 
check < 0: + not_pass_check += 1 + cv2.circle(vertex_vis, (int(projected_vertexes[i, 0]), int(projected_vertexes[i, 1])), 5, (255, 255, 255), -1) + + viz.append(vertex_vis) + + if not_pass_check == in_screen: + raise Exception('all vertexes are before camera') + + + + # tmp_joint_img 19 x 2 + # rescale tmp_joint_img accroding to bbox + tmp_joint_img_on_croppped_image = tmp_joint_img.copy() + tmp_joint_img_on_croppped_image[:, 0] = tmp_joint_img[:, 0] - crop_output['bbox'][0] + tmp_joint_img_on_croppped_image[:, 1] = tmp_joint_img[:, 1] - crop_output['bbox'][1] + tmp_joint_img_on_croppped_image*= crop_image_size/crop_output['bbox'][2] + + skeleton_vis = vis_coco_skeleton(crop_output['cropped_image'].copy(), tmp_joint_img_on_croppped_image.T, vis_skeleton) + if len(head_bbox['bbox']) ==4 and crop_idx == 0: + tmp_head_bbox = np.array(head_bbox['bbox'].copy()) + tmp_head_bbox[0] = head_bbox['bbox'][0] - crop_output['bbox'][0] + tmp_head_bbox[1] = head_bbox['bbox'][1] - crop_output['bbox'][1] + tmp_head_bbox *= crop_image_size / crop_output['bbox'][2] + cv2.rectangle(skeleton_vis, (int(tmp_head_bbox[0]), int(tmp_head_bbox[1])), + (int(tmp_head_bbox[0] + tmp_head_bbox[2]), int(tmp_head_bbox[1] + tmp_head_bbox[3])), + (0, 255, 0), 4) + + viz.append(skeleton_vis) + + viz = np.concatenate(viz, axis=0 ) + if args.debug: + viz = cv2.resize(viz, (int(viz.shape[1]/viz.shape[0] * viz_full_image.shape[0]), viz_full_image.shape[0])) + viz = np.concatenate([viz_full_image, viz], axis=1) + else: + viz = cv2.resize(viz, (viz.shape[1]//6,viz.shape[0]//6)) + + cv2.imwrite(os.path.join(visualization_dir, save_key),viz) + + #''' + + cv2.imwrite(os.path.join(aligned_images_dir,save_key), crop_output['cropped_image']) + + # final ========================================= + res = { + 'bbox':crop_output['bbox'].tolist(), + + 'coco_joint': tmp_joint_img.tolist(), + 'model_input_bbox': model_input_bbox.tolist(), + 'raw_image_name': img_name, + + # real world + 'intrisics': 
intrisics_standard.tolist(), + 'intrisics_dict': intrisics_standard_dict, + 'world2camera_matrix': normalized_transformation_in_realworld.tolist(), + 'camera_pose': camerapose_in_realworld.tolist(), + + + # pyrender + # original + 'intrisics_full_image_dict': intrisics_full_image_dict, + 'camera_to_render_template_in_pyrender':camera_to_render_template_in_pyrender.tolist(), + 'camera_pose_in_pyrender':camera_pose_in_pyrender.tolist(), + + #crop + 'intrisics_crop_dict': intrisics_crop_dict, + 'normalized_camerapose_in_pyrender': normalized_camerapose_in_pyrender.tolist(), + + + + # smpl + 'smpl_pose': out['smpl_pose'].cpu().numpy().tolist(), + 'smpl_shape': out['smpl_shape'].cpu().numpy().tolist(), + 'cam_trans': out['cam_trans'].cpu().numpy().tolist(), + + + } + + result_json[save_key] = res + + result_count += 1 + + if result_count == 0: + print(f">>>>>>> No result in {img_path}!") + shutil.move(img_path, os.path.join(bad_images_dir, os.path.basename(img_path))) + # ============================================== + + +with open(result_json_path, 'w') as f: + json.dump(result_json, f) diff --git a/data_processing/demo/generate_visualization.py b/data_processing/demo/generate_visualization.py new file mode 100644 index 0000000..c51b5c9 --- /dev/null +++ b/data_processing/demo/generate_visualization.py @@ -0,0 +1,293 @@ +import glob +import sys +import os +import os.path as osp +import argparse + +import matplotlib.pyplot as plt +import numpy as np +import cv2 +import colorsys +import json +import random +import torch +import torchvision.transforms as transforms +from torch.nn.parallel.data_parallel import DataParallel +import torch.backends.cudnn as cudnn +import pyrender + +sys.path.insert(0, osp.join('..', 'main')) +sys.path.insert(0, osp.join('..', 'data')) +sys.path.insert(0, osp.join('..', 'common')) +from config import cfg +# from model import get_model +# from utils.preprocessing import process_bbox, generate_patch_image, get_bbox +# from utils.transforms 
import pixel2cam, cam2pixel, transform_joint_to_other_db +from utils.vis import vis_mesh, save_obj, render_mesh, vis_coco_skeleton + +sys.path.insert(0, cfg.smpl_path) +from utils.smpl import SMPL + + +def add_pelvis(joint_coord, joints_name): + lhip_idx = joints_name.index('L_Hip') + rhip_idx = joints_name.index('R_Hip') + pelvis = (joint_coord[lhip_idx, :] + joint_coord[rhip_idx, :]) * 0.5 + pelvis[2] = joint_coord[lhip_idx, 2] * joint_coord[rhip_idx, 2] # confidence for openpose + pelvis = pelvis.reshape(1, 3) + + joint_coord = np.concatenate((joint_coord, pelvis)) + + return joint_coord + + +def add_neck(joint_coord, joints_name): + lshoulder_idx = joints_name.index('L_Shoulder') + rshoulder_idx = joints_name.index('R_Shoulder') + neck = (joint_coord[lshoulder_idx, :] + joint_coord[rshoulder_idx, :]) * 0.5 + neck[2] = joint_coord[lshoulder_idx, 2] * joint_coord[rshoulder_idx, 2] + neck = neck.reshape(1, 3) + + joint_coord = np.concatenate((joint_coord, neck)) + + return joint_coord + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--gpu', type=str, dest='gpu_ids') + parser.add_argument('--model_path', type=str, default='demo_checkpoint.pth.tar') + parser.add_argument('--input_dir', type=str, default='') + parser.add_argument('--output_dir', type=str, default='output') + parser.add_argument('--data_dir', type=str, default='101570') + + args = parser.parse_args() + + # test gpus + if not args.gpu_ids: + assert 0, print("Please set proper gpu ids") + + if '-' in args.gpu_ids: + gpus = args.gpu_ids.split('-') + gpus[0] = int(gpus[0]) + gpus[1] = int(gpus[1]) + 1 + args.gpu_ids = ','.join(map(lambda x: str(x), list(range(*gpus)))) + + return args + + +def get_projected_vertex(mesh, world2screen_matrix): + mesh = mesh[0,...] 
+ mesh = np.concatenate([mesh, np.ones((mesh.shape[0], 1))], axis=1) # 6890 x 4 + points_image = world2screen_matrix @ mesh.T # 4,6890 + points_image = points_image[:3, :] # 3,6890 + + points_on_input_image = points_image / points_image[2, :] + points_on_input_image = points_on_input_image[:2, :].T # 30,2 + + return points_on_input_image + +def flip_yaw(pose_matrix): + flipped = pose_matrix.copy() + flipped[0, 1] *= -1 + flipped[0, 2] *= -1 + flipped[1, 0] *= -1 + flipped[2, 0] *= -1 + flipped[0, 3] *= -1 + return flipped +# argument parsing +args = parse_args() +cfg.set_args(args.gpu_ids, is_test=True) +cfg.set_data_dir(args.data_dir) +cfg.render = True +cudnn.benchmark = True + +# SMPL joint set +joint_num = 30 # original: 24. manually add nose, L/R eye, L/R ear, head top +joints_name = ( + 'Pelvis', 'L_Hip', 'R_Hip', 'Torso', 'L_Knee', 'R_Knee', 'Spine', 'L_Ankle', 'R_Ankle', 'Chest', 'L_Toe', 'R_Toe', + 'Neck', 'L_Thorax', 'R_Thorax', + 'Head', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand', 'Nose', 'L_Eye', + 'R_Eye', 'L_Ear', 'R_Ear', 'Head_top') +flip_pairs = ( + (1, 2), (4, 5), (7, 8), (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), (22, 23), (25, 26), (27, 28)) +skeleton = ( + (0, 1), (1, 4), (4, 7), (7, 10), (0, 2), (2, 5), (5, 8), (8, 11), (0, 3), (3, 6), (6, 9), (9, 14), (14, 17), + (17, 19), + (19, 21), (21, 23), (9, 13), (13, 16), (16, 18), (18, 20), (20, 22), (9, 12), (12, 24), (24, 15), (24, 25), + (24, 26), + (25, 27), (26, 28), (24, 29)) + +# SMPl mesh +vertex_num = 6890 +smpl = SMPL() +face = smpl.face +alpha = 0.8 +# other joint set +coco_joints_name = ( + 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', + 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Pelvis', 'Neck') +coco_skeleton = ( + (1, 2), (0, 1), (0, 2), (2, 4), (1, 3), (6, 8), (8, 10), (5, 7), (7, 9), (12, 14), (14, 16), (11, 13), (13, 15), + (5, 6), + (11, 
17), (12, 17), (17, 18)) + +vis_joints_name = ( + 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', + 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Thorax', 'Pelvis') +vis_skeleton = ( + (0, 1), (0, 2), (2, 4), (1, 3), (5, 7), (7, 9), (12, 14), (14, 16), (11, 13), (13, 15), (5, 17), (6, 17), (11, 18), + (12, 18), (17, 18), (17, 0), (6, 8), (8, 10),) + +human_model_layer = smpl.layer['neutral'].cuda() + + +# prepare input image +transform = transforms.ToTensor() +pose2d_result_path = os.path.join(args.input_dir, '2d_pose_result_hrnet.json') +with open(pose2d_result_path) as f: + pose2d_result = json.load(f) + +img_dir = os.path.join(args.input_dir, 'images') + +debug = True + +output_dir = args.output_dir +print('>>>>>>> output_dir', output_dir) +os.makedirs(output_dir, exist_ok=True) +aligned_images_dir = os.path.join(output_dir, 'aligned_images') +visualization_dir = os.path.join(output_dir, 'visualization_debug') +os.makedirs(visualization_dir, exist_ok=True) + +result_json_path = os.path.join(output_dir, 'result.json') +with open(result_json_path, 'r') as f: + result_json = json.load(f) + +template_mesh_in_pyrender = np.load('./template_mesh_in_pyrender.npy') +print('template_mesh_in_pyrender.shape', template_mesh_in_pyrender.shape) +template_mesh = np.load('./template_mesh.npy') +print('template_mesh.shape', template_mesh.shape) + + + +from model import get_model +model_path = args.model_path +assert osp.exists(model_path), 'Cannot find model at ' + model_path +print('Load checkpoint from {}'.format(model_path)) +model = get_model(vertex_num, joint_num, 'test') + +model = DataParallel(model).cuda() +ckpt = torch.load(model_path) +model.load_state_dict(ckpt['network'], strict=False) +model.eval() + + +for aligned_image_name in result_json.keys(): + color = colorsys.hsv_to_rgb(np.random.rand(), 0.5, 1.0) + meta_info = result_json[aligned_image_name] + + visualization_path = 
os.path.join(visualization_dir, aligned_image_name) + aligned_image_path = os.path.join(aligned_images_dir, aligned_image_name) + aligned_image = cv2.imread(aligned_image_path) + + + + # crop image + viz = [aligned_image] + mesh_cam_render, _ = human_model_layer(torch.from_numpy(np.array(meta_info['smpl_pose'])).float().cuda(), + torch.from_numpy(np.array(meta_info['smpl_shape'])).float().cuda(), + torch.from_numpy(np.array(meta_info['cam_trans'])).float().cuda()) + image_mesh = render_mesh(aligned_image.copy(), + mesh_cam_render[0].cpu().numpy(), face, + meta_info['intrisics_crop_dict'], + color=color) + # image_mesh,_,_ = generate_patch_image(image_mesh, crop_output['bbox'], 1.0, 0.0, False, (crop_image_size,crop_image_size)) + viz.append((image_mesh * alpha + aligned_image.copy() * (1.0 - alpha)).astype(np.uint8)) + + + image_camera_rotate = render_mesh(aligned_image.copy(), + template_mesh_in_pyrender[0], face, + meta_info['intrisics_dict'], + color=color, cam_pose=meta_info['normalized_camerapose_in_pyrender']) + viz.append((image_camera_rotate * alpha + aligned_image.copy() * (1.0-alpha)).astype(np.uint8)) + + + + + + # + projected_vertexes = get_projected_vertex(template_mesh, np.array(meta_info['intrisics']) @ np.array(meta_info['world2camera_matrix'])) + vertex_vis = aligned_image.copy() + camera_pose = np.array(meta_info['camera_pose']) + camera_forward_direction = (camera_pose[:3, :3] @ np.reshape(np.array([0, 0, 1]),(3,1)))[:,0] # 3,1 + camera_position = camera_pose[:, 3:4][:3, 0] # 34,1 + not_pass_check = 0 + in_screen = 0 + for i in range(projected_vertexes.shape[0]): + if projected_vertexes[i, 0] < 0 or projected_vertexes[i, 0] >= vertex_vis.shape[1] or \ + projected_vertexes[i, 1] < 0 or projected_vertexes[i, 1] >= vertex_vis.shape[0]: + continue + check = np.sum((template_mesh[0, i, :3] - camera_position) * camera_forward_direction) + in_screen += 1 + if check < 0: + not_pass_check += 1 + cv2.circle(vertex_vis, (int(projected_vertexes[i, 0]), 
int(projected_vertexes[i, 1])), 6, (255, 0, 0), -1) + print('check', not_pass_check, in_screen) + if not_pass_check == in_screen: + raise Exception('all vertexes are before camera') + viz.append(vertex_vis) + + # flip image + flip_camerapose_in_pyrender = np.array(meta_info['normalized_camerapose_in_pyrender']) + flip_camerapose_in_pyrender = flip_yaw(flip_camerapose_in_pyrender) + + image_camera_rotate_flip = render_mesh(cv2.flip(aligned_image.copy(), 1), + template_mesh_in_pyrender[0], face, + meta_info['intrisics_dict'], + color=color, cam_pose=flip_camerapose_in_pyrender) + viz.append((image_camera_rotate_flip * alpha + cv2.flip(aligned_image.copy(), 1) * (1.0 - alpha)).astype(np.uint8)) + + + # flip + camera_pose = np.array(meta_info['camera_pose']) + flip_camera_pose = flip_yaw(camera_pose) + + + flip_world2camera_matrix = np.linalg.inv(flip_camera_pose) + + projected_vertexes = get_projected_vertex(template_mesh, np.array(meta_info['intrisics']) @ flip_world2camera_matrix) # + # select head & neck vertexes + template_align_joint_coorinate = model.module.template_align_joint_coorinate.cpu().numpy() # 30, 6890 + template_mesh_cam_render = model.module.template_mesh_cam_render.cpu().numpy() # 6890, 3 + # template_mesh_cam_render -template_align_joint_coorinate > 0 + print(template_mesh_cam_render.shape) + selected_vertexes = np.where( template_mesh_cam_render[0,:,1]<0 )[0] + print(selected_vertexes.shape) + projected_vertexes = projected_vertexes[selected_vertexes, :] + print(projected_vertexes.shape) + + + vertex_vis = cv2.flip(aligned_image.copy(), 1) + camera_pose = flip_camera_pose + camera_forward_direction = (camera_pose[:3, :3] @ np.reshape(np.array([0, 0, 1]), (3, 1)))[:, 0] # 3,1 + camera_position = camera_pose[:, 3:4][:3, 0] # 34,1 + not_pass_check = 0 + in_screen = 0 + for i in range(projected_vertexes.shape[0]): + if projected_vertexes[i, 0] < 0 or projected_vertexes[i, 0] >= vertex_vis.shape[1] or \ + projected_vertexes[i, 1] < 0 or 
projected_vertexes[i, 1] >= vertex_vis.shape[0]: + continue + check = np.sum((template_mesh[0, i, :3] - camera_position) * camera_forward_direction) + in_screen += 1 + if check < 0: + not_pass_check += 1 + cv2.circle(vertex_vis, (int(projected_vertexes[i, 0]), int(projected_vertexes[i, 1])), 6, (255, 0, 0), -1) + print('check', not_pass_check, in_screen) + if not_pass_check == in_screen: + raise Exception('all vertexes are before camera') + + + + viz.append(vertex_vis) + viz = np.concatenate(viz, axis=0) + cv2.imwrite(visualization_path, cv2.resize(viz, (viz.shape[1] //4, viz.shape[0] //4))) diff --git a/data_processing/demo/my_demo.py b/data_processing/demo/my_demo.py new file mode 100644 index 0000000..6e4979b --- /dev/null +++ b/data_processing/demo/my_demo.py @@ -0,0 +1,293 @@ +import glob +import sys +import os +import os.path as osp +import argparse +import numpy as np +import cv2 +import colorsys +import json +import random +import torch +import torchvision.transforms as transforms +from torch.nn.parallel.data_parallel import DataParallel +import torch.backends.cudnn as cudnn +import matplotlib.pyplot as plt + +sys.path.insert(0, osp.join('..', 'main')) +sys.path.insert(0, osp.join('..', 'data')) +sys.path.insert(0, osp.join('..', 'common')) +from config import cfg +from model import get_model +from utils.preprocessing import process_bbox, generate_patch_image, get_bbox +from utils.transforms import pixel2cam, cam2pixel, transform_joint_to_other_db +from utils.vis import vis_mesh, save_obj, render_mesh, vis_coco_skeleton + +sys.path.insert(0, cfg.smpl_path) +from utils.smpl import SMPL + + +def add_pelvis(joint_coord, joints_name): + lhip_idx = joints_name.index('L_Hip') + rhip_idx = joints_name.index('R_Hip') + pelvis = (joint_coord[lhip_idx, :] + joint_coord[rhip_idx, :]) * 0.5 + pelvis[2] = joint_coord[lhip_idx, 2] * joint_coord[rhip_idx, 2] # confidence for openpose + pelvis = pelvis.reshape(1, 3) + + joint_coord = np.concatenate((joint_coord, pelvis)) + 
+ return joint_coord + + +def add_neck(joint_coord, joints_name): + lshoulder_idx = joints_name.index('L_Shoulder') + rshoulder_idx = joints_name.index('R_Shoulder') + neck = (joint_coord[lshoulder_idx, :] + joint_coord[rshoulder_idx, :]) * 0.5 + neck[2] = joint_coord[lshoulder_idx, 2] * joint_coord[rshoulder_idx, 2] + neck = neck.reshape(1, 3) + + joint_coord = np.concatenate((joint_coord, neck)) + + return joint_coord + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--gpu', type=str, dest='gpu_ids') + parser.add_argument('--model_path', type=str, default='demo_checkpoint.pth.tar') + parser.add_argument('--img_name', type=str, default='101570') + parser.add_argument('--input_dir', type=str, default='101570') + parser.add_argument('--data_dir', type=str, default='101570') + + args = parser.parse_args() + + # test gpus + if not args.gpu_ids: + assert 0, print("Please set proper gpu ids") + + if '-' in args.gpu_ids: + gpus = args.gpu_ids.split('-') + gpus[0] = int(gpus[0]) + gpus[1] = int(gpus[1]) + 1 + args.gpu_ids = ','.join(map(lambda x: str(x), list(range(*gpus)))) + + return args + + +# argument parsing +args = parse_args() +cfg.set_args(args.gpu_ids, is_test=True) +cfg.set_data_dir(args.data_dir) +cfg.render = True +cudnn.benchmark = True + +# SMPL joint set +joint_num = 30 # original: 24. 
manually add nose, L/R eye, L/R ear, head top +joints_name = ( + 'Pelvis', 'L_Hip', 'R_Hip', 'Torso', 'L_Knee', 'R_Knee', 'Spine', 'L_Ankle', 'R_Ankle', 'Chest', 'L_Toe', 'R_Toe', + 'Neck', 'L_Thorax', 'R_Thorax', + 'Head', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand', 'Nose', 'L_Eye', + 'R_Eye', 'L_Ear', 'R_Ear', 'Head_top') +flip_pairs = ( + (1, 2), (4, 5), (7, 8), (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), (22, 23), (25, 26), (27, 28)) +skeleton = ( + (0, 1), (1, 4), (4, 7), (7, 10), (0, 2), (2, 5), (5, 8), (8, 11), (0, 3), (3, 6), (6, 9), (9, 14), (14, 17), + (17, 19), + (19, 21), (21, 23), (9, 13), (13, 16), (16, 18), (18, 20), (20, 22), (9, 12), (12, 24), (24, 15), (24, 25), + (24, 26), + (25, 27), (26, 28), (24, 29)) + +# SMPl mesh +vertex_num = 6890 +smpl = SMPL() +face = smpl.face + +# other joint set +coco_joints_name = ( + 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', + 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Pelvis', 'Neck') +coco_skeleton = ( + (1, 2), (0, 1), (0, 2), (2, 4), (1, 3), (6, 8), (8, 10), (5, 7), (7, 9), (12, 14), (14, 16), (11, 13), (13, 15), + (5, 6), + (11, 17), (12, 17), (17, 18)) + +vis_joints_name = ( + 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', + 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Thorax', 'Pelvis') +vis_skeleton = ( + (0, 1), (0, 2), (2, 4), (1, 3), (5, 7), (7, 9), (12, 14), (14, 16), (11, 13), (13, 15), (5, 17), (6, 17), (11, 18), + (12, 18), (17, 18), (17, 0), (6, 8), (8, 10),) + +# snapshot load +model_path = args.model_path +assert osp.exists(model_path), 'Cannot find model at ' + model_path +print('Load checkpoint from {}'.format(model_path)) +model = get_model(vertex_num, joint_num, 'test') + +model = DataParallel(model).cuda() +ckpt = torch.load(model_path) 
+model.load_state_dict(ckpt['network'], strict=False) +model.eval() + +# prepare input image +transform = transforms.ToTensor() +pose2d_result_path = os.path.join(args.input_dir, '2d_pose_result_hrnet.json') +with open(pose2d_result_path) as f: + pose2d_result = json.load(f) + +img_dir = os.path.join(args.input_dir, 'images') + +output_dir = 'output' +os.makedirs(output_dir, exist_ok=True) +# img_name = args.img_name +# if img_name not in pose2d_result: +# print('missing pose2d! ') +# exit() +# +# img_path = osp.join(img_dir, img_name) +# + +for img_path in glob.glob(os.path.join(img_dir, "*")): + original_img = cv2.imread(img_path) + img_name = os.path.basename(img_path) + + if img_name.split('.')[0] not in ['arun-4ZpNFwSV7sY-unsplash','taylor-brandon-QAnqDU_fTz0-unsplash']: + continue + + + original_img_height, original_img_width = original_img.shape[:2] + coco_joint_list = pose2d_result[img_name] + + drawn_joints = [] + c = coco_joint_list + # manually assign the order of output meshes + # coco_joint_list = [c[2], c[0], c[1], c[4], c[3]] + + result_count = 0 + + + + + for idx in range(len(coco_joint_list)): + image = original_img.copy() + input = original_img.copy() + input2 = original_img.copy() + """ 2D pose input setting & hard-coding for filtering """ + pose_thr = 0.05 + coco_joint_img = np.asarray(coco_joint_list[idx])[:, :3] + coco_joint_img = add_pelvis(coco_joint_img, coco_joints_name) + coco_joint_img = add_neck(coco_joint_img, coco_joints_name) + coco_joint_valid = (coco_joint_img[:, 2].copy().reshape(-1, 1) > pose_thr).astype(np.float32) + + # filter inaccurate inputs + det_score = sum(coco_joint_img[:, 2]) + if det_score < 0.3: + # print(f'det_score = ({det_score})!') + continue + if len(coco_joint_img[:, 2:].nonzero()[0]) < 1: + #print('len(coco_joint_img[:, 2:].nonzero()[0]) < 1!') + continue + # filter the same targets + tmp_joint_img = coco_joint_img.copy() + continue_check = False + for ddx in range(len(drawn_joints)): + drawn_joint_img = 
drawn_joints[ddx] + drawn_joint_val = (drawn_joint_img[:, 2].copy().reshape(-1, 1) > pose_thr).astype(np.float32) + diff = np.abs(tmp_joint_img[:, :2] - drawn_joint_img[:, :2]) * coco_joint_valid * drawn_joint_val + diff = diff[diff != 0] + if diff.size == 0: + continue_check = True + elif diff.mean() < 20: + continue_check = True + if continue_check: + #print('continue_check failed!') + continue + drawn_joints.append(tmp_joint_img) + + """ Prepare model input """ + # prepare bbox + # bbox = get_bbox(coco_joint_img, coco_joint_valid[:, 0]) # xmin, ymin, width, height + bbox = get_bbox(coco_joint_img, np.ones_like(coco_joint_valid[:, 0])) + + orig_bbox = bbox.copy() + bbox = process_bbox(bbox, original_img_width, original_img_height) + if bbox is None: + #print('bbox is None!') + continue + img, img2bb_trans, bb2img_trans = generate_patch_image(input2[:, :, ::-1], bbox, 1.0, 0.0, False, + cfg.input_img_shape) + img = transform(img.astype(np.float32)) / 255 + img = img.cuda()[None, :, :, :] + + coco_joint_img_xy1 = np.concatenate((coco_joint_img[:, :2], np.ones_like(coco_joint_img[:, :1])), 1) + coco_joint_img[:, :2] = np.dot(img2bb_trans, coco_joint_img_xy1.transpose(1, 0)).transpose(1, 0) + coco_joint_img[:, 0] = coco_joint_img[:, 0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + coco_joint_img[:, 1] = coco_joint_img[:, 1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + + coco_joint_img = transform_joint_to_other_db(coco_joint_img, coco_joints_name, joints_name) + coco_joint_valid = transform_joint_to_other_db(coco_joint_valid, coco_joints_name, joints_name) + coco_joint_valid[coco_joint_img[:, 2] <= pose_thr] = 0 + + # check truncation + coco_joint_trunc = coco_joint_valid * ( + (coco_joint_img[:, 0] >= 0) * (coco_joint_img[:, 0] < cfg.output_hm_shape[2]) * ( + coco_joint_img[:, 1] >= 0) * (coco_joint_img[:, 1] < cfg.output_hm_shape[1])).reshape( + -1, 1).astype(np.float32) + coco_joint_img, coco_joint_trunc, bbox = 
torch.from_numpy(coco_joint_img).cuda()[None, :, :], torch.from_numpy( + coco_joint_trunc).cuda()[None, :, :], torch.from_numpy(bbox).cuda()[None, :] + + """ Model forward """ + inputs = {'img': img, 'joints': coco_joint_img, 'joints_mask': coco_joint_trunc} + targets = {} + meta_info = {'bbox': bbox} + with torch.no_grad(): + out = model(inputs, targets, meta_info, 'test') + + # draw output mesh + mesh_cam_render = out['mesh_cam_render'][0].cpu().numpy() + bbox = out['bbox'][0].cpu().numpy() + princpt = (bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2) + # original_img = vis_bbox(original_img, bbox, alpha=1) # for debug + + # generate random color + color = colorsys.hsv_to_rgb(np.random.rand(), 0.5, 1.0) + image = render_mesh(image, mesh_cam_render, face, {'focal': cfg.focal, 'princpt': princpt}, + color=color) + + # img = img[0].cpu().numpy().transpose(1, 2, 0)[:, :, ::-1] * 255 + # img = render_mesh(img, out['mesh_cam_render_crop'][0].cpu().numpy(), face, {'focal': cfg.focal, 'princpt': cfg.princpt}, + # color=color) + # Save output mesh + + image_name = os.path.basename(img_path).split('.')[0] + file_name = f'{output_dir}/{image_name}_{idx}.jpg' + #print("file name: ", file_name) + + # save_obj(mesh_cam_render, face, file_name=f'{output_dir}/{image_name}_{idx}.obj') + + # cv2.imwrite(file_name, (image*0.7+original_img*0.3).astype(np.uint8)) + + # Draw input 2d pose + tmp_joint_img[-1], tmp_joint_img[-2] = tmp_joint_img[-2].copy(), tmp_joint_img[-1].copy() + input = vis_coco_skeleton(input, tmp_joint_img.T, vis_skeleton) + # cv2.imwrite(file_name[:-4] + f'_2dpose_{idx}.jpg', input) + input = cv2.rectangle(input, (int(orig_bbox[0]), int(orig_bbox[1])), + (int(orig_bbox[0] + orig_bbox[2]), int(orig_bbox[1] + orig_bbox[3])), (0, 0, 255), 2) + input = cv2.rectangle(input, (int(bbox[0]), int(bbox[1])), (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])), + (0, 255, 0), 2) + + viz = np.concatenate([original_img, (image * 0.9 + original_img * 0.1).astype(np.uint8), input], + 
axis=0 if original_img.shape[1] > original_img.shape[0] else 1) + cv2.imwrite(file_name[:-4] + f'_viz_{idx}.jpg', viz) + + result_count += 1 + + + if result_count == 0: + print(f">>>>>>> No result in {img_path}!") + + + # cv2.imwrite(file_name[:-4] + f'_viz_crop_{idx}.jpg', img) + + + diff --git a/data_processing/demo/new_crop_use_densepose.py b/data_processing/demo/new_crop_use_densepose.py new file mode 100644 index 0000000..f591b1a --- /dev/null +++ b/data_processing/demo/new_crop_use_densepose.py @@ -0,0 +1,484 @@ +import glob +import shutil +import sys +import os +import os.path as osp +import argparse + +import matplotlib.pyplot as plt +import numpy as np +import cv2 +import colorsys +import json +import random +import torch +import torchvision.transforms as transforms +from torch.nn.parallel.data_parallel import DataParallel +import torch.backends.cudnn as cudnn +import pyrender + +sys.path.insert(0, osp.join('..', 'main')) +sys.path.insert(0, osp.join('..', 'data')) +sys.path.insert(0, osp.join('..', 'common')) +from config import cfg +# from model import get_model +# from utils.preprocessing import process_bbox, generate_patch_image, get_bbox +# from utils.transforms import pixel2cam, cam2pixel, transform_joint_to_other_db +from utils.vis import vis_mesh, save_obj, render_mesh, vis_coco_skeleton + +sys.path.insert(0, cfg.smpl_path) +from utils.smpl import SMPL + +import os +# os.environ["PYOPENGL_PLATFORM"] = "egl" +# check if on a Linux machine +if os.name == 'posix': # Linux + os.environ["PYOPENGL_PLATFORM"] = "egl" +def add_pelvis(joint_coord, joints_name): + lhip_idx = joints_name.index('L_Hip') + rhip_idx = joints_name.index('R_Hip') + pelvis = (joint_coord[lhip_idx, :] + joint_coord[rhip_idx, :]) * 0.5 + pelvis[2] = joint_coord[lhip_idx, 2] * joint_coord[rhip_idx, 2] # confidence for openpose + pelvis = pelvis.reshape(1, 3) + + joint_coord = np.concatenate((joint_coord, pelvis)) + + return joint_coord + + +def add_neck(joint_coord, joints_name): + 
lshoulder_idx = joints_name.index('L_Shoulder') + rshoulder_idx = joints_name.index('R_Shoulder') + neck = (joint_coord[lshoulder_idx, :] + joint_coord[rshoulder_idx, :]) * 0.5 + neck[2] = joint_coord[lshoulder_idx, 2] * joint_coord[rshoulder_idx, 2] + neck = neck.reshape(1, 3) + + joint_coord = np.concatenate((joint_coord, neck)) + + return joint_coord + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--gpu', type=str, dest='gpu_ids') + parser.add_argument('--model_path', type=str, default='demo_checkpoint.pth.tar') + parser.add_argument('--input_dir', type=str, default='') + parser.add_argument('--output_dir', type=str, default='output') + parser.add_argument('--data_dir', type=str, default='101570') + + args = parser.parse_args() + + # test gpus + if not args.gpu_ids: + assert 0, print("Please set proper gpu ids") + + if '-' in args.gpu_ids: + gpus = args.gpu_ids.split('-') + gpus[0] = int(gpus[0]) + gpus[1] = int(gpus[1]) + 1 + args.gpu_ids = ','.join(map(lambda x: str(x), list(range(*gpus)))) + + return args + + +def get_projected_vertex(mesh, world2screen_matrix): + mesh = mesh[0, ...] + mesh = np.concatenate([mesh, np.ones((mesh.shape[0], 1))], axis=1) # 6890 x 4 + # mesh = torch.cat([mesh, torch.ones((mesh.shape[0], 1))], dim=1) + points_image = world2screen_matrix @ mesh.T # 4,6890 + points_image = points_image[:3, :] # 3,6890 + + points_on_input_image = points_image / points_image[2, :] + points_on_input_image = points_on_input_image[:2, :].T # 30,2 + + return points_on_input_image + + +def flip_yaw(pose_matrix): + flipped = pose_matrix.copy() + flipped[0, 1] *= -1 + flipped[0, 2] *= -1 + flipped[1, 0] *= -1 + flipped[2, 0] *= -1 + flipped[0, 3] *= -1 + return flipped + + +# argument parsing +args = parse_args() +cfg.set_args(args.gpu_ids, is_test=True) +cfg.set_data_dir(args.data_dir) +cfg.render = True +cudnn.benchmark = True + +# SMPL joint set +joint_num = 30 # original: 24. 
manually add nose, L/R eye, L/R ear, head top +joints_name = ( + 'Pelvis', 'L_Hip', 'R_Hip', 'Torso', 'L_Knee', 'R_Knee', 'Spine', 'L_Ankle', 'R_Ankle', 'Chest', 'L_Toe', 'R_Toe', + 'Neck', 'L_Thorax', 'R_Thorax', + 'Head', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand', 'Nose', 'L_Eye', + 'R_Eye', 'L_Ear', 'R_Ear', 'Head_top') +flip_pairs = ( + (1, 2), (4, 5), (7, 8), (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), (22, 23), (25, 26), (27, 28)) +skeleton = ( + (0, 1), (1, 4), (4, 7), (7, 10), (0, 2), (2, 5), (5, 8), (8, 11), (0, 3), (3, 6), (6, 9), (9, 14), (14, 17), + (17, 19), + (19, 21), (21, 23), (9, 13), (13, 16), (16, 18), (18, 20), (20, 22), (9, 12), (12, 24), (24, 15), (24, 25), + (24, 26), + (25, 27), (26, 28), (24, 29)) + +# SMPl mesh +vertex_num = 6890 +smpl = SMPL() +face = smpl.face +alpha = 0.8 +# other joint set +coco_joints_name = ( + 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', + 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Pelvis', 'Neck') +coco_skeleton = ( + (1, 2), (0, 1), (0, 2), (2, 4), (1, 3), (6, 8), (8, 10), (5, 7), (7, 9), (12, 14), (14, 16), (11, 13), (13, 15), + (5, 6), + (11, 17), (12, 17), (17, 18)) + +vis_joints_name = ( + 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', + 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle', 'Thorax', 'Pelvis') +vis_skeleton = ( + (0, 1), (0, 2), (2, 4), (1, 3), (5, 7), (7, 9), (12, 14), (14, 16), (11, 13), (13, 15), (5, 17), (6, 17), (11, 18), + (12, 18), (17, 18), (17, 0), (6, 8), (8, 10),) + +human_model_layer = smpl.layer['neutral'].cuda() + +# prepare input image +transform = transforms.ToTensor() +# pose2d_result_path = os.path.join(args.input_dir, '2d_pose_result_hrnet.json') +# with open(pose2d_result_path) as f: +# pose2d_result = json.load(f) + +img_dir = os.path.join(args.input_dir, 
'images') + +debug = True + +input_aligned_images_dir = os.path.join(args.input_dir, 'aligned_images') + +output_dir = args.output_dir +print('>>>>>>> output_dir', output_dir) +os.makedirs(output_dir, exist_ok=True) + +result_json_path = os.path.join(args.input_dir, 'result.json') +with open(result_json_path, 'r') as f: + result_json = json.load(f) + +template_mesh_in_pyrender = np.load('./template_mesh_in_pyrender.npy') +print('template_mesh_in_pyrender.shape', template_mesh_in_pyrender.shape) +template_mesh = np.load('./template_mesh.npy') +print('template_mesh.shape', template_mesh.shape) + +from model import get_model + +model_path = args.model_path +assert osp.exists(model_path), 'Cannot find model at ' + model_path +print('Load checkpoint from {}'.format(model_path)) +model = get_model(vertex_num, joint_num, 'test') + +model = DataParallel(model).cuda() +ckpt = torch.load(model_path) +model.load_state_dict(ckpt['network'], strict=False) +model.eval() + +from utils.preprocessing import generate_patch_image + +output_aligned_images_dir = os.path.join(output_dir, 'aligned_images') +output_visualization_dir = os.path.join(output_dir, 'visualization') +os.makedirs(output_aligned_images_dir, exist_ok=True) +os.makedirs(output_visualization_dir, exist_ok=True) + +output_result_json_path = os.path.join(output_dir, 'result.json') +if os.path.exists(output_result_json_path): + with open(output_result_json_path, 'r') as f: + output_result_json = json.load(f) +else: + output_result_json = {} + + +def exit_function(): + global output_result_json + with open(output_result_json_path, 'w') as f: + json.dump(output_result_json, f, indent=4) + print('结束') + + +import atexit +from tqdm import tqdm + +atexit.register(exit_function) + +color = colorsys.hsv_to_rgb(np.random.rand(), 0.5, 1.0) +template_mesh_cam_render = model.module.template_mesh_cam_render.cpu().numpy() # 6890, 3 + +head_vertexes_index = np.where(template_mesh_cam_render[0, :, 1] < -0.0649)[0] + +# +# 
align_joint_name = 'Neck' +# align_joint_index = model.module.human_model.joints_name.index(align_joint_name) +# print(align_joint_indexmodel.module.joint_regressor.max(), model.module.joint_regressor.shape) +# neck_vertexes_index = np.where(model.module.joint_regressor[align_joint_index,:] == 1 )[0] +# print('neck_vertexes_index', neck_vertexes_index , neck_vertexes_index.shape) + +for input_aligned_image_path in tqdm(glob.glob(os.path.join(input_aligned_images_dir, '*'))): + aligned_image_name = osp.basename(input_aligned_image_path) + + dense_pose_img_paths = glob.glob(osp.join(args.input_dir, 'seg', aligned_image_name.replace('.png', '*.png'))) + original_aligned_image_path = osp.join(args.input_dir, 'aligned_images', aligned_image_name) + original_visualization_path = osp.join(args.input_dir, 'visualization', aligned_image_name) + + meta_info = result_json[aligned_image_name] + intrisics_dict = meta_info['intrisics_dict'] + intrisics_dict['focal'][0] = intrisics_dict['focal'][0] / 0.75 + intrisics_dict['focal'][1] = intrisics_dict['focal'][1] / 0.75 + + camera_pose = np.array(meta_info['camera_pose']) + + raw_image_name = meta_info['raw_image_name'] + raw_image_path = osp.join(img_dir, raw_image_name) + raw_image = cv2.imread(raw_image_path) + + if len(dense_pose_img_paths) == 0: + res_meta_info = {} + save_key = os.path.basename(input_aligned_image_path) + original_bbox = meta_info['bbox'] + + stride = original_bbox[2] + center = np.array([original_bbox[0] + stride / 2, original_bbox[1] + stride / 2]) + + stride = 0.75 * stride + + new_bbox = [center[0] - stride / 2, center[1] - stride / 2, stride, stride] + + new_aligned_image, img2bb_trans, bb2img_trans = generate_patch_image(raw_image, new_bbox, 1.0, 0.0, False, + (1024, 1024), + enable_padding=True) + viz = [new_aligned_image] + body_pose_param = torch.from_numpy(np.array(meta_info['smpl_pose'])).float().cuda() + body_pose_param = body_pose_param.reshape(-1, 24, 3) + body_pose_param = body_pose_param[:, 
1:, :] + mesh_cam_render = model.module.get_neck_head_rotated_template_mesh(body_pose_param) + image_camera_rotate = render_mesh(new_aligned_image.copy(), + mesh_cam_render[0].cpu().numpy(), face, + intrisics_dict, + color=color, cam_pose=meta_info['normalized_camerapose_in_pyrender']) + viz.append((image_camera_rotate * alpha + new_aligned_image.copy() * (1.0 - alpha)).astype(np.uint8)) + + viz = np.concatenate(viz, axis=0).astype(np.uint8) + viz = cv2.resize(viz, (viz.shape[1] // 4, viz.shape[0] // 4)) + + # output_aligned_images_dir = os.path.join(output_dir, 'aligned_images') + # output_visualization_dir = os.path.join(output_dir, 'visualization') + new_aligned_image_path = os.path.join(output_aligned_images_dir, save_key) + cv2.imwrite(new_aligned_image_path, new_aligned_image) + + new_visualization_path = os.path.join(output_visualization_dir, save_key) + cv2.imwrite(new_visualization_path, viz) + + res_meta_info['bbox'] = new_bbox + res_meta_info['camera_pose'] = camera_pose.tolist() + res_meta_info['smpl_pose'] = meta_info['smpl_pose'] + res_meta_info['raw_image_name'] = raw_image_name + + output_result_json[save_key] = res_meta_info + + for dense_pose_img_path in dense_pose_img_paths: + res_meta_info = {} + save_key = os.path.basename(dense_pose_img_path) + if save_key in output_result_json: + continue + dense_pose_img_ = cv2.imread(dense_pose_img_path).astype(np.int32) + + dense_pose_index_ = dense_pose_img_[:, :, 0] * 255 + dense_pose_img_[:, :, 1] + dense_pose_index_[dense_pose_img_[:, :, 2] == 255] = -1 + + # print('dense_pose_index', dense_pose_index.shape, dense_pose_index.min(), dense_pose_index.max() ) + # mask out dense_pose_index that not in head_vertexes_index + # for i in range(dense_pose_index.shape[0]): + # for j in range(dense_pose_index.shape[1]): + # if dense_pose_index[i,j] not in head_vertexes_index: + # dense_pose_index[i,j] = -1 + # dense_pose_img[i,j,:] = 255 + # dense_pose_index = np.ones_like(dense_pose_index_)*-1 + # dense_pose_img 
= np.ones_like(dense_pose_img_)*255 + + dense_pose_2d_points = np.ones((head_vertexes_index.shape[0], 2)) * -1 + + for i, selected_vertex in enumerate(head_vertexes_index): + mask = dense_pose_index_ == selected_vertex + # dense_pose_index[mask] = selected_vertex + # dense_pose_img[mask, :] = dense_pose_img_[mask, :] + if mask.sum() == 0: + continue + dense_pose_2d_points[i, :] = np.array([np.mean(np.where(mask)[1]), np.mean(np.where(mask)[0])]) + # cv2.circle(dense_pose_img, (int(dense_pose_2d_points[i,0]), int(dense_pose_2d_points[i,1])), 6, (0,255,0), -1) + + # dense_pose_img = dense_pose_img.astype(np.uint8) + valid_head_vertexes_index = np.where(dense_pose_2d_points[:, 0] != -1)[0] + dense_pose_2d_points = dense_pose_2d_points[valid_head_vertexes_index, :] + + if dense_pose_2d_points.shape[0] == 0: + continue + + # project smpl mesh to img: + + # mesh_cam_render, _ = human_model_layer(torch.from_numpy(np.array(meta_info['smpl_pose'])).float().cuda(), + # torch.from_numpy(np.array(meta_info['smpl_shape'])).float().cuda(), + # torch.from_numpy(np.array(meta_info['cam_trans'])).float().cuda()) + # + # mesh_cam_render = mesh_cam_render.cpu().numpy() + body_pose_param = torch.from_numpy(np.array(meta_info['smpl_pose'])).float().cuda() + body_pose_param = body_pose_param.reshape(-1, 24, 3) + body_pose_param = body_pose_param[:, 1:, :] + mesh_cam_render = model.module.get_neck_head_rotated_template_mesh(body_pose_param) + + mesh_proj = torch.matmul(mesh_cam_render, model.module.template_mesh_R[:3, :3]).cpu().numpy() + + intrisics = np.array(meta_info['intrisics']) + # optimize trans and scale + transl = np.array([0, 0]).reshape(1, 2) + scale = np.array([1]).reshape(1, 1) + + proj_matrix = np.array(intrisics) @ np.array(meta_info['world2camera_matrix']) + projected_vertexes = get_projected_vertex(mesh_proj, proj_matrix) + moved_projected_vertexes = projected_vertexes * scale + transl + + projected_vertexes = moved_projected_vertexes[head_vertexes_index, 
:][valid_head_vertexes_index, :] + + # vertex_vis = dense_pose_img.copy() + # for i in range(projected_vertexes.shape[0]): + # if projected_vertexes[i, 0] < 0 or projected_vertexes[i, 0] >= vertex_vis.shape[1] or \ + # projected_vertexes[i, 1] < 0 or projected_vertexes[i, 1] >= vertex_vis.shape[0]: + # continue + # cv2.circle(vertex_vis, (int(projected_vertexes[i, 0]), int(projected_vertexes[i, 1])), 6, (255, 0, 0), -1) + # cv2.imshow('vertex_vis', vertex_vis) + # cv2.waitKey(0) + + # print('dense_pose_2d_points', dense_pose_2d_points.shape) + # print('projected_vertexes', projected_vertexes.shape) + # try to align projected_vertexes to dense_pose_index + height_dense_pose = dense_pose_2d_points[:, 1].max() - dense_pose_2d_points[:, 1].min() + width_dense_pose = dense_pose_2d_points[:, 0].max() - dense_pose_2d_points[:, 0].min() + new_center = np.array([1024 / 2, 1024 / 2]).reshape(1, 2) + + height_projected_vertexes = projected_vertexes[:, 1].max() - projected_vertexes[:, 1].min() + width_projected_vertexes = projected_vertexes[:, 0].max() - projected_vertexes[:, 0].min() + + scale = max(height_projected_vertexes / height_dense_pose, width_projected_vertexes / width_dense_pose) + + scale = max(scale, 0.85) + scale = min(scale, 2) + + dense_pose_2d_points = dense_pose_2d_points * scale + new_center = new_center * scale + + center_dense_pose = np.array([dense_pose_2d_points[:, 0].mean(), dense_pose_2d_points[:, 1].mean()]).reshape(1, + 2) + center_projected_vertexes = np.array( + [projected_vertexes[:, 0].mean(), projected_vertexes[:, 1].mean()]).reshape(1, 2) + + transl = center_projected_vertexes - center_dense_pose + + dense_pose_2d_points = dense_pose_2d_points + transl + new_center = new_center + transl + + # vertex_vis = np.ones_like(dense_pose_img)*255 + # for i in range(projected_vertexes.shape[0]): + # if projected_vertexes[i, 0] < 0 or projected_vertexes[i, 0] >= vertex_vis.shape[1] or \ + # projected_vertexes[i, 1] < 0 or projected_vertexes[i, 1] >= 
vertex_vis.shape[0]: + # continue + # cv2.circle(vertex_vis, (int(projected_vertexes[i, 0]), int(projected_vertexes[i, 1])), 6, (255, 0, 0), -1) + # cv2.circle(vertex_vis, (int(dense_pose_2d_points[i, 0]), int(dense_pose_2d_points[i, 1])), 6, (0, 255, 0), -1) + # + # cv2.imshow('vertex_vis', vertex_vis) + # cv2.waitKey(0) + + original_bbox = meta_info['bbox'] + stride = original_bbox[2] + center = np.array([original_bbox[0] + stride / 2, original_bbox[1] + stride / 2]) + + new_stride = stride / scale + + # transl_ = transl/1024 + + # new_center_on_crop_image = np.array([1024/2, 1024/2]) + transl + new_center = center - (new_center - np.array([1024 / 2, 1024 / 2]).reshape(1, 2)) / 1024 * stride + new_center = new_center.reshape(-1) + # print('new_center', new_center) + # print('new_stride', new_stride) + new_stride = new_stride * 0.75 + new_bbox = [new_center[0] - new_stride / 2, new_center[1] - new_stride / 2, new_stride, new_stride] + meta_info['bbox'] = new_bbox + + try: + new_aligned_image, img2bb_trans, bb2img_trans = generate_patch_image(raw_image, new_bbox, 1.0, 0.0, False, + (1024, 1024), + enable_padding=True) + except: + continue + viz = [new_aligned_image] + image_camera_rotate = render_mesh(new_aligned_image.copy(), + mesh_cam_render[0].cpu().numpy(), face, + intrisics_dict, + color=color, cam_pose=meta_info['normalized_camerapose_in_pyrender']) + viz.append((image_camera_rotate * alpha + new_aligned_image.copy() * (1.0 - alpha)).astype(np.uint8)) + + viz = np.concatenate(viz, axis=0).astype(np.uint8) + viz = cv2.resize(viz, (viz.shape[1] // 4, viz.shape[0] // 4)) + + # output_aligned_images_dir = os.path.join(output_dir, 'aligned_images') + # output_visualization_dir = os.path.join(output_dir, 'visualization') + new_aligned_image_path = os.path.join(output_aligned_images_dir, save_key) + cv2.imwrite(new_aligned_image_path, new_aligned_image) + + new_visualization_path = os.path.join(output_visualization_dir, save_key) + 
def parse_args():
    """Parse the command-line options of the bad-alignment filtering script.

    Options:
        --gpu         GPU ids, stored on the namespace as ``gpu_ids``
                      (default ``'0'``). A value containing ``-`` is read as
                      an inclusive range, e.g. ``'0-2'`` becomes ``'0,1,2'``.
        --model_path  Checkpoint file (default ``'demo_checkpoint.pth.tar'``).
        --input_dir   Directory to process (default ``''``).
        --data_dir    Data directory passed on to the project config.

    Returns:
        argparse.Namespace: the parsed options, with ``gpu_ids`` normalised to
        a comma-separated string when a range was given.

    Raises:
        SystemExit: raised by argparse for unknown/malformed options, and via
            ``parser.error`` when ``--gpu`` is an empty string.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, dest='gpu_ids', default='0')
    parser.add_argument('--model_path', type=str, default='demo_checkpoint.pth.tar')
    parser.add_argument('--input_dir', type=str, default='')

    # Raw string: identical value to the previous literal, but without the
    # invalid "\p" / "\d" escape sequences that newer Pythons warn about.
    parser.add_argument('--data_dir', type=str,
                        default=r'E:\project/3DCrowdNet_upper_body-main\data')

    args = parser.parse_args()

    # Validate explicitly instead of `assert 0, print(...)`: asserts vanish
    # under `python -O`, and print() returns None so the message was lost.
    if not args.gpu_ids:
        parser.error('Please set proper gpu ids')

    if '-' in args.gpu_ids:
        # "a-b" denotes the inclusive id range a..b.
        gpus = args.gpu_ids.split('-')
        gpus[0] = int(gpus[0])
        gpus[1] = int(gpus[1]) + 1
        args.gpu_ids = ','.join(map(str, range(*gpus)))

    return args
def get_projected_vertex(mesh, world2screen_matrix):
    """Project the first point set of a batch into 2D image coordinates.

    Args:
        mesh: array whose first entry ``mesh[0]`` is an (N, 3) set of points
            (only that first entry is used).
        world2screen_matrix: matrix with four columns (and at least three
            rows) applied to the homogeneous points.

    Returns:
        An (N, 2) array: the first two projected components of every point
        divided by its third projected component.
    """
    points = mesh[0, ...]
    ones_column = np.ones((points.shape[0], 1))
    homogeneous = np.hstack([points, ones_column])  # N x 4

    # Keep only the first three rows of the projection: x, y and the divisor.
    projected = (world2screen_matrix @ homogeneous.T)[:3, :]  # 3 x N
    divisor = projected[2, :]

    return (projected[:2, :] / divisor).T  # N x 2
bbox1[2] + + # vis1 = cv2.imread(image_path) + + pose_params_input = torch.from_numpy(np.array(meta_info['smpl_pose'])).float().cuda().view(1, 24, 3) + pose_params_input = pose_params_input[:, 1:, :] + + joints_3d = model.module.get_neck_head_rotated_template_mesh_joint(pose_params_input).cpu().numpy() + # print(joints_3d.shape) + projected_joints = get_projected_vertex(joints_3d, np.array(meta_info['intrisics']) @ np.array( + meta_info['world2camera_matrix'])) + # print(projected_joints.shape) + + vis1 = cv2.imread(image_path) + # joint_names = ['Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear'] + # for j in range(5): + # cv2.circle(vis1, (int(coco_joint1[j, 0]), int(coco_joint1[j, 1])), 3, (0, 0, 255), -1) + # #cv2.circle(vis2, (int(coco_joint2[j, 0]), int(coco_joint2[j, 1])), 3, (0, 0, 255), -1) + # cv2.putText(vis1, joint_names[j], (int(coco_joint1[j, 0]), int(coco_joint1[j, 1])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) + # #cv2.putText(vis2, joint_names[j], (int(coco_joint2[j, 0]), int(coco_joint2[j, 1])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) + # + # for j in range(24,29): + # cv2.circle(vis1, (int(projected_joints[j, 0]), int(projected_joints[j, 1])), 3, (0, 255, 0), -1) + # #cv2.circle(vis2, (int(projected_joints[j, 0]), int(projected_joints[j, 1])), 3, (0, 255, 0), -1) + # cv2.putText(vis1, joint_names[j-24], (int(projected_joints[j, 0]), int(projected_joints[j, 1])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) + # #cv2.putText(vis2, joint_names[j-24], (int(projected_joints[j, 0]), int(projected_joints[j, 1])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) + + coco_joints_name = ( + 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', + 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle') + + smpl_joints_name = ( + 'Pelvis', 'L_Hip', 'R_Hip', 'Torso', 'L_Knee', 'R_Knee', 'Spine', 'L_Ankle', 'R_Ankle', 'Chest', 'L_Toe', 'R_Toe', + 'Neck', 'L_Thorax', 'R_Thorax', + 'Head', 
'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand', 'Nose', 'L_Eye', + 'R_Eye', 'L_Ear', 'R_Ear', 'Head_top') + + selected_joint_name = ('Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder') + + selected_coco_idx = [coco_joints_name.index(joint_name) for joint_name in selected_joint_name] + selected_smpl_idx = [smpl_joints_name.index(joint_name) for joint_name in selected_joint_name] + + distance = 0 + count = 0 + for i in range(len(selected_joint_name)): + if coco_joint1[selected_coco_idx[i], 2] > 0.1: + distance += np.linalg.norm(coco_joint1[selected_coco_idx[i], :2] - projected_joints[selected_smpl_idx[i], :]) + count += 1 + + cv2.circle(vis1, (int(coco_joint1[selected_coco_idx[i], 0]), int(coco_joint1[selected_coco_idx[i], 1])), 3, (0, 0, 255), -1) + cv2.putText(vis1, selected_joint_name[i], (int(coco_joint1[selected_coco_idx[i], 0]), int(coco_joint1[selected_coco_idx[i], 1])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) + cv2.circle(vis1, (int(projected_joints[selected_smpl_idx[i], 0]), int(projected_joints[selected_smpl_idx[i], 1])), 3, (0, 255, 0), -1) + cv2.putText(vis1, selected_joint_name[i], (int(projected_joints[selected_smpl_idx[i], 0]), int(projected_joints[selected_smpl_idx[i], 1])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) + + if count > 0: + distance /= count + else: + shutil.move(vis1_path, os.path.join(bad_path, os.path.basename(vis1_path))) + + if distance > 150: + shutil.move(vis1_path, os.path.join(bad_path, os.path.basename(vis1_path))) + + + + + + + + # cv2.imshow('vis1', vis1) + # cv2.waitKey(0) + # cv2.imshow('vis2', vis2) + # cv2.waitKey(0) + # exit() + + # distance1 = np.linalg.norm(coco_joint1[:5, :2] - projected_joints[24:29, :]) + # + # # if distance1 > distance2: + # # print('move', vis1_path, 'to', os.path.join(bad_path, os.path.basename(vis1_path))) + # # shutil.move(vis1_path, os.path.join(bad_path, os.path.basename(vis1_path))) + # # else: + # # print('move', 
vis2_path, 'to', os.path.join(bad_path, os.path.basename(vis2_path))) + # # shutil.move(vis2_path, os.path.join(bad_path, os.path.basename(vis2_path))) + # + # if distance1 > 50: + # cv2.imshow('vis1', vis1) + # cv2.waitKey(0) + + # exit() diff --git a/data_processing/demo/template_mesh.npy b/data_processing/demo/template_mesh.npy new file mode 100644 index 0000000..f132f73 Binary files /dev/null and b/data_processing/demo/template_mesh.npy differ diff --git a/data_processing/demo/template_mesh_in_pyrender.npy b/data_processing/demo/template_mesh_in_pyrender.npy new file mode 100644 index 0000000..5ec3e50 Binary files /dev/null and b/data_processing/demo/template_mesh_in_pyrender.npy differ diff --git a/data_processing/detectron2/.clang-format b/data_processing/detectron2/.clang-format new file mode 100644 index 0000000..39b1b3d --- /dev/null +++ b/data_processing/detectron2/.clang-format @@ -0,0 +1,85 @@ +AccessModifierOffset: -1 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: false +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: 
false +BreakStringLiterals: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never diff --git a/data_processing/detectron2/.flake8 b/data_processing/detectron2/.flake8 new file mode 100644 index 0000000..28881e4 --- /dev/null +++ b/data_processing/detectron2/.flake8 @@ -0,0 +1,15 @@ +# This is an example .flake8 config, used when developing *Black* itself. +# Keep in sync with setup.cfg which is used for source packages. 
+ +[flake8] +ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002 +max-line-length = 100 +max-complexity = 18 +select = B,C,E,F,W,T4,B9 +exclude = build +per-file-ignores = + **/__init__.py:F401,F403,E402 + **/configs/**.py:F401,E402 + configs/**.py:F401,E402 + **/tests/config/**.py:F401,E402 + tests/config/**.py:F401,E402 diff --git a/data_processing/detectron2/.gitignore b/data_processing/detectron2/.gitignore new file mode 100644 index 0000000..9953d9b --- /dev/null +++ b/data_processing/detectron2/.gitignore @@ -0,0 +1,53 @@ +# output dir +output +instant_test_output +inference_test_output + + +*.png +*.json +*.diff +*.jpg +!/projects/DensePose/doc/images/*.jpg + +# compilation and distribution +__pycache__ +_ext +*.pyc +*.pyd +*.so +*.dll +*.egg-info/ +build/ +dist/ +wheels/ + +# pytorch/python/numpy formats +*.pth +*.pkl +*.npy +*.ts +model_ts*.txt + +# ipython/jupyter notebooks +*.ipynb +**/.ipynb_checkpoints/ + +# Editor temporaries +*.swn +*.swo +*.swp +*~ + +# editor settings +.idea +.vscode +_darcs + +# project dirs +/detectron2/model_zoo/configs +/datasets/* +!/datasets/*.* +/projects/*/datasets +/models +/snippet diff --git a/data_processing/detectron2/GETTING_STARTED.md b/data_processing/detectron2/GETTING_STARTED.md new file mode 100644 index 0000000..404b0c8 --- /dev/null +++ b/data_processing/detectron2/GETTING_STARTED.md @@ -0,0 +1,79 @@ +## Getting Started with Detectron2 + +This document provides a brief intro of the usage of builtin command-line tools in detectron2. + +For a tutorial that involves actual coding with the API, +see our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5) +which covers how to run inference with an +existing model, and how to train a builtin model on a custom dataset. + + +### Inference Demo with Pre-trained Models + +1. 
Pick a model and its config file from + [model zoo](MODEL_ZOO.md), + for example, `mask_rcnn_R_50_FPN_3x.yaml`. +2. We provide `demo.py` that is able to demo builtin configs. Run it with: +``` +cd demo/ +python demo.py --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \ + --input input1.jpg input2.jpg \ + [--other-options] + --opts MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl +``` +The configs are made for training, therefore we need to specify `MODEL.WEIGHTS` to a model from model zoo for evaluation. +This command will run the inference and show visualizations in an OpenCV window. + +For details of the command line arguments, see `demo.py -h` or look at its source code +to understand its behavior. Some common arguments are: +* To run __on your webcam__, replace `--input files` with `--webcam`. +* To run __on a video__, replace `--input files` with `--video-input video.mp4`. +* To run __on cpu__, add `MODEL.DEVICE cpu` after `--opts`. +* To save outputs to a directory (for images) or a file (for webcam or video), use `--output`. + + +### Training & Evaluation in Command Line + +We provide two scripts in "tools/plain_train_net.py" and "tools/train_net.py", +that are made to train all the configs provided in detectron2. You may want to +use it as a reference to write your own training script. + +Compared to "train_net.py", "plain_train_net.py" supports fewer default +features. It also includes fewer abstraction, therefore is easier to add custom +logic. + +To train a model with "train_net.py", first +setup the corresponding datasets following +[datasets/README.md](./datasets/README.md), +then run: +``` +cd tools/ +./train_net.py --num-gpus 8 \ + --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml +``` + +The configs are made for 8-GPU training. 
+To train on 1 GPU, you may need to [change some parameters](https://arxiv.org/abs/1706.02677), e.g.: +``` +./train_net.py \ + --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \ + --num-gpus 1 SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025 +``` + +To evaluate a model's performance, use +``` +./train_net.py \ + --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \ + --eval-only MODEL.WEIGHTS /path/to/checkpoint_file +``` +For more options, see `./train_net.py -h`. + +### Use Detectron2 APIs in Your Code + +See our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5) +to learn how to use detectron2 APIs to: +1. run inference with an existing model +2. train a builtin model on a custom dataset + +See [detectron2/projects](https://github.com/facebookresearch/detectron2/tree/main/projects) +for more ways to build your project on detectron2. diff --git a/data_processing/detectron2/INSTALL.md b/data_processing/detectron2/INSTALL.md new file mode 100644 index 0000000..f522e6f --- /dev/null +++ b/data_processing/detectron2/INSTALL.md @@ -0,0 +1,261 @@ +## Installation + +### Requirements +- Linux or macOS with Python ≥ 3.7 +- PyTorch ≥ 1.8 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. + Install them together at [pytorch.org](https://pytorch.org) to make sure of this +- OpenCV is optional but needed by demo and visualization + + +### Build Detectron2 from Source + +gcc & g++ ≥ 5.4 are required. [ninja](https://ninja-build.org/) is optional but recommended for faster build. 
+After having them, run: +``` +python -m pip install 'git+https://github.com/facebookresearch/detectron2.git' +# (add --user if you don't have permission) + +# Or, to install it from a local clone: +git clone https://github.com/facebookresearch/detectron2.git +python -m pip install -e detectron2 + +# On macOS, you may need to prepend the above commands with a few environment variables: +CC=clang CXX=clang++ ARCHFLAGS="-arch x86_64" python -m pip install ... +``` + +To __rebuild__ detectron2 that's built from a local clone, use `rm -rf build/ **/*.so` to clean the +old build first. You often need to rebuild detectron2 after reinstalling PyTorch. + +### Install Pre-Built Detectron2 (Linux only) + +Choose from this table to install [v0.6 (Oct 2021)](https://github.com/facebookresearch/detectron2/releases): + +
CUDA torch 1.10torch 1.9torch 1.8
11.3
install
python -m pip install detectron2 -f \
+  https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
+
11.1
install
python -m pip install detectron2 -f \
+  https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.10/index.html
+
install
python -m pip install detectron2 -f \
+  https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.9/index.html
+
install
python -m pip install detectron2 -f \
+  https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.8/index.html
+
10.2
install
python -m pip install detectron2 -f \
+  https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.10/index.html
+
install
python -m pip install detectron2 -f \
+  https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html
+
install
python -m pip install detectron2 -f \
+  https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.8/index.html
+
10.1
install
python -m pip install detectron2 -f \
+  https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.8/index.html
+
cpu
install
python -m pip install detectron2 -f \
+  https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html
+
install
python -m pip install detectron2 -f \
+  https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.9/index.html
+
install
python -m pip install detectron2 -f \
+  https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.8/index.html
+
+ +Note that: +1. The pre-built packages have to be used with corresponding version of CUDA and the official package of PyTorch. + Otherwise, please build detectron2 from source. +2. New packages are released every few months. Therefore, packages may not contain latest features in the main + branch and may not be compatible with the main branch of a research project that uses detectron2 + (e.g. those in [projects](projects)). + +### Common Installation Issues + +Click each issue for its solutions: + +
+ +Undefined symbols that look like "TH..","at::Tensor...","torch..." + +
+ +This usually happens when detectron2 or torchvision is not +compiled with the version of PyTorch you're running. + +If the error comes from a pre-built torchvision, uninstall torchvision and pytorch and reinstall them +following [pytorch.org](http://pytorch.org). So the versions will match. + +If the error comes from a pre-built detectron2, check [release notes](https://github.com/facebookresearch/detectron2/releases), +uninstall and reinstall the correct pre-built detectron2 that matches pytorch version. + +If the error comes from detectron2 or torchvision that you built manually from source, +remove files you built (`build/`, `**/*.so`) and rebuild it so it can pick up the version of pytorch currently in your environment. + +If the above instructions do not resolve this problem, please provide an environment (e.g. a dockerfile) that can reproduce the issue. +
+ +
+ +Missing torch dynamic libraries, OR segmentation fault immediately when using detectron2. + +This usually happens when detectron2 or torchvision is not +compiled with the version of PyTorch you're running. See the previous common issue for the solution. +
+ +
+ +Undefined C++ symbols (e.g. "GLIBCXX..") or C++ symbols not found. + +
+Usually it's because the library is compiled with a newer C++ compiler but run with an old C++ runtime. + +This often happens with old anaconda. +It may help to run `conda update libgcc` to upgrade its runtime. + +The fundamental solution is to avoid the mismatch, either by compiling using older version of C++ +compiler, or run the code with proper C++ runtime. +To run the code with a specific C++ runtime, you can use environment variable `LD_PRELOAD=/path/to/libstdc++.so`. + +
+ +
+ +"nvcc not found" or "Not compiled with GPU support" or "Detectron2 CUDA Compiler: not available". + +
+CUDA is not found when building detectron2. +You should make sure + +``` +python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)' +``` + +print `(True, a directory with cuda)` at the time you build detectron2. + +Most models can run inference (but not training) without GPU support. To use CPUs, set `MODEL.DEVICE='cpu'` in the config. +
+ +
+ +"invalid device function" or "no kernel image is available for execution". + +
+Two possibilities: + +* You build detectron2 with one version of CUDA but run it with a different version. + + To check whether it is the case, + use `python -m detectron2.utils.collect_env` to find out inconsistent CUDA versions. + In the output of this command, you should expect "Detectron2 CUDA Compiler", "CUDA_HOME", "PyTorch built with - CUDA" + to contain cuda libraries of the same version. + + When they are inconsistent, + you need to either install a different build of PyTorch (or build by yourself) + to match your local CUDA installation, or install a different version of CUDA to match PyTorch. + +* PyTorch/torchvision/Detectron2 is not built for the correct GPU SM architecture (aka. compute capability). + + The architecture included by PyTorch/detectron2/torchvision is available in the "architecture flags" in + `python -m detectron2.utils.collect_env`. It must include + the architecture of your GPU, which can be found at [developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus). + + If you're using pre-built PyTorch/detectron2/torchvision, they have included support for most popular GPUs already. + If not supported, you need to build them from source. + + When building detectron2/torchvision from source, they detect the GPU device and build for only the device. + This means the compiled code may not work on a different GPU device. + To recompile them for the correct architecture, remove all installed/compiled files, + and rebuild them with the `TORCH_CUDA_ARCH_LIST` environment variable set properly. + For example, `export TORCH_CUDA_ARCH_LIST="6.0;7.0"` makes it compile for both P100s and V100s. +
+ +
+ +Undefined CUDA symbols; Cannot open libcudart.so + +
+The version of NVCC you use to build detectron2 or torchvision does +not match the version of CUDA you are running with. +This often happens when using anaconda's CUDA runtime. + +Use `python -m detectron2.utils.collect_env` to find out inconsistent CUDA versions. +In the output of this command, you should expect "Detectron2 CUDA Compiler", "CUDA_HOME", "PyTorch built with - CUDA" +to contain cuda libraries of the same version. + +When they are inconsistent, +you need to either install a different build of PyTorch (or build by yourself) +to match your local CUDA installation, or install a different version of CUDA to match PyTorch. +
+ + +
+ +C++ compilation errors from NVCC / NVRTC, or "Unsupported gpu architecture" + +
+A few possibilities: + +1. Local CUDA/NVCC version has to match the CUDA version of your PyTorch. Both can be found in `python collect_env.py` + (download from [here](./detectron2/utils/collect_env.py)). + When they are inconsistent, you need to either install a different build of PyTorch (or build by yourself) + to match your local CUDA installation, or install a different version of CUDA to match PyTorch. + +2. Local CUDA/NVCC version shall support the SM architecture (a.k.a. compute capability) of your GPU. + The capability of your GPU can be found at [developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus). + The capability supported by NVCC is listed [here](https://gist.github.com/ax3l/9489132). + If your NVCC version is too old, this can be worked around by setting the environment variable + `TORCH_CUDA_ARCH_LIST` to a lower, supported capability. + +3. The combination of NVCC and GCC you use is incompatible. You need to change one of their versions. + See [here](https://gist.github.com/ax3l/9489132) for some valid combinations. + Notably, CUDA<=10.1.105 doesn't support GCC>7.3. + + The CUDA/GCC version used by PyTorch can be found by `print(torch.__config__.show())`. + +
+ + +
+ +"ImportError: cannot import name '_C'". + +
+Please build and install detectron2 following the instructions above. + +Or, if you are running code from detectron2's root directory, `cd` to a different one. +Otherwise you may not import the code that you installed. +
+ + +
+ +Any issue on windows. + +
+ +Detectron2 is continuously built on Windows with [CircleCI](https://app.circleci.com/pipelines/github/facebookresearch/detectron2?branch=main). +However, we do not provide official support for it. +PRs that improve code compatibility on Windows are welcome. +
+ +
+ +ONNX conversion segfault after some "TraceWarning". + +
+The ONNX package is compiled with a compiler that is too old. + +Please build and install ONNX from its source code using a compiler +whose version is closer to what's used by PyTorch (available in `torch.__config__.show()`). +
+ + +
+ +"library not found for -lstdc++" on older versions of macOS + +
+ +See [this stackoverflow answer](https://stackoverflow.com/questions/56083725/macos-build-issues-lstdc-not-found-while-building-python-package). + +
+ + +### Installation inside specific environments: + +* __Colab__: see our [Colab Tutorial](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5) + which has step-by-step instructions. + +* __Docker__: The official [Dockerfile](docker) installs detectron2 with a few simple commands. diff --git a/data_processing/detectron2/LICENSE b/data_processing/detectron2/LICENSE new file mode 100644 index 0000000..cd1b070 --- /dev/null +++ b/data_processing/detectron2/LICENSE @@ -0,0 +1,202 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, +and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by +the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all +other entities that control, are controlled by, or are under common +control with that entity. For the purposes of this definition, +"control" means (i) the power, direct or indirect, to cause the +direction or management of such entity, whether by contract or +otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity +exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, +including but not limited to software source code, documentation +source, and configuration files. 
+ +"Object" form shall mean any form resulting from mechanical +transformation or translation of a Source form, including but +not limited to compiled object code, generated documentation, +and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or +Object form, made available under the License, as indicated by a +copyright notice that is included in or attached to the work +(an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object +form, that is based on (or derived from) the Work and for which the +editorial revisions, annotations, elaborations, or other modifications +represent, as a whole, an original work of authorship. For the purposes +of this License, Derivative Works shall not include works that remain +separable from, or merely link (or bind by name) to the interfaces of, +the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including +the original version of the Work and any modifications or additions +to that Work or Derivative Works thereof, that is intentionally +submitted to Licensor for inclusion in the Work by the copyright owner +or by an individual or Legal Entity authorized to submit on behalf of +the copyright owner. For the purposes of this definition, "submitted" +means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, +and issue tracking systems that are managed by, or on behalf of, the +Licensor for the purpose of discussing and improving the Work, but +excluding communication that is conspicuously marked or otherwise +designated in writing by the copyright owner as "Not a Contribution." 
+ +"Contributor" shall mean Licensor and any individual or Legal Entity +on behalf of whom a Contribution has been received by Licensor and +subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the +Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +(except as stated in this section) patent license to make, have made, +use, offer to sell, sell, import, and otherwise transfer the Work, +where such license applies only to those patent claims licensable +by such Contributor that are necessarily infringed by their +Contribution(s) alone or by combination of their Contribution(s) +with the Work to which such Contribution(s) was submitted. If You +institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work +or a Contribution incorporated within the Work constitutes direct +or contributory patent infringement, then any patent licenses +granted to You under this License for that Work shall terminate +as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the +Work or Derivative Works thereof in any medium, with or without +modifications, and in Source or Object form, provided that You +meet the following conditions: + +(a) You must give any other recipients of the Work or +Derivative Works a copy of this License; and + +(b) You must cause any modified files to carry prominent notices +stating that You changed the files; and + +(c) You must retain, in the Source form of any Derivative Works +that You distribute, all copyright, patent, trademark, and +attribution notices from the Source form of the Work, +excluding those notices that do not pertain to any part of +the Derivative Works; and + +(d) If the Work includes a "NOTICE" text file as part of its +distribution, then any Derivative Works that You distribute must +include a readable copy of the attribution notices contained +within such NOTICE file, excluding those notices that do not +pertain to any part of the Derivative Works, in at least one +of the following places: within a NOTICE text file distributed +as part of the Derivative Works; within the Source form or +documentation, if provided along with the Derivative Works; or, +within a display generated by the Derivative Works, if and +wherever such third-party notices normally appear. The contents +of the NOTICE file are for informational purposes only and +do not modify the License. You may add Your own attribution +notices within Derivative Works that You distribute, alongside +or as an addendum to the NOTICE text from the Work, provided +that such additional attribution notices cannot be construed +as modifying the License. 
+ +You may add Your own copyright statement to Your modifications and +may provide additional or different license terms and conditions +for use, reproduction, or distribution of Your modifications, or +for any such Derivative Works as a whole, provided Your use, +reproduction, and distribution of the Work otherwise complies with +the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, +any Contribution intentionally submitted for inclusion in the Work +by You to the Licensor shall be under the terms and conditions of +this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify +the terms of any separate license agreement you may have executed +with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade +names, trademarks, service marks, or product names of the Licensor, +except as required for reasonable and customary use in describing the +origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or +agreed to in writing, Licensor provides the Work (and each +Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied, including, without limitation, any warranties or conditions +of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A +PARTICULAR PURPOSE. You are solely responsible for determining the +appropriateness of using or redistributing the Work and assume any +risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, +whether in tort (including negligence), contract, or otherwise, +unless required by applicable law (such as deliberate and grossly +negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, +incidental, or consequential damages of any character arising as a +result of this License or out of the use or inability to use the +Work (including but not limited to damages for loss of goodwill, +work stoppage, computer failure or malfunction, or any and all +other commercial damages or losses), even if such Contributor +has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing +the Work or Derivative Works thereof, You may choose to offer, +and charge a fee for, acceptance of support, warranty, indemnity, +or other liability obligations and/or rights consistent with this +License. However, in accepting such obligations, You may act only +on Your own behalf and on Your sole responsibility, not on behalf +of any other Contributor, and only if You agree to indemnify, +defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason +of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following +boilerplate notice, with the fields enclosed by brackets "[]" +replaced with your own identifying information. (Don't include +the brackets!) The text should be enclosed in the appropriate +comment syntax for the file format. We also recommend that a +file or class name and description of purpose be included on the +same "printed page" as the copyright notice for easier +identification within third-party archives. 
+ +Copyright [yyyy] [name of copyright owner] + + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/data_processing/detectron2/MODEL_ZOO.md b/data_processing/detectron2/MODEL_ZOO.md new file mode 100644 index 0000000..69db272 --- /dev/null +++ b/data_processing/detectron2/MODEL_ZOO.md @@ -0,0 +1,1052 @@ +# Detectron2 Model Zoo and Baselines + +## Introduction + +This file documents a large collection of baselines trained +with detectron2 in Sep-Oct, 2019. +All numbers were obtained on [Big Basin](https://engineering.fb.com/data-center-engineering/introducing-big-basin-our-next-generation-ai-hardware/) +servers with 8 NVIDIA V100 GPUs & NVLink. The speed numbers are periodically updated with latest PyTorch/CUDA/cuDNN versions. +You can access these models from code using [detectron2.model_zoo](https://detectron2.readthedocs.io/modules/model_zoo.html) APIs. + +In addition to these official baseline models, you can find more models in [projects/](projects/). + +#### How to Read the Tables +* The "Name" column contains a link to the config file. Models can be reproduced using `tools/train_net.py` with the corresponding yaml config file, + or `tools/lazyconfig_train_net.py` for python config files. +* Training speed is averaged across the entire training. 
+ We keep updating the speed with latest version of detectron2/pytorch/etc., + so they might be different from the `metrics` file. + Training speed for multi-machine jobs is not provided. +* Inference speed is measured by `tools/train_net.py --eval-only`, or [inference_on_dataset()](https://detectron2.readthedocs.io/modules/evaluation.html#detectron2.evaluation.inference_on_dataset), + with batch size 1 in detectron2 directly. + Measuring it with custom code may introduce other overhead. + Actual deployment in production should in general be faster than the given inference + speed due to more optimizations. +* The *model id* column is provided for ease of reference. + To check downloaded file integrity, any model on this page contains its md5 prefix in its file name. +* Training curves and other statistics can be found in `metrics` for each model. + +#### Common Settings for COCO Models +* All COCO models were trained on `train2017` and evaluated on `val2017`. +* The default settings are __not directly comparable__ with Detectron's standard settings. + For example, our default training data augmentation uses scale jittering in addition to horizontal flipping. + + To make fair comparisons with Detectron's settings, see + [Detectron1-Comparisons](configs/Detectron1-Comparisons/) for accuracy comparison, + and [benchmarks](https://detectron2.readthedocs.io/notes/benchmarks.html) + for speed comparison. +* For Faster/Mask R-CNN, we provide baselines based on __3 different backbone combinations__: + * __FPN__: Use a ResNet+FPN backbone with standard conv and FC heads for mask and box prediction, + respectively. It obtains the best + speed/accuracy tradeoff, but the other two are still useful for research. + * __C4__: Use a ResNet conv4 backbone with conv5 head. The original baseline in the Faster R-CNN paper. 
+ * __DC5__ (Dilated-C5): Use a ResNet conv5 backbone with dilations in conv5, and standard conv and FC heads + for mask and box prediction, respectively. + This is used by the Deformable ConvNet paper. +* Most models are trained with the 3x schedule (~37 COCO epochs). + Although 1x models are heavily under-trained, we provide some ResNet-50 models with the 1x (~12 COCO epochs) + training schedule for comparison when doing quick research iteration. + +#### ImageNet Pretrained Models + +It's common to initialize from backbone models pre-trained on ImageNet classification tasks. The following backbone models are available: + +* [R-50.pkl](https://dl.fbaipublicfiles.com/detectron2/ImageNetPretrained/MSRA/R-50.pkl): converted copy of [MSRA's original ResNet-50](https://github.com/KaimingHe/deep-residual-networks) model. +* [R-101.pkl](https://dl.fbaipublicfiles.com/detectron2/ImageNetPretrained/MSRA/R-101.pkl): converted copy of [MSRA's original ResNet-101](https://github.com/KaimingHe/deep-residual-networks) model. +* [X-101-32x8d.pkl](https://dl.fbaipublicfiles.com/detectron2/ImageNetPretrained/FAIR/X-101-32x8d.pkl): ResNeXt-101-32x8d model trained with Caffe2 at FB. +* [R-50.pkl (torchvision)](https://dl.fbaipublicfiles.com/detectron2/ImageNetPretrained/torchvision/R-50.pkl): converted copy of [torchvision's ResNet-50](https://pytorch.org/docs/stable/torchvision/models.html#torchvision.models.resnet50) model. + More details can be found in [the conversion script](tools/convert-torchvision-to-d2.py). 
+ +Note that the above models have __different__ format from those provided in Detectron: we do not fuse BatchNorm into an affine layer. +Pretrained models in Detectron's format can still be used. For example: +* [X-152-32x8d-IN5k.pkl](https://dl.fbaipublicfiles.com/detectron/ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl): + ResNeXt-152-32x8d model trained on ImageNet-5k with Caffe2 at FB (see ResNeXt paper for details on ImageNet-5k). +* [R-50-GN.pkl](https://dl.fbaipublicfiles.com/detectron/ImageNetPretrained/47261647/R-50-GN.pkl): + ResNet-50 with Group Normalization. +* [R-101-GN.pkl](https://dl.fbaipublicfiles.com/detectron/ImageNetPretrained/47592356/R-101-GN.pkl): + ResNet-101 with Group Normalization. + +These models require slightly different settings regarding normalization and architecture. See the model zoo configs for reference. + +#### License + +All models available for download through this document are licensed under the +[Creative Commons Attribution-ShareAlike 3.0 license](https://creativecommons.org/licenses/by-sa/3.0/). + +### COCO Object Detection Baselines + +#### Faster R-CNN: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
model iddownload
R50-C41x0.5510.1024.835.7137257644model | metrics
R50-DC51x0.3800.0685.037.3137847829model | metrics
R50-FPN1x0.2100.0383.037.9137257794model | metrics
R50-C43x0.5430.1044.838.4137849393model | metrics
R50-DC53x0.3780.0705.039.0137849425model | metrics
R50-FPN3x0.2090.0383.040.2137849458model | metrics
R101-C43x0.6190.1395.941.1138204752model | metrics
R101-DC53x0.4520.0866.140.6138204841model | metrics
R101-FPN3x0.2860.0514.142.0137851257model | metrics
X101-FPN3x0.6380.0986.743.0139173657model | metrics
+ +#### RetinaNet: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
model iddownload
R501x0.2050.0414.137.4190397773model | metrics
R503x0.2050.0414.138.7190397829model | metrics
R1013x0.2910.0545.240.4190397697model | metrics
+ + +#### RPN & Fast R-CNN: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
prop.
AR
model iddownload
RPN R50-C41x0.1300.0341.551.6137258005model | metrics
RPN R50-FPN1x0.1860.0322.758.0137258492model | metrics
Fast R-CNN R50-FPN1x0.1400.0292.637.8137635226model | metrics
+ +### COCO Instance Segmentation Baselines with Mask R-CNN + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
mask
AP
model iddownload
R50-C41x0.5840.1105.236.832.2137259246model | metrics
R50-DC51x0.4710.0766.538.334.2137260150model | metrics
R50-FPN1x0.2610.0433.438.635.2137260431model | metrics
R50-C43x0.5750.1115.239.834.4137849525model | metrics
R50-DC53x0.4700.0766.540.035.9137849551model | metrics
R50-FPN3x0.2610.0433.441.037.2137849600model | metrics
R101-C43x0.6520.1456.342.636.7138363239model | metrics
R101-DC53x0.5450.0927.641.937.3138363294model | metrics
R101-FPN3x0.3400.0564.642.938.6138205316model | metrics
X101-FPN3x0.6900.1037.244.339.5139653917model | metrics
+ + + +#### New baselines using Large-Scale Jitter and Longer Training Schedule + +The following baselines of COCO Instance Segmentation with Mask R-CNN are generated +using a longer training schedule and large-scale jitter as described in Google's +[Simple Copy-Paste Data Augmentation](https://arxiv.org/pdf/2012.07177.pdf) paper. These +models are trained from scratch using random initialization. These baselines exceed the +previous Mask R-CNN baselines. + +In the following table, one epoch consists of training on 118000 COCO images. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Nameepochstrain
time
(s/im)
inference
time
(s/im)
box
AP
mask
AP
model iddownload
R50-FPN1000.3760.06944.640.342047764model | metrics
R50-FPN2000.3760.06946.341.742047638model | metrics
R50-FPN4000.3760.06947.442.542019571model | metrics
R101-FPN1000.5180.07346.441.642025812model | metrics
R101-FPN2000.5180.07348.043.142131867model | metrics
R101-FPN4000.5180.07348.943.742073830model | metrics
regnetx_4gf_dds_FPN1000.4740.07146.041.342047771model | metrics
regnetx_4gf_dds_FPN2000.4740.07148.143.142132721model | metrics
regnetx_4gf_dds_FPN4000.4740.07148.643.542025447model | metrics
regnety_4gf_dds_FPN1000.4870.07346.141.642047784model | metrics
regnety_4gf_dds_FPN2000.4870.07247.843.042047642model | metrics
regnety_4gf_dds_FPN4000.4870.07248.243.342045954model | metrics
+ +### COCO Person Keypoint Detection Baselines with Keypoint R-CNN + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
kp.
AP
model iddownload
R50-FPN1x0.3150.0725.053.664.0137261548model | metrics
R50-FPN3x0.3160.0665.055.465.5137849621model | metrics
R101-FPN3x0.3900.0766.156.466.1138363331model | metrics
X101-FPN3x0.7380.1218.757.366.0139686956model | metrics
+ +### COCO Panoptic Segmentation Baselines with Panoptic FPN + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
mask
AP
PQmodel iddownload
R50-FPN1x0.3040.0534.837.634.739.4139514544model | metrics
R50-FPN3x0.3020.0534.840.036.541.5139514569model | metrics
R101-FPN3x0.3920.0666.042.438.543.0139514519model | metrics
+ + +### LVIS Instance Segmentation Baselines with Mask R-CNN + +Mask R-CNN baselines on the [LVIS dataset](https://lvisdataset.org), v0.5. +These baselines are described in Table 3(c) of the [LVIS paper](https://arxiv.org/abs/1908.03195). + +NOTE: the 1x schedule here has the same amount of __iterations__ as the COCO 1x baselines. +They are roughly 24 epochs of LVISv0.5 data. +The final results of these configs have large variance across different runs. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
mask
AP
model iddownload
R50-FPN1x0.2920.1077.123.624.4144219072model | metrics
R101-FPN1x0.3710.1147.825.625.9144219035model | metrics
X101-FPN1x0.7120.15110.226.727.1144219108model | metrics
+ + + +### Cityscapes & Pascal VOC Baselines + +Simple baselines for +* Mask R-CNN on Cityscapes instance segmentation (initialized from COCO pre-training, then trained on Cityscapes fine annotations only) +* Faster R-CNN on PASCAL VOC object detection (trained on VOC 2007 train+val + VOC 2012 train+val, tested on VOC 2007 using 11-point interpolated AP) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Nametrain
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
box
AP50
mask
AP
model iddownload
R50-FPN, Cityscapes0.2400.0784.436.5142423278model | metrics
R50-C4, VOC0.5370.0814.851.980.3142202221model | metrics
+ + + +### Other Settings + +Ablations for Deformable Conv and Cascade R-CNN: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
mask
AP
model iddownload
Baseline R50-FPN1x0.2610.0433.438.635.2137260431model | metrics
Deformable Conv1x0.3420.0483.541.537.5138602867model | metrics
Cascade R-CNN1x0.3170.0524.042.136.4138602847model | metrics
Baseline R50-FPN3x0.2610.0433.441.037.2137849600model | metrics
Deformable Conv3x0.3490.0473.542.738.5144998336model | metrics
Cascade R-CNN3x0.3280.0534.044.338.5144998488model | metrics
+ + +Ablations for normalization methods, and a few models trained from scratch following [Rethinking ImageNet Pre-training](https://arxiv.org/abs/1811.08883). +(Note: The baseline uses `2fc` head while the others use [`4conv1fc` head](https://arxiv.org/abs/1803.08494)) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
mask
AP
model iddownload
Baseline R50-FPN3x0.2610.0433.441.037.2137849600model | metrics
GN3x0.3090.0605.642.638.6138602888model | metrics
SyncBN3x0.3450.0535.541.937.8169527823model | metrics
GN (from scratch)3x0.3380.0617.239.936.6138602908model | metrics
GN (from scratch)9xN/A0.0617.243.739.6183808979model | metrics
SyncBN (from scratch)9xN/A0.0557.243.639.3184226666model | metrics
+ + +A few very large models trained for a long time, for demo purposes. They are trained using multiple machines: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Nameinference
time
(s/im)
train
mem
(GB)
box
AP
mask
AP
PQmodel iddownload
Panoptic FPN R1010.09811.447.441.346.1139797668model | metrics
Mask R-CNN X1520.23415.150.244.018131413model | metrics
above + test-time aug.51.945.9
diff --git a/data_processing/detectron2/README.md b/data_processing/detectron2/README.md new file mode 100644 index 0000000..75db3c5 --- /dev/null +++ b/data_processing/detectron2/README.md @@ -0,0 +1,68 @@ + + + + Support Ukraine - Help Provide Humanitarian Aid to Ukraine. + + +Detectron2 is Facebook AI Research's next generation library +that provides state-of-the-art detection and segmentation algorithms. +It is the successor of +[Detectron](https://github.com/facebookresearch/Detectron/) +and [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark/). +It supports a number of computer vision research projects and production applications in Facebook. + +
+ +
+
+ +## Learn More about Detectron2 + +Explain Like I’m 5: Detectron2 | Using Machine Learning with Detectron2 +:-------------------------:|:-------------------------: +[![Explain Like I’m 5: Detectron2](https://img.youtube.com/vi/1oq1Ye7dFqc/0.jpg)](https://www.youtube.com/watch?v=1oq1Ye7dFqc) | [![Using Machine Learning with Detectron2](https://img.youtube.com/vi/eUSgtfK4ivk/0.jpg)](https://www.youtube.com/watch?v=eUSgtfK4ivk) + +## What's New +* Includes new capabilities such as panoptic segmentation, Densepose, Cascade R-CNN, rotated bounding boxes, PointRend, + DeepLab, ViTDet, MViTv2 etc. +* Used as a library to support building [research projects](projects/) on top of it. +* Models can be exported to TorchScript format or Caffe2 format for deployment. +* It [trains much faster](https://detectron2.readthedocs.io/notes/benchmarks.html). + +See our [blog post](https://ai.facebook.com/blog/-detectron2-a-pytorch-based-modular-object-detection-library-/) +to see more demos and learn about detectron2. + +## Installation + +See [installation instructions](https://detectron2.readthedocs.io/tutorials/install.html). + +## Getting Started + +See [Getting Started with Detectron2](https://detectron2.readthedocs.io/tutorials/getting_started.html), +and the [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5) +to learn about basic usage. 
+ +Learn more at our [documentation](https://detectron2.readthedocs.org). +And see [projects/](projects/) for some projects that are built on top of detectron2. + +## Model Zoo and Baselines + +We provide a large set of baseline results and trained models available for download in the [Detectron2 Model Zoo](MODEL_ZOO.md). + +## License + +Detectron2 is released under the [Apache 2.0 license](LICENSE). + +## Citing Detectron2 + +If you use Detectron2 in your research or wish to refer to the baseline results published in the [Model Zoo](MODEL_ZOO.md), please use the following BibTeX entry. + +```BibTeX +@misc{wu2019detectron2, + author = {Yuxin Wu and Alexander Kirillov and Francisco Massa and + Wan-Yen Lo and Ross Girshick}, + title = {Detectron2}, + howpublished = {\url{https://github.com/facebookresearch/detectron2}}, + year = {2019} +} +``` diff --git a/data_processing/detectron2/configs/Base-RCNN-C4.yaml b/data_processing/detectron2/configs/Base-RCNN-C4.yaml new file mode 100644 index 0000000..fbf34a0 --- /dev/null +++ b/data_processing/detectron2/configs/Base-RCNN-C4.yaml @@ -0,0 +1,18 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + RPN: + PRE_NMS_TOPK_TEST: 6000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "Res5ROIHeads" +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +VERSION: 2 diff --git a/data_processing/detectron2/configs/Base-RCNN-DilatedC5.yaml b/data_processing/detectron2/configs/Base-RCNN-DilatedC5.yaml new file mode 100644 index 0000000..c0d6d16 --- /dev/null +++ b/data_processing/detectron2/configs/Base-RCNN-DilatedC5.yaml @@ -0,0 +1,31 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + RESNETS: + OUT_FEATURES: ["res5"] + RES5_DILATION: 2 + RPN: + IN_FEATURES: ["res5"] + 
PRE_NMS_TOPK_TEST: 6000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["res5"] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +VERSION: 2 diff --git a/data_processing/detectron2/configs/Base-RCNN-FPN.yaml b/data_processing/detectron2/configs/Base-RCNN-FPN.yaml new file mode 100644 index 0000000..3e020f2 --- /dev/null +++ b/data_processing/detectron2/configs/Base-RCNN-FPN.yaml @@ -0,0 +1,42 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: "build_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map + ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level + PRE_NMS_TOPK_TEST: 1000 # Per FPN level + # Detectron1 uses 2000 proposals per-batch, + # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) + # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
+ POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +VERSION: 2 diff --git a/data_processing/detectron2/configs/Base-RetinaNet.yaml b/data_processing/detectron2/configs/Base-RetinaNet.yaml new file mode 100644 index 0000000..8b45b98 --- /dev/null +++ b/data_processing/detectron2/configs/Base-RetinaNet.yaml @@ -0,0 +1,25 @@ +MODEL: + META_ARCHITECTURE: "RetinaNet" + BACKBONE: + NAME: "build_retinanet_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [32, 64, 128, 256, 512 ]]"] + FPN: + IN_FEATURES: ["res3", "res4", "res5"] + RETINANET: + IOU_THRESHOLDS: [0.4, 0.5] + IOU_LABELS: [0, -1, 1] + SMOOTH_L1_LOSS_BETA: 0.0 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.01 # Note that RetinaNet uses a different default learning rate + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +VERSION: 2 diff --git a/data_processing/detectron2/configs/COCO-Detection/fast_rcnn_R_50_FPN_1x.yaml b/data_processing/detectron2/configs/COCO-Detection/fast_rcnn_R_50_FPN_1x.yaml new file mode 100644 index 0000000..773ac10 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/fast_rcnn_R_50_FPN_1x.yaml @@ -0,0 +1,17 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + LOAD_PROPOSALS: True + RESNETS: + DEPTH: 50 + 
PROPOSAL_GENERATOR: + NAME: "PrecomputedProposals" +DATASETS: + TRAIN: ("coco_2017_train",) + PROPOSAL_FILES_TRAIN: ("detectron2://COCO-Detection/rpn_R_50_FPN_1x/137258492/coco_2017_train_box_proposals_21bc3a.pkl", ) + TEST: ("coco_2017_val",) + PROPOSAL_FILES_TEST: ("detectron2://COCO-Detection/rpn_R_50_FPN_1x/137258492/coco_2017_val_box_proposals_ee0dad.pkl", ) +DATALOADER: + # proposals are part of the dataset_dicts, and take a lot of RAM + NUM_WORKERS: 2 diff --git a/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_101_C4_3x.yaml b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_101_C4_3x.yaml new file mode 100644 index 0000000..db142cd --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_101_C4_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-C4.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: False + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_101_DC5_3x.yaml b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_101_DC5_3x.yaml new file mode 100644 index 0000000..bceb6b3 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_101_DC5_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-DilatedC5.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: False + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml new file mode 100644 index 0000000..57a098f --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: False + RESNETS: + DEPTH: 
101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_C4_1x.yaml b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_C4_1x.yaml new file mode 100644 index 0000000..f961301 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_C4_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "../Base-RCNN-C4.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + DEPTH: 50 diff --git a/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_C4_3x.yaml b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_C4_3x.yaml new file mode 100644 index 0000000..bc51bce --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_C4_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-C4.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_DC5_1x.yaml b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_DC5_1x.yaml new file mode 100644 index 0000000..0fe96f5 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_DC5_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "../Base-RCNN-DilatedC5.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + DEPTH: 50 diff --git a/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_DC5_3x.yaml b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_DC5_3x.yaml new file mode 100644 index 0000000..33fadeb --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_DC5_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-DilatedC5.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + 
DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml new file mode 100644 index 0000000..3262019 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + DEPTH: 50 diff --git a/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml new file mode 100644 index 0000000..4139518 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml new file mode 100644 index 0000000..9c9b5ab --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml @@ -0,0 +1,13 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + MASK_ON: False + WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" + PIXEL_STD: [57.375, 57.120, 58.395] + RESNETS: + STRIDE_IN_1X1: False # this is a C2 model + NUM_GROUPS: 32 + WIDTH_PER_GROUP: 8 + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-Detection/fcos_R_50_FPN_1x.py b/data_processing/detectron2/configs/COCO-Detection/fcos_R_50_FPN_1x.py new file mode 100644 index 0000000..86f83c6 --- /dev/null +++ 
b/data_processing/detectron2/configs/COCO-Detection/fcos_R_50_FPN_1x.py @@ -0,0 +1,11 @@ +from ..common.optim import SGD as optimizer +from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier +from ..common.data.coco import dataloader +from ..common.models.fcos import model +from ..common.train import train + +dataloader.train.mapper.use_instance_mask = False +optimizer.lr = 0.01 + +model.backbone.bottom_up.freeze_at = 2 +train.init_checkpoint = "detectron2://ImageNetPretrained/MSRA/R-50.pkl" diff --git a/data_processing/detectron2/configs/COCO-Detection/retinanet_R_101_FPN_3x.yaml b/data_processing/detectron2/configs/COCO-Detection/retinanet_R_101_FPN_3x.yaml new file mode 100644 index 0000000..4abb1b9 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/retinanet_R_101_FPN_3x.yaml @@ -0,0 +1,8 @@ +_BASE_: "../Base-RetinaNet.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_1x.py b/data_processing/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_1x.py new file mode 100644 index 0000000..43057a8 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_1x.py @@ -0,0 +1,11 @@ +from ..common.optim import SGD as optimizer +from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier +from ..common.data.coco import dataloader +from ..common.models.retinanet import model +from ..common.train import train + +dataloader.train.mapper.use_instance_mask = False +model.backbone.bottom_up.freeze_at = 2 +optimizer.lr = 0.01 + +train.init_checkpoint = "detectron2://ImageNetPretrained/MSRA/R-50.pkl" diff --git a/data_processing/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_1x.yaml b/data_processing/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_1x.yaml new file mode 100644 index 0000000..4a24ce3 --- /dev/null +++ 
b/data_processing/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_1x.yaml @@ -0,0 +1,5 @@ +_BASE_: "../Base-RetinaNet.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 diff --git a/data_processing/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_3x.yaml b/data_processing/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_3x.yaml new file mode 100644 index 0000000..3b5412d --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/retinanet_R_50_FPN_3x.yaml @@ -0,0 +1,8 @@ +_BASE_: "../Base-RetinaNet.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-Detection/rpn_R_50_C4_1x.yaml b/data_processing/detectron2/configs/COCO-Detection/rpn_R_50_C4_1x.yaml new file mode 100644 index 0000000..e048211 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/rpn_R_50_C4_1x.yaml @@ -0,0 +1,10 @@ +_BASE_: "../Base-RCNN-C4.yaml" +MODEL: + META_ARCHITECTURE: "ProposalNetwork" + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + DEPTH: 50 + RPN: + PRE_NMS_TOPK_TEST: 12000 + POST_NMS_TOPK_TEST: 2000 diff --git a/data_processing/detectron2/configs/COCO-Detection/rpn_R_50_FPN_1x.yaml b/data_processing/detectron2/configs/COCO-Detection/rpn_R_50_FPN_1x.yaml new file mode 100644 index 0000000..dc9c952 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Detection/rpn_R_50_FPN_1x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + META_ARCHITECTURE: "ProposalNetwork" + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + DEPTH: 50 + RPN: + POST_NMS_TOPK_TEST: 2000 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml new file mode 
100644 index 0000000..1a94cc4 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-C4.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: True + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_DC5_3x.yaml b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_DC5_3x.yaml new file mode 100644 index 0000000..67b70cf --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_DC5_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-DilatedC5.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: True + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml new file mode 100644 index 0000000..1935a30 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: True + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x.py b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x.py new file mode 100644 index 0000000..22016be --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x.py @@ -0,0 +1,8 @@ +from ..common.train import train +from ..common.optim import SGD as optimizer +from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier +from ..common.data.coco import 
dataloader +from ..common.models.mask_rcnn_c4 import model + +model.backbone.freeze_at = 2 +train.init_checkpoint = "detectron2://ImageNetPretrained/MSRA/R-50.pkl" diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x.yaml b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x.yaml new file mode 100644 index 0000000..a9aeb4e --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "../Base-RCNN-C4.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml new file mode 100644 index 0000000..38ed867 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-C4.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_1x.yaml b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_1x.yaml new file mode 100644 index 0000000..b13eefa --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "../Base-RCNN-DilatedC5.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x.yaml b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x.yaml new file mode 100644 index 0000000..d401016 --- /dev/null +++ 
b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-DilatedC5.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.py b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.py new file mode 100644 index 0000000..40844dd --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.py @@ -0,0 +1,8 @@ +from ..common.optim import SGD as optimizer +from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier +from ..common.data.coco import dataloader +from ..common.models.mask_rcnn_fpn import model +from ..common.train import train + +model.backbone.bottom_up.freeze_at = 2 +train.init_checkpoint = "detectron2://ImageNetPretrained/MSRA/R-50.pkl" diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml new file mode 100644 index 0000000..d50fb86 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x_giou.yaml b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x_giou.yaml new file mode 100644 index 0000000..bec680e --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x_giou.yaml @@ -0,0 +1,12 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: 
True + RESNETS: + DEPTH: 50 + RPN: + BBOX_REG_LOSS_TYPE: "giou" + BBOX_REG_LOSS_WEIGHT: 2.0 + ROI_BOX_HEAD: + BBOX_REG_LOSS_TYPE: "giou" + BBOX_REG_LOSS_WEIGHT: 10.0 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml new file mode 100644 index 0000000..be7d06b --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml new file mode 100644 index 0000000..d14c63f --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml @@ -0,0 +1,13 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + MASK_ON: True + WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" + PIXEL_STD: [57.375, 57.120, 58.395] + RESNETS: + STRIDE_IN_1X1: False # this is a C2 model + NUM_GROUPS: 32 + WIDTH_PER_GROUP: 8 + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_regnetx_4gf_dds_fpn_1x.py b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_regnetx_4gf_dds_fpn_1x.py new file mode 100644 index 0000000..d7bbdd7 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_regnetx_4gf_dds_fpn_1x.py @@ -0,0 +1,34 @@ +from ..common.optim import SGD as optimizer +from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier +from ..common.data.coco import dataloader +from 
..common.models.mask_rcnn_fpn import model +from ..common.train import train + +from detectron2.config import LazyCall as L +from detectron2.modeling.backbone import RegNet +from detectron2.modeling.backbone.regnet import SimpleStem, ResBottleneckBlock + + +# Replace default ResNet with RegNetX-4GF from the DDS paper. Config source: +# https://github.com/facebookresearch/pycls/blob/2c152a6e5d913e898cca4f0a758f41e6b976714d/configs/dds_baselines/regnetx/RegNetX-4.0GF_dds_8gpu.yaml#L4-L9 # noqa +model.backbone.bottom_up = L(RegNet)( + stem_class=SimpleStem, + stem_width=32, + block_class=ResBottleneckBlock, + depth=23, + w_a=38.65, + w_0=96, + w_m=2.43, + group_width=40, + freeze_at=2, + norm="FrozenBN", + out_features=["s1", "s2", "s3", "s4"], +) +model.pixel_std = [57.375, 57.120, 58.395] + +optimizer.weight_decay = 5e-5 +train.init_checkpoint = ( + "https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906383/RegNetX-4.0GF_dds_8gpu.pyth" +) +# RegNets benefit from enabling cudnn benchmark mode +train.cudnn_benchmark = True diff --git a/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_regnety_4gf_dds_fpn_1x.py b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_regnety_4gf_dds_fpn_1x.py new file mode 100644 index 0000000..72c6b7a --- /dev/null +++ b/data_processing/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_regnety_4gf_dds_fpn_1x.py @@ -0,0 +1,35 @@ +from ..common.optim import SGD as optimizer +from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier +from ..common.data.coco import dataloader +from ..common.models.mask_rcnn_fpn import model +from ..common.train import train + +from detectron2.config import LazyCall as L +from detectron2.modeling.backbone import RegNet +from detectron2.modeling.backbone.regnet import SimpleStem, ResBottleneckBlock + + +# Replace default ResNet with RegNetY-4GF 
from the DDS paper. Config source: +# https://github.com/facebookresearch/pycls/blob/2c152a6e5d913e898cca4f0a758f41e6b976714d/configs/dds_baselines/regnety/RegNetY-4.0GF_dds_8gpu.yaml#L4-L10 # noqa +model.backbone.bottom_up = L(RegNet)( + stem_class=SimpleStem, + stem_width=32, + block_class=ResBottleneckBlock, + depth=22, + w_a=31.41, + w_0=96, + w_m=2.24, + group_width=64, + se_ratio=0.25, + freeze_at=2, + norm="FrozenBN", + out_features=["s1", "s2", "s3", "s4"], +) +model.pixel_std = [57.375, 57.120, 58.395] + +optimizer.weight_decay = 5e-5 +train.init_checkpoint = ( + "https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906838/RegNetY-4.0GF_dds_8gpu.pyth" +) +# RegNets benefit from enabling cudnn benchmark mode +train.cudnn_benchmark = True diff --git a/data_processing/detectron2/configs/COCO-Keypoints/Base-Keypoint-RCNN-FPN.yaml b/data_processing/detectron2/configs/COCO-Keypoints/Base-Keypoint-RCNN-FPN.yaml new file mode 100644 index 0000000..4e03944 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Keypoints/Base-Keypoint-RCNN-FPN.yaml @@ -0,0 +1,15 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + KEYPOINT_ON: True + ROI_HEADS: + NUM_CLASSES: 1 + ROI_BOX_HEAD: + SMOOTH_L1_BETA: 0.5 # Keypoint AP degrades (though box AP improves) when using plain L1 loss + RPN: + # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2. + # 1000 proposals per-image is found to hurt box AP. + # Therefore we increase it to 1500 per-image. 
+ POST_NMS_TOPK_TRAIN: 1500 +DATASETS: + TRAIN: ("keypoints_coco_2017_train",) + TEST: ("keypoints_coco_2017_val",) diff --git a/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml b/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml new file mode 100644 index 0000000..9309535 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml @@ -0,0 +1,8 @@ +_BASE_: "Base-Keypoint-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x.py b/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x.py new file mode 100644 index 0000000..1aad53b --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x.py @@ -0,0 +1,8 @@ +from ..common.optim import SGD as optimizer +from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier +from ..common.data.coco_keypoint import dataloader +from ..common.models.keypoint_rcnn_fpn import model +from ..common.train import train + +model.backbone.bottom_up.freeze_at = 2 +train.init_checkpoint = "detectron2://ImageNetPretrained/MSRA/R-50.pkl" diff --git a/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x.yaml b/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x.yaml new file mode 100644 index 0000000..7bf85cf --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x.yaml @@ -0,0 +1,5 @@ +_BASE_: "Base-Keypoint-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 diff --git a/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml b/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml new file 
mode 100644 index 0000000..a07f243 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml @@ -0,0 +1,8 @@ +_BASE_: "Base-Keypoint-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x.yaml b/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x.yaml new file mode 100644 index 0000000..d4bfa20 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x.yaml @@ -0,0 +1,12 @@ +_BASE_: "Base-Keypoint-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" + PIXEL_STD: [57.375, 57.120, 58.395] + RESNETS: + STRIDE_IN_1X1: False # this is a C2 model + NUM_GROUPS: 32 + WIDTH_PER_GROUP: 8 + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-PanopticSegmentation/Base-Panoptic-FPN.yaml b/data_processing/detectron2/configs/COCO-PanopticSegmentation/Base-Panoptic-FPN.yaml new file mode 100644 index 0000000..f00d54b --- /dev/null +++ b/data_processing/detectron2/configs/COCO-PanopticSegmentation/Base-Panoptic-FPN.yaml @@ -0,0 +1,11 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + META_ARCHITECTURE: "PanopticFPN" + MASK_ON: True + SEM_SEG_HEAD: + LOSS_WEIGHT: 0.5 +DATASETS: + TRAIN: ("coco_2017_train_panoptic_separated",) + TEST: ("coco_2017_val_panoptic_separated",) +DATALOADER: + FILTER_EMPTY_ANNOTATIONS: False diff --git a/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml b/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml new file mode 100644 index 0000000..0e01f6f --- /dev/null +++ b/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml @@ -0,0 +1,8 @@ 
+_BASE_: "Base-Panoptic-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.py b/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.py new file mode 100644 index 0000000..40cf181 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.py @@ -0,0 +1,8 @@ +from ..common.optim import SGD as optimizer +from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier +from ..common.data.coco_panoptic_separated import dataloader +from ..common.models.panoptic_fpn import model +from ..common.train import train + +model.backbone.bottom_up.freeze_at = 2 +train.init_checkpoint = "detectron2://ImageNetPretrained/MSRA/R-50.pkl" diff --git a/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.yaml b/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.yaml new file mode 100644 index 0000000..6afa2c1 --- /dev/null +++ b/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.yaml @@ -0,0 +1,5 @@ +_BASE_: "Base-Panoptic-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 diff --git a/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml b/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml new file mode 100644 index 0000000..b956b3f --- /dev/null +++ b/data_processing/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml @@ -0,0 +1,8 @@ +_BASE_: "Base-Panoptic-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git 
a/data_processing/detectron2/configs/Cityscapes/mask_rcnn_R_50_FPN.yaml b/data_processing/detectron2/configs/Cityscapes/mask_rcnn_R_50_FPN.yaml new file mode 100644 index 0000000..1a7aaeb --- /dev/null +++ b/data_processing/detectron2/configs/Cityscapes/mask_rcnn_R_50_FPN.yaml @@ -0,0 +1,27 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + # WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + # For better, more stable performance initialize from COCO + WEIGHTS: "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl" + MASK_ON: True + ROI_HEADS: + NUM_CLASSES: 8 +# This is similar to the setting used in Mask R-CNN paper, Appendix A +# But there are some differences, e.g., we did not initialize the output +# layer using the corresponding classes from COCO +INPUT: + MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024) + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 1024 + MAX_SIZE_TRAIN: 2048 + MAX_SIZE_TEST: 2048 +DATASETS: + TRAIN: ("cityscapes_fine_instance_seg_train",) + TEST: ("cityscapes_fine_instance_seg_val",) +SOLVER: + BASE_LR: 0.01 + STEPS: (18000,) + MAX_ITER: 24000 + IMS_PER_BATCH: 8 +TEST: + EVAL_PERIOD: 8000 diff --git a/data_processing/detectron2/configs/Detectron1-Comparisons/README.md b/data_processing/detectron2/configs/Detectron1-Comparisons/README.md new file mode 100644 index 0000000..924fd00 --- /dev/null +++ b/data_processing/detectron2/configs/Detectron1-Comparisons/README.md @@ -0,0 +1,84 @@ + +Detectron2 model zoo's experimental settings and a few implementation details are different from Detectron. + +The differences in implementation details are shared in +[Compatibility with Other Libraries](../../docs/notes/compatibility.md). + +The differences in model zoo's experimental settings include: +* Use scale augmentation during training. This improves AP with lower training cost. +* Use L1 loss instead of smooth L1 loss for simplicity. 
This sometimes improves box AP but may + affect other AP. +* Use `POOLER_SAMPLING_RATIO=0` instead of 2. This does not significantly affect AP. +* Use `ROIAlignV2`. This does not significantly affect AP. + +In this directory, we provide a few configs that __do not__ have the above changes. +They mimic Detectron's behavior as close as possible, +and provide a fair comparison of accuracy and speed against Detectron. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
mask
AP
kp.
AP
model iddownload
Faster R-CNN1x0.2190.0383.136.9137781054model | metrics
Keypoint R-CNN1x0.3130.0715.053.164.2137781195model | metrics
Mask R-CNN1x0.2730.0433.437.834.9137781281model | metrics
+ +## Comparisons: + +* Faster R-CNN: Detectron's AP is 36.7, similar to ours. +* Keypoint R-CNN: Detectron's AP is box 53.6, keypoint 64.2. Fixing a Detectron + [bug](https://github.com/facebookresearch/Detectron/issues/459) led to a drop in box AP, which can be + compensated for by some parameter tuning. +* Mask R-CNN: Detectron's AP is box 37.7, mask 33.9. We're 1 AP better in mask AP, due to a more correct implementation. + See [this article](https://ppwwyyxx.com/blog/2021/Where-are-Pixels/) for details. + +For speed comparison, see [benchmarks](https://detectron2.readthedocs.io/notes/benchmarks.html). diff --git a/data_processing/detectron2/configs/Detectron1-Comparisons/faster_rcnn_R_50_FPN_noaug_1x.yaml b/data_processing/detectron2/configs/Detectron1-Comparisons/faster_rcnn_R_50_FPN_noaug_1x.yaml new file mode 100644 index 0000000..6ce77f1 --- /dev/null +++ b/data_processing/detectron2/configs/Detectron1-Comparisons/faster_rcnn_R_50_FPN_noaug_1x.yaml @@ -0,0 +1,17 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + DEPTH: 50 + # Detectron1 uses smooth L1 loss with some magic beta values. + # The defaults are changed to L1 loss in Detectron2.
+ RPN: + SMOOTH_L1_BETA: 0.1111 + ROI_BOX_HEAD: + SMOOTH_L1_BETA: 1.0 + POOLER_SAMPLING_RATIO: 2 + POOLER_TYPE: "ROIAlign" +INPUT: + # no scale augmentation + MIN_SIZE_TRAIN: (800, ) diff --git a/data_processing/detectron2/configs/Detectron1-Comparisons/keypoint_rcnn_R_50_FPN_1x.yaml b/data_processing/detectron2/configs/Detectron1-Comparisons/keypoint_rcnn_R_50_FPN_1x.yaml new file mode 100644 index 0000000..aacf868 --- /dev/null +++ b/data_processing/detectron2/configs/Detectron1-Comparisons/keypoint_rcnn_R_50_FPN_1x.yaml @@ -0,0 +1,27 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + KEYPOINT_ON: True + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 1 + ROI_KEYPOINT_HEAD: + POOLER_RESOLUTION: 14 + POOLER_SAMPLING_RATIO: 2 + POOLER_TYPE: "ROIAlign" + # Detectron1 uses smooth L1 loss with some magic beta values. + # The defaults are changed to L1 loss in Detectron2. + ROI_BOX_HEAD: + SMOOTH_L1_BETA: 1.0 + POOLER_SAMPLING_RATIO: 2 + POOLER_TYPE: "ROIAlign" + RPN: + SMOOTH_L1_BETA: 0.1111 + # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2 + # 1000 proposals per-image is found to hurt box AP. + # Therefore we increase it to 1500 per-image. + POST_NMS_TOPK_TRAIN: 1500 +DATASETS: + TRAIN: ("keypoints_coco_2017_train",) + TEST: ("keypoints_coco_2017_val",) diff --git a/data_processing/detectron2/configs/Detectron1-Comparisons/mask_rcnn_R_50_FPN_noaug_1x.yaml b/data_processing/detectron2/configs/Detectron1-Comparisons/mask_rcnn_R_50_FPN_noaug_1x.yaml new file mode 100644 index 0000000..4ea86a8 --- /dev/null +++ b/data_processing/detectron2/configs/Detectron1-Comparisons/mask_rcnn_R_50_FPN_noaug_1x.yaml @@ -0,0 +1,20 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + # Detectron1 uses smooth L1 loss with some magic beta values. + # The defaults are changed to L1 loss in Detectron2. 
+ RPN: + SMOOTH_L1_BETA: 0.1111 + ROI_BOX_HEAD: + SMOOTH_L1_BETA: 1.0 + POOLER_SAMPLING_RATIO: 2 + POOLER_TYPE: "ROIAlign" + ROI_MASK_HEAD: + POOLER_SAMPLING_RATIO: 2 + POOLER_TYPE: "ROIAlign" +INPUT: + # no scale augmentation + MIN_SIZE_TRAIN: (800, ) diff --git a/data_processing/detectron2/configs/LVISv0.5-InstanceSegmentation/mask_rcnn_R_101_FPN_1x.yaml b/data_processing/detectron2/configs/LVISv0.5-InstanceSegmentation/mask_rcnn_R_101_FPN_1x.yaml new file mode 100644 index 0000000..f0c3a1b --- /dev/null +++ b/data_processing/detectron2/configs/LVISv0.5-InstanceSegmentation/mask_rcnn_R_101_FPN_1x.yaml @@ -0,0 +1,19 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: True + RESNETS: + DEPTH: 101 + ROI_HEADS: + NUM_CLASSES: 1230 + SCORE_THRESH_TEST: 0.0001 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +DATASETS: + TRAIN: ("lvis_v0.5_train",) + TEST: ("lvis_v0.5_val",) +TEST: + DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300 +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 diff --git a/data_processing/detectron2/configs/LVISv0.5-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml b/data_processing/detectron2/configs/LVISv0.5-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml new file mode 100644 index 0000000..64b4caa --- /dev/null +++ b/data_processing/detectron2/configs/LVISv0.5-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml @@ -0,0 +1,19 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 1230 + SCORE_THRESH_TEST: 0.0001 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +DATASETS: + TRAIN: ("lvis_v0.5_train",) + TEST: ("lvis_v0.5_val",) +TEST: + DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300 +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 diff --git 
a/data_processing/detectron2/configs/LVISv0.5-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x.yaml b/data_processing/detectron2/configs/LVISv0.5-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x.yaml new file mode 100644 index 0000000..c8b822c --- /dev/null +++ b/data_processing/detectron2/configs/LVISv0.5-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x.yaml @@ -0,0 +1,23 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" + PIXEL_STD: [57.375, 57.120, 58.395] + MASK_ON: True + RESNETS: + STRIDE_IN_1X1: False # this is a C2 model + NUM_GROUPS: 32 + WIDTH_PER_GROUP: 8 + DEPTH: 101 + ROI_HEADS: + NUM_CLASSES: 1230 + SCORE_THRESH_TEST: 0.0001 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +DATASETS: + TRAIN: ("lvis_v0.5_train",) + TEST: ("lvis_v0.5_val",) +TEST: + DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300 +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 diff --git a/data_processing/detectron2/configs/LVISv1-InstanceSegmentation/mask_rcnn_R_101_FPN_1x.yaml b/data_processing/detectron2/configs/LVISv1-InstanceSegmentation/mask_rcnn_R_101_FPN_1x.yaml new file mode 100644 index 0000000..ca4dd97 --- /dev/null +++ b/data_processing/detectron2/configs/LVISv1-InstanceSegmentation/mask_rcnn_R_101_FPN_1x.yaml @@ -0,0 +1,22 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: True + RESNETS: + DEPTH: 101 + ROI_HEADS: + NUM_CLASSES: 1203 + SCORE_THRESH_TEST: 0.0001 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +DATASETS: + TRAIN: ("lvis_v1_train",) + TEST: ("lvis_v1_val",) +TEST: + DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300 +SOLVER: + STEPS: (120000, 160000) + MAX_ITER: 180000 # 180000 * 16 / 100000 ~ 28.8 epochs +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 diff --git 
a/data_processing/detectron2/configs/LVISv1-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml b/data_processing/detectron2/configs/LVISv1-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml new file mode 100644 index 0000000..f313295 --- /dev/null +++ b/data_processing/detectron2/configs/LVISv1-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml @@ -0,0 +1,22 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 1203 + SCORE_THRESH_TEST: 0.0001 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +DATASETS: + TRAIN: ("lvis_v1_train",) + TEST: ("lvis_v1_val",) +TEST: + DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300 +SOLVER: + STEPS: (120000, 160000) + MAX_ITER: 180000 # 180000 * 16 / 100000 ~ 28.8 epochs +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 diff --git a/data_processing/detectron2/configs/LVISv1-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x.yaml b/data_processing/detectron2/configs/LVISv1-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x.yaml new file mode 100644 index 0000000..f6528f7 --- /dev/null +++ b/data_processing/detectron2/configs/LVISv1-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x.yaml @@ -0,0 +1,26 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" + PIXEL_STD: [57.375, 57.120, 58.395] + MASK_ON: True + RESNETS: + STRIDE_IN_1X1: False # this is a C2 model + NUM_GROUPS: 32 + WIDTH_PER_GROUP: 8 + DEPTH: 101 + ROI_HEADS: + NUM_CLASSES: 1203 + SCORE_THRESH_TEST: 0.0001 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +DATASETS: + TRAIN: ("lvis_v1_train",) + TEST: ("lvis_v1_val",) +SOLVER: + STEPS: (120000, 160000) + MAX_ITER: 180000 # 180000 * 16 / 100000 ~ 28.8 epochs +TEST: + DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300 +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 diff --git 
a/data_processing/detectron2/configs/Misc/cascade_mask_rcnn_R_50_FPN_1x.yaml b/data_processing/detectron2/configs/Misc/cascade_mask_rcnn_R_50_FPN_1x.yaml new file mode 100644 index 0000000..abb33b6 --- /dev/null +++ b/data_processing/detectron2/configs/Misc/cascade_mask_rcnn_R_50_FPN_1x.yaml @@ -0,0 +1,12 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + CLS_AGNOSTIC_BBOX_REG: True + RPN: + POST_NMS_TOPK_TRAIN: 2000 diff --git a/data_processing/detectron2/configs/Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml b/data_processing/detectron2/configs/Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml new file mode 100644 index 0000000..e2201ad --- /dev/null +++ b/data_processing/detectron2/configs/Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml @@ -0,0 +1,15 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + CLS_AGNOSTIC_BBOX_REG: True + RPN: + POST_NMS_TOPK_TRAIN: 2000 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml b/data_processing/detectron2/configs/Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml new file mode 100644 index 0000000..fc117f6 --- /dev/null +++ b/data_processing/detectron2/configs/Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml @@ -0,0 +1,36 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + MASK_ON: True + WEIGHTS: "catalog://ImageNetPretrained/FAIR/X-152-32x8d-IN5k" + RESNETS: + STRIDE_IN_1X1: False # this is a C2 model + NUM_GROUPS: 32 + WIDTH_PER_GROUP: 8 + DEPTH: 152 + DEFORM_ON_PER_STAGE: [False, True, True, True] + ROI_HEADS: + NAME: "CascadeROIHeads" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "GN" + CLS_AGNOSTIC_BBOX_REG: 
True + ROI_MASK_HEAD: + NUM_CONV: 8 + NORM: "GN" + RPN: + POST_NMS_TOPK_TRAIN: 2000 +SOLVER: + IMS_PER_BATCH: 128 + STEPS: (35000, 45000) + MAX_ITER: 50000 + BASE_LR: 0.16 +INPUT: + MIN_SIZE_TRAIN: (640, 864) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1440 + CROP: + ENABLED: True +TEST: + EVAL_PERIOD: 2500 diff --git a/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_1x_cls_agnostic.yaml b/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_1x_cls_agnostic.yaml new file mode 100644 index 0000000..4c3b767 --- /dev/null +++ b/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_1x_cls_agnostic.yaml @@ -0,0 +1,10 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + ROI_BOX_HEAD: + CLS_AGNOSTIC_BBOX_REG: True + ROI_MASK_HEAD: + CLS_AGNOSTIC_MASK: True diff --git a/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_1x_dconv_c3-c5.yaml b/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_1x_dconv_c3-c5.yaml new file mode 100644 index 0000000..04ff988 --- /dev/null +++ b/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_1x_dconv_c3-c5.yaml @@ -0,0 +1,8 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + DEFORM_ON_PER_STAGE: [False, True, True, True] # on Res3,Res4,Res5 + DEFORM_MODULATED: False diff --git a/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5.yaml b/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5.yaml new file mode 100644 index 0000000..68c0ca5 --- /dev/null +++ b/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5.yaml @@ -0,0 +1,11 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + DEFORM_ON_PER_STAGE: [False, True, True, True] # on Res3,Res4,Res5 + 
DEFORM_MODULATED: False +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_gn.yaml b/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_gn.yaml new file mode 100644 index 0000000..74d274e --- /dev/null +++ b/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_gn.yaml @@ -0,0 +1,21 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "catalog://ImageNetPretrained/FAIR/R-50-GN" + MASK_ON: True + RESNETS: + DEPTH: 50 + NORM: "GN" + STRIDE_IN_1X1: False + FPN: + NORM: "GN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "GN" + ROI_MASK_HEAD: + NORM: "GN" +SOLVER: + # 3x schedule + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_syncbn.yaml b/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_syncbn.yaml new file mode 100644 index 0000000..11ebb07 --- /dev/null +++ b/data_processing/detectron2/configs/Misc/mask_rcnn_R_50_FPN_3x_syncbn.yaml @@ -0,0 +1,24 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + NORM: "SyncBN" + STRIDE_IN_1X1: True + FPN: + NORM: "SyncBN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + ROI_MASK_HEAD: + NORM: "SyncBN" +SOLVER: + # 3x schedule + STEPS: (210000, 250000) + MAX_ITER: 270000 +TEST: + PRECISE_BN: + ENABLED: True diff --git a/data_processing/detectron2/configs/Misc/mmdet_mask_rcnn_R_50_FPN_1x.py b/data_processing/detectron2/configs/Misc/mmdet_mask_rcnn_R_50_FPN_1x.py new file mode 100644 index 0000000..bdd49a4 --- /dev/null +++ b/data_processing/detectron2/configs/Misc/mmdet_mask_rcnn_R_50_FPN_1x.py @@ -0,0 +1,152 @@ +# An example config to train a mmdetection model using detectron2. 
+ +from ..common.data.coco import dataloader +from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier +from ..common.optim import SGD as optimizer +from ..common.train import train +from ..common.data.constants import constants + +from detectron2.modeling.mmdet_wrapper import MMDetDetector +from detectron2.config import LazyCall as L + +model = L(MMDetDetector)( + detector=dict( + type="MaskRCNN", + pretrained="torchvision://resnet50", + backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type="BN", requires_grad=True), + norm_eval=True, + style="pytorch", + ), + neck=dict(type="FPN", in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5), + rpn_head=dict( + type="RPNHead", + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type="AnchorGenerator", + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + ), + bbox_coder=dict( + type="DeltaXYWHBBoxCoder", + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[1.0, 1.0, 1.0, 1.0], + ), + loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type="L1Loss", loss_weight=1.0), + ), + roi_head=dict( + type="StandardRoIHead", + bbox_roi_extractor=dict( + type="SingleRoIExtractor", + roi_layer=dict(type="RoIAlign", output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + ), + bbox_head=dict( + type="Shared2FCBBoxHead", + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type="DeltaXYWHBBoxCoder", + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[0.1, 0.1, 0.2, 0.2], + ), + reg_class_agnostic=False, + loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type="L1Loss", loss_weight=1.0), + ), + mask_roi_extractor=dict( + type="SingleRoIExtractor", + roi_layer=dict(type="RoIAlign", output_size=14, sampling_ratio=0), + out_channels=256, + 
featmap_strides=[4, 8, 16, 32], + ), + mask_head=dict( + type="FCNMaskHead", + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=80, + loss_mask=dict(type="CrossEntropyLoss", use_mask=True, loss_weight=1.0), + ), + ), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type="MaxIoUAssigner", + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1, + ), + sampler=dict( + type="RandomSampler", + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False, + ), + allowed_border=-1, + pos_weight=-1, + debug=False, + ), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type="nms", iou_threshold=0.7), + min_bbox_size=0, + ), + rcnn=dict( + assigner=dict( + type="MaxIoUAssigner", + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=True, + ignore_iof_thr=-1, + ), + sampler=dict( + type="RandomSampler", + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + ), + mask_size=28, + pos_weight=-1, + debug=False, + ), + ), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type="nms", iou_threshold=0.7), + min_bbox_size=0, + ), + rcnn=dict( + score_thr=0.05, + nms=dict(type="nms", iou_threshold=0.5), + max_per_img=100, + mask_thr_binary=0.5, + ), + ), + ), + pixel_mean=constants.imagenet_rgb256_mean, + pixel_std=constants.imagenet_rgb256_std, +) + +dataloader.train.mapper.image_format = "RGB" # torchvision pretrained model +train.init_checkpoint = None # pretrained model is loaded inside backbone diff --git a/data_processing/detectron2/configs/Misc/panoptic_fpn_R_101_dconv_cascade_gn_3x.yaml b/data_processing/detectron2/configs/Misc/panoptic_fpn_R_101_dconv_cascade_gn_3x.yaml new file mode 100644 index 0000000..34016ce --- /dev/null +++ b/data_processing/detectron2/configs/Misc/panoptic_fpn_R_101_dconv_cascade_gn_3x.yaml @@ -0,0 +1,26 @@ +# A large PanopticFPN for demo 
purposes. +# Use GN on backbone to support semantic seg. +# Use Cascade + Deform Conv to improve localization. +_BASE_: "../COCO-PanopticSegmentation/Base-Panoptic-FPN.yaml" +MODEL: + WEIGHTS: "catalog://ImageNetPretrained/FAIR/R-101-GN" + RESNETS: + DEPTH: 101 + NORM: "GN" + DEFORM_ON_PER_STAGE: [False, True, True, True] + STRIDE_IN_1X1: False + FPN: + NORM: "GN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + CLS_AGNOSTIC_BBOX_REG: True + ROI_MASK_HEAD: + NORM: "GN" + RPN: + POST_NMS_TOPK_TRAIN: 2000 +SOLVER: + STEPS: (105000, 125000) + MAX_ITER: 135000 + IMS_PER_BATCH: 32 + BASE_LR: 0.04 diff --git a/data_processing/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.yaml b/data_processing/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.yaml new file mode 100644 index 0000000..f340028 --- /dev/null +++ b/data_processing/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.yaml @@ -0,0 +1,13 @@ +_BASE_: "mask_rcnn_R_50_FPN_3x_gn.yaml" +MODEL: + # Train from random initialization. + WEIGHTS: "" + # It makes sense to divide by STD when training from scratch + # But it seems to make no difference on the results and C2's models didn't do this. + # So we keep things consistent with C2. + # PIXEL_STD: [57.375, 57.12, 58.395] + MASK_ON: True + BACKBONE: + FREEZE_AT: 0 +# NOTE: Please refer to Rethinking ImageNet Pre-training https://arxiv.org/abs/1811.08883 +# to learn what you need for training from scratch.
diff --git a/data_processing/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_9x_gn.yaml b/data_processing/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_9x_gn.yaml new file mode 100644 index 0000000..d90c9ff --- /dev/null +++ b/data_processing/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_9x_gn.yaml @@ -0,0 +1,19 @@ +_BASE_: "mask_rcnn_R_50_FPN_3x_gn.yaml" +MODEL: + PIXEL_STD: [57.375, 57.12, 58.395] + WEIGHTS: "" + MASK_ON: True + RESNETS: + STRIDE_IN_1X1: False + BACKBONE: + FREEZE_AT: 0 +SOLVER: + # 9x schedule + IMS_PER_BATCH: 64 # 4x the standard + STEPS: (187500, 197500) # last 60/4==15k and last 20/4==5k + MAX_ITER: 202500 # 90k * 9 / 4 + BASE_LR: 0.08 +TEST: + EVAL_PERIOD: 2500 +# NOTE: Please refer to Rethinking ImageNet Pre-training https://arxiv.org/abs/1811.08883 +# to learn what you need for training from scratch. diff --git a/data_processing/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn.yaml b/data_processing/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn.yaml new file mode 100644 index 0000000..60d4e42 --- /dev/null +++ b/data_processing/detectron2/configs/Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn.yaml @@ -0,0 +1,19 @@ +_BASE_: "mask_rcnn_R_50_FPN_3x_syncbn.yaml" +MODEL: + PIXEL_STD: [57.375, 57.12, 58.395] + WEIGHTS: "" + MASK_ON: True + RESNETS: + STRIDE_IN_1X1: False + BACKBONE: + FREEZE_AT: 0 +SOLVER: + # 9x schedule + IMS_PER_BATCH: 64 # 4x the standard + STEPS: (187500, 197500) # last 60/4==15k and last 20/4==5k + MAX_ITER: 202500 # 90k * 9 / 4 + BASE_LR: 0.08 +TEST: + EVAL_PERIOD: 2500 +# NOTE: Please refer to Rethinking ImageNet Pre-training https://arxiv.org/abs/1811.08883 +# to learn what you need for training from scratch.
diff --git a/data_processing/detectron2/configs/Misc/semantic_R_50_FPN_1x.yaml b/data_processing/detectron2/configs/Misc/semantic_R_50_FPN_1x.yaml new file mode 100644 index 0000000..ac256e1 --- /dev/null +++ b/data_processing/detectron2/configs/Misc/semantic_R_50_FPN_1x.yaml @@ -0,0 +1,11 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + META_ARCHITECTURE: "SemanticSegmentor" + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +DATASETS: + TRAIN: ("coco_2017_train_panoptic_stuffonly",) + TEST: ("coco_2017_val_panoptic_stuffonly",) +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) diff --git a/data_processing/detectron2/configs/Misc/torchvision_imagenet_R_50.py b/data_processing/detectron2/configs/Misc/torchvision_imagenet_R_50.py new file mode 100644 index 0000000..0d75305 --- /dev/null +++ b/data_processing/detectron2/configs/Misc/torchvision_imagenet_R_50.py @@ -0,0 +1,150 @@ +""" +An example config file to train an ImageNet classifier with detectron2. +Model and dataloader both come from torchvision. +This shows how to use detectron2 as a general engine for any new models and tasks. 
+ +To run, use the following command: + +python tools/lazyconfig_train_net.py --config-file configs/Misc/torchvision_imagenet_R_50.py \ + --num-gpus 8 dataloader.train.dataset.root=/path/to/imagenet/ + +""" + + +import torch +from torch import nn +from torch.nn import functional as F +from omegaconf import OmegaConf +import torchvision +from torchvision.transforms import transforms as T +from torchvision.models.resnet import ResNet, Bottleneck +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2.solver import WarmupParamScheduler +from detectron2.solver.build import get_default_optimizer_params +from detectron2.config import LazyCall as L +from detectron2.model_zoo import get_config +from detectron2.data.samplers import TrainingSampler, InferenceSampler +from detectron2.evaluation import DatasetEvaluator +from detectron2.utils import comm + + +""" +Note: Here we put reusable code (models, evaluation, data) together with configs just as a +proof-of-concept, to easily demonstrate what's needed to train an ImageNet classifier in detectron2. +Writing code in configs offers extreme flexibility but is often not a good engineering practice. +In practice, you might want to put code in your project and import it instead. 
+""" + + +def build_data_loader(dataset, batch_size, num_workers, training=True): + return torch.utils.data.DataLoader( + dataset, + sampler=(TrainingSampler if training else InferenceSampler)(len(dataset)), + batch_size=batch_size, + num_workers=num_workers, + pin_memory=True, + ) + + +class ClassificationNet(nn.Module): + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + @property + def device(self): + return list(self.model.parameters())[0].device + + def forward(self, inputs): + image, label = inputs + pred = self.model(image.to(self.device)) + if self.training: + label = label.to(self.device) + return F.cross_entropy(pred, label) + else: + return pred + + +class ClassificationAcc(DatasetEvaluator): + def reset(self): + self.corr = self.total = 0 + + def process(self, inputs, outputs): + image, label = inputs + self.corr += (outputs.argmax(dim=1).cpu() == label.cpu()).sum().item() + self.total += len(label) + + def evaluate(self): + all_corr_total = comm.all_gather([self.corr, self.total]) + corr = sum(x[0] for x in all_corr_total) + total = sum(x[1] for x in all_corr_total) + return {"accuracy": corr / total} + + +# --- End of code that could be in a project and be imported + + +dataloader = OmegaConf.create() +dataloader.train = L(build_data_loader)( + dataset=L(torchvision.datasets.ImageNet)( + root="/path/to/imagenet", + split="train", + transform=L(T.Compose)( + transforms=[ + L(T.RandomResizedCrop)(size=224), + L(T.RandomHorizontalFlip)(), + T.ToTensor(), + L(T.Normalize)(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] + ), + ), + batch_size=256 // 8, + num_workers=4, + training=True, +) + +dataloader.test = L(build_data_loader)( + dataset=L(torchvision.datasets.ImageNet)( + root="${...train.dataset.root}", + split="val", + transform=L(T.Compose)( + transforms=[ + L(T.Resize)(size=256), + L(T.CenterCrop)(size=224), + T.ToTensor(), + L(T.Normalize)(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] + 
), + ), + batch_size=256 // 8, + num_workers=4, + training=False, +) + +dataloader.evaluator = L(ClassificationAcc)() + +model = L(ClassificationNet)( + model=(ResNet)(block=Bottleneck, layers=[3, 4, 6, 3], zero_init_residual=True) +) + + +optimizer = L(torch.optim.SGD)( + params=L(get_default_optimizer_params)(), + lr=0.1, + momentum=0.9, + weight_decay=1e-4, +) + +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01, 0.001], milestones=[30, 60, 90, 100] + ), + warmup_length=1 / 100, + warmup_factor=0.1, +) + + +train = get_config("common/train.py").train +train.init_checkpoint = None +train.max_iter = 100 * 1281167 // 256 diff --git a/data_processing/detectron2/configs/PascalVOC-Detection/faster_rcnn_R_50_C4.yaml b/data_processing/detectron2/configs/PascalVOC-Detection/faster_rcnn_R_50_C4.yaml new file mode 100644 index 0000000..ea2a6ba --- /dev/null +++ b/data_processing/detectron2/configs/PascalVOC-Detection/faster_rcnn_R_50_C4.yaml @@ -0,0 +1,18 @@ +_BASE_: "../Base-RCNN-C4.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 20 +INPUT: + MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) + MIN_SIZE_TEST: 800 +DATASETS: + TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') + TEST: ('voc_2007_test',) +SOLVER: + STEPS: (12000, 16000) + MAX_ITER: 18000 # 17.4 epochs + WARMUP_ITERS: 100 diff --git a/data_processing/detectron2/configs/PascalVOC-Detection/faster_rcnn_R_50_FPN.yaml b/data_processing/detectron2/configs/PascalVOC-Detection/faster_rcnn_R_50_FPN.yaml new file mode 100644 index 0000000..e554cab --- /dev/null +++ b/data_processing/detectron2/configs/PascalVOC-Detection/faster_rcnn_R_50_FPN.yaml @@ -0,0 +1,18 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 20 +INPUT: + 
MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) + MIN_SIZE_TEST: 800 +DATASETS: + TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') + TEST: ('voc_2007_test',) +SOLVER: + STEPS: (12000, 16000) + MAX_ITER: 18000 # 17.4 epochs + WARMUP_ITERS: 100 diff --git a/data_processing/detectron2/configs/common/README.md b/data_processing/detectron2/configs/common/README.md new file mode 100644 index 0000000..912cc29 --- /dev/null +++ b/data_processing/detectron2/configs/common/README.md @@ -0,0 +1,6 @@ +This directory provides definitions for a few common models, dataloaders, schedulers, +and optimizers that are often used in training. +The definitions of these objects are provided in the form of lazy instantiation: +their arguments can be edited by users before constructing the objects. + +They can be imported, or loaded by the `model_zoo.get_config` API in users' own configs. diff --git a/data_processing/detectron2/configs/common/coco_schedule.py b/data_processing/detectron2/configs/common/coco_schedule.py new file mode 100644 index 0000000..355e66a --- /dev/null +++ b/data_processing/detectron2/configs/common/coco_schedule.py @@ -0,0 +1,47 @@ +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler + + +def default_X_scheduler(num_X): + """ + Returns the config for a default multi-step LR scheduler such as "1x", "3x", + commonly referred to in papers, where every 1x has the total length of 1440k + training images (~12 COCO epochs). LR is decayed twice at the end of training + following the strategy defined in "Rethinking ImageNet Pretraining", Sec 4. 
+ + Args: + num_X: a positive real number + + Returns: + DictConfig: configs that define the multiplier for LR during training + """ + # total number of iterations assuming 16 batch size, using 1440000/16=90000 + total_steps_16bs = num_X * 90000 + + if num_X <= 2: + scheduler = L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + # note that scheduler is scale-invariant. This is equivalent to + # milestones=[6, 8, 9] + milestones=[60000, 80000, 90000], + ) + else: + scheduler = L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[total_steps_16bs - 60000, total_steps_16bs - 20000, total_steps_16bs], + ) + return L(WarmupParamScheduler)( + scheduler=scheduler, + warmup_length=1000 / total_steps_16bs, + warmup_method="linear", + warmup_factor=0.001, + ) + + +lr_multiplier_1x = default_X_scheduler(1) +lr_multiplier_2x = default_X_scheduler(2) +lr_multiplier_3x = default_X_scheduler(3) +lr_multiplier_6x = default_X_scheduler(6) +lr_multiplier_9x = default_X_scheduler(9) diff --git a/data_processing/detectron2/configs/common/data/coco.py b/data_processing/detectron2/configs/common/data/coco.py new file mode 100644 index 0000000..703c438 --- /dev/null +++ b/data_processing/detectron2/configs/common/data/coco.py @@ -0,0 +1,48 @@ +from omegaconf import OmegaConf + +import detectron2.data.transforms as T +from detectron2.config import LazyCall as L +from detectron2.data import ( + DatasetMapper, + build_detection_test_loader, + build_detection_train_loader, + get_detection_dataset_dicts, +) +from detectron2.evaluation import COCOEvaluator + +dataloader = OmegaConf.create() + +dataloader.train = L(build_detection_train_loader)( + dataset=L(get_detection_dataset_dicts)(names="coco_2017_train"), + mapper=L(DatasetMapper)( + is_train=True, + augmentations=[ + L(T.ResizeShortestEdge)( + short_edge_length=(640, 672, 704, 736, 768, 800), + sample_style="choice", + max_size=1333, + ), + L(T.RandomFlip)(horizontal=True), + ], + image_format="BGR", + 
use_instance_mask=True, + ), + total_batch_size=16, + num_workers=4, +) + +dataloader.test = L(build_detection_test_loader)( + dataset=L(get_detection_dataset_dicts)(names="coco_2017_val", filter_empty=False), + mapper=L(DatasetMapper)( + is_train=False, + augmentations=[ + L(T.ResizeShortestEdge)(short_edge_length=800, max_size=1333), + ], + image_format="${...train.mapper.image_format}", + ), + num_workers=4, +) + +dataloader.evaluator = L(COCOEvaluator)( + dataset_name="${..test.dataset.names}", +) diff --git a/data_processing/detectron2/configs/common/data/coco_keypoint.py b/data_processing/detectron2/configs/common/data/coco_keypoint.py new file mode 100644 index 0000000..b4ceb06 --- /dev/null +++ b/data_processing/detectron2/configs/common/data/coco_keypoint.py @@ -0,0 +1,13 @@ +from detectron2.data.detection_utils import create_keypoint_hflip_indices + +from .coco import dataloader + +dataloader.train.dataset.min_keypoints = 1 +dataloader.train.dataset.names = "keypoints_coco_2017_train" +dataloader.test.dataset.names = "keypoints_coco_2017_val" + +dataloader.train.mapper.update( + use_instance_mask=False, + use_keypoint=True, + keypoint_hflip_indices=create_keypoint_hflip_indices(dataloader.train.dataset.names), +) diff --git a/data_processing/detectron2/configs/common/data/coco_panoptic_separated.py b/data_processing/detectron2/configs/common/data/coco_panoptic_separated.py new file mode 100644 index 0000000..5ccbc77 --- /dev/null +++ b/data_processing/detectron2/configs/common/data/coco_panoptic_separated.py @@ -0,0 +1,26 @@ +from detectron2.config import LazyCall as L +from detectron2.evaluation import ( + COCOEvaluator, + COCOPanopticEvaluator, + DatasetEvaluators, + SemSegEvaluator, +) + +from .coco import dataloader + +dataloader.train.dataset.names = "coco_2017_train_panoptic_separated" +dataloader.train.dataset.filter_empty = False +dataloader.test.dataset.names = "coco_2017_val_panoptic_separated" + + +dataloader.evaluator = [ + L(COCOEvaluator)( + 
dataset_name="${...test.dataset.names}", + ), + L(SemSegEvaluator)( + dataset_name="${...test.dataset.names}", + ), + L(COCOPanopticEvaluator)( + dataset_name="${...test.dataset.names}", + ), +] diff --git a/data_processing/detectron2/configs/common/data/constants.py b/data_processing/detectron2/configs/common/data/constants.py new file mode 100644 index 0000000..be11cb5 --- /dev/null +++ b/data_processing/detectron2/configs/common/data/constants.py @@ -0,0 +1,9 @@ +constants = dict( + imagenet_rgb256_mean=[123.675, 116.28, 103.53], + imagenet_rgb256_std=[58.395, 57.12, 57.375], + imagenet_bgr256_mean=[103.530, 116.280, 123.675], + # When using pre-trained models in Detectron1 or any MSRA models, + # std has been absorbed into its conv1 weights, so the std needs to be set 1. + # Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std) + imagenet_bgr256_std=[1.0, 1.0, 1.0], +) diff --git a/data_processing/detectron2/configs/common/models/cascade_rcnn.py b/data_processing/detectron2/configs/common/models/cascade_rcnn.py new file mode 100644 index 0000000..c7372a8 --- /dev/null +++ b/data_processing/detectron2/configs/common/models/cascade_rcnn.py @@ -0,0 +1,36 @@ +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.roi_heads import FastRCNNOutputLayers, FastRCNNConvFCHead, CascadeROIHeads + +from .mask_rcnn_fpn import model + +# arguments that don't exist for Cascade R-CNN +[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] + +model.roi_heads.update( + _target_=CascadeROIHeads, + box_heads=[ + L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[], + fc_dims=[1024, 1024], + ) + for k in range(3) + ], + box_predictors=[ + L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + test_score_thresh=0.05, + 
box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), + cls_agnostic_bbox_reg=True, + num_classes="${...num_classes}", + ) + for (w1, w2) in [(10, 5), (20, 10), (30, 15)] + ], + proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) + for th in [0.5, 0.6, 0.7] + ], +) diff --git a/data_processing/detectron2/configs/common/models/fcos.py b/data_processing/detectron2/configs/common/models/fcos.py new file mode 100644 index 0000000..1c75202 --- /dev/null +++ b/data_processing/detectron2/configs/common/models/fcos.py @@ -0,0 +1,23 @@ +from detectron2.modeling.meta_arch.fcos import FCOS, FCOSHead + +from .retinanet import model + +model._target_ = FCOS + +del model.anchor_generator +del model.box2box_transform +del model.anchor_matcher +del model.input_format + +# Use P5 instead of C5 to compute P6/P7 +# (Sec 2.2 of https://arxiv.org/abs/2006.09214) +model.backbone.top_block.in_feature = "p5" +model.backbone.top_block.in_channels = 256 + +# New score threshold determined based on sqrt(cls_score * centerness) +model.test_score_thresh = 0.2 +model.test_nms_thresh = 0.6 + +model.head._target_ = FCOSHead +del model.head.num_anchors +model.head.norm = "GN" diff --git a/data_processing/detectron2/configs/common/models/keypoint_rcnn_fpn.py b/data_processing/detectron2/configs/common/models/keypoint_rcnn_fpn.py new file mode 100644 index 0000000..56b3994 --- /dev/null +++ b/data_processing/detectron2/configs/common/models/keypoint_rcnn_fpn.py @@ -0,0 +1,33 @@ +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.modeling.poolers import ROIPooler +from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead + +from .mask_rcnn_fpn import model + +[model.roi_heads.pop(x) for x in ["mask_in_features", "mask_pooler", "mask_head"]] + +model.roi_heads.update( + num_classes=1, + keypoint_in_features=["p2", "p3", "p4", 
"p5"], + keypoint_pooler=L(ROIPooler)( + output_size=14, + scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), + sampling_ratio=0, + pooler_type="ROIAlignV2", + ), + keypoint_head=L(KRCNNConvDeconvUpsampleHead)( + input_shape=ShapeSpec(channels=256, width=14, height=14), + num_keypoints=17, + conv_dims=[512] * 8, + loss_normalizer="visible", + ), +) + +# Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2. +# 1000 proposals per-image is found to hurt box AP. +# Therefore we increase it to 1500 per-image. +model.proposal_generator.post_nms_topk = (1500, 1000) + +# Keypoint AP degrades (though box AP improves) when using plain L1 loss +model.roi_heads.box_predictor.smooth_l1_beta = 0.5 diff --git a/data_processing/detectron2/configs/common/models/mask_rcnn_c4.py b/data_processing/detectron2/configs/common/models/mask_rcnn_c4.py new file mode 100644 index 0000000..902d5b1 --- /dev/null +++ b/data_processing/detectron2/configs/common/models/mask_rcnn_c4.py @@ -0,0 +1,90 @@ +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.modeling.meta_arch import GeneralizedRCNN +from detectron2.modeling.anchor_generator import DefaultAnchorGenerator +from detectron2.modeling.backbone import BasicStem, BottleneckBlock, ResNet +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.poolers import ROIPooler +from detectron2.modeling.proposal_generator import RPN, StandardRPNHead +from detectron2.modeling.roi_heads import ( + FastRCNNOutputLayers, + MaskRCNNConvUpsampleHead, + Res5ROIHeads, +) + +from ..data.constants import constants + +model = L(GeneralizedRCNN)( + backbone=L(ResNet)( + stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"), + stages=L(ResNet.make_default_stages)( + depth=50, + stride_in_1x1=True, + norm="FrozenBN", + ), + out_features=["res4"], + ), + proposal_generator=L(RPN)( + 
in_features=["res4"], + head=L(StandardRPNHead)(in_channels=1024, num_anchors=15), + anchor_generator=L(DefaultAnchorGenerator)( + sizes=[[32, 64, 128, 256, 512]], + aspect_ratios=[0.5, 1.0, 2.0], + strides=[16], + offset=0.0, + ), + anchor_matcher=L(Matcher)( + thresholds=[0.3, 0.7], labels=[0, -1, 1], allow_low_quality_matches=True + ), + box2box_transform=L(Box2BoxTransform)(weights=[1.0, 1.0, 1.0, 1.0]), + batch_size_per_image=256, + positive_fraction=0.5, + pre_nms_topk=(12000, 6000), + post_nms_topk=(2000, 1000), + nms_thresh=0.7, + ), + roi_heads=L(Res5ROIHeads)( + num_classes=80, + batch_size_per_image=512, + positive_fraction=0.25, + proposal_matcher=L(Matcher)( + thresholds=[0.5], labels=[0, 1], allow_low_quality_matches=False + ), + in_features=["res4"], + pooler=L(ROIPooler)( + output_size=14, + scales=(1.0 / 16,), + sampling_ratio=0, + pooler_type="ROIAlignV2", + ), + res5=L(ResNet.make_stage)( + block_class=BottleneckBlock, + num_blocks=3, + stride_per_block=[2, 1, 1], + in_channels=1024, + bottleneck_channels=512, + out_channels=2048, + norm="FrozenBN", + stride_in_1x1=True, + ), + box_predictor=L(FastRCNNOutputLayers)( + input_shape=L(ShapeSpec)(channels="${...res5.out_channels}", height=1, width=1), + test_score_thresh=0.05, + box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)), + num_classes="${..num_classes}", + ), + mask_head=L(MaskRCNNConvUpsampleHead)( + input_shape=L(ShapeSpec)( + channels="${...res5.out_channels}", + width="${...pooler.output_size}", + height="${...pooler.output_size}", + ), + num_classes="${..num_classes}", + conv_dims=[256], + ), + ), + pixel_mean=constants.imagenet_bgr256_mean, + pixel_std=constants.imagenet_bgr256_std, + input_format="BGR", +) diff --git a/data_processing/detectron2/configs/common/models/mask_rcnn_fpn.py b/data_processing/detectron2/configs/common/models/mask_rcnn_fpn.py new file mode 100644 index 0000000..5e5c501 --- /dev/null +++ 
b/data_processing/detectron2/configs/common/models/mask_rcnn_fpn.py @@ -0,0 +1,95 @@ +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.modeling.meta_arch import GeneralizedRCNN +from detectron2.modeling.anchor_generator import DefaultAnchorGenerator +from detectron2.modeling.backbone.fpn import LastLevelMaxPool +from detectron2.modeling.backbone import BasicStem, FPN, ResNet +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.poolers import ROIPooler +from detectron2.modeling.proposal_generator import RPN, StandardRPNHead +from detectron2.modeling.roi_heads import ( + StandardROIHeads, + FastRCNNOutputLayers, + MaskRCNNConvUpsampleHead, + FastRCNNConvFCHead, +) + +from ..data.constants import constants + +model = L(GeneralizedRCNN)( + backbone=L(FPN)( + bottom_up=L(ResNet)( + stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"), + stages=L(ResNet.make_default_stages)( + depth=50, + stride_in_1x1=True, + norm="FrozenBN", + ), + out_features=["res2", "res3", "res4", "res5"], + ), + in_features="${.bottom_up.out_features}", + out_channels=256, + top_block=L(LastLevelMaxPool)(), + ), + proposal_generator=L(RPN)( + in_features=["p2", "p3", "p4", "p5", "p6"], + head=L(StandardRPNHead)(in_channels=256, num_anchors=3), + anchor_generator=L(DefaultAnchorGenerator)( + sizes=[[32], [64], [128], [256], [512]], + aspect_ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + offset=0.0, + ), + anchor_matcher=L(Matcher)( + thresholds=[0.3, 0.7], labels=[0, -1, 1], allow_low_quality_matches=True + ), + box2box_transform=L(Box2BoxTransform)(weights=[1.0, 1.0, 1.0, 1.0]), + batch_size_per_image=256, + positive_fraction=0.5, + pre_nms_topk=(2000, 1000), + post_nms_topk=(1000, 1000), + nms_thresh=0.7, + ), + roi_heads=L(StandardROIHeads)( + num_classes=80, + batch_size_per_image=512, + positive_fraction=0.25, + 
proposal_matcher=L(Matcher)( + thresholds=[0.5], labels=[0, 1], allow_low_quality_matches=False + ), + box_in_features=["p2", "p3", "p4", "p5"], + box_pooler=L(ROIPooler)( + output_size=7, + scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), + sampling_ratio=0, + pooler_type="ROIAlignV2", + ), + box_head=L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[], + fc_dims=[1024, 1024], + ), + box_predictor=L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + test_score_thresh=0.05, + box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)), + num_classes="${..num_classes}", + ), + mask_in_features=["p2", "p3", "p4", "p5"], + mask_pooler=L(ROIPooler)( + output_size=14, + scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), + sampling_ratio=0, + pooler_type="ROIAlignV2", + ), + mask_head=L(MaskRCNNConvUpsampleHead)( + input_shape=ShapeSpec(channels=256, width=14, height=14), + num_classes="${..num_classes}", + conv_dims=[256, 256, 256, 256, 256], + ), + ), + pixel_mean=constants.imagenet_bgr256_mean, + pixel_std=constants.imagenet_bgr256_std, + input_format="BGR", +) diff --git a/data_processing/detectron2/configs/common/models/mask_rcnn_vitdet.py b/data_processing/detectron2/configs/common/models/mask_rcnn_vitdet.py new file mode 100644 index 0000000..d6f5244 --- /dev/null +++ b/data_processing/detectron2/configs/common/models/mask_rcnn_vitdet.py @@ -0,0 +1,59 @@ +from functools import partial +import torch.nn as nn +from detectron2.config import LazyCall as L +from detectron2.modeling import ViT, SimpleFeaturePyramid +from detectron2.modeling.backbone.fpn import LastLevelMaxPool + +from .mask_rcnn_fpn import model +from ..data.constants import constants + +model.pixel_mean = constants.imagenet_rgb256_mean +model.pixel_std = constants.imagenet_rgb256_std +model.input_format = "RGB" + +# Base +embed_dim, depth, num_heads, dp = 768, 12, 12, 0.1 +# Creates Simple Feature Pyramid from ViT backbone +model.backbone = 
L(SimpleFeaturePyramid)( + net=L(ViT)( # Single-scale ViT backbone + img_size=1024, + patch_size=16, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + drop_path_rate=dp, + window_size=14, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + window_block_indexes=[ + # 2, 5, 8 11 for global attention + 0, + 1, + 3, + 4, + 6, + 7, + 9, + 10, + ], + residual_block_indexes=[], + use_rel_pos=True, + out_feature="last_feat", + ), + in_feature="${.net.out_feature}", + out_channels=256, + scale_factors=(4.0, 2.0, 1.0, 0.5), + top_block=L(LastLevelMaxPool)(), + norm="LN", + square_pad=1024, +) + +model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "LN" + +# 2conv in RPN: +model.proposal_generator.head.conv_dims = [-1, -1] + +# 4conv1fc box head +model.roi_heads.box_head.conv_dims = [256, 256, 256, 256] +model.roi_heads.box_head.fc_dims = [1024] diff --git a/data_processing/detectron2/configs/common/models/panoptic_fpn.py b/data_processing/detectron2/configs/common/models/panoptic_fpn.py new file mode 100644 index 0000000..88f55d2 --- /dev/null +++ b/data_processing/detectron2/configs/common/models/panoptic_fpn.py @@ -0,0 +1,20 @@ +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.modeling import PanopticFPN +from detectron2.modeling.meta_arch.semantic_seg import SemSegFPNHead + +from .mask_rcnn_fpn import model + +model._target_ = PanopticFPN +model.sem_seg_head = L(SemSegFPNHead)( + input_shape={ + f: L(ShapeSpec)(stride=s, channels="${....backbone.out_channels}") + for f, s in zip(["p2", "p3", "p4", "p5"], [4, 8, 16, 32]) + }, + ignore_value=255, + num_classes=54, # COCO stuff + 1 + conv_dims=128, + common_stride=4, + loss_weight=0.5, + norm="GN", +) diff --git a/data_processing/detectron2/configs/common/models/retinanet.py b/data_processing/detectron2/configs/common/models/retinanet.py new file mode 100644 index 0000000..784e531 --- /dev/null +++ 
b/data_processing/detectron2/configs/common/models/retinanet.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- + +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.modeling.meta_arch import RetinaNet +from detectron2.modeling.anchor_generator import DefaultAnchorGenerator +from detectron2.modeling.backbone.fpn import LastLevelP6P7 +from detectron2.modeling.backbone import BasicStem, FPN, ResNet +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.meta_arch.retinanet import RetinaNetHead + +from ..data.constants import constants + +model = L(RetinaNet)( + backbone=L(FPN)( + bottom_up=L(ResNet)( + stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"), + stages=L(ResNet.make_default_stages)( + depth=50, + stride_in_1x1=True, + norm="FrozenBN", + ), + out_features=["res3", "res4", "res5"], + ), + in_features=["res3", "res4", "res5"], + out_channels=256, + top_block=L(LastLevelP6P7)(in_channels=2048, out_channels="${..out_channels}"), + ), + head=L(RetinaNetHead)( + # Shape for each input feature map + input_shape=[ShapeSpec(channels=256)] * 5, + num_classes="${..num_classes}", + conv_dims=[256, 256, 256, 256], + prior_prob=0.01, + num_anchors=9, + ), + anchor_generator=L(DefaultAnchorGenerator)( + sizes=[[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]], + aspect_ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128], + offset=0.0, + ), + box2box_transform=L(Box2BoxTransform)(weights=[1.0, 1.0, 1.0, 1.0]), + anchor_matcher=L(Matcher)( + thresholds=[0.4, 0.5], labels=[0, -1, 1], allow_low_quality_matches=True + ), + num_classes=80, + head_in_features=["p3", "p4", "p5", "p6", "p7"], + focal_loss_alpha=0.25, + focal_loss_gamma=2.0, + pixel_mean=constants.imagenet_bgr256_mean, + pixel_std=constants.imagenet_bgr256_std, + input_format="BGR", +) diff --git 
a/data_processing/detectron2/configs/common/optim.py b/data_processing/detectron2/configs/common/optim.py new file mode 100644 index 0000000..6cf43e8 --- /dev/null +++ b/data_processing/detectron2/configs/common/optim.py @@ -0,0 +1,28 @@ +import torch + +from detectron2.config import LazyCall as L +from detectron2.solver.build import get_default_optimizer_params + +SGD = L(torch.optim.SGD)( + params=L(get_default_optimizer_params)( + # params.model is meant to be set to the model object, before instantiating + # the optimizer. + weight_decay_norm=0.0 + ), + lr=0.02, + momentum=0.9, + weight_decay=1e-4, +) + + +AdamW = L(torch.optim.AdamW)( + params=L(get_default_optimizer_params)( + # params.model is meant to be set to the model object, before instantiating + # the optimizer. + base_lr="${..lr}", + weight_decay_norm=0.0, + ), + lr=1e-4, + betas=(0.9, 0.999), + weight_decay=0.1, +) diff --git a/data_processing/detectron2/configs/common/train.py b/data_processing/detectron2/configs/common/train.py new file mode 100644 index 0000000..b6ed02b --- /dev/null +++ b/data_processing/detectron2/configs/common/train.py @@ -0,0 +1,18 @@ +# Common training-related configs that are designed for "tools/lazyconfig_train_net.py" +# You can use your own instead, together with your own train_net.py +train = dict( + output_dir="./output", + init_checkpoint="", + max_iter=90000, + amp=dict(enabled=False), # options for Automatic Mixed Precision + ddp=dict( # options for DistributedDataParallel + broadcast_buffers=False, + find_unused_parameters=False, + fp16_compression=False, + ), + checkpointer=dict(period=5000, max_to_keep=100), # options for PeriodicCheckpointer + eval_period=5000, + log_period=20, + device="cuda" + # ... 
+) diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_101_FPN_100ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_101_FPN_100ep_LSJ.py new file mode 100644 index 0000000..3740e9b --- /dev/null +++ b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_101_FPN_100ep_LSJ.py @@ -0,0 +1,9 @@ +from .mask_rcnn_R_50_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +model.backbone.bottom_up.stages.depth = 101 diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_101_FPN_200ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_101_FPN_200ep_LSJ.py new file mode 100644 index 0000000..18e5f07 --- /dev/null +++ b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_101_FPN_200ep_LSJ.py @@ -0,0 +1,14 @@ +from .mask_rcnn_R_101_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +train.max_iter *= 2 # 100ep -> 200ep + +lr_multiplier.scheduler.milestones = [ + milestone * 2 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ.py new file mode 100644 index 0000000..63c54ee --- /dev/null +++ b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ.py @@ -0,0 +1,14 @@ +from .mask_rcnn_R_101_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +train.max_iter *= 4 # 100ep -> 400ep + +lr_multiplier.scheduler.milestones = [ + milestone * 4 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ.py 
new file mode 100644 index 0000000..df7a2ae --- /dev/null +++ b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ.py @@ -0,0 +1,72 @@ +import detectron2.data.transforms as T +from detectron2.config.lazy import LazyCall as L +from detectron2.layers.batch_norm import NaiveSyncBatchNorm +from detectron2.solver import WarmupParamScheduler +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from ..common.data.coco import dataloader +from ..common.models.mask_rcnn_fpn import model +from ..common.optim import SGD as optimizer +from ..common.train import train + +# train from scratch +train.init_checkpoint = "" +train.amp.enabled = True +train.ddp.fp16_compression = True +model.backbone.bottom_up.freeze_at = 0 + +# SyncBN +# fmt: off +model.backbone.bottom_up.stem.norm = \ + model.backbone.bottom_up.stages.norm = \ + model.backbone.norm = "SyncBN" + +# Using NaiveSyncBatchNorm because heads may have empty input. That is not supported by +# torch.nn.SyncBatchNorm. We can remove this after +# https://github.com/pytorch/pytorch/issues/36530 is fixed. 
+model.roi_heads.box_head.conv_norm = \ + model.roi_heads.mask_head.conv_norm = lambda c: NaiveSyncBatchNorm(c, + stats_mode="N") +# fmt: on + +# 2conv in RPN: +# https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/detection/modeling/architecture/heads.py#L95-L97 # noqa: E501, B950 +model.proposal_generator.head.conv_dims = [-1, -1] + +# 4conv1fc box head +model.roi_heads.box_head.conv_dims = [256, 256, 256, 256] +model.roi_heads.box_head.fc_dims = [1024] + +# resize_and_crop_image in: +# https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/detection/utils/input_utils.py#L127 # noqa: E501, B950 +image_size = 1024 +dataloader.train.mapper.augmentations = [ + L(T.ResizeScale)( + min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size + ), + L(T.FixedSizeCrop)(crop_size=(image_size, image_size)), + L(T.RandomFlip)(horizontal=True), +] + +# recompute boxes due to cropping +dataloader.train.mapper.recompute_boxes = True + +# larger batch-size. +dataloader.train.total_batch_size = 64 + +# Equivalent to 100 epochs. 
+# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +train.max_iter = 184375 + +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[163889, 177546], + num_updates=train.max_iter, + ), + warmup_length=500 / train.max_iter, + warmup_factor=0.067, +) + +optimizer.lr = 0.1 +optimizer.weight_decay = 4e-5 diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_200ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_200ep_LSJ.py new file mode 100644 index 0000000..2a7c376 --- /dev/null +++ b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_200ep_LSJ.py @@ -0,0 +1,14 @@ +from .mask_rcnn_R_50_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +train.max_iter *= 2 # 100ep -> 200ep + +lr_multiplier.scheduler.milestones = [ + milestone * 2 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_400ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_400ep_LSJ.py new file mode 100644 index 0000000..97586b8 --- /dev/null +++ b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_400ep_LSJ.py @@ -0,0 +1,14 @@ +from .mask_rcnn_R_50_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +train.max_iter *= 4 # 100ep -> 400ep + +lr_multiplier.scheduler.milestones = [ + milestone * 4 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_50ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_50ep_LSJ.py new file mode 100644 index 0000000..2ca1ede --- /dev/null +++ 
b/data_processing/detectron2/configs/new_baselines/mask_rcnn_R_50_FPN_50ep_LSJ.py @@ -0,0 +1,14 @@ +from .mask_rcnn_R_50_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +train.max_iter //= 2 # 100ep -> 50ep + +lr_multiplier.scheduler.milestones = [ + milestone // 2 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_100ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_100ep_LSJ.py new file mode 100644 index 0000000..ef0b6d1 --- /dev/null +++ b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_100ep_LSJ.py @@ -0,0 +1,29 @@ +from .mask_rcnn_R_50_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) +from detectron2.config import LazyCall as L +from detectron2.modeling.backbone import RegNet +from detectron2.modeling.backbone.regnet import SimpleStem, ResBottleneckBlock + +# Config source: +# https://github.com/facebookresearch/detectron2/blob/main/configs/COCO-InstanceSegmentation/mask_rcnn_regnetx_4gf_dds_fpn_1x.py # noqa +model.backbone.bottom_up = L(RegNet)( + stem_class=SimpleStem, + stem_width=32, + block_class=ResBottleneckBlock, + depth=23, + w_a=38.65, + w_0=96, + w_m=2.43, + group_width=40, + norm="SyncBN", + out_features=["s1", "s2", "s3", "s4"], +) +model.pixel_std = [57.375, 57.120, 58.395] + +# RegNets benefit from enabling cudnn benchmark mode +train.cudnn_benchmark = True diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_200ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_200ep_LSJ.py new file mode 100644 index 0000000..731320e --- /dev/null +++ 
b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_200ep_LSJ.py @@ -0,0 +1,14 @@ +from .mask_rcnn_regnetx_4gf_dds_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +train.max_iter *= 2 # 100ep -> 200ep + +lr_multiplier.scheduler.milestones = [ + milestone * 2 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_400ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_400ep_LSJ.py new file mode 100644 index 0000000..8f369a2 --- /dev/null +++ b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_400ep_LSJ.py @@ -0,0 +1,14 @@ +from .mask_rcnn_regnetx_4gf_dds_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +train.max_iter *= 4 # 100ep -> 400ep + +lr_multiplier.scheduler.milestones = [ + milestone * 4 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnety_4gf_dds_FPN_100ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnety_4gf_dds_FPN_100ep_LSJ.py new file mode 100644 index 0000000..ba2c327 --- /dev/null +++ b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnety_4gf_dds_FPN_100ep_LSJ.py @@ -0,0 +1,30 @@ +from .mask_rcnn_R_50_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) +from detectron2.config import LazyCall as L +from detectron2.modeling.backbone import RegNet +from detectron2.modeling.backbone.regnet import SimpleStem, ResBottleneckBlock + +# Config source: +# https://github.com/facebookresearch/detectron2/blob/main/configs/COCO-InstanceSegmentation/mask_rcnn_regnety_4gf_dds_fpn_1x.py 
# noqa +model.backbone.bottom_up = L(RegNet)( + stem_class=SimpleStem, + stem_width=32, + block_class=ResBottleneckBlock, + depth=22, + w_a=31.41, + w_0=96, + w_m=2.24, + group_width=64, + se_ratio=0.25, + norm="SyncBN", + out_features=["s1", "s2", "s3", "s4"], +) +model.pixel_std = [57.375, 57.120, 58.395] + +# RegNets benefit from enabling cudnn benchmark mode +train.cudnn_benchmark = True diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnety_4gf_dds_FPN_200ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnety_4gf_dds_FPN_200ep_LSJ.py new file mode 100644 index 0000000..b867cc8 --- /dev/null +++ b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnety_4gf_dds_FPN_200ep_LSJ.py @@ -0,0 +1,14 @@ +from .mask_rcnn_regnety_4gf_dds_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +train.max_iter *= 2 # 100ep -> 200ep + +lr_multiplier.scheduler.milestones = [ + milestone * 2 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnety_4gf_dds_FPN_400ep_LSJ.py b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnety_4gf_dds_FPN_400ep_LSJ.py new file mode 100644 index 0000000..7b86ea8 --- /dev/null +++ b/data_processing/detectron2/configs/new_baselines/mask_rcnn_regnety_4gf_dds_FPN_400ep_LSJ.py @@ -0,0 +1,14 @@ +from .mask_rcnn_regnety_4gf_dds_FPN_100ep_LSJ import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +train.max_iter *= 4 # 100ep -> 400ep + +lr_multiplier.scheduler.milestones = [ + milestone * 4 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/data_processing/detectron2/configs/quick_schedules/README.md b/data_processing/detectron2/configs/quick_schedules/README.md new file mode 100644 index 0000000..4e6c82e --- /dev/null +++ 
b/data_processing/detectron2/configs/quick_schedules/README.md @@ -0,0 +1,8 @@ +These are quick configs for performance or accuracy regression tracking purposes. + +* `*instant_test.yaml`: can train on 2 GPUs. They are used to test whether the training can + successfully finish. They are not expected to produce reasonable training results. +* `*inference_acc_test.yaml`: They should be run using `--eval-only`. They run inference using pre-trained models and verify + the results are as expected. +* `*training_acc_test.yaml`: They should be trained on 8 GPUs. They finish in about an hour and verify the training accuracy + is within the normal range. diff --git a/data_processing/detectron2/configs/quick_schedules/cascade_mask_rcnn_R_50_FPN_inference_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/cascade_mask_rcnn_R_50_FPN_inference_acc_test.yaml new file mode 100644 index 0000000..fc5a411 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/cascade_mask_rcnn_R_50_FPN_inference_acc_test.yaml @@ -0,0 +1,7 @@ +_BASE_: "../Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml" +MODEL: + WEIGHTS: "detectron2://Misc/cascade_mask_rcnn_R_50_FPN_3x/144998488/model_final_480dd8.pkl" +DATASETS: + TEST: ("coco_2017_val_100",) +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 50.18, 0.02], ["segm", "AP", 43.87, 0.02]] diff --git a/data_processing/detectron2/configs/quick_schedules/cascade_mask_rcnn_R_50_FPN_instant_test.yaml b/data_processing/detectron2/configs/quick_schedules/cascade_mask_rcnn_R_50_FPN_instant_test.yaml new file mode 100644 index 0000000..e41a0fe --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/cascade_mask_rcnn_R_50_FPN_instant_test.yaml @@ -0,0 +1,11 @@ +_BASE_: "../Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml" +DATASETS: + TRAIN: ("coco_2017_val_100",) + TEST: ("coco_2017_val_100",) +SOLVER: + BASE_LR: 0.005 + STEPS: (30,) + MAX_ITER: 40 + IMS_PER_BATCH: 4 +DATALOADER: + NUM_WORKERS: 2 diff --git 
a/data_processing/detectron2/configs/quick_schedules/fast_rcnn_R_50_FPN_inference_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/fast_rcnn_R_50_FPN_inference_acc_test.yaml new file mode 100644 index 0000000..a2f37e5 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/fast_rcnn_R_50_FPN_inference_acc_test.yaml @@ -0,0 +1,7 @@ +_BASE_: "../COCO-Detection/fast_rcnn_R_50_FPN_1x.yaml" +MODEL: + WEIGHTS: "detectron2://COCO-Detection/fast_rcnn_R_50_FPN_1x/137635226/model_final_e5f7ce.pkl" +DATASETS: + TEST: ("coco_2017_val_100",) +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 45.70, 0.02]] diff --git a/data_processing/detectron2/configs/quick_schedules/fast_rcnn_R_50_FPN_instant_test.yaml b/data_processing/detectron2/configs/quick_schedules/fast_rcnn_R_50_FPN_instant_test.yaml new file mode 100644 index 0000000..52fc0ec --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/fast_rcnn_R_50_FPN_instant_test.yaml @@ -0,0 +1,15 @@ +_BASE_: "../COCO-Detection/fast_rcnn_R_50_FPN_1x.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" +DATASETS: + TRAIN: ("coco_2017_val_100",) + PROPOSAL_FILES_TRAIN: ("detectron2://COCO-Detection/rpn_R_50_FPN_1x/137258492/coco_2017_val_box_proposals_ee0dad.pkl", ) + TEST: ("coco_2017_val_100",) + PROPOSAL_FILES_TEST: ("detectron2://COCO-Detection/rpn_R_50_FPN_1x/137258492/coco_2017_val_box_proposals_ee0dad.pkl", ) +SOLVER: + BASE_LR: 0.005 + STEPS: (30,) + MAX_ITER: 40 + IMS_PER_BATCH: 4 +DATALOADER: + NUM_WORKERS: 2 diff --git a/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_inference_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_inference_acc_test.yaml new file mode 100644 index 0000000..14cf2aa --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_inference_acc_test.yaml @@ -0,0 +1,7 @@ +_BASE_: "../COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml" +MODEL: + WEIGHTS: 
"detectron2://COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x/137849621/model_final_a6e10b.pkl" +DATASETS: + TEST: ("keypoints_coco_2017_val_100",) +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 52.47, 0.02], ["keypoints", "AP", 67.36, 0.02]] diff --git a/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_instant_test.yaml b/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_instant_test.yaml new file mode 100644 index 0000000..3dd209f --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_instant_test.yaml @@ -0,0 +1,16 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + KEYPOINT_ON: True + ROI_HEADS: + NUM_CLASSES: 1 +DATASETS: + TRAIN: ("keypoints_coco_2017_val_100",) + TEST: ("keypoints_coco_2017_val_100",) +SOLVER: + BASE_LR: 0.005 + STEPS: (30,) + MAX_ITER: 40 + IMS_PER_BATCH: 4 +DATALOADER: + NUM_WORKERS: 2 diff --git a/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_normalized_training_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_normalized_training_acc_test.yaml new file mode 100644 index 0000000..4b92392 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_normalized_training_acc_test.yaml @@ -0,0 +1,30 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + KEYPOINT_ON: True + RESNETS: + DEPTH: 50 + ROI_HEADS: + BATCH_SIZE_PER_IMAGE: 256 + NUM_CLASSES: 1 + ROI_KEYPOINT_HEAD: + POOLER_RESOLUTION: 14 + POOLER_SAMPLING_RATIO: 2 + NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS: False + LOSS_WEIGHT: 4.0 + ROI_BOX_HEAD: + SMOOTH_L1_BETA: 1.0 # Keypoint AP degrades when using plain L1 loss + RPN: + SMOOTH_L1_BETA: 0.2 # Keypoint AP degrades when using plain L1 loss +DATASETS: + TRAIN: ("keypoints_coco_2017_val",) + TEST: ("keypoints_coco_2017_val",) +INPUT: + MIN_SIZE_TRAIN: (640, 672, 
704, 736, 768, 800) +SOLVER: + WARMUP_FACTOR: 0.33333333 + WARMUP_ITERS: 100 + STEPS: (5500, 5800) + MAX_ITER: 6000 +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 55.35, 1.0], ["keypoints", "AP", 76.91, 1.0]] diff --git a/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_training_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_training_acc_test.yaml new file mode 100644 index 0000000..9bd9628 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/keypoint_rcnn_R_50_FPN_training_acc_test.yaml @@ -0,0 +1,28 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + KEYPOINT_ON: True + RESNETS: + DEPTH: 50 + ROI_HEADS: + BATCH_SIZE_PER_IMAGE: 256 + NUM_CLASSES: 1 + ROI_KEYPOINT_HEAD: + POOLER_RESOLUTION: 14 + POOLER_SAMPLING_RATIO: 2 + ROI_BOX_HEAD: + SMOOTH_L1_BETA: 1.0 # Keypoint AP degrades when using plain L1 loss + RPN: + SMOOTH_L1_BETA: 0.2 # Keypoint AP degrades when using plain L1 loss +DATASETS: + TRAIN: ("keypoints_coco_2017_val",) + TEST: ("keypoints_coco_2017_val",) +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +SOLVER: + WARMUP_FACTOR: 0.33333333 + WARMUP_ITERS: 100 + STEPS: (5500, 5800) + MAX_ITER: 6000 +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 53.5, 1.0], ["keypoints", "AP", 72.4, 1.0]] diff --git a/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_GCV_instant_test.yaml b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_GCV_instant_test.yaml new file mode 100644 index 0000000..ab6e698 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_GCV_instant_test.yaml @@ -0,0 +1,18 @@ +_BASE_: "../Base-RCNN-C4.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True +DATASETS: + TRAIN: ("coco_2017_val_100",) + TEST: ("coco_2017_val_100",) +SOLVER: + BASE_LR: 0.001 + STEPS: (30,) + MAX_ITER: 40 + IMS_PER_BATCH: 4 + 
CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: "value" + CLIP_VALUE: 1.0 +DATALOADER: + NUM_WORKERS: 2 diff --git a/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_inference_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_inference_acc_test.yaml new file mode 100644 index 0000000..b2d5b7f --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_inference_acc_test.yaml @@ -0,0 +1,7 @@ +_BASE_: "../COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml" +MODEL: + WEIGHTS: "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x/137849525/model_final_4ce675.pkl" +DATASETS: + TEST: ("coco_2017_val_100",) +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 47.37, 0.02], ["segm", "AP", 40.99, 0.02]] diff --git a/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_instant_test.yaml b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_instant_test.yaml new file mode 100644 index 0000000..6c4f121 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_instant_test.yaml @@ -0,0 +1,14 @@ +_BASE_: "../Base-RCNN-C4.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True +DATASETS: + TRAIN: ("coco_2017_val_100",) + TEST: ("coco_2017_val_100",) +SOLVER: + BASE_LR: 0.001 + STEPS: (30,) + MAX_ITER: 40 + IMS_PER_BATCH: 4 +DATALOADER: + NUM_WORKERS: 2 diff --git a/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_training_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_training_acc_test.yaml new file mode 100644 index 0000000..f68dd8f --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_C4_training_acc_test.yaml @@ -0,0 +1,22 @@ +_BASE_: "../Base-RCNN-C4.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + ROI_HEADS: + BATCH_SIZE_PER_IMAGE: 256 + MASK_ON: True +DATASETS: + TRAIN: 
("coco_2017_val",) + TEST: ("coco_2017_val",) +INPUT: + MIN_SIZE_TRAIN: (600,) + MAX_SIZE_TRAIN: 1000 + MIN_SIZE_TEST: 800 + MAX_SIZE_TEST: 1000 +SOLVER: + IMS_PER_BATCH: 8 # base uses 16 + WARMUP_FACTOR: 0.33333 + WARMUP_ITERS: 100 + STEPS: (11000, 11600) + MAX_ITER: 12000 +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 41.88, 0.7], ["segm", "AP", 33.79, 0.5]] diff --git a/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_DC5_inference_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_DC5_inference_acc_test.yaml new file mode 100644 index 0000000..e3ce6cf --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_DC5_inference_acc_test.yaml @@ -0,0 +1,7 @@ +_BASE_: "../COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x.yaml" +MODEL: + WEIGHTS: "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x/137849551/model_final_84107b.pkl" +DATASETS: + TEST: ("coco_2017_val_100",) +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 47.44, 0.02], ["segm", "AP", 42.94, 0.02]] diff --git a/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml new file mode 100644 index 0000000..e5454bf --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml @@ -0,0 +1,10 @@ +_BASE_: "../COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml" +MODEL: + WEIGHTS: "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl" +DATASETS: + TEST: ("coco_2017_val_100",) +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 47.34, 0.02], ["segm", "AP", 42.67, 0.02], ["bbox_TTA", "AP", 49.11, 0.02], ["segm_TTA", "AP", 45.04, 0.02]] + AUG: + ENABLED: True + MIN_SIZES: (700, 800) # to save some time diff --git a/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_instant_test.yaml 
b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_instant_test.yaml new file mode 100644 index 0000000..6dbfcde --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_instant_test.yaml @@ -0,0 +1,14 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True +DATASETS: + TRAIN: ("coco_2017_val_100",) + TEST: ("coco_2017_val_100",) +SOLVER: + BASE_LR: 0.005 + STEPS: (30,) + MAX_ITER: 40 + IMS_PER_BATCH: 4 +DATALOADER: + NUM_WORKERS: 2 diff --git a/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_pred_boxes_training_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_pred_boxes_training_acc_test.yaml new file mode 100644 index 0000000..52f7876 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_pred_boxes_training_acc_test.yaml @@ -0,0 +1,6 @@ +_BASE_: "./mask_rcnn_R_50_FPN_training_acc_test.yaml" +MODEL: + ROI_BOX_HEAD: + TRAIN_ON_PRED_BOXES: True +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 42.6, 1.0], ["segm", "AP", 35.8, 0.8]] diff --git a/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_training_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_training_acc_test.yaml new file mode 100644 index 0000000..aadae4c --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/mask_rcnn_R_50_FPN_training_acc_test.yaml @@ -0,0 +1,21 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + ROI_HEADS: + BATCH_SIZE_PER_IMAGE: 256 + MASK_ON: True +DATASETS: + TRAIN: ("coco_2017_val",) + TEST: ("coco_2017_val",) +INPUT: + MIN_SIZE_TRAIN: (600,) + MAX_SIZE_TRAIN: 1000 + MIN_SIZE_TEST: 800 + MAX_SIZE_TEST: 1000 +SOLVER: + WARMUP_FACTOR: 0.3333333 + WARMUP_ITERS: 100 + STEPS: (5500, 5800) + MAX_ITER: 6000 +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 42.5, 1.0], 
["segm", "AP", 35.8, 0.8]] diff --git a/data_processing/detectron2/configs/quick_schedules/panoptic_fpn_R_50_inference_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/panoptic_fpn_R_50_inference_acc_test.yaml new file mode 100644 index 0000000..70874e3 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/panoptic_fpn_R_50_inference_acc_test.yaml @@ -0,0 +1,7 @@ +_BASE_: "../COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml" +MODEL: + WEIGHTS: "detectron2://COCO-PanopticSegmentation/panoptic_fpn_R_50_3x/139514569/model_final_c10459.pkl" +DATASETS: + TEST: ("coco_2017_val_100_panoptic_separated",) +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 46.47, 0.02], ["segm", "AP", 43.39, 0.02], ["sem_seg", "mIoU", 42.55, 0.02], ["panoptic_seg", "PQ", 38.99, 0.02]] diff --git a/data_processing/detectron2/configs/quick_schedules/panoptic_fpn_R_50_instant_test.yaml b/data_processing/detectron2/configs/quick_schedules/panoptic_fpn_R_50_instant_test.yaml new file mode 100644 index 0000000..7cdee7b --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/panoptic_fpn_R_50_instant_test.yaml @@ -0,0 +1,19 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + META_ARCHITECTURE: "PanopticFPN" + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + SEM_SEG_HEAD: + LOSS_WEIGHT: 0.5 +DATASETS: + TRAIN: ("coco_2017_val_100_panoptic_separated",) + TEST: ("coco_2017_val_100_panoptic_separated",) +SOLVER: + BASE_LR: 0.005 + STEPS: (30,) + MAX_ITER: 40 + IMS_PER_BATCH: 4 +DATALOADER: + NUM_WORKERS: 1 diff --git a/data_processing/detectron2/configs/quick_schedules/panoptic_fpn_R_50_training_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/panoptic_fpn_R_50_training_acc_test.yaml new file mode 100644 index 0000000..f3bbf30 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/panoptic_fpn_R_50_training_acc_test.yaml @@ -0,0 +1,20 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + 
META_ARCHITECTURE: "PanopticFPN" + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + SEM_SEG_HEAD: + LOSS_WEIGHT: 0.5 +DATASETS: + TRAIN: ("coco_2017_val_panoptic_separated",) + TEST: ("coco_2017_val_panoptic_separated",) +SOLVER: + BASE_LR: 0.01 + WARMUP_FACTOR: 0.001 + WARMUP_ITERS: 500 + STEPS: (5500,) + MAX_ITER: 7000 +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 46.70, 1.1], ["segm", "AP", 39.0, 0.7], ["sem_seg", "mIoU", 64.73, 1.3], ["panoptic_seg", "PQ", 48.13, 0.8]] diff --git a/data_processing/detectron2/configs/quick_schedules/retinanet_R_50_FPN_inference_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/retinanet_R_50_FPN_inference_acc_test.yaml new file mode 100644 index 0000000..cb666c1 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/retinanet_R_50_FPN_inference_acc_test.yaml @@ -0,0 +1,7 @@ +_BASE_: "../COCO-Detection/retinanet_R_50_FPN_3x.yaml" +MODEL: + WEIGHTS: "detectron2://COCO-Detection/retinanet_R_50_FPN_3x/190397829/model_final_5bd44e.pkl" +DATASETS: + TEST: ("coco_2017_val_100",) +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 44.45, 0.02]] diff --git a/data_processing/detectron2/configs/quick_schedules/retinanet_R_50_FPN_instant_test.yaml b/data_processing/detectron2/configs/quick_schedules/retinanet_R_50_FPN_instant_test.yaml new file mode 100644 index 0000000..8d95c1f --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/retinanet_R_50_FPN_instant_test.yaml @@ -0,0 +1,13 @@ +_BASE_: "../COCO-Detection/retinanet_R_50_FPN_1x.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" +DATASETS: + TRAIN: ("coco_2017_val_100",) + TEST: ("coco_2017_val_100",) +SOLVER: + BASE_LR: 0.005 + STEPS: (30,) + MAX_ITER: 40 + IMS_PER_BATCH: 4 +DATALOADER: + NUM_WORKERS: 2 diff --git a/data_processing/detectron2/configs/quick_schedules/rpn_R_50_FPN_inference_acc_test.yaml 
b/data_processing/detectron2/configs/quick_schedules/rpn_R_50_FPN_inference_acc_test.yaml new file mode 100644 index 0000000..c7c3f90 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/rpn_R_50_FPN_inference_acc_test.yaml @@ -0,0 +1,7 @@ +_BASE_: "../COCO-Detection/rpn_R_50_FPN_1x.yaml" +MODEL: + WEIGHTS: "detectron2://COCO-Detection/rpn_R_50_FPN_1x/137258492/model_final_02ce48.pkl" +DATASETS: + TEST: ("coco_2017_val_100",) +TEST: + EXPECTED_RESULTS: [["box_proposals", "AR@1000", 58.16, 0.02]] diff --git a/data_processing/detectron2/configs/quick_schedules/rpn_R_50_FPN_instant_test.yaml b/data_processing/detectron2/configs/quick_schedules/rpn_R_50_FPN_instant_test.yaml new file mode 100644 index 0000000..402d432 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/rpn_R_50_FPN_instant_test.yaml @@ -0,0 +1,13 @@ +_BASE_: "../COCO-Detection/rpn_R_50_FPN_1x.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" +DATASETS: + TRAIN: ("coco_2017_val_100",) + TEST: ("coco_2017_val_100",) +SOLVER: + STEPS: (30,) + MAX_ITER: 40 + BASE_LR: 0.005 + IMS_PER_BATCH: 4 +DATALOADER: + NUM_WORKERS: 2 diff --git a/data_processing/detectron2/configs/quick_schedules/semantic_R_50_FPN_inference_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/semantic_R_50_FPN_inference_acc_test.yaml new file mode 100644 index 0000000..bca7498 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/semantic_R_50_FPN_inference_acc_test.yaml @@ -0,0 +1,10 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + META_ARCHITECTURE: "SemanticSegmentor" + WEIGHTS: "detectron2://semantic_R_50_FPN_1x/111802073/model_final_c18079783c55a94968edc28b7101c5f0.pkl" + RESNETS: + DEPTH: 50 +DATASETS: + TEST: ("coco_2017_val_100_panoptic_stuffonly",) +TEST: + EXPECTED_RESULTS: [["sem_seg", "mIoU", 39.53, 0.02], ["sem_seg", "mACC", 51.50, 0.02]] diff --git a/data_processing/detectron2/configs/quick_schedules/semantic_R_50_FPN_instant_test.yaml 
b/data_processing/detectron2/configs/quick_schedules/semantic_R_50_FPN_instant_test.yaml new file mode 100644 index 0000000..14ab606 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/semantic_R_50_FPN_instant_test.yaml @@ -0,0 +1,18 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + META_ARCHITECTURE: "SemanticSegmentor" + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +DATASETS: + TRAIN: ("coco_2017_val_100_panoptic_stuffonly",) + TEST: ("coco_2017_val_100_panoptic_stuffonly",) +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +SOLVER: + BASE_LR: 0.005 + STEPS: (30,) + MAX_ITER: 40 + IMS_PER_BATCH: 4 +DATALOADER: + NUM_WORKERS: 2 diff --git a/data_processing/detectron2/configs/quick_schedules/semantic_R_50_FPN_training_acc_test.yaml b/data_processing/detectron2/configs/quick_schedules/semantic_R_50_FPN_training_acc_test.yaml new file mode 100644 index 0000000..1f78d77 --- /dev/null +++ b/data_processing/detectron2/configs/quick_schedules/semantic_R_50_FPN_training_acc_test.yaml @@ -0,0 +1,20 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + META_ARCHITECTURE: "SemanticSegmentor" + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +DATASETS: + TRAIN: ("coco_2017_val_panoptic_stuffonly",) + TEST: ("coco_2017_val_panoptic_stuffonly",) +SOLVER: + BASE_LR: 0.01 + WARMUP_FACTOR: 0.001 + WARMUP_ITERS: 300 + STEPS: (5500,) + MAX_ITER: 7000 +TEST: + EXPECTED_RESULTS: [["sem_seg", "mIoU", 76.51, 1.0], ["sem_seg", "mACC", 83.25, 1.0]] +INPUT: + # no scale augmentation + MIN_SIZE_TRAIN: (800, ) diff --git a/data_processing/detectron2/datasets/README.md b/data_processing/detectron2/datasets/README.md new file mode 100644 index 0000000..0eb44cc --- /dev/null +++ b/data_processing/detectron2/datasets/README.md @@ -0,0 +1,140 @@ +# Use Builtin Datasets + +A dataset can be used by accessing 
[DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog) +for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc.). +This document explains how to set up the builtin datasets so they can be used by the above APIs. +[Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`, +and how to add new datasets to them. + +Detectron2 has builtin support for a few datasets. +The datasets are assumed to exist in a directory specified by the environment variable +`DETECTRON2_DATASETS`. +Under this directory, detectron2 will look for datasets in the structure described below, if needed. +``` +$DETECTRON2_DATASETS/ + coco/ + lvis/ + cityscapes/ + VOC20{07,12}/ +``` + +You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`. +If left unset, the default is `./datasets` relative to your current working directory. + +The [model zoo](https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md) +contains configs and models that use these builtin datasets. + +## Expected dataset structure for [COCO instance/keypoint detection](https://cocodataset.org/#download): + +``` +coco/ + annotations/ + instances_{train,val}2017.json + person_keypoints_{train,val}2017.json + {train,val}2017/ + # image files that are mentioned in the corresponding json +``` + +You can use the 2014 version of the dataset as well. 
+ +Some of the builtin tests (`dev/run_*_tests.sh`) use a tiny version of the COCO dataset, +which you can download with `./datasets/prepare_for_tests.sh`. + +## Expected dataset structure for PanopticFPN: + +Extract panoptic annotations from the [COCO website](https://cocodataset.org/#download) +into the following structure: +``` +coco/ + annotations/ + panoptic_{train,val}2017.json + panoptic_{train,val}2017/ # png annotations + panoptic_stuff_{train,val}2017/ # generated by the script mentioned below +``` + +Install panopticapi by: +``` +pip install git+https://github.com/cocodataset/panopticapi.git +``` +Then, run `python datasets/prepare_panoptic_fpn.py`, to extract semantic annotations from panoptic annotations. + +## Expected dataset structure for [LVIS instance segmentation](https://www.lvisdataset.org/dataset): +``` +coco/ + {train,val,test}2017/ +lvis/ + lvis_v0.5_{train,val}.json + lvis_v0.5_image_info_test.json + lvis_v1_{train,val}.json + lvis_v1_image_info_test{,_challenge}.json +``` + +Install lvis-api by: +``` +pip install git+https://github.com/lvis-dataset/lvis-api.git +``` + +To evaluate models trained on the COCO dataset using LVIS annotations, +run `python datasets/prepare_cocofied_lvis.py` to prepare "cocofied" LVIS annotations. + +## Expected dataset structure for [cityscapes](https://www.cityscapes-dataset.com/downloads/): +``` +cityscapes/ + gtFine/ + train/ + aachen/ + color.png, instanceIds.png, labelIds.png, polygons.json, + labelTrainIds.png + ... 
+ val/ + test/ + # below are generated Cityscapes panoptic annotation + cityscapes_panoptic_train.json + cityscapes_panoptic_train/ + cityscapes_panoptic_val.json + cityscapes_panoptic_val/ + cityscapes_panoptic_test.json + cityscapes_panoptic_test/ + leftImg8bit/ + train/ + val/ + test/ +``` +Install cityscapes scripts by: +``` +pip install git+https://github.com/mcordts/cityscapesScripts.git +``` + +Note: to create labelTrainIds.png, first prepare the above structure, then run cityscapesScripts with: +``` +CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createTrainIdLabelImgs.py +``` +These files are not needed for instance segmentation. + +Note: to generate the Cityscapes panoptic dataset, run cityscapesScripts with: +``` +CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createPanopticImgs.py +``` +These files are not needed for semantic and instance segmentation. + +## Expected dataset structure for [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/index.html): +``` +VOC20{07,12}/ + Annotations/ + ImageSets/ + Main/ + trainval.txt + test.txt + # train.txt or val.txt, if you use these splits + JPEGImages/ +``` + +## Expected dataset structure for [ADE20k Scene Parsing](http://sceneparsing.csail.mit.edu/): +``` +ADEChallengeData2016/ + annotations/ + annotations_detectron2/ + images/ + objectInfo150.txt +``` +The directory `annotations_detectron2` is generated by running `python datasets/prepare_ade20k_sem_seg.py`. 
def convert(input, output):
    """
    Re-encode one ADE20k annotation image with every label shifted down by one.

    Args:
        input: path of the annotation image to read.
        output: path the shifted annotation image is written to.
    """
    labels = np.asarray(Image.open(input))
    assert labels.dtype == np.uint8
    # The array is uint8, so label 0 (ignore) becomes 255 and every other
    # label is shifted down by 1.
    shifted = labels - 1
    result = Image.fromarray(shifted)
    result.save(output)
+ +import copy +import json +import os +from collections import defaultdict + +# This mapping is extracted from the official LVIS mapping: +# https://github.com/lvis-dataset/lvis-api/blob/master/data/coco_to_synset.json +COCO_SYNSET_CATEGORIES = [ + {"synset": "person.n.01", "coco_cat_id": 1}, + {"synset": "bicycle.n.01", "coco_cat_id": 2}, + {"synset": "car.n.01", "coco_cat_id": 3}, + {"synset": "motorcycle.n.01", "coco_cat_id": 4}, + {"synset": "airplane.n.01", "coco_cat_id": 5}, + {"synset": "bus.n.01", "coco_cat_id": 6}, + {"synset": "train.n.01", "coco_cat_id": 7}, + {"synset": "truck.n.01", "coco_cat_id": 8}, + {"synset": "boat.n.01", "coco_cat_id": 9}, + {"synset": "traffic_light.n.01", "coco_cat_id": 10}, + {"synset": "fireplug.n.01", "coco_cat_id": 11}, + {"synset": "stop_sign.n.01", "coco_cat_id": 13}, + {"synset": "parking_meter.n.01", "coco_cat_id": 14}, + {"synset": "bench.n.01", "coco_cat_id": 15}, + {"synset": "bird.n.01", "coco_cat_id": 16}, + {"synset": "cat.n.01", "coco_cat_id": 17}, + {"synset": "dog.n.01", "coco_cat_id": 18}, + {"synset": "horse.n.01", "coco_cat_id": 19}, + {"synset": "sheep.n.01", "coco_cat_id": 20}, + {"synset": "beef.n.01", "coco_cat_id": 21}, + {"synset": "elephant.n.01", "coco_cat_id": 22}, + {"synset": "bear.n.01", "coco_cat_id": 23}, + {"synset": "zebra.n.01", "coco_cat_id": 24}, + {"synset": "giraffe.n.01", "coco_cat_id": 25}, + {"synset": "backpack.n.01", "coco_cat_id": 27}, + {"synset": "umbrella.n.01", "coco_cat_id": 28}, + {"synset": "bag.n.04", "coco_cat_id": 31}, + {"synset": "necktie.n.01", "coco_cat_id": 32}, + {"synset": "bag.n.06", "coco_cat_id": 33}, + {"synset": "frisbee.n.01", "coco_cat_id": 34}, + {"synset": "ski.n.01", "coco_cat_id": 35}, + {"synset": "snowboard.n.01", "coco_cat_id": 36}, + {"synset": "ball.n.06", "coco_cat_id": 37}, + {"synset": "kite.n.03", "coco_cat_id": 38}, + {"synset": "baseball_bat.n.01", "coco_cat_id": 39}, + {"synset": 
"baseball_glove.n.01", "coco_cat_id": 40}, + {"synset": "skateboard.n.01", "coco_cat_id": 41}, + {"synset": "surfboard.n.01", "coco_cat_id": 42}, + {"synset": "tennis_racket.n.01", "coco_cat_id": 43}, + {"synset": "bottle.n.01", "coco_cat_id": 44}, + {"synset": "wineglass.n.01", "coco_cat_id": 46}, + {"synset": "cup.n.01", "coco_cat_id": 47}, + {"synset": "fork.n.01", "coco_cat_id": 48}, + {"synset": "knife.n.01", "coco_cat_id": 49}, + {"synset": "spoon.n.01", "coco_cat_id": 50}, + {"synset": "bowl.n.03", "coco_cat_id": 51}, + {"synset": "banana.n.02", "coco_cat_id": 52}, + {"synset": "apple.n.01", "coco_cat_id": 53}, + {"synset": "sandwich.n.01", "coco_cat_id": 54}, + {"synset": "orange.n.01", "coco_cat_id": 55}, + {"synset": "broccoli.n.01", "coco_cat_id": 56}, + {"synset": "carrot.n.01", "coco_cat_id": 57}, + {"synset": "frank.n.02", "coco_cat_id": 58}, + {"synset": "pizza.n.01", "coco_cat_id": 59}, + {"synset": "doughnut.n.02", "coco_cat_id": 60}, + {"synset": "cake.n.03", "coco_cat_id": 61}, + {"synset": "chair.n.01", "coco_cat_id": 62}, + {"synset": "sofa.n.01", "coco_cat_id": 63}, + {"synset": "pot.n.04", "coco_cat_id": 64}, + {"synset": "bed.n.01", "coco_cat_id": 65}, + {"synset": "dining_table.n.01", "coco_cat_id": 67}, + {"synset": "toilet.n.02", "coco_cat_id": 70}, + {"synset": "television_receiver.n.01", "coco_cat_id": 72}, + {"synset": "laptop.n.01", "coco_cat_id": 73}, + {"synset": "mouse.n.04", "coco_cat_id": 74}, + {"synset": "remote_control.n.01", "coco_cat_id": 75}, + {"synset": "computer_keyboard.n.01", "coco_cat_id": 76}, + {"synset": "cellular_telephone.n.01", "coco_cat_id": 77}, + {"synset": "microwave.n.02", "coco_cat_id": 78}, + {"synset": "oven.n.01", "coco_cat_id": 79}, + {"synset": "toaster.n.02", "coco_cat_id": 80}, + {"synset": "sink.n.01", "coco_cat_id": 81}, + {"synset": "electric_refrigerator.n.01", "coco_cat_id": 82}, + {"synset": "book.n.01", "coco_cat_id": 84}, + {"synset": "clock.n.01", "coco_cat_id": 85}, + {"synset": 
def cocofy_lvis(input_filename, output_filename):
    """
    Write a copy of an LVIS annotation file restricted to the COCO categories.

    Annotations, per-image category lists and the category table are filtered
    down to the synsets listed in COCO_SYNSET_CATEGORIES. The new json files
    can be used to evaluate COCO AP using `lvis-api`. The category ids in the
    output json are the incontiguous COCO dataset ids.

    Args:
        input_filename (str): path to the LVIS json file.
        output_filename (str): path to the COCOfied json file.
    """
    with open(input_filename, "r") as f:
        lvis_json = json.load(f)

    lvis_annos = lvis_json["annotations"]
    # Deep-copy everything except the annotation list, which is rebuilt below.
    cocofied_lvis = copy.deepcopy(
        {field: value for field, value in lvis_json.items() if field != "annotations"}
    )

    synset_to_coco_cat_id = {x["synset"]: x["coco_cat_id"] for x in COCO_SYNSET_CATEGORIES}
    # LVIS category id -> COCO category id, or None when the synset is not kept.
    lvis_to_coco = {
        cat["id"]: synset_to_coco_cat_id.get(cat["synset"]) for cat in lvis_json["categories"]
    }
    # COCO ids referenced by a kept annotation or a per-image category list.
    referenced_coco_ids = set()

    kept_annos = []
    for ann in lvis_annos:
        coco_cat_id = lvis_to_coco[ann["category_id"]]
        if coco_cat_id is None:
            continue
        converted = copy.deepcopy(ann)
        converted["category_id"] = coco_cat_id
        # Annotation ids are renumbered consecutively starting from 1.
        converted["id"] = len(kept_annos) + 1
        kept_annos.append(converted)
        referenced_coco_ids.add(coco_cat_id)
    cocofied_lvis["annotations"] = kept_annos

    for image in cocofied_lvis["images"]:
        for key in ("not_exhaustive_category_ids", "neg_category_ids"):
            mapped = [lvis_to_coco[lvis_cat_id] for lvis_cat_id in image[key]]
            image[key] = [coco_cat_id for coco_cat_id in mapped if coco_cat_id is not None]
            referenced_coco_ids.update(image[key])

    kept_categories = []
    for cat in lvis_json["categories"]:
        coco_cat_id = synset_to_coco_cat_id.get(cat["synset"])
        if coco_cat_id is None or coco_cat_id not in referenced_coco_ids:
            continue
        converted_cat = copy.deepcopy(cat)
        converted_cat["id"] = coco_cat_id
        kept_categories.append(converted_cat)
    cocofied_lvis["categories"] = kept_categories

    with open(output_filename, "w") as f:
        json.dump(cocofied_lvis, f)
    print("{} is COCOfied and stored in {}.".format(input_filename, output_filename))
def separate_coco_semantic_from_panoptic(panoptic_json, panoptic_root, sem_seg_root, categories):
    """
    Create semantic segmentation annotations from panoptic segmentation
    annotations, to be used by PanopticFPN.

    It maps all thing categories to class 0, and maps all unlabeled pixels to class 255.
    It maps all stuff categories to contiguous ids starting from 1.

    Args:
        panoptic_json (str): path to the panoptic json file, in COCO's format.
        panoptic_root (str): a directory with panoptic annotation files, in COCO's format.
        sem_seg_root (str): a directory to output semantic annotation files
        categories (list[dict]): category metadata. Each dict needs to have:
            "id": corresponds to the "category_id" in the json annotations
            "isthing": 0 or 1
    """
    os.makedirs(sem_seg_root, exist_ok=True)

    stuff_ids = [k["id"] for k in categories if k["isthing"] == 0]
    thing_ids = [k["id"] for k in categories if k["isthing"] == 1]
    id_map = {}  # map from category id to id in the output semantic annotation
    # Output ids 0 (things) and 255 (unlabeled) are reserved.
    assert len(stuff_ids) <= 254
    for i, stuff_id in enumerate(stuff_ids):
        id_map[stuff_id] = i + 1
    for thing_id in thing_ids:
        id_map[thing_id] = 0
    id_map[0] = 255

    with open(panoptic_json) as f:
        obj = json.load(f)

    def iter_annotations():
        # Yield one (input png, output png, segments) task per annotated image.
        for anno in obj["annotations"]:
            file_name = anno["file_name"]
            segments = anno["segments_info"]
            src_path = os.path.join(panoptic_root, file_name)
            dst_path = os.path.join(sem_seg_root, file_name)
            yield src_path, dst_path, segments

    # Use the pool as a context manager so its worker processes are torn down
    # when the conversion finishes (or fails). starmap blocks until every task
    # is done, so no work is pending when the pool exits.
    with mp.Pool(processes=max(mp.cpu_count() // 2, 4)) as pool:
        print("Start writing to {} ...".format(sem_seg_root))
        start = time.time()
        pool.starmap(
            functools.partial(_process_panoptic_to_semantic, id_map=id_map),
            iter_annotations(),
            chunksize=100,
        )
        print("Finished. time: {:.2f}s".format(time.time() - start))
def setup_cfg(args):
    """
    Build the frozen config used by the demo.

    The config file named by ``args.config_file`` is merged first, then the
    ``args.opts`` overrides, and finally ``args.confidence_threshold`` is
    written into the score thresholds of the builtin models.
    """
    cfg = get_cfg()
    # To use demo for Panoptic-DeepLab, please uncomment the following two lines.
    # from detectron2.projects.panoptic_deeplab import add_panoptic_deeplab_config  # noqa
    # add_panoptic_deeplab_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # Set score_threshold for builtin models
    threshold = args.confidence_threshold
    model_cfg = cfg.MODEL
    model_cfg.RETINANET.SCORE_THRESH_TEST = threshold
    model_cfg.ROI_HEADS.SCORE_THRESH_TEST = threshold
    model_cfg.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = threshold

    cfg.freeze()
    return cfg
if __name__ == "__main__":
    # Demo entry point: run the configured model on images, a webcam stream or
    # a video file, and either save the visualizations or show them in a window.
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)

    demo = VisualizationDemo(cfg)

    if args.input:
        if len(args.input) == 1:
            # A single argument may be a glob pattern; expand it here.
            args.input = glob.glob(os.path.expanduser(args.input[0]))
            assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input, disable=not args.output):
            # use PIL, to be consistent with evaluation
            img = read_image(path, format="BGR")
            start_time = time.time()
            predictions, visualized_output = demo.run_on_image(img)
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )

            if args.output:
                if os.path.isdir(args.output):
                    out_filename = os.path.join(args.output, os.path.basename(path))
                else:
                    # A single output file can only hold a single input image.
                    assert len(args.input) == 1, "Please specify a directory with args.output"
                    out_filename = args.output
                visualized_output.save(out_filename)
            else:
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
                cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
                if cv2.waitKey(0) == 27:
                    break  # esc to quit
    elif args.webcam:
        assert args.input is None, "Cannot have both --input and --webcam!"
        assert args.output is None, "output not yet supported with --webcam!"
        cam = cv2.VideoCapture(0)
        for vis in tqdm.tqdm(demo.run_on_video(cam)):
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, vis)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
        cam.release()
        cv2.destroyAllWindows()
    elif args.video_input:
        video = cv2.VideoCapture(args.video_input)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        basename = os.path.basename(args.video_input)
        codec, file_ext = (
            ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
        )
        # The fallback codec is "mp4v" (no leading dot); comparing against
        # ".mp4v" meant this warning could never be emitted.
        if codec == "mp4v":
            warnings.warn("x264 codec not available, switching to mp4v")
        if args.output:
            if os.path.isdir(args.output):
                output_fname = os.path.join(args.output, basename)
                output_fname = os.path.splitext(output_fname)[0] + file_ext
            else:
                output_fname = args.output
            # Refuse to overwrite an existing output file.
            assert not os.path.isfile(output_fname), output_fname
            output_file = cv2.VideoWriter(
                filename=output_fname,
                # some installation of opencv may not support x264 (due to its license),
                # you can try other format (e.g. MPEG)
                fourcc=cv2.VideoWriter_fourcc(*codec),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )
        assert os.path.isfile(args.video_input)
        for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
                cv2.imshow(basename, vis_frame)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
        video.release()
        if args.output:
            output_file.release()
        else:
            cv2.destroyAllWindows()
def run_on_video(self, video):
    """
    Visualizes predictions on frames of the input video.

    Args:
        video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
            either a webcam or a video file.

    Yields:
        ndarray: BGR visualizations of each video frame.

    Raises:
        ValueError: if the predictions for a frame contain none of
            "panoptic_seg", "instances" or "sem_seg".
    """
    video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)

    def process_predictions(frame, predictions):
        # Draw the predictions for one BGR frame and return the result as BGR.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if "panoptic_seg" in predictions:
            panoptic_seg, segments_info = predictions["panoptic_seg"]
            vis_frame = video_visualizer.draw_panoptic_seg_predictions(
                frame, panoptic_seg.to(self.cpu_device), segments_info
            )
        elif "instances" in predictions:
            predictions = predictions["instances"].to(self.cpu_device)
            vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
        elif "sem_seg" in predictions:
            vis_frame = video_visualizer.draw_sem_seg(
                frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
            )
        else:
            # Without this branch vis_frame stays unbound and the failure
            # surfaces as an opaque UnboundLocalError below.
            raise ValueError(
                "Cannot visualize predictions with keys {}; expected one of "
                "'panoptic_seg', 'instances' or 'sem_seg'.".format(sorted(predictions))
            )

        # Converts Matplotlib RGB format to OpenCV BGR format
        vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
        return vis_frame

    frame_gen = self._frame_from_video(video)
    if self.parallel:
        buffer_size = self.predictor.default_buffer_size

        # Frames submitted to the predictor whose results are not consumed yet.
        frame_data = deque()

        for cnt, frame in enumerate(frame_gen):
            frame_data.append(frame)
            self.predictor.put(frame)

            # Start consuming results only once the buffer has filled up.
            if cnt >= buffer_size:
                frame = frame_data.popleft()
                predictions = self.predictor.get()
                yield process_predictions(frame, predictions)

        # Drain the frames still in flight after the video ends.
        while len(frame_data):
            frame = frame_data.popleft()
            predictions = self.predictor.get()
            yield process_predictions(frame, predictions)
    else:
        for frame in frame_gen:
            yield process_predictions(frame, self.predictor(frame))
def get(self):
    """
    Return the result for the next request, in the order requests were put.

    Results read from ``result_queue`` that belong to a later request are
    parked in ``result_data``, kept sorted by their index in ``result_rank``,
    until their turn comes.
    """
    self.get_idx += 1  # the index needed for this request
    wanted = self.get_idx

    # The wanted result may already have been parked by an earlier call.
    if self.result_rank and self.result_rank[0] == wanted:
        parked = self.result_data.pop(0)
        self.result_rank.pop(0)
        return parked

    # Otherwise keep reading the queue, parking out-of-order results so that
    # results are returned in the correct order.
    while True:
        idx, payload = self.result_queue.get()
        if idx == wanted:
            return payload
        position = bisect.bisect(self.result_rank, idx)
        self.result_rank.insert(position, idx)
        self.result_data.insert(position, payload)
def convert_basic_c2_names(original_keys):
    """
    Apply some basic name conversion to names in C2 weights.
    It only deals with typical backbone models.

    Every key goes through the same ordered pipeline: two hard-coded renames,
    "_" -> ".", suffix/normalization regexes, stem rules, residual-block renames and
    finally the DensePose substitutions. The order of the rules matters.

    Args:
        original_keys (list[str]):
    Returns:
        list[str]: The same number of strings matching those in original_keys.
    """
    hard_coded = {"pred_b": "linear_b", "pred_w": "linear_w"}

    # Regex rules applied right after "_" has been turned into ".".
    # Both bn and gn names are unified to "norm".
    norm_rules = (
        (r"\.b$", ".bias"),
        (r"\.w$", ".weight"),
        (r"bn\.s$", "norm.weight"),
        (r"bn\.bias$", "norm.bias"),
        (r"bn\.rm", "norm.running_mean"),
        (r"bn\.running.mean$", "norm.running_mean"),
        (r"bn\.riv$", "norm.running_var"),
        (r"bn\.running.var$", "norm.running_var"),
        (r"bn\.gamma$", "norm.weight"),
        (r"bn\.beta$", "norm.bias"),
        (r"gn\.s$", "norm.weight"),
        (r"gn\.bias$", "norm.bias"),
        # stem
        (r"^res\.conv1\.norm\.", "conv1.norm."),
        # the "^" avoids mis-matching "conv1" in other components (e.g. detection head)
        (r"^conv1\.", "stem.conv1."),
    )

    # layer1-4 is used by torchvision, however the C2 naming strategy (res2-5) is kept,
    # so no res* -> layer* renaming happens here.
    block_rules = (
        (".branch1.", ".shortcut."),
        (".branch2a.", ".conv1."),
        (".branch2b.", ".conv2."),
        (".branch2c.", ".conv3."),
    )

    # DensePose substitutions (plain string replacements, applied last).
    densepose_rules = (
        ("AnnIndex.lowres", "ann_index_lowres"),
        ("Index.UV.lowres", "index_uv_lowres"),
        ("U.lowres", "u_lowres"),
        ("V.lowres", "v_lowres"),
    )

    def _convert(key):
        name = hard_coded.get(key, key)
        name = name.replace("_", ".")
        for pattern, replacement in norm_rules:
            name = re.sub(pattern, replacement, name)
        for old, new in block_rules:
            name = name.replace(old, new)
        name = re.sub("^body.conv.fcn", "body_conv_fcn", name)
        for old, new in densepose_rules:
            name = name.replace(old, new)
        return name

    return [_convert(key) for key in original_keys]
+ + Args: + weights (dict): name -> tensor + + Returns: + dict: detectron2 names -> tensor + dict: detectron2 names -> C2 names + """ + logger = logging.getLogger(__name__) + logger.info("Renaming Caffe2 weights ......") + original_keys = sorted(weights.keys()) + layer_keys = copy.deepcopy(original_keys) + + layer_keys = convert_basic_c2_names(layer_keys) + + # -------------------------------------------------------------------------- + # RPN hidden representation conv + # -------------------------------------------------------------------------- + # FPN case + # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then + # shared for all other levels, hence the appearance of "fpn2" + layer_keys = [ + k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys + ] + # Non-FPN case + layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys] + + # -------------------------------------------------------------------------- + # RPN box transformation conv + # -------------------------------------------------------------------------- + # FPN case (see note above about "fpn2") + layer_keys = [ + k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas") + for k in layer_keys + ] + layer_keys = [ + k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits") + for k in layer_keys + ] + # Non-FPN case + layer_keys = [ + k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys + ] + layer_keys = [ + k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits") + for k in layer_keys + ] + + # -------------------------------------------------------------------------- + # Fast R-CNN box head + # -------------------------------------------------------------------------- + layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys] + layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in 
layer_keys] + layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys] + layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys] + # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s + layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys] + + # -------------------------------------------------------------------------- + # FPN lateral and output convolutions + # -------------------------------------------------------------------------- + def fpn_map(name): + """ + Look for keys with the following patterns: + 1) Starts with "fpn.inner." + Example: "fpn.inner.res2.2.sum.lateral.weight" + Meaning: These are lateral pathway convolutions + 2) Starts with "fpn.res" + Example: "fpn.res2.2.sum.weight" + Meaning: These are FPN output convolutions + """ + splits = name.split(".") + norm = ".norm" if "norm" in splits else "" + if name.startswith("fpn.inner."): + # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight'] + stage = int(splits[2][len("res") :]) + return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1]) + elif name.startswith("fpn.res"): + # splits example: ['fpn', 'res2', '2', 'sum', 'weight'] + stage = int(splits[1][len("res") :]) + return "fpn_output{}{}.{}".format(stage, norm, splits[-1]) + return name + + layer_keys = [fpn_map(k) for k in layer_keys] + + # -------------------------------------------------------------------------- + # Mask R-CNN mask head + # -------------------------------------------------------------------------- + # roi_heads.StandardROIHeads case + layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys] + layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys] + layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys] + # roi_heads.Res5ROIHeads case + layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys] + + # 
-------------------------------------------------------------------------- + # Keypoint R-CNN head + # -------------------------------------------------------------------------- + # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX" + layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys] + layer_keys = [ + k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys + ] + layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys] + + # -------------------------------------------------------------------------- + # Done with replacements + # -------------------------------------------------------------------------- + assert len(set(layer_keys)) == len(layer_keys) + assert len(original_keys) == len(layer_keys) + + new_weights = {} + new_keys_to_original_keys = {} + for orig, renamed in zip(original_keys, layer_keys): + new_keys_to_original_keys[renamed] = orig + if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."): + # remove the meaningless prediction weight for background class + new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1 + new_weights[renamed] = weights[orig][new_start_idx:] + logger.info( + "Remove prediction weight for background class in {}. The shape changes from " + "{} to {}.".format( + renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape) + ) + ) + elif renamed.startswith("cls_score."): + # move weights of bg class from original index 0 to last index + logger.info( + "Move classification weights for background class in {} from index 0 to " + "index {}.".format(renamed, weights[orig].shape[0] - 1) + ) + new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]]) + else: + new_weights[renamed] = weights[orig] + + return new_weights, new_keys_to_original_keys + + +# Note the current matching is not symmetric. 
# it assumes model_state_dict will have longer names.
def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
    """
    Match names between the two state dicts, and return a new checkpoint state dict with
    names converted to match model_state_dict with heuristics. The returned dict can be
    later loaded with fvcore checkpointer.
    If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
    model and will be renamed at first.

    Strategy: suppose that the models that we will create will have prefixes appended
    to each of its keys, for example due to an extra level of nesting that the original
    pre-trained weights from ImageNet won't contain. For example, model.state_dict()
    might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
    res2.conv1.weight. We thus want to match both parameters together.
    For that, we look for each model weight, look among all loaded keys if there is one
    that is a suffix of the current weight name, and use it if that's the case.
    If multiple matches exist, take the one with longest size
    of the corresponding name. For example, for the same model as before, the pretrained
    weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
    we want to match backbone[0].body.conv1.weight to conv1.weight, and
    backbone[0].body.res2.conv1.weight to res2.conv1.weight.

    Args:
        model_state_dict (dict): name -> tensor, as returned by the model's state_dict().
        ckpt_state_dict (dict): name -> tensor loaded from the checkpoint.
        c2_conversion (bool): whether to first rename the checkpoint keys with
            `convert_c2_detectron_names`.

    Returns:
        dict: matched checkpoint values stored under the model's key names, plus every
        unmatched checkpoint entry under its (possibly converted) checkpoint name.
        If nothing matches, the (possibly converted) checkpoint dict is returned as is.

    Raises:
        ValueError: if one checkpoint key is the best match of more than one model key.
    """
    model_keys = sorted(model_state_dict.keys())
    if c2_conversion:
        ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
        # original_keys: the name in the original dict (before renaming)
    else:
        original_keys = {x: x for x in ckpt_state_dict.keys()}
    ckpt_keys = sorted(ckpt_state_dict.keys())

    logger = logging.getLogger(__name__)
    if not ckpt_keys:
        # An empty checkpoint would give a (len(model_keys), 0) match matrix, and the
        # row-wise max below cannot reduce over a zero-sized dimension. Treat it the
        # same way as the "nothing matched" case further down.
        logger.warning("No weights in checkpoint matched with model.")
        return ckpt_state_dict

    def match(a, b):
        # Matched ckpt_key should be a complete (starts with '.') suffix.
        # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
        # but matches whatever_conv1 or mesh_head.whatever_conv1.
        return a == b or a.endswith("." + b)

    # get a matrix of string matches, where each (i, j) entry correspond to the size of the
    # ckpt_key string, if it matches
    match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
    match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
    # use the matched one with longest size in case of multiple matches
    max_match_size, idxs = match_matrix.max(1)
    # remove indices that correspond to no-match
    idxs[max_match_size == 0] = -1

    # matched_pairs (matched checkpoint key --> matched model key)
    matched_keys = {}
    result_state_dict = {}
    for idx_model, idx_ckpt in enumerate(idxs.tolist()):
        if idx_ckpt == -1:
            continue
        key_model = model_keys[idx_model]
        key_ckpt = ckpt_keys[idx_ckpt]
        value_ckpt = ckpt_state_dict[key_ckpt]
        shape_in_model = model_state_dict[key_model].shape

        if shape_in_model != value_ckpt.shape:
            logger.warning(
                "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
                    key_ckpt, value_ckpt.shape, key_model, shape_in_model
                )
            )
            logger.warning(
                "{} will not be loaded. Please double check and see if this is desired.".format(
                    key_ckpt
                )
            )
            continue

        assert key_model not in result_state_dict
        result_state_dict[key_model] = value_ckpt
        if key_ckpt in matched_keys:  # already added to matched_keys
            logger.error(
                "Ambiguity found for {} in checkpoint! "
                "It matches at least two keys in the model ({} and {}).".format(
                    key_ckpt, key_model, matched_keys[key_ckpt]
                )
            )
            raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")

        matched_keys[key_ckpt] = key_model

    # logging:
    matched_model_keys = sorted(matched_keys.values())
    if len(matched_model_keys) == 0:
        logger.warning("No weights in checkpoint matched with model.")
        return ckpt_state_dict
    common_prefix = _longest_common_prefix(matched_model_keys)
    rev_matched_keys = {v: k for k, v in matched_keys.items()}
    original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}

    model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
    table = []
    memo = set()
    for key_model in matched_model_keys:
        if key_model in memo:
            continue
        if key_model in model_key_groups:
            group = model_key_groups[key_model]
            memo |= set(group)
            shapes = [tuple(model_state_dict[k].shape) for k in group]
            table.append(
                (
                    _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
                    _group_str([original_keys[k] for k in group]),
                    " ".join([str(x).replace(" ", "") for x in shapes]),
                )
            )
        else:
            key_checkpoint = original_keys[key_model]
            shape = str(tuple(model_state_dict[key_model].shape))
            table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
    table_str = tabulate(
        table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"]
    )
    logger.info(
        "Following weights matched with "
        + (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
        + ":\n"
        + table_str
    )

    # Carry over checkpoint entries that matched nothing, under their checkpoint names.
    # `matched_keys` is a dict, so the membership test is O(1) per key.
    unmatched_ckpt_keys = [k for k in ckpt_keys if k not in matched_keys]
    for k in unmatched_ckpt_keys:
        result_state_dict[k] = ckpt_state_dict[k]
    return result_state_dict
+ + Args: + keys: names of all parameters + original_names: mapping from parameter name to their name in the checkpoint + + Returns: + dict[name -> all other names in the same group] + """ + + def _submodule_name(key): + pos = key.rfind(".") + if pos < 0: + return None + prefix = key[: pos + 1] + return prefix + + all_submodules = [_submodule_name(k) for k in keys] + all_submodules = [x for x in all_submodules if x] + all_submodules = sorted(all_submodules, key=len) + + ret = {} + for prefix in all_submodules: + group = [k for k in keys if k.startswith(prefix)] + if len(group) <= 1: + continue + original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group]) + if len(original_name_lcp) == 0: + # don't group weights if original names don't share prefix + continue + + for k in group: + if k in ret: + continue + ret[k] = group + return ret + + +def _longest_common_prefix(names: List[str]) -> str: + """ + ["abc.zfg", "abc.zef"] -> "abc." + """ + names = [n.split(".") for n in names] + m1, m2 = min(names), max(names) + ret = [a for a, b in zip(m1, m2) if a == b] + ret = ".".join(ret) + "." 
if len(ret) else "" + return ret + + +def _longest_common_prefix_str(names: List[str]) -> str: + m1, m2 = min(names), max(names) + lcp = [] + for a, b in zip(m1, m2): + if a == b: + lcp.append(a) + else: + break + lcp = "".join(lcp) + return lcp + + +def _group_str(names: List[str]) -> str: + """ + Turn "common1", "common2", "common3" into "common{1,2,3}" + """ + lcp = _longest_common_prefix_str(names) + rest = [x[len(lcp) :] for x in names] + rest = "{" + ",".join(rest) + "}" + ret = lcp + rest + + # add some simplification for BN specifically + ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*") + ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*") + return ret diff --git a/data_processing/detectron2/detectron2/checkpoint/catalog.py b/data_processing/detectron2/detectron2/checkpoint/catalog.py new file mode 100644 index 0000000..9a85736 --- /dev/null +++ b/data_processing/detectron2/detectron2/checkpoint/catalog.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging + +from detectron2.utils.file_io import PathHandler, PathManager + + +class ModelCatalog(object): + """ + Store mappings from names to third-party models. + """ + + S3_C2_DETECTRON_PREFIX = "https://dl.fbaipublicfiles.com/detectron" + + # MSRA models have STRIDE_IN_1X1=True. False otherwise. + # NOTE: all BN models here have fused BN into an affine layer. + # As a result, you should only load them to a model with "FrozenBN". + # Loading them to a model with regular BN or SyncBN is wrong. + # Even when loaded to FrozenBN, it is still different from affine by an epsilon, + # which should be negligible for training. + # NOTE: all models here uses PIXEL_STD=[1,1,1] + # NOTE: Most of the BN models here are no longer used. We use the + # re-converted pre-trained models under detectron2 model zoo instead. 
+ C2_IMAGENET_MODELS = { + "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl", + "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl", + "FAIR/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl", + "FAIR/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl", + "FAIR/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl", + "FAIR/X-101-64x4d": "ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl", + "FAIR/X-152-32x8d-IN5k": "ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl", + } + + C2_DETECTRON_PATH_FORMAT = ( + "{prefix}/{url}/output/train/{dataset}/{type}/model_final.pkl" # noqa B950 + ) + + C2_DATASET_COCO = "coco_2014_train%3Acoco_2014_valminusminival" + C2_DATASET_COCO_KEYPOINTS = "keypoints_coco_2014_train%3Akeypoints_coco_2014_valminusminival" + + # format: {model_name} -> part of the url + C2_DETECTRON_MODELS = { + "35857197/e2e_faster_rcnn_R-50-C4_1x": "35857197/12_2017_baselines/e2e_faster_rcnn_R-50-C4_1x.yaml.01_33_49.iAX0mXvW", # noqa B950 + "35857345/e2e_faster_rcnn_R-50-FPN_1x": "35857345/12_2017_baselines/e2e_faster_rcnn_R-50-FPN_1x.yaml.01_36_30.cUF7QR7I", # noqa B950 + "35857890/e2e_faster_rcnn_R-101-FPN_1x": "35857890/12_2017_baselines/e2e_faster_rcnn_R-101-FPN_1x.yaml.01_38_50.sNxI7sX7", # noqa B950 + "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "36761737/12_2017_baselines/e2e_faster_rcnn_X-101-32x8d-FPN_1x.yaml.06_31_39.5MIHi1fZ", # noqa B950 + "35858791/e2e_mask_rcnn_R-50-C4_1x": "35858791/12_2017_baselines/e2e_mask_rcnn_R-50-C4_1x.yaml.01_45_57.ZgkA7hPB", # noqa B950 + "35858933/e2e_mask_rcnn_R-50-FPN_1x": "35858933/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml.01_48_14.DzEQe4wC", # noqa B950 + "35861795/e2e_mask_rcnn_R-101-FPN_1x": "35861795/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_1x.yaml.02_31_37.KqyEK4tT", # noqa B950 + "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "36761843/12_2017_baselines/e2e_mask_rcnn_X-101-32x8d-FPN_1x.yaml.06_35_59.RZotkLKI", # noqa B950 + "48616381/e2e_mask_rcnn_R-50-FPN_2x_gn": 
"GN/48616381/04_2018_gn_baselines/e2e_mask_rcnn_R-50-FPN_2x_gn_0416.13_23_38.bTlTI97Q", # noqa B950 + "37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "37697547/12_2017_baselines/e2e_keypoint_rcnn_R-50-FPN_1x.yaml.08_42_54.kdzV35ao", # noqa B950 + "35998355/rpn_R-50-C4_1x": "35998355/12_2017_baselines/rpn_R-50-C4_1x.yaml.08_00_43.njH5oD9L", # noqa B950 + "35998814/rpn_R-50-FPN_1x": "35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179", # noqa B950 + "36225147/fast_R-50-FPN_1x": "36225147/12_2017_baselines/fast_rcnn_R-50-FPN_1x.yaml.08_39_09.L3obSdQ2", # noqa B950 + } + + @staticmethod + def get(name): + if name.startswith("Caffe2Detectron/COCO"): + return ModelCatalog._get_c2_detectron_baseline(name) + if name.startswith("ImageNetPretrained/"): + return ModelCatalog._get_c2_imagenet_pretrained(name) + raise RuntimeError("model not present in the catalog: {}".format(name)) + + @staticmethod + def _get_c2_imagenet_pretrained(name): + prefix = ModelCatalog.S3_C2_DETECTRON_PREFIX + name = name[len("ImageNetPretrained/") :] + name = ModelCatalog.C2_IMAGENET_MODELS[name] + url = "/".join([prefix, name]) + return url + + @staticmethod + def _get_c2_detectron_baseline(name): + name = name[len("Caffe2Detectron/COCO/") :] + url = ModelCatalog.C2_DETECTRON_MODELS[name] + if "keypoint_rcnn" in name: + dataset = ModelCatalog.C2_DATASET_COCO_KEYPOINTS + else: + dataset = ModelCatalog.C2_DATASET_COCO + + if "35998355/rpn_R-50-C4_1x" in name: + # this one model is somehow different from others .. + type = "rpn" + else: + type = "generalized_rcnn" + + # Detectron C2 models are stored in the structure defined in `C2_DETECTRON_PATH_FORMAT`. + url = ModelCatalog.C2_DETECTRON_PATH_FORMAT.format( + prefix=ModelCatalog.S3_C2_DETECTRON_PREFIX, url=url, type=type, dataset=dataset + ) + return url + + +class ModelCatalogHandler(PathHandler): + """ + Resolve URL like catalog://. 
+ """ + + PREFIX = "catalog://" + + def _get_supported_prefixes(self): + return [self.PREFIX] + + def _get_local_path(self, path, **kwargs): + logger = logging.getLogger(__name__) + catalog_path = ModelCatalog.get(path[len(self.PREFIX) :]) + logger.info("Catalog entry {} points to {}".format(path, catalog_path)) + return PathManager.get_local_path(catalog_path, **kwargs) + + def _open(self, path, mode="r", **kwargs): + return PathManager.open(self._get_local_path(path), mode, **kwargs) + + +PathManager.register_handler(ModelCatalogHandler()) diff --git a/data_processing/detectron2/detectron2/checkpoint/detection_checkpoint.py b/data_processing/detectron2/detectron2/checkpoint/detection_checkpoint.py new file mode 100644 index 0000000..cecb1fc --- /dev/null +++ b/data_processing/detectron2/detectron2/checkpoint/detection_checkpoint.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import os +import pickle +from urllib.parse import parse_qs, urlparse +import torch +from fvcore.common.checkpoint import Checkpointer +from torch.nn.parallel import DistributedDataParallel + +import detectron2.utils.comm as comm +from detectron2.utils.file_io import PathManager + +from .c2_model_loading import align_and_update_state_dicts + + +class DetectionCheckpointer(Checkpointer): + """ + Same as :class:`Checkpointer`, but is able to: + 1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models. + 2. 
correctly load checkpoints that are only available on the master worker + """ + + def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables): + is_main_process = comm.is_main_process() + super().__init__( + model, + save_dir, + save_to_disk=is_main_process if save_to_disk is None else save_to_disk, + **checkpointables, + ) + self.path_manager = PathManager + self._parsed_url_during_load = None + + def load(self, path, *args, **kwargs): + assert self._parsed_url_during_load is None + need_sync = False + logger = logging.getLogger(__name__) + logger.info("[DetectionCheckpointer] Loading from {} ...".format(path)) + + if path and isinstance(self.model, DistributedDataParallel): + path = self.path_manager.get_local_path(path) + has_file = os.path.isfile(path) + all_has_file = comm.all_gather(has_file) + if not all_has_file[0]: + raise OSError(f"File {path} not found on main worker.") + if not all(all_has_file): + logger.warning( + f"Not all workers can read checkpoint {path}. " + "Training may fail to fully resume." + ) + # TODO: broadcast the checkpoint file contents from main + # worker, and load from it instead. 
+ need_sync = True + if not has_file: + path = None # don't load if not readable + + if path: + parsed_url = urlparse(path) + self._parsed_url_during_load = parsed_url + path = parsed_url._replace(query="").geturl() # remove query from filename + path = self.path_manager.get_local_path(path) + ret = super().load(path, *args, **kwargs) + + if need_sync: + logger.info("Broadcasting model states from main worker ...") + self.model._sync_params_and_buffers() + self._parsed_url_during_load = None # reset to None + return ret + + def _load_file(self, filename): + if filename.endswith(".pkl"): + with PathManager.open(filename, "rb") as f: + data = pickle.load(f, encoding="latin1") + if "model" in data and "__author__" in data: + # file is in Detectron2 model zoo format + self.logger.info("Reading a file from '{}'".format(data["__author__"])) + return data + else: + # assume file is from Caffe2 / Detectron1 model zoo + if "blobs" in data: + # Detection models have "blobs", but ImageNet models don't + data = data["blobs"] + data = {k: v for k, v in data.items() if not k.endswith("_momentum")} + return {"model": data, "__author__": "Caffe2", "matching_heuristics": True} + elif filename.endswith(".pyth"): + # assume file is from pycls; no one else seems to use the ".pyth" extension + with PathManager.open(filename, "rb") as f: + data = torch.load(f) + assert ( + "model_state" in data + ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'." 
+ model_state = { + k: v + for k, v in data["model_state"].items() + if not k.endswith("num_batches_tracked") + } + return {"model": model_state, "__author__": "pycls", "matching_heuristics": True} + + loaded = self._torch_load(filename) + if "model" not in loaded: + loaded = {"model": loaded} + assert self._parsed_url_during_load is not None, "`_load_file` must be called inside `load`" + parsed_url = self._parsed_url_during_load + queries = parse_qs(parsed_url.query) + if queries.pop("matching_heuristics", "False") == ["True"]: + loaded["matching_heuristics"] = True + if len(queries) > 0: + raise ValueError( + f"Unsupported query remaining: f{queries}, orginal filename: {parsed_url.geturl()}" + ) + return loaded + + def _torch_load(self, f): + return super()._load_file(f) + + def _load_model(self, checkpoint): + if checkpoint.get("matching_heuristics", False): + self._convert_ndarray_to_tensor(checkpoint["model"]) + # convert weights by name-matching heuristics + checkpoint["model"] = align_and_update_state_dicts( + self.model.state_dict(), + checkpoint["model"], + c2_conversion=checkpoint.get("__author__", None) == "Caffe2", + ) + # for non-caffe2 models, use standard ways to load it + incompatible = super()._load_model(checkpoint) + + model_buffers = dict(self.model.named_buffers(recurse=False)) + for k in ["pixel_mean", "pixel_std"]: + # Ignore missing key message about pixel_mean/std. + # Though they may be missing in old checkpoints, they will be correctly + # initialized from config anyway. + if k in model_buffers: + try: + incompatible.missing_keys.remove(k) + except ValueError: + pass + for k in incompatible.unexpected_keys[:]: + # Ignore unexpected keys about cell anchors. They exist in old checkpoints + # but now they are non-persistent buffers and will not be in new checkpoints. 
+ if "anchor_generator.cell_anchors" in k: + incompatible.unexpected_keys.remove(k) + return incompatible diff --git a/data_processing/detectron2/detectron2/config/__init__.py b/data_processing/detectron2/detectron2/config/__init__.py new file mode 100644 index 0000000..4e648e6 --- /dev/null +++ b/data_processing/detectron2/detectron2/config/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .compat import downgrade_config, upgrade_config +from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable +from .instantiate import instantiate +from .lazy import LazyCall, LazyConfig + +__all__ = [ + "CfgNode", + "get_cfg", + "global_cfg", + "set_global_cfg", + "downgrade_config", + "upgrade_config", + "configurable", + "instantiate", + "LazyCall", + "LazyConfig", +] + + +from detectron2.utils.env import fixup_module_metadata + +fixup_module_metadata(__name__, globals(), __all__) +del fixup_module_metadata diff --git a/data_processing/detectron2/detectron2/config/compat.py b/data_processing/detectron2/detectron2/config/compat.py new file mode 100644 index 0000000..11a08c4 --- /dev/null +++ b/data_processing/detectron2/detectron2/config/compat.py @@ -0,0 +1,229 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Backward compatibility of configs. + +Instructions to bump version: ++ It's not needed to bump version if new keys are added. + It's only needed when backward-incompatible changes happen + (i.e., some existing keys disappear, or the meaning of a key changes) ++ To bump version, do the following: + 1. Increment _C.VERSION in defaults.py + 2. Add a converter in this file. + + Each ConverterVX has a function "upgrade" which in-place upgrades config from X-1 to X, + and a function "downgrade" which in-place downgrades config from X to X-1 + + In each function, VERSION is left unchanged. + + Each converter assumes that its input has the relevant keys + (i.e., the input is not a partial config). + 3. 
def upgrade_config(cfg: CN, to_version: Optional[int] = None) -> CN:
    """
    Upgrade a config from its current version to a newer version.

    The input is not modified: a clone is upgraded one version at a time by the
    module-level ``ConverterV<N>`` classes and then returned.

    Args:
        cfg (CfgNode):
        to_version (int): defaults to the latest version.

    Returns:
        CfgNode: the upgraded clone, with VERSION set to the target version.
    """
    upgraded = cfg.clone()
    target = _C.VERSION if to_version is None else to_version
    current = upgraded.VERSION

    assert current <= target, "Cannot upgrade from v{} to v{}!".format(current, target)

    # Converters are looked up by name; ConverterV<N> upgrades from N-1 to N.
    for next_version in range(current + 1, target + 1):
        globals()["ConverterV" + str(next_version)].upgrade(upgraded)
        upgraded.VERSION = next_version
    return upgraded
+ """ + logger = logging.getLogger(__name__) + + def _has(name: str) -> bool: + cur = cfg + for n in name.split("."): + if n not in cur: + return False + cur = cur[n] + return True + + # Most users' partial configs have "MODEL.WEIGHT", so guess on it + ret = None + if _has("MODEL.WEIGHT") or _has("TEST.AUG_ON"): + ret = 1 + + if ret is not None: + logger.warning("Config '{}' has no VERSION. Assuming it to be v{}.".format(filename, ret)) + else: + ret = _C.VERSION + logger.warning( + "Config '{}' has no VERSION. Assuming it to be compatible with latest v{}.".format( + filename, ret + ) + ) + return ret + + +def _rename(cfg: CN, old: str, new: str) -> None: + old_keys = old.split(".") + new_keys = new.split(".") + + def _set(key_seq: List[str], val: str) -> None: + cur = cfg + for k in key_seq[:-1]: + if k not in cur: + cur[k] = CN() + cur = cur[k] + cur[key_seq[-1]] = val + + def _get(key_seq: List[str]) -> CN: + cur = cfg + for k in key_seq: + cur = cur[k] + return cur + + def _del(key_seq: List[str]) -> None: + cur = cfg + for k in key_seq[:-1]: + cur = cur[k] + del cur[key_seq[-1]] + if len(cur) == 0 and len(key_seq) > 1: + _del(key_seq[:-1]) + + _set(new_keys, _get(old_keys)) + _del(old_keys) + + +class _RenameConverter: + """ + A converter that handles simple rename. + """ + + RENAME: List[Tuple[str, str]] = [] # list of tuples of (old name, new name) + + @classmethod + def upgrade(cls, cfg: CN) -> None: + for old, new in cls.RENAME: + _rename(cfg, old, new) + + @classmethod + def downgrade(cls, cfg: CN) -> None: + for old, new in cls.RENAME[::-1]: + _rename(cfg, new, old) + + +class ConverterV1(_RenameConverter): + RENAME = [("MODEL.RPN_HEAD.NAME", "MODEL.RPN.HEAD_NAME")] + + +class ConverterV2(_RenameConverter): + """ + A large bulk of rename, before public release. 
+ """ + + RENAME = [ + ("MODEL.WEIGHT", "MODEL.WEIGHTS"), + ("MODEL.PANOPTIC_FPN.SEMANTIC_LOSS_SCALE", "MODEL.SEM_SEG_HEAD.LOSS_WEIGHT"), + ("MODEL.PANOPTIC_FPN.RPN_LOSS_SCALE", "MODEL.RPN.LOSS_WEIGHT"), + ("MODEL.PANOPTIC_FPN.INSTANCE_LOSS_SCALE", "MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT"), + ("MODEL.PANOPTIC_FPN.COMBINE_ON", "MODEL.PANOPTIC_FPN.COMBINE.ENABLED"), + ( + "MODEL.PANOPTIC_FPN.COMBINE_OVERLAP_THRESHOLD", + "MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH", + ), + ( + "MODEL.PANOPTIC_FPN.COMBINE_STUFF_AREA_LIMIT", + "MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT", + ), + ( + "MODEL.PANOPTIC_FPN.COMBINE_INSTANCES_CONFIDENCE_THRESHOLD", + "MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH", + ), + ("MODEL.ROI_HEADS.SCORE_THRESH", "MODEL.ROI_HEADS.SCORE_THRESH_TEST"), + ("MODEL.ROI_HEADS.NMS", "MODEL.ROI_HEADS.NMS_THRESH_TEST"), + ("MODEL.RETINANET.INFERENCE_SCORE_THRESHOLD", "MODEL.RETINANET.SCORE_THRESH_TEST"), + ("MODEL.RETINANET.INFERENCE_TOPK_CANDIDATES", "MODEL.RETINANET.TOPK_CANDIDATES_TEST"), + ("MODEL.RETINANET.INFERENCE_NMS_THRESHOLD", "MODEL.RETINANET.NMS_THRESH_TEST"), + ("TEST.DETECTIONS_PER_IMG", "TEST.DETECTIONS_PER_IMAGE"), + ("TEST.AUG_ON", "TEST.AUG.ENABLED"), + ("TEST.AUG_MIN_SIZES", "TEST.AUG.MIN_SIZES"), + ("TEST.AUG_MAX_SIZE", "TEST.AUG.MAX_SIZE"), + ("TEST.AUG_FLIP", "TEST.AUG.FLIP"), + ] + + @classmethod + def upgrade(cls, cfg: CN) -> None: + super().upgrade(cfg) + + if cfg.MODEL.META_ARCHITECTURE == "RetinaNet": + _rename( + cfg, "MODEL.RETINANET.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS" + ) + _rename(cfg, "MODEL.RETINANET.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES") + del cfg["MODEL"]["RPN"]["ANCHOR_SIZES"] + del cfg["MODEL"]["RPN"]["ANCHOR_ASPECT_RATIOS"] + else: + _rename(cfg, "MODEL.RPN.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS") + _rename(cfg, "MODEL.RPN.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES") + del cfg["MODEL"]["RETINANET"]["ANCHOR_SIZES"] + del 
cfg["MODEL"]["RETINANET"]["ANCHOR_ASPECT_RATIOS"] + del cfg["MODEL"]["RETINANET"]["ANCHOR_STRIDES"] + + @classmethod + def downgrade(cls, cfg: CN) -> None: + super().downgrade(cfg) + + _rename(cfg, "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS", "MODEL.RPN.ANCHOR_ASPECT_RATIOS") + _rename(cfg, "MODEL.ANCHOR_GENERATOR.SIZES", "MODEL.RPN.ANCHOR_SIZES") + cfg.MODEL.RETINANET.ANCHOR_ASPECT_RATIOS = cfg.MODEL.RPN.ANCHOR_ASPECT_RATIOS + cfg.MODEL.RETINANET.ANCHOR_SIZES = cfg.MODEL.RPN.ANCHOR_SIZES + cfg.MODEL.RETINANET.ANCHOR_STRIDES = [] # this is not used anywhere in any version diff --git a/data_processing/detectron2/detectron2/config/config.py b/data_processing/detectron2/detectron2/config/config.py new file mode 100644 index 0000000..49a55b1 --- /dev/null +++ b/data_processing/detectron2/detectron2/config/config.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import functools +import inspect +import logging +from fvcore.common.config import CfgNode as _CfgNode + +from detectron2.utils.file_io import PathManager + + +class CfgNode(_CfgNode): + """ + The same as `fvcore.common.config.CfgNode`, but different in: + + 1. Use unsafe yaml loading by default. + Note that this may lead to arbitrary code execution: you must not + load a config file from untrusted sources before manually inspecting + the content of the file. + 2. Support config versioning. + When attempting to merge an old config, it will convert the old config automatically. + + .. automethod:: clone + .. automethod:: freeze + .. automethod:: defrost + .. automethod:: is_frozen + .. automethod:: load_yaml_with_base + .. automethod:: merge_from_list + .. 
automethod:: merge_from_other_cfg + """ + + @classmethod + def _open_cfg(cls, filename): + return PathManager.open(filename, "r") + + # Note that the default value of allow_unsafe is changed to True + def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None: + """ + Load content from the given config file and merge it into self. + + Args: + cfg_filename: config filename + allow_unsafe: allow unsafe yaml syntax + """ + assert PathManager.isfile(cfg_filename), f"Config file '{cfg_filename}' does not exist!" + loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe) + loaded_cfg = type(self)(loaded_cfg) + + # defaults.py needs to import CfgNode + from .defaults import _C + + latest_ver = _C.VERSION + assert ( + latest_ver == self.VERSION + ), "CfgNode.merge_from_file is only allowed on a config object of latest version!" + + logger = logging.getLogger(__name__) + + loaded_ver = loaded_cfg.get("VERSION", None) + if loaded_ver is None: + from .compat import guess_version + + loaded_ver = guess_version(loaded_cfg, cfg_filename) + assert loaded_ver <= self.VERSION, "Cannot merge a v{} config into a v{} config.".format( + loaded_ver, self.VERSION + ) + + if loaded_ver == self.VERSION: + self.merge_from_other_cfg(loaded_cfg) + else: + # compat.py needs to import CfgNode + from .compat import upgrade_config, downgrade_config + + logger.warning( + "Loading an old v{} config file '{}' by automatically upgrading to v{}. 
" + "See docs/CHANGELOG.md for instructions to update your files.".format( + loaded_ver, cfg_filename, self.VERSION + ) + ) + # To convert, first obtain a full config at an old version + old_self = downgrade_config(self, to_version=loaded_ver) + old_self.merge_from_other_cfg(loaded_cfg) + new_config = upgrade_config(old_self) + self.clear() + self.update(new_config) + + def dump(self, *args, **kwargs): + """ + Returns: + str: a yaml string representation of the config + """ + # to make it show up in docs + return super().dump(*args, **kwargs) + + +global_cfg = CfgNode() + + +def get_cfg() -> CfgNode: + """ + Get a copy of the default config. + + Returns: + a detectron2 CfgNode instance. + """ + from .defaults import _C + + return _C.clone() + + +def set_global_cfg(cfg: CfgNode) -> None: + """ + Let the global config point to the given cfg. + + Assume that the given "cfg" has the key "KEY", after calling + `set_global_cfg(cfg)`, the key can be accessed by: + :: + from detectron2.config import global_cfg + print(global_cfg.KEY) + + By using a hacky global config, you can access these configs anywhere, + without having to pass the config object or the values deep into the code. + This is a hacky feature introduced for quick prototyping / research exploration. + """ + global global_cfg + global_cfg.clear() + global_cfg.update(cfg) + + +def configurable(init_func=None, *, from_config=None): + """ + Decorate a function or a class's __init__ method so that it can be called + with a :class:`CfgNode` object using a :func:`from_config` function that translates + :class:`CfgNode` to arguments. 
+ + Examples: + :: + # Usage 1: Decorator on __init__: + class A: + @configurable + def __init__(self, a, b=2, c=3): + pass + + @classmethod + def from_config(cls, cfg): # 'cfg' must be the first argument + # Returns kwargs to be passed to __init__ + return {"a": cfg.A, "b": cfg.B} + + a1 = A(a=1, b=2) # regular construction + a2 = A(cfg) # construct with a cfg + a3 = A(cfg, b=3, c=4) # construct with extra overwrite + + # Usage 2: Decorator on any function. Needs an extra from_config argument: + @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B}) + def a_func(a, b=2, c=3): + pass + + a1 = a_func(a=1, b=2) # regular call + a2 = a_func(cfg) # call with a cfg + a3 = a_func(cfg, b=3, c=4) # call with extra overwrite + + Args: + init_func (callable): a class's ``__init__`` method in usage 1. The + class must have a ``from_config`` classmethod which takes `cfg` as + the first argument. + from_config (callable): the from_config function in usage 2. It must take `cfg` + as its first argument. + """ + + if init_func is not None: + assert ( + inspect.isfunction(init_func) + and from_config is None + and init_func.__name__ == "__init__" + ), "Incorrect use of @configurable. Check API documentation for examples." + + @functools.wraps(init_func) + def wrapped(self, *args, **kwargs): + try: + from_config_func = type(self).from_config + except AttributeError as e: + raise AttributeError( + "Class with @configurable must have a 'from_config' classmethod." 
+ ) from e + if not inspect.ismethod(from_config_func): + raise TypeError("Class with @configurable must have a 'from_config' classmethod.") + + if _called_with_cfg(*args, **kwargs): + explicit_args = _get_args_from_config(from_config_func, *args, **kwargs) + init_func(self, **explicit_args) + else: + init_func(self, *args, **kwargs) + + return wrapped + + else: + if from_config is None: + return configurable # @configurable() is made equivalent to @configurable + assert inspect.isfunction( + from_config + ), "from_config argument of configurable must be a function!" + + def wrapper(orig_func): + @functools.wraps(orig_func) + def wrapped(*args, **kwargs): + if _called_with_cfg(*args, **kwargs): + explicit_args = _get_args_from_config(from_config, *args, **kwargs) + return orig_func(**explicit_args) + else: + return orig_func(*args, **kwargs) + + wrapped.from_config = from_config + return wrapped + + return wrapper + + +def _get_args_from_config(from_config_func, *args, **kwargs): + """ + Use `from_config` to obtain explicit arguments. 
+ + Returns: + dict: arguments to be used for cls.__init__ + """ + signature = inspect.signature(from_config_func) + if list(signature.parameters.keys())[0] != "cfg": + if inspect.isfunction(from_config_func): + name = from_config_func.__name__ + else: + name = f"{from_config_func.__self__}.from_config" + raise TypeError(f"{name} must take 'cfg' as the first argument!") + support_var_arg = any( + param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD] + for param in signature.parameters.values() + ) + if support_var_arg: # forward all arguments to from_config, if from_config accepts them + ret = from_config_func(*args, **kwargs) + else: + # forward supported arguments to from_config + supported_arg_names = set(signature.parameters.keys()) + extra_kwargs = {} + for name in list(kwargs.keys()): + if name not in supported_arg_names: + extra_kwargs[name] = kwargs.pop(name) + ret = from_config_func(*args, **kwargs) + # forward the other arguments to __init__ + ret.update(extra_kwargs) + return ret + + +def _called_with_cfg(*args, **kwargs): + """ + Returns: + bool: whether the arguments contain CfgNode and should be considered + forwarded to from_config. + """ + from omegaconf import DictConfig + + if len(args) and isinstance(args[0], (_CfgNode, DictConfig)): + return True + if isinstance(kwargs.pop("cfg", None), (_CfgNode, DictConfig)): + return True + # `from_config`'s first argument is forced to be "cfg". + # So the above check covers all cases. + return False diff --git a/data_processing/detectron2/detectron2/config/defaults.py b/data_processing/detectron2/detectron2/config/defaults.py new file mode 100644 index 0000000..bd2a5f6 --- /dev/null +++ b/data_processing/detectron2/detectron2/config/defaults.py @@ -0,0 +1,650 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+from .config import CfgNode as CN + +# NOTE: given the new config system +# (https://detectron2.readthedocs.io/en/latest/tutorials/lazyconfigs.html), +# we will stop adding new functionalities to default CfgNode. + +# ----------------------------------------------------------------------------- +# Convention about Training / Test specific parameters +# ----------------------------------------------------------------------------- +# Whenever an argument can be either used for training or for testing, the +# corresponding name will be post-fixed by a _TRAIN for a training parameter, +# or _TEST for a test-specific parameter. +# For example, the number of images during training will be +# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be +# IMAGES_PER_BATCH_TEST + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = CN() + +# The version number, to upgrade from old configs to new ones if any +# changes happen. It's recommended to keep a VERSION in your config file. +_C.VERSION = 2 + +_C.MODEL = CN() +_C.MODEL.LOAD_PROPOSALS = False +_C.MODEL.MASK_ON = False +_C.MODEL.KEYPOINT_ON = False +_C.MODEL.DEVICE = "cuda" +_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN" + +# Path (a file path, or URL like detectron2://.., https://..) to a checkpoint file +# to be loaded to the model. You can find available models in the model zoo. +_C.MODEL.WEIGHTS = "" + +# Values to be used for image normalization (BGR order, since INPUT.FORMAT defaults to BGR). +# To train on images of different number of channels, just set different mean & std. 
+# Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675] +_C.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675] +# When using pre-trained models in Detectron1 or any MSRA models, +# std has been absorbed into its conv1 weights, so the std needs to be set to 1. +# Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std) +_C.MODEL.PIXEL_STD = [1.0, 1.0, 1.0] + + +# ----------------------------------------------------------------------------- +# INPUT +# ----------------------------------------------------------------------------- +_C.INPUT = CN() +# By default, {MIN,MAX}_SIZE options are used in transforms.ResizeShortestEdge. +# Please refer to ResizeShortestEdge for detailed definition. +# Size of the smallest side of the image during training +_C.INPUT.MIN_SIZE_TRAIN = (800,) +# Sample size of smallest side by choice or random selection from range given by +# INPUT.MIN_SIZE_TRAIN +_C.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice" +# Maximum size of the side of the image during training +_C.INPUT.MAX_SIZE_TRAIN = 1333 +# Size of the smallest side of the image during testing. Set to zero to disable resize in testing. +_C.INPUT.MIN_SIZE_TEST = 800 +# Maximum size of the side of the image during testing +_C.INPUT.MAX_SIZE_TEST = 1333 +# Mode for flipping images used in data augmentation during training +# choose one of ["horizontal", "vertical", "none"] +_C.INPUT.RANDOM_FLIP = "horizontal" + +# `True` if cropping is used for data augmentation during training +_C.INPUT.CROP = CN({"ENABLED": False}) +# Cropping type. See documentation of `detectron2.data.transforms.RandomCrop` for explanation. +_C.INPUT.CROP.TYPE = "relative_range" +# Size of crop in range (0, 1] if CROP.TYPE is "relative" or "relative_range" and in number of +# pixels if CROP.TYPE is "absolute" +_C.INPUT.CROP.SIZE = [0.9, 0.9] + + +# Whether the model needs RGB, YUV, HSV etc. 
+# Should be one of the modes defined here, as we use PIL to read the image: +# https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes +# with BGR being the one exception. One can set image format to BGR, we will +# internally use RGB for conversion and flip the channels over +_C.INPUT.FORMAT = "BGR" +# The ground truth mask format that the model will use. +# Mask R-CNN supports either "polygon" or "bitmask" as ground truth. +_C.INPUT.MASK_FORMAT = "polygon" # alternative: "bitmask" + + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- +_C.DATASETS = CN() +# List of the dataset names for training. Must be registered in DatasetCatalog +# Samples from these datasets will be merged and used as one dataset. +_C.DATASETS.TRAIN = () +# List of the pre-computed proposal files for training, which must be consistent +# with datasets listed in DATASETS.TRAIN. +_C.DATASETS.PROPOSAL_FILES_TRAIN = () +# Number of top scoring precomputed proposals to keep for training +_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN = 2000 +# List of the dataset names for testing. Must be registered in DatasetCatalog +_C.DATASETS.TEST = () +# List of the pre-computed proposal files for test, which must be consistent +# with datasets listed in DATASETS.TEST. +_C.DATASETS.PROPOSAL_FILES_TEST = () +# Number of top scoring precomputed proposals to keep for test +_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST = 1000 + +# ----------------------------------------------------------------------------- +# DataLoader +# ----------------------------------------------------------------------------- +_C.DATALOADER = CN() +# Number of data loading threads +_C.DATALOADER.NUM_WORKERS = 4 +# If True, each batch should contain only images for which the aspect ratio +# is compatible. 
This groups portrait images together, and landscape images +# are not batched with portrait images. +_C.DATALOADER.ASPECT_RATIO_GROUPING = True +# Options: TrainingSampler, RepeatFactorTrainingSampler +_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler" +# Repeat threshold for RepeatFactorTrainingSampler +_C.DATALOADER.REPEAT_THRESHOLD = 0.0 +# If True, when working on datasets that have instance annotations, the +# training dataloader will filter out images without associated annotations +_C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True + +# ---------------------------------------------------------------------------- # +# Backbone options +# ---------------------------------------------------------------------------- # +_C.MODEL.BACKBONE = CN() + +_C.MODEL.BACKBONE.NAME = "build_resnet_backbone" +# Freeze the first several stages so they are not trained. +# There are 5 stages in ResNet. The first is a convolution, and the following +# stages are each group of residual blocks. +_C.MODEL.BACKBONE.FREEZE_AT = 2 + + +# ---------------------------------------------------------------------------- # +# FPN options +# ---------------------------------------------------------------------------- # +_C.MODEL.FPN = CN() +# Names of the input feature maps to be used by FPN +# They must have contiguous power of 2 strides +# e.g., ["res2", "res3", "res4", "res5"] +_C.MODEL.FPN.IN_FEATURES = [] +_C.MODEL.FPN.OUT_CHANNELS = 256 + +# Options: "" (no norm), "GN" +_C.MODEL.FPN.NORM = "" + +# Types for fusing the FPN top-down and lateral features. 
Can be either "sum" or "avg" +_C.MODEL.FPN.FUSE_TYPE = "sum" + + +# ---------------------------------------------------------------------------- # +# Proposal generator options +# ---------------------------------------------------------------------------- # +_C.MODEL.PROPOSAL_GENERATOR = CN() +# Current proposal generators include "RPN", "RRPN" and "PrecomputedProposals" +_C.MODEL.PROPOSAL_GENERATOR.NAME = "RPN" +# Proposal height and width both need to be greater than MIN_SIZE +# (at the scale used during training or inference) +_C.MODEL.PROPOSAL_GENERATOR.MIN_SIZE = 0 + + +# ---------------------------------------------------------------------------- # +# Anchor generator options +# ---------------------------------------------------------------------------- # +_C.MODEL.ANCHOR_GENERATOR = CN() +# The generator can be any name in the ANCHOR_GENERATOR registry +_C.MODEL.ANCHOR_GENERATOR.NAME = "DefaultAnchorGenerator" +# Anchor sizes (i.e. sqrt of area) in absolute pixels w.r.t. the network input. +# Format: list[list[float]]. SIZES[i] specifies the list of sizes to use for +# IN_FEATURES[i]; len(SIZES) must be equal to len(IN_FEATURES) or 1. +# When len(SIZES) == 1, SIZES[0] is used for all IN_FEATURES. +_C.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256, 512]] +# Anchor aspect ratios. For each area given in `SIZES`, anchors with different aspect +# ratios are generated by an anchor generator. +# Format: list[list[float]]. ASPECT_RATIOS[i] specifies the list of aspect ratios (H/W) +# to use for IN_FEATURES[i]; len(ASPECT_RATIOS) == len(IN_FEATURES) must be true, +# or len(ASPECT_RATIOS) == 1 is true and aspect ratio list ASPECT_RATIOS[0] is used +# for all IN_FEATURES. +_C.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.5, 1.0, 2.0]] +# Anchor angles. +# list[list[float]], the angle in degrees, for each input feature map. +# ANGLES[i] specifies the list of angles for IN_FEATURES[i]. 
+_C.MODEL.ANCHOR_GENERATOR.ANGLES = [[-90, 0, 90]] +# Relative offset between the center of the first anchor and the top-left corner of the image +# Value has to be in [0, 1). Recommend to use 0.5, which means half stride. +# The value is not expected to affect model accuracy. +_C.MODEL.ANCHOR_GENERATOR.OFFSET = 0.0 + +# ---------------------------------------------------------------------------- # +# RPN options +# ---------------------------------------------------------------------------- # +_C.MODEL.RPN = CN() +_C.MODEL.RPN.HEAD_NAME = "StandardRPNHead" # used by RPN_HEAD_REGISTRY + +# Names of the input feature maps to be used by RPN +# e.g., ["p2", "p3", "p4", "p5", "p6"] for FPN +_C.MODEL.RPN.IN_FEATURES = ["res4"] +# Remove RPN anchors that go outside the image by BOUNDARY_THRESH pixels +# Set to -1 or a large value, e.g. 100000, to disable pruning anchors +_C.MODEL.RPN.BOUNDARY_THRESH = -1 +# IOU overlap ratios [BG_IOU_THRESHOLD, FG_IOU_THRESHOLD] +# Minimum overlap required between an anchor and ground-truth box for the +# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD +# ==> positive RPN example: 1) +# Maximum overlap allowed between an anchor and ground-truth box for the +# (anchor, gt box) pair to be a negative examples (IoU < BG_IOU_THRESHOLD +# ==> negative RPN example: 0) +# Anchors with overlap in between (BG_IOU_THRESHOLD <= IoU < FG_IOU_THRESHOLD) +# are ignored (-1) +_C.MODEL.RPN.IOU_THRESHOLDS = [0.3, 0.7] +_C.MODEL.RPN.IOU_LABELS = [0, -1, 1] +# Number of regions per image used to train RPN +_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256 +# Target fraction of foreground (positive) examples per RPN minibatch +_C.MODEL.RPN.POSITIVE_FRACTION = 0.5 +# Options are: "smooth_l1", "giou", "diou", "ciou" +_C.MODEL.RPN.BBOX_REG_LOSS_TYPE = "smooth_l1" +_C.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = 1.0 +# Weights on (dx, dy, dw, dh) for normalizing RPN anchor regression targets +_C.MODEL.RPN.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0) +# The 
transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1. +_C.MODEL.RPN.SMOOTH_L1_BETA = 0.0 +_C.MODEL.RPN.LOSS_WEIGHT = 1.0 +# Number of top scoring RPN proposals to keep before applying NMS +# When FPN is used, this is *per FPN level* (not total) +_C.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 12000 +_C.MODEL.RPN.PRE_NMS_TOPK_TEST = 6000 +# Number of top scoring RPN proposals to keep after applying NMS +# When FPN is used, this limit is applied per level and then again to the union +# of proposals from all levels +# NOTE: When FPN is used, the meaning of this config is different from Detectron1. +# It means per-batch topk in Detectron1, but per-image topk here. +# See the "find_top_rpn_proposals" function for details. +_C.MODEL.RPN.POST_NMS_TOPK_TRAIN = 2000 +_C.MODEL.RPN.POST_NMS_TOPK_TEST = 1000 +# NMS threshold used on RPN proposals +_C.MODEL.RPN.NMS_THRESH = 0.7 +# Set this to -1 to use the same number of output channels as input channels. +_C.MODEL.RPN.CONV_DIMS = [-1] + +# ---------------------------------------------------------------------------- # +# ROI HEADS options +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_HEADS = CN() +_C.MODEL.ROI_HEADS.NAME = "Res5ROIHeads" +# Number of foreground classes +_C.MODEL.ROI_HEADS.NUM_CLASSES = 80 +# Names of the input feature maps to be used by ROI heads +# Currently all heads (box, mask, ...) 
use the same input feature map list +# e.g., ["p2", "p3", "p4", "p5"] is commonly used for FPN +_C.MODEL.ROI_HEADS.IN_FEATURES = ["res4"] +# IOU overlap ratios [IOU_THRESHOLD] +# Overlap threshold for an RoI to be considered background (if < IOU_THRESHOLD) +# Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD) +_C.MODEL.ROI_HEADS.IOU_THRESHOLDS = [0.5] +_C.MODEL.ROI_HEADS.IOU_LABELS = [0, 1] +# RoI minibatch size *per image* (number of regions of interest [ROIs]) during training +# Total number of RoIs per training minibatch = +# ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH +# E.g., a common configuration is: 512 * 16 = 8192 +_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512 +# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0) +_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25 + +# Only used on test mode + +# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to +# balance obtaining high recall with not having too many low precision +# detections that will slow down inference post processing steps (like NMS) +# A default threshold of 0.0 increases AP by ~0.2-0.3 but significantly slows down +# inference. +_C.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05 +# Overlap threshold used for non-maximum suppression (suppress boxes with +# IoU >= this threshold) +_C.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5 +# If True, augment proposals with ground-truth boxes before sampling proposals to +# train ROI heads. 
+_C.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT = True + +# ---------------------------------------------------------------------------- # +# Box Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_BOX_HEAD = CN() +# C4 don't use head name option +# Options for non-C4 models: FastRCNNConvFCHead, +_C.MODEL.ROI_BOX_HEAD.NAME = "" +# Options are: "smooth_l1", "giou", "diou", "ciou" +_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE = "smooth_l1" +# The final scaling coefficient on the box regression loss, used to balance the magnitude of its +# gradients with other losses in the model. See also `MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT`. +_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT = 1.0 +# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets +# These are empirically chosen to approximately lead to unit variance targets +_C.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0) +# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1. +_C.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA = 0.0 +_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0 +# Type of pooling operation applied to the incoming feature map for each RoI +_C.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignV2" + +_C.MODEL.ROI_BOX_HEAD.NUM_FC = 0 +# Hidden layer dimension for FC layers in the RoI box head +_C.MODEL.ROI_BOX_HEAD.FC_DIM = 1024 +_C.MODEL.ROI_BOX_HEAD.NUM_CONV = 0 +# Channel dimension for Conv layers in the RoI box head +_C.MODEL.ROI_BOX_HEAD.CONV_DIM = 256 +# Normalization method for the convolution layers. +# Options: "" (no norm), "GN", "SyncBN". +_C.MODEL.ROI_BOX_HEAD.NORM = "" +# Whether to use class agnostic for bbox regression +_C.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG = False +# If true, RoI heads use bounding boxes predicted by the box head rather than proposal boxes. 
+_C.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES = False + +# Federated loss can be used to improve the training of LVIS +_C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False +# Sigmoid cross entropy is used with federated loss +_C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False +# The power value applied to image_count when calculating frequency weight +_C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT_POWER = 0.5 +# Number of classes to keep in total +_C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CLASSES = 50 + +# ---------------------------------------------------------------------------- # +# Cascaded Box Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_BOX_CASCADE_HEAD = CN() +# The number of cascade stages is implicitly defined by the length of the following two configs. +_C.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS = ( + (10.0, 10.0, 5.0, 5.0), + (20.0, 20.0, 10.0, 10.0), + (30.0, 30.0, 15.0, 15.0), +) +_C.MODEL.ROI_BOX_CASCADE_HEAD.IOUS = (0.5, 0.6, 0.7) + + +# ---------------------------------------------------------------------------- # +# Mask Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_MASK_HEAD = CN() +_C.MODEL.ROI_MASK_HEAD.NAME = "MaskRCNNConvUpsampleHead" +_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0 +_C.MODEL.ROI_MASK_HEAD.NUM_CONV = 0 # The number of convs in the mask head +_C.MODEL.ROI_MASK_HEAD.CONV_DIM = 256 +# Normalization method for the convolution layers. +# Options: "" (no norm), "GN", "SyncBN". 
+_C.MODEL.ROI_MASK_HEAD.NORM = "" +# Whether to use class agnostic for mask prediction +_C.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK = False +# Type of pooling operation applied to the incoming feature map for each RoI +_C.MODEL.ROI_MASK_HEAD.POOLER_TYPE = "ROIAlignV2" + + +# ---------------------------------------------------------------------------- # +# Keypoint Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_KEYPOINT_HEAD = CN() +_C.MODEL.ROI_KEYPOINT_HEAD.NAME = "KRCNNConvDeconvUpsampleHead" +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0 +_C.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS = tuple(512 for _ in range(8)) +_C.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 17 # 17 is the number of keypoints in COCO. + +# Images with too few (or no) keypoints are excluded from training. +_C.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE = 1 +# Normalize by the total number of visible keypoints in the minibatch if True. +# Otherwise, normalize by the total number of keypoints that could ever exist +# in the minibatch. +# The keypoint softmax loss is only calculated on visible keypoints. +# Since the number of visible keypoints can vary significantly between +# minibatches, this has the effect of up-weighting the importance of +# minibatches with few visible keypoints. (Imagine the extreme case of +# only one visible keypoint versus N: in the case of N, each one +# contributes 1/N to the gradient compared to the single keypoint +# determining the gradient direction). Instead, we can normalize the +# loss by the total number of keypoints, if it were the case that all +# keypoints were visible in a full minibatch. (Returning to the example, +# this means that the one visible keypoint contributes as much as each +# of the N keypoints.) 
+_C.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS = True +# Multi-task loss weight to use for keypoints +# Recommended values: +# - use 1.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is True +# - use 4.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is False +_C.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT = 1.0 +# Type of pooling operation applied to the incoming feature map for each RoI +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE = "ROIAlignV2" + +# ---------------------------------------------------------------------------- # +# Semantic Segmentation Head +# ---------------------------------------------------------------------------- # +_C.MODEL.SEM_SEG_HEAD = CN() +_C.MODEL.SEM_SEG_HEAD.NAME = "SemSegFPNHead" +_C.MODEL.SEM_SEG_HEAD.IN_FEATURES = ["p2", "p3", "p4", "p5"] +# Label in the semantic segmentation ground truth that is ignored, i.e., no loss is calculated for +# the corresponding pixel. +_C.MODEL.SEM_SEG_HEAD.IGNORE_VALUE = 255 +# Number of classes in the semantic segmentation head +_C.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 54 +# Number of channels in the 3x3 convs inside semantic-FPN heads. +_C.MODEL.SEM_SEG_HEAD.CONVS_DIM = 128 +# Outputs from semantic-FPN heads are up-scaled to the COMMON_STRIDE stride. +_C.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4 +# Normalization method for the convolution layers. Options: "" (no norm), "GN". +_C.MODEL.SEM_SEG_HEAD.NORM = "GN" +_C.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT = 1.0 + +_C.MODEL.PANOPTIC_FPN = CN() +# Scaling of all losses from instance detection / segmentation head. 
+_C.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT = 1.0 + +# options when combining instance & semantic segmentation outputs +_C.MODEL.PANOPTIC_FPN.COMBINE = CN({"ENABLED": True}) # "COMBINE.ENABLED" is deprecated & not used +_C.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH = 0.5 +_C.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT = 4096 +_C.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.5 + + +# ---------------------------------------------------------------------------- # +# RetinaNet Head +# ---------------------------------------------------------------------------- # +_C.MODEL.RETINANET = CN() + +# This is the number of foreground classes. +_C.MODEL.RETINANET.NUM_CLASSES = 80 + +_C.MODEL.RETINANET.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"] + +# Convolutions to use in the cls and bbox tower +# NOTE: this doesn't include the last conv for logits +_C.MODEL.RETINANET.NUM_CONVS = 4 + +# IoU overlap ratio [bg, fg] for labeling anchors. +# Anchors with < bg are labeled negative (0) +# Anchors with >= bg and < fg are ignored (-1) +# Anchors with >= fg are labeled positive (1) +_C.MODEL.RETINANET.IOU_THRESHOLDS = [0.4, 0.5] +_C.MODEL.RETINANET.IOU_LABELS = [0, -1, 1] + +# Prior prob for rare case (i.e. foreground) at the beginning of training. +# This is used to set the bias for the logits layer of the classifier subnet. +# This improves training stability in the case of heavy class imbalance. 
+_C.MODEL.RETINANET.PRIOR_PROB = 0.01 + +# Inference cls score threshold, only anchors with score > INFERENCE_TH are +# considered for inference (to improve speed) +_C.MODEL.RETINANET.SCORE_THRESH_TEST = 0.05 +# Select topk candidates before NMS +_C.MODEL.RETINANET.TOPK_CANDIDATES_TEST = 1000 +_C.MODEL.RETINANET.NMS_THRESH_TEST = 0.5 + +# Weights on (dx, dy, dw, dh) for normalizing Retinanet anchor regression targets +_C.MODEL.RETINANET.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0) + +# Loss parameters +_C.MODEL.RETINANET.FOCAL_LOSS_GAMMA = 2.0 +_C.MODEL.RETINANET.FOCAL_LOSS_ALPHA = 0.25 +_C.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA = 0.1 +# Options are: "smooth_l1", "giou", "diou", "ciou" +_C.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = "smooth_l1" + +# One of BN, SyncBN, FrozenBN, GN +# Only supports GN until unshared norm is implemented +_C.MODEL.RETINANET.NORM = "" + + +# ---------------------------------------------------------------------------- # +# ResNe[X]t options (ResNets = {ResNet, ResNeXt} +# Note that parts of a resnet may be used for both the backbone and the head +# These options apply to both +# ---------------------------------------------------------------------------- # +_C.MODEL.RESNETS = CN() + +_C.MODEL.RESNETS.DEPTH = 50 +_C.MODEL.RESNETS.OUT_FEATURES = ["res4"] # res4 for C4 backbone, res2..5 for FPN backbone + +# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt +_C.MODEL.RESNETS.NUM_GROUPS = 1 + +# Options: FrozenBN, GN, "SyncBN", "BN" +_C.MODEL.RESNETS.NORM = "FrozenBN" + +# Baseline width of each group. +# Scaling this parameters will scale the width of all bottleneck layers. +_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64 + +# Place the stride 2 conv on the 1x1 filter +# Use True only for the original MSRA ResNet; use False for C2 and Torch models +_C.MODEL.RESNETS.STRIDE_IN_1X1 = True + +# Apply dilation in stage "res5" +_C.MODEL.RESNETS.RES5_DILATION = 1 + +# Output width of res2. 
Scaling this parameters will scale the width of all 1x1 convs in ResNet +# For R18 and R34, this needs to be set to 64 +_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256 +_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64 + +# Apply Deformable Convolution in stages +# Specify if apply deform_conv on Res2, Res3, Res4, Res5 +_C.MODEL.RESNETS.DEFORM_ON_PER_STAGE = [False, False, False, False] +# Use True to use modulated deform_conv (DeformableV2, https://arxiv.org/abs/1811.11168); +# Use False for DeformableV1. +_C.MODEL.RESNETS.DEFORM_MODULATED = False +# Number of groups in deformable conv. +_C.MODEL.RESNETS.DEFORM_NUM_GROUPS = 1 + + +# ---------------------------------------------------------------------------- # +# Solver +# ---------------------------------------------------------------------------- # +_C.SOLVER = CN() + +# Options: WarmupMultiStepLR, WarmupCosineLR. +# See detectron2/solver/build.py for definition. +_C.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR" + +_C.SOLVER.MAX_ITER = 40000 + +_C.SOLVER.BASE_LR = 0.001 +# The end lr, only used by WarmupCosineLR +_C.SOLVER.BASE_LR_END = 0.0 + +_C.SOLVER.MOMENTUM = 0.9 + +_C.SOLVER.NESTEROV = False + +_C.SOLVER.WEIGHT_DECAY = 0.0001 +# The weight decay that's applied to parameters of normalization layers +# (typically the affine transformation) +_C.SOLVER.WEIGHT_DECAY_NORM = 0.0 + +_C.SOLVER.GAMMA = 0.1 +# The iteration number to decrease learning rate by GAMMA. +_C.SOLVER.STEPS = (30000,) +# Number of decays in WarmupStepWithFixedGammaLR schedule +_C.SOLVER.NUM_DECAYS = 3 + +_C.SOLVER.WARMUP_FACTOR = 1.0 / 1000 +_C.SOLVER.WARMUP_ITERS = 1000 +_C.SOLVER.WARMUP_METHOD = "linear" +# Whether to rescale the interval for the learning schedule after warmup +_C.SOLVER.RESCALE_INTERVAL = False + +# Save a checkpoint after every this number of iterations +_C.SOLVER.CHECKPOINT_PERIOD = 5000 + +# Number of images per batch across all machines. 
This is also the number +# of training images per step (i.e. per iteration). If we use 16 GPUs +# and IMS_PER_BATCH = 32, each GPU will see 2 images per batch. +# May be adjusted automatically if REFERENCE_WORLD_SIZE is set. +_C.SOLVER.IMS_PER_BATCH = 16 + +# The reference number of workers (GPUs) this config is meant to train with. +# It takes no effect when set to 0. +# With a non-zero value, it will be used by DefaultTrainer to compute a desired +# per-worker batch size, and then scale the other related configs (total batch size, +# learning rate, etc) to match the per-worker batch size. +# See documentation of `DefaultTrainer.auto_scale_workers` for details: +_C.SOLVER.REFERENCE_WORLD_SIZE = 0 + +# Detectron v1 (and previous detection code) used a 2x higher LR and 0 WD for +# biases. This is not useful (at least for recent models). You should avoid +# changing these and they exist only to reproduce Detectron v1 training if +# desired. +_C.SOLVER.BIAS_LR_FACTOR = 1.0 +_C.SOLVER.WEIGHT_DECAY_BIAS = None # None means following WEIGHT_DECAY + +# Gradient clipping +_C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False}) +# Type of gradient clipping, currently 2 values are supported: +# - "value": the absolute values of elements of each gradients are clipped +# - "norm": the norm of the gradient for each parameter is clipped thus +# affecting all elements in the parameter +_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value" +# Maximum absolute value used for clipping gradients +_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0 +# Floating point number p for L-p norm to be used with the "norm" +# gradient clipping type; for L-inf, please specify .inf +_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0 + +# Enable automatic mixed precision for training +# Note that this does not change model's inference behavior. 
+# To use AMP in inference, run inference under autocast() +_C.SOLVER.AMP = CN({"ENABLED": False}) + +# ---------------------------------------------------------------------------- # +# Specific test options +# ---------------------------------------------------------------------------- # +_C.TEST = CN() +# For end-to-end tests to verify the expected accuracy. +# Each item is [task, metric, value, tolerance] +# e.g.: [['bbox', 'AP', 38.5, 0.2]] +_C.TEST.EXPECTED_RESULTS = [] +# The period (in terms of steps) to evaluate the model during training. +# Set to 0 to disable. +_C.TEST.EVAL_PERIOD = 0 +# The sigmas used to calculate keypoint OKS. See http://cocodataset.org/#keypoints-eval +# When empty, it will use the defaults in COCO. +# Otherwise it should be a list[float] with the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS. +_C.TEST.KEYPOINT_OKS_SIGMAS = [] +# Maximum number of detections to return per image during inference (100 is +# based on the limit established for the COCO dataset). +_C.TEST.DETECTIONS_PER_IMAGE = 100 + +_C.TEST.AUG = CN({"ENABLED": False}) +_C.TEST.AUG.MIN_SIZES = (400, 500, 600, 700, 800, 900, 1000, 1100, 1200) +_C.TEST.AUG.MAX_SIZE = 4000 +_C.TEST.AUG.FLIP = True + +_C.TEST.PRECISE_BN = CN({"ENABLED": False}) +_C.TEST.PRECISE_BN.NUM_ITER = 200 + +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # +# Directory where output files are written +_C.OUTPUT_DIR = "./output" +# Set seed to negative to fully randomize everything. +# Set seed to positive to use a fixed seed. Note that a fixed seed increases +# reproducibility but does not guarantee fully deterministic behavior. +# Disabling all parallelism further increases reproducibility. +_C.SEED = -1 +# Benchmark different cudnn algorithms. 
def instantiate(cfg):
    """
    Build the object(s) described by a config, recursing into nested values.

    A mapping that contains the key "_target_" is treated as a call description:
    the target is resolved to a callable and invoked with the remaining
    (already instantiated) entries as keyword arguments.

    Args:
        cfg: a dict-like object with "_target_" that defines the caller, and
            other keys that define the arguments. Lists are instantiated
            element by element; anything else is returned unchanged.

    Returns:
        object instantiated by cfg
    """
    from omegaconf import ListConfig, DictConfig, OmegaConf

    if isinstance(cfg, ListConfig):
        built_items = [instantiate(item) for item in cfg]
        return ListConfig(built_items, flags={"allow_objects": True})

    if isinstance(cfg, list):
        # Plain lists get their own branch because many classes take
        # list[objects] as arguments, such as ResNet, DatasetMapper
        return [instantiate(item) for item in cfg]

    # A DictConfig backed by a dataclass (omegaconf's structured config) is
    # converted to the actual dataclass instance.
    if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(cfg._metadata.object_type):
        return OmegaConf.to_object(cfg)

    is_call_description = isinstance(cfg, abc.Mapping) and "_target_" in cfg
    if not is_call_description:
        return cfg  # return as-is if don't know what to do

    # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
    # but faster: https://github.com/facebookresearch/hydra/issues/1200
    call_kwargs = {key: instantiate(value) for key, value in cfg.items()}
    target = instantiate(call_kwargs.pop("_target_"))

    if isinstance(target, str):
        target_name = target
        target = locate(target_name)
        assert target is not None, target_name
    else:
        try:
            target_name = ".".join((target.__module__, target.__qualname__))
        except Exception:
            # target could be anything, so the above could fail
            target_name = str(target)

    assert callable(target), f"_target_ {target} does not define a callable object"

    try:
        return target(**call_kwargs)
    except TypeError:
        # Name the failing target in the log, then let the error propagate.
        logging.getLogger(__name__).error(f"Error when instantiating {target_name}!")
        raise
class LazyCall:
    """
    Deferred-call wrapper: calling a LazyCall object does not run the wrapped
    callable, it returns a config dict that records the target and the keyword
    arguments of the call.

    Only keyword arguments are accepted when the LazyCall object is called.
    Positional arguments are not yet supported.

    Examples:
    ::
        from detectron2.config import instantiate, LazyCall

        layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
        layer_cfg.out_channels = 64   # can edit it afterwards
        layer = instantiate(layer_cfg)
    """

    def __init__(self, target):
        # Accepted targets: a callable, a string, or a mapping.
        target_is_valid = callable(target) or isinstance(target, (str, abc.Mapping))
        if not target_is_valid:
            raise TypeError(
                f"target of LazyCall must be a callable or defines a callable! Got {target}"
            )
        self._target = target

    def __call__(self, **kwargs):
        # omegaconf object cannot hold dataclass type
        # https://github.com/omry/omegaconf/issues/784
        # so a dataclass target is stored by its string name instead.
        kwargs["_target_"] = (
            _convert_target_to_string(self._target)
            if is_dataclass(self._target)
            else self._target
        )
        return DictConfig(content=kwargs, flags={"allow_objects": True})
@contextmanager
def _patch_import():
    """
    Temporarily replace ``builtins.__import__`` to enhance relative import
    statements in config files, so that they:

    1. locate files purely based on relative location, regardless of packages.
       e.g. you can import file without having __init__
    2. do not cache modules globally; modifications of module states has no side effect
    3. support other storage system through PathManager, so config files can be in the cloud
    4. imported dict are turned into omegaconf.DictConfig automatically

    The previous ``builtins.__import__`` is put back when the ``with`` block
    exits, including when the block raises.

    Yields:
        the replacement import function that is installed while the context is active.
    """
    # A bare ``import importlib`` does not guarantee these submodules are loaded;
    # import them explicitly (and before the hook is installed).
    import importlib.machinery
    import importlib.util

    old_import = builtins.__import__

    def find_relative_file(original_file, relative_import_path, level):
        # Resolve a relative import made from `original_file` to the path of a .py file.
        # NOTE: "from . import x" is not handled. Because then it's unclear
        # if such import should produce `x` as a python module or DictConfig.
        # This can be discussed further if needed.
        relative_import_err = """
Relative import of directories is not allowed within config files.
Within a config file, relative import can only import other config files.
""".replace(
            "\n", " "
        )
        if not relative_import_path:
            raise ImportError(relative_import_err)

        # One leading dot means "same directory"; each extra dot goes one level up.
        cur_file = os.path.dirname(original_file)
        for _ in range(level - 1):
            cur_file = os.path.dirname(cur_file)
        cur_name = relative_import_path.lstrip(".")
        for part in cur_name.split("."):
            cur_file = os.path.join(cur_file, part)
        if not cur_file.endswith(".py"):
            cur_file += ".py"
        if not PathManager.isfile(cur_file):
            cur_file_no_suffix = cur_file[: -len(".py")]
            if PathManager.isdir(cur_file_no_suffix):
                raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err)
            else:
                raise ImportError(
                    f"Cannot import name {relative_import_path} from "
                    f"{original_file}: {cur_file} does not exist."
                )
        return cur_file

    def new_import(name, globals=None, locals=None, fromlist=(), level=0):
        if (
            # Only deal with relative imports inside config files
            level != 0
            and globals is not None
            and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
        ):
            cur_file = find_relative_file(globals["__file__"], name, level)
            _validate_py_syntax(cur_file)
            spec = importlib.machinery.ModuleSpec(
                _random_package_name(cur_file), None, origin=cur_file
            )
            module = importlib.util.module_from_spec(spec)
            module.__file__ = cur_file
            with PathManager.open(cur_file) as f:
                content = f.read()
            exec(compile(content, cur_file, "exec"), module.__dict__)
            # turn imported dict into DictConfig automatically
            for imported_name in fromlist:
                val = _cast_to_config(module.__dict__[imported_name])
                module.__dict__[imported_name] = val
            return module
        # Everything else is delegated to the import function that was active before.
        return old_import(name, globals, locals, fromlist=fromlist, level=level)

    builtins.__import__ = new_import
    try:
        yield new_import
    finally:
        # Restore even when the code inside the `with` block raises (e.g. a
        # config file that fails while being executed); otherwise the hook
        # would stay installed for the rest of the process.
        builtins.__import__ = old_import
+ """ + + @staticmethod + def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): + """ + Similar to :meth:`load()`, but load path relative to the caller's + source file. + + This has the same functionality as a relative import, except that this method + accepts filename as a string, so more characters are allowed in the filename. + """ + caller_frame = inspect.stack()[1] + caller_fname = caller_frame[0].f_code.co_filename + assert caller_fname != "", "load_rel Unable to find caller" + caller_dir = os.path.dirname(caller_fname) + filename = os.path.join(caller_dir, filename) + return LazyConfig.load(filename, keys) + + @staticmethod + def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): + """ + Load a config file. + + Args: + filename: absolute path or relative path w.r.t. the current working directory + keys: keys to load and return. If not given, return all keys + (whose values are config objects) in a dict. + """ + has_keys = keys is not None + filename = filename.replace("/./", "/") # redundant + if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]: + raise ValueError(f"Config file {filename} has to be a python or yaml file.") + if filename.endswith(".py"): + _validate_py_syntax(filename) + + with _patch_import(): + # Record the filename + module_namespace = { + "__file__": filename, + "__package__": _random_package_name(filename), + } + with PathManager.open(filename) as f: + content = f.read() + # Compile first with filename to: + # 1. make filename appears in stacktrace + # 2. 
make load_rel able to find its parent's (possibly remote) location + exec(compile(content, filename, "exec"), module_namespace) + + ret = module_namespace + else: + with PathManager.open(filename) as f: + obj = yaml.unsafe_load(f) + ret = OmegaConf.create(obj, flags={"allow_objects": True}) + + if has_keys: + if isinstance(keys, str): + return _cast_to_config(ret[keys]) + else: + return tuple(_cast_to_config(ret[a]) for a in keys) + else: + if filename.endswith(".py"): + # when not specified, only load those that are config objects + ret = DictConfig( + { + name: _cast_to_config(value) + for name, value in ret.items() + if isinstance(value, (DictConfig, ListConfig, dict)) + and not name.startswith("_") + }, + flags={"allow_objects": True}, + ) + return ret + + @staticmethod + def save(cfg, filename: str): + """ + Save a config object to a yaml file. + Note that when the config dictionary contains complex objects (e.g. lambda), + it can't be saved to yaml. In that case we will print an error and + attempt to save to a pkl file instead. + + Args: + cfg: an omegaconf config object + filename: yaml file name to save the config file + """ + logger = logging.getLogger(__name__) + try: + cfg = deepcopy(cfg) + except Exception: + pass + else: + # if it's deep-copyable, then... + def _replace_type_by_name(x): + if "_target_" in x and callable(x._target_): + try: + x._target_ = _convert_target_to_string(x._target_) + except AttributeError: + pass + + # not necessary, but makes yaml looks nicer + _visit_dict_config(cfg, _replace_type_by_name) + + save_pkl = False + try: + dict = OmegaConf.to_container( + cfg, + # Do not resolve interpolation when saving, i.e. do not turn ${a} into + # actual values when saving. + resolve=False, + # Save structures (dataclasses) in a format that can be instantiated later. + # Without this option, the type information of the dataclass will be erased. 
+ structured_config_mode=SCMode.INSTANTIATE, + ) + dumped = yaml.dump(dict, default_flow_style=None, allow_unicode=True, width=9999) + with PathManager.open(filename, "w") as f: + f.write(dumped) + + try: + _ = yaml.unsafe_load(dumped) # test that it is loadable + except Exception: + logger.warning( + "The config contains objects that cannot serialize to a valid yaml. " + f"{filename} is human-readable but cannot be loaded." + ) + save_pkl = True + except Exception: + logger.exception("Unable to serialize the config to yaml. Error:") + save_pkl = True + + if save_pkl: + new_filename = filename + ".pkl" + try: + # retry by pickle + with PathManager.open(new_filename, "wb") as f: + cloudpickle.dump(cfg, f) + logger.warning(f"Config is saved using cloudpickle at {new_filename}.") + except Exception: + pass + + @staticmethod + def apply_overrides(cfg, overrides: List[str]): + """ + In-place override contents of cfg. + + Args: + cfg: an omegaconf config object + overrides: list of strings in the format of "a=b" to override configs. + See https://hydra.cc/docs/next/advanced/override_grammar/basic/ + for syntax. + + Returns: + the cfg object + """ + + def safe_update(cfg, key, value): + parts = key.split(".") + for idx in range(1, len(parts)): + prefix = ".".join(parts[:idx]) + v = OmegaConf.select(cfg, prefix, default=None) + if v is None: + break + if not OmegaConf.is_config(v): + raise KeyError( + f"Trying to update key {key}, but {prefix} " + f"is not a config, but has type {type(v)}." 
+ ) + OmegaConf.update(cfg, key, value, merge=True) + + try: + from hydra.core.override_parser.overrides_parser import OverridesParser + + has_hydra = True + except ImportError: + has_hydra = False + + if has_hydra: + parser = OverridesParser.create() + overrides = parser.parse_overrides(overrides) + for o in overrides: + key = o.key_or_group + value = o.value() + if o.is_delete(): + # TODO support this + raise NotImplementedError("deletion is not yet a supported override") + safe_update(cfg, key, value) + else: + # Fallback. Does not support all the features and error checking like hydra. + for o in overrides: + key, value = o.split("=") + try: + value = eval(value, {}) + except NameError: + pass + safe_update(cfg, key, value) + return cfg + + @staticmethod + def to_py(cfg, prefix: str = "cfg."): + """ + Try to convert a config object into Python-like psuedo code. + + Note that perfect conversion is not always possible. So the returned + results are mainly meant to be human-readable, and not meant to be executed. 
+ + Args: + cfg: an omegaconf config object + prefix: root name for the resulting code (default: "cfg.") + + + Returns: + str of formatted Python code + """ + import black + + cfg = OmegaConf.to_container(cfg, resolve=True) + + def _to_str(obj, prefix=None, inside_call=False): + if prefix is None: + prefix = [] + if isinstance(obj, abc.Mapping) and "_target_" in obj: + # Dict representing a function call + target = _convert_target_to_string(obj.pop("_target_")) + args = [] + for k, v in sorted(obj.items()): + args.append(f"{k}={_to_str(v, inside_call=True)}") + args = ", ".join(args) + call = f"{target}({args})" + return "".join(prefix) + call + elif isinstance(obj, abc.Mapping) and not inside_call: + # Dict that is not inside a call is a list of top-level config objects that we + # render as one object per line with dot separated prefixes + key_list = [] + for k, v in sorted(obj.items()): + if isinstance(v, abc.Mapping) and "_target_" not in v: + key_list.append(_to_str(v, prefix=prefix + [k + "."])) + else: + key = "".join(prefix) + k + key_list.append(f"{key}={_to_str(v)}") + return "\n".join(key_list) + elif isinstance(obj, abc.Mapping): + # Dict that is inside a call is rendered as a regular dict + return ( + "{" + + ",".join( + f"{repr(k)}: {_to_str(v, inside_call=inside_call)}" + for k, v in sorted(obj.items()) + ) + + "}" + ) + elif isinstance(obj, list): + return "[" + ",".join(_to_str(x, inside_call=inside_call) for x in obj) + "]" + else: + return repr(obj) + + py_str = _to_str(cfg, prefix=[prefix]) + try: + return black.format_str(py_str, mode=black.Mode()) + except black.InvalidInput: + return py_str diff --git a/data_processing/detectron2/detectron2/data/__init__.py b/data_processing/detectron2/detectron2/data/__init__.py new file mode 100644 index 0000000..259f669 --- /dev/null +++ b/data_processing/detectron2/detectron2/data/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from . 
import transforms # isort:skip + +from .build import ( + build_batch_data_loader, + build_detection_test_loader, + build_detection_train_loader, + get_detection_dataset_dicts, + load_proposals_into_dataset, + print_instances_class_histogram, +) +from .catalog import DatasetCatalog, MetadataCatalog, Metadata +from .common import DatasetFromList, MapDataset, ToIterableDataset +from .dataset_mapper import DatasetMapper + +# ensure the builtin datasets are registered +from . import datasets, samplers # isort:skip + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/data_processing/detectron2/detectron2/data/benchmark.py b/data_processing/detectron2/detectron2/data/benchmark.py new file mode 100644 index 0000000..ac2f372 --- /dev/null +++ b/data_processing/detectron2/detectron2/data/benchmark.py @@ -0,0 +1,225 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import numpy as np +from itertools import count +from typing import List, Tuple +import torch +import tqdm +from fvcore.common.timer import Timer + +from detectron2.utils import comm + +from .build import build_batch_data_loader +from .common import DatasetFromList, MapDataset +from .samplers import TrainingSampler + +logger = logging.getLogger(__name__) + + +class _EmptyMapDataset(torch.utils.data.Dataset): + """ + Map anything to emptiness. + """ + + def __init__(self, dataset): + self.ds = dataset + + def __len__(self): + return len(self.ds) + + def __getitem__(self, idx): + _ = self.ds[idx] + return [0] + + +def iter_benchmark( + iterator, num_iter: int, warmup: int = 5, max_time_seconds: float = 60 +) -> Tuple[float, List[float]]: + """ + Benchmark an iterator/iterable for `num_iter` iterations with an extra + `warmup` iterations of warmup. + End early if `max_time_seconds` time is spent on iterations. + + Returns: + float: average time (seconds) per iteration + list[float]: time spent on each iteration. Sometimes useful for further analysis. 
class DataLoaderBenchmark:
    """
    Some common benchmarks that help understand perf bottleneck of a standard dataloader
    made of dataset, mapper and sampler.
    """

    def __init__(
        self,
        dataset,
        *,
        mapper,
        sampler=None,
        total_batch_size,
        num_workers=0,
        max_time_seconds: int = 90,
    ):
        """
        Args:
            max_time_seconds (int): maximum time to spend for each benchmark
            other args: same as in `build.py:build_detection_train_loader`
        """
        if isinstance(dataset, list):
            dataset = DatasetFromList(dataset, copy=False, serialize=True)
        if sampler is None:
            sampler = TrainingSampler(len(dataset))

        self.dataset = dataset
        self.mapper = mapper
        self.sampler = sampler
        self.total_batch_size = total_batch_size
        self.num_workers = num_workers
        self.per_gpu_batch_size = self.total_batch_size // comm.get_world_size()

        self.max_time_seconds = max_time_seconds

    @staticmethod
    def _nearest_percentiles(all_times, qs=(1, 5, 95, 99)):
        """
        Return the ``qs`` percentiles of ``all_times`` using nearest-sample selection.

        NumPy 1.22 renamed the ``interpolation`` keyword of ``np.percentile`` to
        ``method`` and deprecated the old spelling, so the new keyword is tried
        first and the old one is used only when NumPy does not know ``method``.
        """
        try:
            return [np.percentile(all_times, q, method="nearest") for q in qs]
        except TypeError:
            # NumPy < 1.22: `method` is an unexpected keyword argument.
            return [np.percentile(all_times, q, interpolation="nearest") for q in qs]

    def _benchmark(self, iterator, num_iter, warmup, msg=None):
        """
        Run `iter_benchmark` on ``iterator`` and, when ``msg`` is given, log the result.

        Returns:
            the ``(avg, all_times)`` pair produced by `iter_benchmark`.
        """
        avg, all_times = iter_benchmark(iterator, num_iter, warmup, self.max_time_seconds)
        if msg is not None:
            self._log_time(msg, avg, all_times)
        return avg, all_times

    def _log_time(self, msg, avg, all_times, distributed=False):
        """
        Log throughput and the p1/p5/p95/p99 per-iteration times.

        With ``distributed=True`` the numbers of every process are gathered and
        logged by rank 0 only, one line per process.
        """
        percentiles = self._nearest_percentiles(all_times)
        if not distributed:
            logger.info(
                f"{msg}: avg={1.0/avg:.1f} it/s, "
                f"p1={percentiles[0]:.2g}s, p5={percentiles[1]:.2g}s, "
                f"p95={percentiles[2]:.2g}s, p99={percentiles[3]:.2g}s."
            )
            return
        avg_per_gpu = comm.all_gather(avg)
        percentiles_per_gpu = comm.all_gather(percentiles)
        if comm.get_rank() > 0:
            return
        for idx, (gpu_avg, gpu_pcts) in enumerate(zip(avg_per_gpu, percentiles_per_gpu)):
            logger.info(
                f"GPU{idx} {msg}: avg={1.0/gpu_avg:.1f} it/s, "
                f"p1={gpu_pcts[0]:.2g}s, p5={gpu_pcts[1]:.2g}s, "
                f"p95={gpu_pcts[2]:.2g}s, p99={gpu_pcts[3]:.2g}s."
            )

    def benchmark_dataset(self, num_iter, warmup=5):
        """
        Benchmark the speed of taking raw samples from the dataset.
        """

        def loader():
            # Cycle over the sampler forever; iter_benchmark decides when to stop.
            while True:
                for k in self.sampler:
                    yield self.dataset[k]

        self._benchmark(loader(), num_iter, warmup, "Dataset Alone")

    def benchmark_mapper(self, num_iter, warmup=5):
        """
        Benchmark the speed of taking raw samples from the dataset and map
        them in a single process.
        """

        def loader():
            while True:
                for k in self.sampler:
                    yield self.mapper(self.dataset[k])

        self._benchmark(loader(), num_iter, warmup, "Single Process Mapper (sec/sample)")

    def benchmark_workers(self, num_iter, warmup=10):
        """
        Benchmark the dataloader by tuning num_workers to [0, 1, self.num_workers].
        """
        candidates = [0, 1]
        if self.num_workers not in candidates:
            candidates.append(self.num_workers)

        dataset = MapDataset(self.dataset, self.mapper)
        for n in candidates:
            loader = build_batch_data_loader(
                dataset,
                self.sampler,
                self.total_batch_size,
                num_workers=n,
            )
            # Iteration and warmup counts are scaled by the worker count.
            self._benchmark(
                iter(loader),
                num_iter * max(n, 1),
                warmup * max(n, 1),
                f"DataLoader ({n} workers, bs={self.per_gpu_batch_size})",
            )
            del loader

    def benchmark_IPC(self, num_iter, warmup=10):
        """
        Benchmark the dataloader where each worker outputs nothing. This
        eliminates the IPC overhead compared to the regular dataloader.

        PyTorch multiprocessing's IPC only optimizes for torch tensors.
        Large numpy arrays or other data structure may incur large IPC overhead.
        """
        n = self.num_workers
        dataset = _EmptyMapDataset(MapDataset(self.dataset, self.mapper))
        loader = build_batch_data_loader(
            dataset, self.sampler, self.total_batch_size, num_workers=n
        )
        self._benchmark(
            iter(loader),
            num_iter * max(n, 1),
            warmup * max(n, 1),
            f"DataLoader ({n} workers, bs={self.per_gpu_batch_size}) w/o comm",
        )

    def benchmark_distributed(self, num_iter, warmup=10):
        """
        Benchmark the dataloader in each distributed worker, and log results of
        all workers. This helps understand the final performance as well as
        the variances among workers.

        It also prints startup time (first iter) of the dataloader.
        """
        gpu = comm.get_world_size()
        dataset = MapDataset(self.dataset, self.mapper)
        n = self.num_workers
        loader = build_batch_data_loader(
            dataset, self.sampler, self.total_batch_size, num_workers=n
        )

        # The first batch is timed separately and reported as startup time.
        timer = Timer()
        loader = iter(loader)
        next(loader)
        startup_time = timer.seconds()
        logger.info("Dataloader startup time: {:.2f} seconds".format(startup_time))

        comm.synchronize()

        avg, all_times = self._benchmark(loader, num_iter * max(n, 1), warmup * max(n, 1))
        del loader
        self._log_time(
            f"DataLoader ({gpu} GPUs x {n} workers, total bs={self.total_batch_size})",
            avg,
            all_times,
            True,
        )
def filter_images_with_only_crowd_annotations(dataset_dicts):
    """
    Drop every image that has no annotation at all or whose annotations are
    all marked ``iscrowd``. An annotation without an "iscrowd" key counts as
    non-crowd.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.

    Returns:
        list[dict]: a new list in the same format holding only the kept images.
    """
    kept = [
        record
        for record in dataset_dicts
        if any(ann.get("iscrowd", 0) == 0 for ann in record["annotations"])
    ]
    num_removed = len(dataset_dicts) - len(kept)
    logging.getLogger(__name__).info(
        "Removed {} images with no usable annotations. {} images left.".format(
            num_removed, len(kept)
        )
    )
    return kept
def print_instances_class_histogram(dataset_dicts, class_names):
    """
    Log a table with the number of non-crowd instances per category.

    Args:
        dataset_dicts (list[dict]): list of dataset dicts.
        class_names (list[str]): list of class names (zero-indexed).
    """
    num_classes = len(class_names)
    hist_bins = np.arange(num_classes + 1)
    # `np.int` was only an alias of the builtin `int` and was removed in
    # NumPy 1.24; the builtin gives the same dtype on every NumPy version.
    histogram = np.zeros((num_classes,), dtype=int)
    for entry in dataset_dicts:
        annos = entry["annotations"]
        classes = np.asarray(
            [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=int
        )
        if len(classes):
            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
            assert (
                classes.max() < num_classes
            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
        histogram += np.histogram(classes, bins=hist_bins)[0]

    # Up to three (category, #instances) column pairs per table row.
    N_COLS = min(6, len(class_names) * 2)

    def short_name(x):
        # make long class names shorter. useful for lvis
        if len(x) > 13:
            return x[:11] + ".."
        return x

    # Flat list [name0, count0, name1, count1, ...]
    data = list(
        itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
    )
    total_num_instances = sum(data[1::2])
    # Pad with None so that "total" starts in the first column of a new row.
    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
    if num_classes > 1:
        data.extend(["total", total_num_instances])
    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
    table = tabulate(
        data,
        headers=["category", "#instances"] * (N_COLS // 2),
        tablefmt="pipe",
        numalign="left",
        stralign="center",
    )
    log_first_n(
        logging.INFO,
        "Distribution of instances among all {} categories:\n".format(num_classes)
        + colored(table, "cyan"),
        key="message",
    )
def build_batch_data_loader(
    dataset,
    sampler,
    total_batch_size,
    *,
    aspect_ratio_grouping=False,
    num_workers=0,
    collate_fn=None,
):
    """
    Build a batched dataloader. Compared with a plain `torch.utils.data.DataLoader` it
    1. can group elements by aspect ratio, and
    2. by default performs no batch collation: each batch is a plain list.

    Args:
        dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
            Must be provided iff. ``dataset`` is a map-style dataset.
        total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
            :func:`build_detection_train_loader`.

    Returns:
        iterable[list]. Length of each list is the batch size of the current
            GPU. Each element in the list comes from the dataset.
    """
    num_gpus = get_world_size()
    assert (
        total_batch_size > 0 and total_batch_size % num_gpus == 0
    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
        total_batch_size, num_gpus
    )
    per_gpu_batch_size = total_batch_size // num_gpus

    if not isinstance(dataset, torchdata.IterableDataset):
        # Map-style datasets are turned into iterable ones driven by the sampler.
        dataset = ToIterableDataset(dataset, sampler)
    else:
        assert sampler is None, "sampler must be None if dataset is IterableDataset"

    if not aspect_ratio_grouping:
        batch_collate = trivial_batch_collator if collate_fn is None else collate_fn
        return torchdata.DataLoader(
            dataset,
            batch_size=per_gpu_batch_size,
            drop_last=True,
            num_workers=num_workers,
            collate_fn=batch_collate,
            worker_init_fn=worker_init_reset_seed,
        )

    # Yield individual mapped dicts (no batching here); batches are formed by
    # the aspect-ratio grouping wrapper below.
    element_loader = torchdata.DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=operator.itemgetter(0),
        worker_init_fn=worker_init_reset_seed,
    )
    grouped = AspectRatioGroupedDataset(element_loader, per_gpu_batch_size)
    if collate_fn is None:
        return grouped
    return MapDataset(grouped, collate_fn)
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset,
    *,
    mapper,
    sampler=None,
    total_batch_size,
    aspect_ratio_grouping=True,
    num_workers=0,
    collate_fn=None,
):
    """
    Build a dataloader for object detection with some default features.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a pytorch dataset (either map-style or iterable). It can be obtained
            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model. If None, the dataset
            is used without mapping.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``.
            If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`.
            Sampler must be None if ``dataset`` is iterable.
        total_batch_size (int): total batch size across all GPU workers.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers
        collate_fn: a function that determines how to do batching, same as the argument of
            `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of
            data. No collation is OK for small batch size and simple data structures.
            If your batch size is large and each sample contains too many small tensors,
            it's more efficient to collate them in data loader.

    Returns:
        The loader built by :func:`build_batch_data_loader`. Each output from it is a
            ``list[mapped_element]`` whose length is ``total_batch_size`` divided by
            the number of GPU workers, where ``mapped_element`` is produced by the
            ``mapper``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)

    if not isinstance(dataset, torchdata.IterableDataset):
        if sampler is None:
            sampler = TrainingSampler(len(dataset))
        assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}"
    else:
        assert sampler is None, "sampler must be None if dataset is IterableDataset"

    loader_options = {
        "aspect_ratio_grouping": aspect_ratio_grouping,
        "num_workers": num_workers,
        "collate_fn": collate_fn,
    }
    return build_batch_data_loader(dataset, sampler, total_batch_size, **loader_options)
This sampler coordinates all workers + to produce the exact set of all samples. + + Args: + dataset: a list of dataset dicts, + or a pytorch dataset (either map-style or iterable). They can be obtained + by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. + mapper: a callable which takes a sample (dict) from dataset + and returns the format to be consumed by the model. + When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``. + sampler: a sampler that produces + indices to be applied on ``dataset``. Default to :class:`InferenceSampler`, + which splits the dataset across all workers. Sampler must be None + if `dataset` is iterable. + batch_size: the batch size of the data loader to be created. + Default to 1 image per worker since this is the standard when reporting + inference time in papers. + num_workers: number of parallel data loading workers + collate_fn: same as the argument of `torch.utils.data.DataLoader`. + Defaults to do no collation and return a list of data. + + Returns: + DataLoader: a torch DataLoader, that loads the given detection + dataset, with test-time transformation and batching. 
def worker_init_reset_seed(worker_id):
    """
    Seed the RNGs of a dataloader worker.

    The seed handed to ``seed_all_rng`` is torch's initial seed reduced modulo
    2**31, offset by ``worker_id``.

    Args:
        worker_id (int): id of the dataloader worker being initialised.
    """
    # NOTE(review): the modulo presumably keeps the seed inside the range the
    # seeding helper accepts; confirm against `seed_all_rng`.
    seed_all_rng((torch.initial_seed() % (2**31)) + worker_id)
class Metadata(types.SimpleNamespace):
    """
    A namespace with plain attribute get/set, used to keep the metadata of a
    dataset and make it accessible globally.

    Once an attribute holds a value, assigning a different value to it fails
    an assertion; assigning an equal value is a no-op.

    Examples:
    ::
        # somewhere when you load the data:
        MetadataCatalog.get("mydataset").thing_classes = ["person", "dog"]

        # somewhere when you print statistics or visualize:
        classes = MetadataCatalog.get("mydataset").thing_classes
    """

    # the name of the dataset
    # the class-level default lets error messages read `self.name` without
    # falling through to `__getattr__` again
    name: str = "N/A"

    # old attribute name -> current attribute name
    _RENAMED = {
        "class_names": "thing_classes",
        "dataset_id_to_contiguous_id": "thing_dataset_id_to_contiguous_id",
        "stuff_class_names": "stuff_classes",
    }

    def __getattr__(self, key):
        # Only reached when regular attribute lookup has failed.
        if key in self._RENAMED:
            current_name = self._RENAMED[key]
            log_first_n(
                logging.WARNING,
                "Metadata '{}' was renamed to '{}'!".format(key, current_name),
                n=10,
            )
            return getattr(self, current_name)

        # "name" exists in every metadata, so one entry or fewer means "empty"
        if len(self.__dict__) <= 1:
            raise AttributeError(
                f"Attribute '{key}' does not exist in the metadata of dataset '{self.name}': "
                "metadata is empty."
            )
        raise AttributeError(
            "Attribute '{}' does not exist in the metadata of dataset '{}'. Available "
            "keys are {}.".format(key, self.name, str(self.__dict__.keys()))
        )

    def __setattr__(self, key, val):
        if key in self._RENAMED:
            current_name = self._RENAMED[key]
            log_first_n(
                logging.WARNING,
                "Metadata '{}' was renamed to '{}'!".format(key, current_name),
                n=10,
            )
            setattr(self, current_name, val)

        # Ensure that metadata of the same name stays consistent: the value is
        # stored only when the attribute does not exist yet.
        try:
            existing = getattr(self, key)
            assert existing == val, (
                "Attribute '{}' in the metadata of '{}' cannot be set "
                "to a different value!\n{} != {}".format(key, self.name, existing, val)
            )
        except AttributeError:
            super().__setattr__(key, val)

    def as_dict(self):
        """
        Returns all the metadata as a dict.
        The dict is a shallow copy: adding or removing keys on it does not
        affect the Metadata object.
        """
        return self.__dict__.copy()

    def set(self, **kwargs):
        """
        Set multiple metadata with kwargs and return ``self``.
        """
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
        return self

    def get(self, key, default=None):
        """
        Return the value of attribute ``key`` if it exists,
        otherwise return ``default``.
        """
        return getattr(self, key, default)
class MapDataset(data.Dataset):
    """
    Map a function over the elements in a dataset.
    """

    def __init__(self, dataset, map_func):
        """
        Args:
            dataset: a dataset where map function is applied. Can be either
                map-style or iterable dataset. When given an iterable dataset,
                the returned object will also be an iterable dataset.
            map_func: a callable which maps the element in dataset. map_func can
                return None to skip the data (e.g. in case of errors).
                How None is handled depends on the style of `dataset`.
                If `dataset` is map-style, it randomly tries other elements.
                If `dataset` is iterable, it skips the data and tries the next.
        """
        self._dataset = dataset
        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work

        # Fixed seed so the choice of fallback indices is reproducible.
        self._rng = random.Random(42)
        # Indices that are still believed to map successfully.
        self._fallback_candidates = set(range(len(dataset)))

    def __new__(cls, dataset, map_func):
        # Iterable datasets are handled by a dedicated iterable wrapper instead.
        is_iterable = isinstance(dataset, data.IterableDataset)
        if is_iterable:
            return _MapIterableDataset(dataset, map_func)
        else:
            return super().__new__(cls)

    def __getnewargs__(self):
        # Arguments handed back to `__new__` when the object is unpickled/copied.
        return self._dataset, self._map_func

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        """
        Return the mapped element at ``idx``.

        When the map function returns None for an index, that index is removed
        from the fallback pool and a random index from the pool is tried instead,
        until one maps successfully.

        Raises:
            ValueError: if the map function failed and no fallback index is left.
        """
        retry_count = 0
        cur_idx = int(idx)

        while True:
            mapped = self._map_func(self._dataset[cur_idx])
            if mapped is not None:
                self._fallback_candidates.add(cur_idx)
                return mapped

            # _map_func fails for this idx, use a random new index from the pool
            retry_count += 1
            self._fallback_candidates.discard(cur_idx)
            if not self._fallback_candidates:
                raise ValueError(
                    "`_map_func` returned None for idx {} and no fallback "
                    "candidates are left to try".format(idx)
                )
            # `random.sample` rejects sets since Python 3.11, so sample from a
            # sequence snapshot of the pool (this is what older versions did
            # internally for set populations).
            cur_idx = self._rng.sample(tuple(self._fallback_candidates), k=1)[0]

            if retry_count >= 3:
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
                        idx, retry_count
                    )
                )
@contextlib.contextmanager
def set_default_dataset_from_list_serialize_method(new):
    """
    Context manager for using custom serialize function when creating DatasetFromList

    Args:
        new: the serialize method installed as the module-wide default while the
            ``with`` block runs. The previous default is put back when the block
            exits, whether it finishes normally or raises.
    """

    global _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
    orig = _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
    _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = new
    try:
        yield
    finally:
        # Without `finally`, an exception in the `with` body would leave the
        # temporary method installed as the global default forever.
        _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = orig
class ToIterableDataset(data.IterableDataset):
    """
    Adapter that turns an indices-based (map-style) dataset into an
    iterable-style dataset by walking a sampler and looking up each index.
    """

    def __init__(self, dataset: data.Dataset, sampler: Sampler, shard_sampler: bool = True):
        """
        Args:
            dataset: an old-style dataset with ``__getitem__``
            sampler: a cheap iterable that produces indices to be applied on ``dataset``.
            shard_sampler: whether to shard the sampler based on the current pytorch data loader
                worker id. When an IterableDataset is forked by pytorch's DataLoader into multiple
                workers, it is responsible for sharding its data based on worker id so that workers
                don't produce identical data.

                Most samplers (like our TrainingSampler) do not shard based on dataloader worker id
                and this argument should be set to True. But certain samplers may be already
                sharded, in that case this argument should be set to False.
        """
        assert not isinstance(dataset, data.IterableDataset), dataset
        assert isinstance(sampler, Sampler), sampler
        self.shard_sampler = shard_sampler
        self.sampler = sampler
        self.dataset = dataset

    def __iter__(self):
        # With a map-style dataset, `DataLoader(dataset, sampler)` runs the sampler in
        # the main process only, whereas `DataLoader(ToIterableDataset(dataset, sampler))`
        # runs it in each of the N workers. When sharding is requested, every worker
        # therefore keeps only 1/N of the ids; the sampler is assumed to be cheap to
        # iterate, so discarding ids inside the workers is fine.
        indices = (
            _shard_iterator_dataloader_worker(self.sampler)
            if self.shard_sampler
            else self.sampler
        )
        for index in indices:
            yield self.dataset[index]

    def __len__(self):
        return len(self.sampler)
class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and map it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic,
    such as a different way to read or transform images.
    See :doc:`/tutorials/data_loading` for details.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            use_instance_mask: whether to process instance segmentation annotations, if available
            use_keypoint: whether to process keypoint annotations if available
            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
                masks into this format.
            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
            precomputed_proposal_topk: if given, will load pre-computed
                proposals from dataset_dict and keep the top k proposals for each image.
            recompute_boxes: whether to overwrite bounding box annotations
                by computing tight bounding boxes from instance mask annotations.
        """
        if recompute_boxes:
            assert use_instance_mask, "recompute_boxes requires instance masks"
        # fmt: off
        self.is_train = is_train
        self.augmentations = T.AugmentationList(augmentations)
        self.image_format = image_format
        self.use_instance_mask = use_instance_mask
        self.instance_mask_format = instance_mask_format
        self.use_keypoint = use_keypoint
        self.keypoint_hflip_indices = keypoint_hflip_indices
        self.proposal_topk = precomputed_proposal_topk
        self.recompute_boxes = recompute_boxes
        # fmt: on
        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        """
        Translate a config into a dict keyed by the argument names of ``__init__``.

        NOTE(review): the returned dict is presumably consumed by the
        ``configurable`` decorator on ``__init__`` -- confirm against that decorator.

        Args:
            cfg: config object; the ``INPUT``, ``MODEL`` and ``DATASETS`` entries
                read below must be present.
            is_train: whether the mapper is built for training or inference.

        Returns:
            dict: values for ``is_train``, ``augmentations``, ``image_format``,
                ``use_instance_mask``, ``instance_mask_format``, ``use_keypoint`` and
                ``recompute_boxes``, plus ``keypoint_hflip_indices`` when keypoints
                are enabled and ``precomputed_proposal_topk`` when proposals are loaded.
        """
        augs = utils.build_augmentation(cfg, is_train)
        if cfg.INPUT.CROP.ENABLED and is_train:
            # Random crop goes first, ahead of the augmentations built above.
            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
            # Boxes are recomputed from masks only when cropping is on and masks are used.
            recompute_boxes = cfg.MODEL.MASK_ON
        else:
            recompute_boxes = False

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
            "recompute_boxes": recompute_boxes,
        }

        if cfg.MODEL.KEYPOINT_ON:
            ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)

        if cfg.MODEL.LOAD_PROPOSALS:
            # Separate top-k settings exist for training and for testing.
            ret["precomputed_proposal_topk"] = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        return ret

    def _transform_annotations(self, dataset_dict, transforms, image_shape):
        """
        Replace ``dataset_dict["annotations"]`` with ``dataset_dict["instances"]``.

        The "annotations" entry is popped from ``dataset_dict``; annotations whose
        "iscrowd" value is non-zero are dropped (a missing "iscrowd" counts as 0),
        the remaining ones are passed through ``transforms`` and converted to
        instances, which are stored after filtering out empty ones.

        Args:
            dataset_dict (dict): must contain an "annotations" list; modified in place.
            transforms: the transforms that were applied to the image.
            image_shape: shape of the transformed image, passed as (h, w) by ``__call__``.
        """
        # USER: Modify this if you want to keep them for some reason.
        for anno in dataset_dict["annotations"]:
            if not self.use_instance_mask:
                anno.pop("segmentation", None)
            if not self.use_keypoint:
                anno.pop("keypoints", None)

        # USER: Implement additional transformations if you have other types of data
        annos = [
            utils.transform_instance_annotations(
                obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
            )
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]
        instances = utils.annotations_to_instances(
            annos, image_shape, mask_format=self.instance_mask_format
        )

        # After transforms such as cropping are applied, the bounding box may no longer
        # tightly bound the object. As an example, imagine a triangle object
        # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
        # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
        # the intersection of original bounding box and the cropping box.
        if self.recompute_boxes:
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
        dataset_dict["instances"] = utils.filter_empty_instances(instances)

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
        else:
            sem_seg_gt = None

        # The image and (optional) semantic mask are augmented together; the applied
        # transforms are kept so that proposals/annotations can follow them below.
        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg

        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # USER: Remove if you don't use pre-computed proposals.
        # Most users would not need this feature.
        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
            )

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            self._transform_annotations(dataset_dict, transforms, image_shape)

        return dataset_dict
diff --git a/data_processing/detectron2/detectron2/data/datasets/__init__.py b/data_processing/detectron2/detectron2/data/datasets/__init__.py new file mode 100644 index 0000000..a44bedc --- /dev/null +++ b/data_processing/detectron2/detectron2/data/datasets/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .coco import load_coco_json, load_sem_seg, register_coco_instances, convert_to_coco_json +from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated +from .lvis import load_lvis_json, register_lvis_instances, get_lvis_instances_meta +from .pascal_voc import load_voc_instances, register_pascal_voc +from . import builtin as _builtin # ensure the builtin datasets are registered + + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/data_processing/detectron2/detectron2/data/datasets/builtin.py b/data_processing/detectron2/detectron2/data/datasets/builtin.py new file mode 100644 index 0000000..c3a68aa --- /dev/null +++ b/data_processing/detectron2/detectron2/data/datasets/builtin.py @@ -0,0 +1,259 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + + +""" +This file registers pre-defined datasets at hard-coded paths, and their metadata. + +We hard-code metadata for common datasets. This will enable: +1. Consistency check when loading the datasets +2. Use models on these standard datasets directly and run demos, + without having to download the dataset annotations + +We hard-code some paths to the dataset that's assumed to +exist in "./datasets/". + +Users SHOULD NOT use this file to create new dataset / metadata for new dataset. +To add new dataset, refer to the tutorial "docs/DATASETS.md". 
+""" + +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog + +from .builtin_meta import ADE20K_SEM_SEG_CATEGORIES, _get_builtin_metadata +from .cityscapes import load_cityscapes_instances, load_cityscapes_semantic +from .cityscapes_panoptic import register_all_cityscapes_panoptic +from .coco import load_sem_seg, register_coco_instances +from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated +from .lvis import get_lvis_instances_meta, register_lvis_instances +from .pascal_voc import register_pascal_voc + +# ==== Predefined datasets and splits for COCO ========== + +_PREDEFINED_SPLITS_COCO = {} +_PREDEFINED_SPLITS_COCO["coco"] = { + "coco_2014_train": ("coco/train2014", "coco/annotations/instances_train2014.json"), + "coco_2014_val": ("coco/val2014", "coco/annotations/instances_val2014.json"), + "coco_2014_minival": ("coco/val2014", "coco/annotations/instances_minival2014.json"), + "coco_2014_valminusminival": ( + "coco/val2014", + "coco/annotations/instances_valminusminival2014.json", + ), + "coco_2017_train": ("coco/train2017", "coco/annotations/instances_train2017.json"), + "coco_2017_val": ("coco/val2017", "coco/annotations/instances_val2017.json"), + "coco_2017_test": ("coco/test2017", "coco/annotations/image_info_test2017.json"), + "coco_2017_test-dev": ("coco/test2017", "coco/annotations/image_info_test-dev2017.json"), + "coco_2017_val_100": ("coco/val2017", "coco/annotations/instances_val2017_100.json"), +} + +_PREDEFINED_SPLITS_COCO["coco_person"] = { + "keypoints_coco_2014_train": ( + "coco/train2014", + "coco/annotations/person_keypoints_train2014.json", + ), + "keypoints_coco_2014_val": ("coco/val2014", "coco/annotations/person_keypoints_val2014.json"), + "keypoints_coco_2014_minival": ( + "coco/val2014", + "coco/annotations/person_keypoints_minival2014.json", + ), + "keypoints_coco_2014_valminusminival": ( + "coco/val2014", + "coco/annotations/person_keypoints_valminusminival2014.json", + ), + 
"keypoints_coco_2017_train": ( + "coco/train2017", + "coco/annotations/person_keypoints_train2017.json", + ), + "keypoints_coco_2017_val": ("coco/val2017", "coco/annotations/person_keypoints_val2017.json"), + "keypoints_coco_2017_val_100": ( + "coco/val2017", + "coco/annotations/person_keypoints_val2017_100.json", + ), +} + + +_PREDEFINED_SPLITS_COCO_PANOPTIC = { + "coco_2017_train_panoptic": ( + # This is the original panoptic annotation directory + "coco/panoptic_train2017", + "coco/annotations/panoptic_train2017.json", + # This directory contains semantic annotations that are + # converted from panoptic annotations. + # It is used by PanopticFPN. + # You can use the script at detectron2/datasets/prepare_panoptic_fpn.py + # to create these directories. + "coco/panoptic_stuff_train2017", + ), + "coco_2017_val_panoptic": ( + "coco/panoptic_val2017", + "coco/annotations/panoptic_val2017.json", + "coco/panoptic_stuff_val2017", + ), + "coco_2017_val_100_panoptic": ( + "coco/panoptic_val2017_100", + "coco/annotations/panoptic_val2017_100.json", + "coco/panoptic_stuff_val2017_100", + ), +} + + +def register_all_coco(root): + for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_COCO.items(): + for key, (image_root, json_file) in splits_per_dataset.items(): + # Assume pre-defined datasets live in `./datasets`. + register_coco_instances( + key, + _get_builtin_metadata(dataset_name), + os.path.join(root, json_file) if "://" not in json_file else json_file, + os.path.join(root, image_root), + ) + + for ( + prefix, + (panoptic_root, panoptic_json, semantic_root), + ) in _PREDEFINED_SPLITS_COCO_PANOPTIC.items(): + prefix_instances = prefix[: -len("_panoptic")] + instances_meta = MetadataCatalog.get(prefix_instances) + image_root, instances_json = instances_meta.image_root, instances_meta.json_file + # The "separated" version of COCO panoptic segmentation dataset, + # e.g. 
def register_all_lvis(root):
    """
    Register every split listed in ``_PREDEFINED_SPLITS_LVIS``.

    Args:
        root (str): directory that the relative image and annotation paths of the
            predefined splits are joined onto. An annotation path containing
            "://" is passed through unchanged.
    """
    for version, splits in _PREDEFINED_SPLITS_LVIS.items():
        for split_name, (image_dir, annotation) in splits.items():
            metadata = get_lvis_instances_meta(version)
            if "://" in annotation:
                json_path = annotation
            else:
                json_path = os.path.join(root, annotation)
            image_path = os.path.join(root, image_dir)
            register_lvis_instances(split_name, metadata, json_path, image_path)
def register_all_pascal_voc(root):
    """
    Register the predefined PASCAL VOC 2007 and 2012 splits.

    Args:
        root (str): directory that contains the ``VOC2007`` / ``VOC2012`` folders.

    Each split is registered through ``register_pascal_voc`` and its metadata
    gets ``evaluator_type`` set to "pascal_voc".
    """
    voc_splits = (
        ("voc_2007_trainval", "VOC2007", "trainval"),
        ("voc_2007_train", "VOC2007", "train"),
        ("voc_2007_val", "VOC2007", "val"),
        ("voc_2007_test", "VOC2007", "test"),
        ("voc_2012_trainval", "VOC2012", "trainval"),
        ("voc_2012_train", "VOC2012", "train"),
        ("voc_2012_val", "VOC2012", "val"),
    )
    for dataset_name, folder, split_name in voc_splits:
        # The year is derived from the registered dataset name.
        if "2007" in dataset_name:
            voc_year = 2007
        else:
            voc_year = 2012
        dataset_dir = os.path.join(root, folder)
        register_pascal_voc(dataset_name, dataset_dir, split_name, voc_year)
        MetadataCatalog.get(dataset_name).evaluator_type = "pascal_voc"
x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg") + ) + MetadataCatalog.get(name).set( + stuff_classes=ADE20K_SEM_SEG_CATEGORIES[:], + image_root=image_dir, + sem_seg_root=gt_dir, + evaluator_type="sem_seg", + ignore_label=255, + ) + + +# True for open source; +# Internally at fb, we register them elsewhere +if __name__.endswith(".builtin"): + # Assume pre-defined datasets live in `./datasets`. + _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets")) + register_all_coco(_root) + register_all_lvis(_root) + register_all_cityscapes(_root) + register_all_cityscapes_panoptic(_root) + register_all_pascal_voc(_root) + register_all_ade20k(_root) diff --git a/data_processing/detectron2/detectron2/data/datasets/builtin_meta.py b/data_processing/detectron2/detectron2/data/datasets/builtin_meta.py new file mode 100644 index 0000000..63c7a1a --- /dev/null +++ b/data_processing/detectron2/detectron2/data/datasets/builtin_meta.py @@ -0,0 +1,350 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +""" +Note: +For your custom dataset, there is no need to hard-code metadata anywhere in the code. +For example, for COCO-format dataset, metadata will be obtained automatically +when calling `load_coco_json`. For other dataset, metadata may also be obtained in other ways +during loading. + +However, we hard-coded metadata for a few common dataset here. +The only goal is to allow users who don't have these dataset to use pre-trained models. +Users don't have to download a COCO json (which contains metadata), in order to visualize a +COCO model (with correct class names and colors). 
+""" + + +# All coco categories, together with their nice-looking visualization colors +# It's from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json +COCO_CATEGORIES = [ + {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, + {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, + {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, + {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, + {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, + {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, + {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, + {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, + {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, + {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, + {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, + {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, + {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, + {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, + {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, + {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, + {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, + {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, + {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, + {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, + {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, + {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, + {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, + {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, + {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": 
"backpack"}, + {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, + {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, + {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, + {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, + {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, + {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, + {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, + {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, + {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, + {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, + {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, + {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, + {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, + {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, + {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, + {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, + {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, + {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, + {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, + {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, + {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, + {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, + {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, + {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, + {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, + {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, + {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, + {"color": [209, 99, 106], "isthing": 1, "id": 
58, "name": "hot dog"}, + {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, + {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, + {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, + {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, + {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, + {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, + {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, + {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, + {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, + {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, + {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, + {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, + {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, + {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, + {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, + {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, + {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, + {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, + {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, + {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, + {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, + {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, + {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, + {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, + {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, + {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, + {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, + {"color": [255, 255, 128], "isthing": 0, "id": 92, 
"name": "banner"}, + {"color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"}, + {"color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"}, + {"color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"}, + {"color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"}, + {"color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"}, + {"color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"}, + {"color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"}, + {"color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"}, + {"color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"}, + {"color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"}, + {"color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"}, + {"color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"}, + {"color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"}, + {"color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"}, + {"color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"}, + {"color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"}, + {"color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"}, + {"color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"}, + {"color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"}, + {"color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"}, + {"color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"}, + {"color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"}, + {"color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"}, + {"color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"}, + {"color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"}, + {"color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"}, + {"color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"}, + {"color": [225, 199, 255], 
"isthing": 0, "id": 168, "name": "towel"}, + {"color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"}, + {"color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"}, + {"color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"}, + {"color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"}, + {"color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"}, + {"color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"}, + {"color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"}, + {"color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"}, + {"color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"}, + {"color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"}, + {"color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"}, + {"color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"}, + {"color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"}, + {"color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"}, + {"color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"}, + {"color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"}, + {"color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"}, + {"color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"}, + {"color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"}, + {"color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"}, + {"color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"}, + {"color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"}, + {"color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"}, + {"color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"}, +] + +# fmt: off +COCO_PERSON_KEYPOINT_NAMES = ( + "nose", 
+ "left_eye", "right_eye", + "left_ear", "right_ear", + "left_shoulder", "right_shoulder", + "left_elbow", "right_elbow", + "left_wrist", "right_wrist", + "left_hip", "right_hip", + "left_knee", "right_knee", + "left_ankle", "right_ankle", +) +# fmt: on + +# Pairs of keypoints that should be exchanged under horizontal flipping +COCO_PERSON_KEYPOINT_FLIP_MAP = ( + ("left_eye", "right_eye"), + ("left_ear", "right_ear"), + ("left_shoulder", "right_shoulder"), + ("left_elbow", "right_elbow"), + ("left_wrist", "right_wrist"), + ("left_hip", "right_hip"), + ("left_knee", "right_knee"), + ("left_ankle", "right_ankle"), +) + +# rules for pairs of keypoints to draw a line between, and the line color to use. +KEYPOINT_CONNECTION_RULES = [ + # face + ("left_ear", "left_eye", (102, 204, 255)), + ("right_ear", "right_eye", (51, 153, 255)), + ("left_eye", "nose", (102, 0, 204)), + ("nose", "right_eye", (51, 102, 255)), + # upper-body + ("left_shoulder", "right_shoulder", (255, 128, 0)), + ("left_shoulder", "left_elbow", (153, 255, 204)), + ("right_shoulder", "right_elbow", (128, 229, 255)), + ("left_elbow", "left_wrist", (153, 255, 153)), + ("right_elbow", "right_wrist", (102, 255, 224)), + # lower-body + ("left_hip", "right_hip", (255, 102, 0)), + ("left_hip", "left_knee", (255, 255, 77)), + ("right_hip", "right_knee", (153, 255, 204)), + ("left_knee", "left_ankle", (191, 255, 128)), + ("right_knee", "right_ankle", (255, 195, 77)), +] + +# All Cityscapes categories, together with their nice-looking visualization colors +# It's from https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py # noqa +CITYSCAPES_CATEGORIES = [ + {"color": (128, 64, 128), "isthing": 0, "id": 7, "trainId": 0, "name": "road"}, + {"color": (244, 35, 232), "isthing": 0, "id": 8, "trainId": 1, "name": "sidewalk"}, + {"color": (70, 70, 70), "isthing": 0, "id": 11, "trainId": 2, "name": "building"}, + {"color": (102, 
102, 156), "isthing": 0, "id": 12, "trainId": 3, "name": "wall"}, + {"color": (190, 153, 153), "isthing": 0, "id": 13, "trainId": 4, "name": "fence"}, + {"color": (153, 153, 153), "isthing": 0, "id": 17, "trainId": 5, "name": "pole"}, + {"color": (250, 170, 30), "isthing": 0, "id": 19, "trainId": 6, "name": "traffic light"}, + {"color": (220, 220, 0), "isthing": 0, "id": 20, "trainId": 7, "name": "traffic sign"}, + {"color": (107, 142, 35), "isthing": 0, "id": 21, "trainId": 8, "name": "vegetation"}, + {"color": (152, 251, 152), "isthing": 0, "id": 22, "trainId": 9, "name": "terrain"}, + {"color": (70, 130, 180), "isthing": 0, "id": 23, "trainId": 10, "name": "sky"}, + {"color": (220, 20, 60), "isthing": 1, "id": 24, "trainId": 11, "name": "person"}, + {"color": (255, 0, 0), "isthing": 1, "id": 25, "trainId": 12, "name": "rider"}, + {"color": (0, 0, 142), "isthing": 1, "id": 26, "trainId": 13, "name": "car"}, + {"color": (0, 0, 70), "isthing": 1, "id": 27, "trainId": 14, "name": "truck"}, + {"color": (0, 60, 100), "isthing": 1, "id": 28, "trainId": 15, "name": "bus"}, + {"color": (0, 80, 100), "isthing": 1, "id": 31, "trainId": 16, "name": "train"}, + {"color": (0, 0, 230), "isthing": 1, "id": 32, "trainId": 17, "name": "motorcycle"}, + {"color": (119, 11, 32), "isthing": 1, "id": 33, "trainId": 18, "name": "bicycle"}, +] + +# fmt: off +ADE20K_SEM_SEG_CATEGORIES = [ + "wall", "building", "sky", "floor", "tree", "ceiling", "road, route", "bed", "window ", "grass", "cabinet", "sidewalk, pavement", "person", "earth, ground", "door", "table", "mountain, mount", "plant", "curtain", "chair", "car", "water", "painting, picture", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock, stone", "wardrobe, closet, press", "lamp", "tub", "rail", "cushion", "base, pedestal, stand", "box", "column, pillar", "signboard, sign", "chest of drawers, chest, bureau, dresser", "counter", "sand", "sink", "skyscraper", "fireplace", 
"refrigerator, icebox", "grandstand, covered stand", "path", "stairs", "runway", "case, display case, showcase, vitrine", "pool table, billiard table, snooker table", "pillow", "screen door, screen", "stairway, staircase", "river", "bridge, span", "bookcase", "blind, screen", "coffee table", "toilet, can, commode, crapper, pot, potty, stool, throne", "flower", "book", "hill", "bench", "countertop", "stove", "palm, palm tree", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel, hut, hutch, shack, shanty", "bus", "towel", "light", "truck", "tower", "chandelier", "awning, sunshade, sunblind", "street lamp", "booth", "tv", "plane", "dirt track", "clothes", "pole", "land, ground, soil", "bannister, banister, balustrade, balusters, handrail", "escalator, moving staircase, moving stairway", "ottoman, pouf, pouffe, puff, hassock", "bottle", "buffet, counter, sideboard", "poster, posting, placard, notice, bill, card", "stage", "van", "ship", "fountain", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "canopy", "washer, automatic washer, washing machine", "plaything, toy", "pool", "stool", "barrel, cask", "basket, handbasket", "falls", "tent", "bag", "minibike, motorbike", "cradle", "oven", "ball", "food, solid food", "step, stair", "tank, storage tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket, cover", "sculpture", "hood, exhaust hood", "sconce", "vase", "traffic light", "tray", "trash can", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass, drinking glass", "clock", "flag", # noqa +] +# After processed by `prepare_ade20k_sem_seg.py`, id 255 means ignore +# fmt: on + + +def _get_coco_instances_meta(): + thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1] + thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1] + assert len(thing_ids) == 80, len(thing_ids) + # Mapping from the incontiguous 
COCO category id to an id in [0, 79] + thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)} + thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1] + ret = { + "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, + "thing_classes": thing_classes, + "thing_colors": thing_colors, + } + return ret + + +def _get_coco_panoptic_separated_meta(): + """ + Returns metadata for "separated" version of the panoptic segmentation dataset. + """ + stuff_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 0] + assert len(stuff_ids) == 53, len(stuff_ids) + + # For semantic segmentation, this mapping maps from contiguous stuff id + # (in [0, 53], used in models) to ids in the dataset (used for processing results) + # The id 0 is mapped to an extra category "thing". + stuff_dataset_id_to_contiguous_id = {k: i + 1 for i, k in enumerate(stuff_ids)} + # When converting COCO panoptic annotations to semantic annotations + # We label the "thing" category to 0 + stuff_dataset_id_to_contiguous_id[0] = 0 + + # 54 names for COCO stuff categories (including "things") + stuff_classes = ["things"] + [ + k["name"].replace("-other", "").replace("-merged", "") + for k in COCO_CATEGORIES + if k["isthing"] == 0 + ] + + # NOTE: I randomly picked a color for things + stuff_colors = [[82, 18, 128]] + [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 0] + ret = { + "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id, + "stuff_classes": stuff_classes, + "stuff_colors": stuff_colors, + } + ret.update(_get_coco_instances_meta()) + return ret + + +def _get_builtin_metadata(dataset_name): + if dataset_name == "coco": + return _get_coco_instances_meta() + if dataset_name == "coco_panoptic_separated": + return _get_coco_panoptic_separated_meta() + elif dataset_name == "coco_panoptic_standard": + meta = {} + # The following metadata maps contiguous id from [0, #thing categories + + # #stuff categories) to their names and 
colors. We have to replica of the + # same name and color under "thing_*" and "stuff_*" because the current + # visualization function in D2 handles thing and class classes differently + # due to some heuristic used in Panoptic FPN. We keep the same naming to + # enable reusing existing visualization functions. + thing_classes = [k["name"] for k in COCO_CATEGORIES] + thing_colors = [k["color"] for k in COCO_CATEGORIES] + stuff_classes = [k["name"] for k in COCO_CATEGORIES] + stuff_colors = [k["color"] for k in COCO_CATEGORIES] + + meta["thing_classes"] = thing_classes + meta["thing_colors"] = thing_colors + meta["stuff_classes"] = stuff_classes + meta["stuff_colors"] = stuff_colors + + # Convert category id for training: + # category id: like semantic segmentation, it is the class id for each + # pixel. Since there are some classes not used in evaluation, the category + # id is not always contiguous and thus we have two set of category ids: + # - original category id: category id in the original dataset, mainly + # used for evaluation. + # - contiguous category id: [0, #classes), in order to train the linear + # softmax classifier. 
+ thing_dataset_id_to_contiguous_id = {} + stuff_dataset_id_to_contiguous_id = {} + + for i, cat in enumerate(COCO_CATEGORIES): + if cat["isthing"]: + thing_dataset_id_to_contiguous_id[cat["id"]] = i + else: + stuff_dataset_id_to_contiguous_id[cat["id"]] = i + + meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id + meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id + + return meta + elif dataset_name == "coco_person": + return { + "thing_classes": ["person"], + "keypoint_names": COCO_PERSON_KEYPOINT_NAMES, + "keypoint_flip_map": COCO_PERSON_KEYPOINT_FLIP_MAP, + "keypoint_connection_rules": KEYPOINT_CONNECTION_RULES, + } + elif dataset_name == "cityscapes": + # fmt: off + CITYSCAPES_THING_CLASSES = [ + "person", "rider", "car", "truck", + "bus", "train", "motorcycle", "bicycle", + ] + CITYSCAPES_STUFF_CLASSES = [ + "road", "sidewalk", "building", "wall", "fence", "pole", "traffic light", + "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car", + "truck", "bus", "train", "motorcycle", "bicycle", + ] + # fmt: on + return { + "thing_classes": CITYSCAPES_THING_CLASSES, + "stuff_classes": CITYSCAPES_STUFF_CLASSES, + } + raise KeyError("No built-in metadata for dataset {}".format(dataset_name)) diff --git a/data_processing/detectron2/detectron2/data/datasets/cityscapes.py b/data_processing/detectron2/detectron2/data/datasets/cityscapes.py new file mode 100644 index 0000000..1e84a5b --- /dev/null +++ b/data_processing/detectron2/detectron2/data/datasets/cityscapes.py @@ -0,0 +1,329 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import functools +import json +import logging +import multiprocessing as mp +import numpy as np +import os +from itertools import chain +import pycocotools.mask as mask_util +from PIL import Image + +from detectron2.structures import BoxMode +from detectron2.utils.comm import get_world_size +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import setup_logger + +try: + import cv2 # noqa +except ImportError: + # OpenCV is an optional dependency at the moment + pass + + +logger = logging.getLogger(__name__) + + +def _get_cityscapes_files(image_dir, gt_dir): + files = [] + # scan through the directory + cities = PathManager.ls(image_dir) + logger.info(f"{len(cities)} cities found in '{image_dir}'.") + for city in cities: + city_img_dir = os.path.join(image_dir, city) + city_gt_dir = os.path.join(gt_dir, city) + for basename in PathManager.ls(city_img_dir): + image_file = os.path.join(city_img_dir, basename) + + suffix = "leftImg8bit.png" + assert basename.endswith(suffix), basename + basename = basename[: -len(suffix)] + + instance_file = os.path.join(city_gt_dir, basename + "gtFine_instanceIds.png") + label_file = os.path.join(city_gt_dir, basename + "gtFine_labelIds.png") + json_file = os.path.join(city_gt_dir, basename + "gtFine_polygons.json") + + files.append((image_file, instance_file, label_file, json_file)) + assert len(files), "No images found in {}".format(image_dir) + for f in files[0]: + assert PathManager.isfile(f), f + return files + + +def load_cityscapes_instances(image_dir, gt_dir, from_json=True, to_polygons=True): + """ + Args: + image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train". + gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train". + from_json (bool): whether to read annotations from the raw json file or the png files. + to_polygons (bool): whether to represent the segmentation as polygons + (COCO's format) instead of masks (cityscapes's format). 
+ + Returns: + list[dict]: a list of dicts in Detectron2 standard format. (See + `Using Custom Datasets `_ ) + """ + if from_json: + assert to_polygons, ( + "Cityscapes's json annotations are in polygon format. " + "Converting to mask format is not supported now." + ) + files = _get_cityscapes_files(image_dir, gt_dir) + + logger.info("Preprocessing cityscapes annotations ...") + # This is still not fast: all workers will execute duplicate works and will + # take up to 10m on a 8GPU server. + pool = mp.Pool(processes=max(mp.cpu_count() // get_world_size() // 2, 4)) + + ret = pool.map( + functools.partial(_cityscapes_files_to_dict, from_json=from_json, to_polygons=to_polygons), + files, + ) + logger.info("Loaded {} images from {}".format(len(ret), image_dir)) + + # Map cityscape ids to contiguous ids + from cityscapesscripts.helpers.labels import labels + + labels = [l for l in labels if l.hasInstances and not l.ignoreInEval] + dataset_id_to_contiguous_id = {l.id: idx for idx, l in enumerate(labels)} + for dict_per_image in ret: + for anno in dict_per_image["annotations"]: + anno["category_id"] = dataset_id_to_contiguous_id[anno["category_id"]] + return ret + + +def load_cityscapes_semantic(image_dir, gt_dir): + """ + Args: + image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train". + gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train". + + Returns: + list[dict]: a list of dict, each has "file_name" and + "sem_seg_file_name". + """ + ret = [] + # gt_dir is small and contain many small files. 
make sense to fetch to local first + gt_dir = PathManager.get_local_path(gt_dir) + for image_file, _, label_file, json_file in _get_cityscapes_files(image_dir, gt_dir): + label_file = label_file.replace("labelIds", "labelTrainIds") + + with PathManager.open(json_file, "r") as f: + jsonobj = json.load(f) + ret.append( + { + "file_name": image_file, + "sem_seg_file_name": label_file, + "height": jsonobj["imgHeight"], + "width": jsonobj["imgWidth"], + } + ) + assert len(ret), f"No images found in {image_dir}!" + assert PathManager.isfile( + ret[0]["sem_seg_file_name"] + ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py" # noqa + return ret + + +def _cityscapes_files_to_dict(files, from_json, to_polygons): + """ + Parse cityscapes annotation files to a instance segmentation dataset dict. + + Args: + files (tuple): consists of (image_file, instance_id_file, label_id_file, json_file) + from_json (bool): whether to read annotations from the raw json file or the png files. + to_polygons (bool): whether to represent the segmentation as polygons + (COCO's format) instead of masks (cityscapes's format). + + Returns: + A dict in Detectron2 Dataset format. + """ + from cityscapesscripts.helpers.labels import id2label, name2label + + image_file, instance_id_file, _, json_file = files + + annos = [] + + if from_json: + from shapely.geometry import MultiPolygon, Polygon + + with PathManager.open(json_file, "r") as f: + jsonobj = json.load(f) + ret = { + "file_name": image_file, + "image_id": os.path.basename(image_file), + "height": jsonobj["imgHeight"], + "width": jsonobj["imgWidth"], + } + + # `polygons_union` contains the union of all valid polygons. + polygons_union = Polygon() + + # CityscapesScripts draw the polygons in sequential order + # and each polygon *overwrites* existing ones. 
See + # (https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/json2instanceImg.py) # noqa + # We use reverse order, and each polygon *avoids* early ones. + # This will resolve the ploygon overlaps in the same way as CityscapesScripts. + for obj in jsonobj["objects"][::-1]: + if "deleted" in obj: # cityscapes data format specific + continue + label_name = obj["label"] + + try: + label = name2label[label_name] + except KeyError: + if label_name.endswith("group"): # crowd area + label = name2label[label_name[: -len("group")]] + else: + raise + if label.id < 0: # cityscapes data format + continue + + # Cityscapes's raw annotations uses integer coordinates + # Therefore +0.5 here + poly_coord = np.asarray(obj["polygon"], dtype="f4") + 0.5 + # CityscapesScript uses PIL.ImageDraw.polygon to rasterize + # polygons for evaluation. This function operates in integer space + # and draws each pixel whose center falls into the polygon. + # Therefore it draws a polygon which is 0.5 "fatter" in expectation. + # We therefore dilate the input polygon by 0.5 as our input. 
+ poly = Polygon(poly_coord).buffer(0.5, resolution=4) + + if not label.hasInstances or label.ignoreInEval: + # even if we won't store the polygon it still contributes to overlaps resolution + polygons_union = polygons_union.union(poly) + continue + + # Take non-overlapping part of the polygon + poly_wo_overlaps = poly.difference(polygons_union) + if poly_wo_overlaps.is_empty: + continue + polygons_union = polygons_union.union(poly) + + anno = {} + anno["iscrowd"] = label_name.endswith("group") + anno["category_id"] = label.id + + if isinstance(poly_wo_overlaps, Polygon): + poly_list = [poly_wo_overlaps] + elif isinstance(poly_wo_overlaps, MultiPolygon): + poly_list = poly_wo_overlaps.geoms + else: + raise NotImplementedError("Unknown geometric structure {}".format(poly_wo_overlaps)) + + poly_coord = [] + for poly_el in poly_list: + # COCO API can work only with exterior boundaries now, hence we store only them. + # TODO: store both exterior and interior boundaries once other parts of the + # codebase support holes in polygons. 
+ poly_coord.append(list(chain(*poly_el.exterior.coords))) + anno["segmentation"] = poly_coord + (xmin, ymin, xmax, ymax) = poly_wo_overlaps.bounds + + anno["bbox"] = (xmin, ymin, xmax, ymax) + anno["bbox_mode"] = BoxMode.XYXY_ABS + + annos.append(anno) + else: + # See also the official annotation parsing scripts at + # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/instances2dict.py # noqa + with PathManager.open(instance_id_file, "rb") as f: + inst_image = np.asarray(Image.open(f), order="F") + # ids < 24 are stuff labels (filtering them first is about 5% faster) + flattened_ids = np.unique(inst_image[inst_image >= 24]) + + ret = { + "file_name": image_file, + "image_id": os.path.basename(image_file), + "height": inst_image.shape[0], + "width": inst_image.shape[1], + } + + for instance_id in flattened_ids: + # For non-crowd annotations, instance_id // 1000 is the label_id + # Crowd annotations have <1000 instance ids + label_id = instance_id // 1000 if instance_id >= 1000 else instance_id + label = id2label[label_id] + if not label.hasInstances or label.ignoreInEval: + continue + + anno = {} + anno["iscrowd"] = instance_id < 1000 + anno["category_id"] = label.id + + mask = np.asarray(inst_image == instance_id, dtype=np.uint8, order="F") + + inds = np.nonzero(mask) + ymin, ymax = inds[0].min(), inds[0].max() + xmin, xmax = inds[1].min(), inds[1].max() + anno["bbox"] = (xmin, ymin, xmax, ymax) + if xmax <= xmin or ymax <= ymin: + continue + anno["bbox_mode"] = BoxMode.XYXY_ABS + if to_polygons: + # This conversion comes from D4809743 and D5171122, + # when Mask-RCNN was first developed. 
+ contours = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[ + -2 + ] + polygons = [c.reshape(-1).tolist() for c in contours if len(c) >= 3] + # opencv's can produce invalid polygons + if len(polygons) == 0: + continue + anno["segmentation"] = polygons + else: + anno["segmentation"] = mask_util.encode(mask[:, :, None])[0] + annos.append(anno) + ret["annotations"] = annos + return ret + + +if __name__ == "__main__": + """ + Test the cityscapes dataset loader. + + Usage: + python -m detectron2.data.datasets.cityscapes \ + cityscapes/leftImg8bit/train cityscapes/gtFine/train + """ + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("image_dir") + parser.add_argument("gt_dir") + parser.add_argument("--type", choices=["instance", "semantic"], default="instance") + args = parser.parse_args() + from detectron2.data.catalog import Metadata + from detectron2.utils.visualizer import Visualizer + from cityscapesscripts.helpers.labels import labels + + logger = setup_logger(name=__name__) + + dirname = "cityscapes-data-vis" + os.makedirs(dirname, exist_ok=True) + + if args.type == "instance": + dicts = load_cityscapes_instances( + args.image_dir, args.gt_dir, from_json=True, to_polygons=True + ) + logger.info("Done loading {} samples.".format(len(dicts))) + + thing_classes = [k.name for k in labels if k.hasInstances and not k.ignoreInEval] + meta = Metadata().set(thing_classes=thing_classes) + + else: + dicts = load_cityscapes_semantic(args.image_dir, args.gt_dir) + logger.info("Done loading {} samples.".format(len(dicts))) + + stuff_classes = [k.name for k in labels if k.trainId != 255] + stuff_colors = [k.color for k in labels if k.trainId != 255] + meta = Metadata().set(stuff_classes=stuff_classes, stuff_colors=stuff_colors) + + for d in dicts: + img = np.array(Image.open(PathManager.open(d["file_name"], "rb"))) + visualizer = Visualizer(img, metadata=meta) + vis = visualizer.draw_dataset_dict(d) + # cv2.imshow("a", 
vis.get_image()[:, :, ::-1]) + # cv2.waitKey() + fpath = os.path.join(dirname, os.path.basename(d["file_name"])) + vis.save(fpath) diff --git a/data_processing/detectron2/detectron2/data/datasets/cityscapes_panoptic.py b/data_processing/detectron2/detectron2/data/datasets/cityscapes_panoptic.py new file mode 100644 index 0000000..48c136f --- /dev/null +++ b/data_processing/detectron2/detectron2/data/datasets/cityscapes_panoptic.py @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import json +import logging +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets.builtin_meta import CITYSCAPES_CATEGORIES +from detectron2.utils.file_io import PathManager + +""" +This file contains functions to register the Cityscapes panoptic dataset to the DatasetCatalog. +""" + + +logger = logging.getLogger(__name__) + + +def get_cityscapes_panoptic_files(image_dir, gt_dir, json_info): + files = [] + # scan through the directory + cities = PathManager.ls(image_dir) + logger.info(f"{len(cities)} cities found in '{image_dir}'.") + image_dict = {} + for city in cities: + city_img_dir = os.path.join(image_dir, city) + for basename in PathManager.ls(city_img_dir): + image_file = os.path.join(city_img_dir, basename) + + suffix = "_leftImg8bit.png" + assert basename.endswith(suffix), basename + basename = os.path.basename(basename)[: -len(suffix)] + + image_dict[basename] = image_file + + for ann in json_info["annotations"]: + image_file = image_dict.get(ann["image_id"], None) + assert image_file is not None, "No image {} found for annotation {}".format( + ann["image_id"], ann["file_name"] + ) + label_file = os.path.join(gt_dir, ann["file_name"]) + segments_info = ann["segments_info"] + + files.append((image_file, label_file, segments_info)) + + assert len(files), "No images found in {}".format(image_dir) + assert PathManager.isfile(files[0][0]), files[0][0] + assert PathManager.isfile(files[0][1]), files[0][1] + return 
files + + +def load_cityscapes_panoptic(image_dir, gt_dir, gt_json, meta): + """ + Args: + image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train". + gt_dir (str): path to the raw annotations. e.g., + "~/cityscapes/gtFine/cityscapes_panoptic_train". + gt_json (str): path to the json file. e.g., + "~/cityscapes/gtFine/cityscapes_panoptic_train.json". + meta (dict): dictionary containing "thing_dataset_id_to_contiguous_id" + and "stuff_dataset_id_to_contiguous_id" to map category ids to + contiguous ids for training. + + Returns: + list[dict]: a list of dicts in Detectron2 standard format. (See + `Using Custom Datasets `_ ) + """ + + def _convert_category_id(segment_info, meta): + if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]: + segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][ + segment_info["category_id"] + ] + else: + segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][ + segment_info["category_id"] + ] + return segment_info + + assert os.path.exists( + gt_json + ), "Please run `python cityscapesscripts/preparation/createPanopticImgs.py` to generate label files." # noqa + with open(gt_json) as f: + json_info = json.load(f) + files = get_cityscapes_panoptic_files(image_dir, gt_dir, json_info) + ret = [] + for image_file, label_file, segments_info in files: + sem_label_file = ( + image_file.replace("leftImg8bit", "gtFine").split(".")[0] + "_labelTrainIds.png" + ) + segments_info = [_convert_category_id(x, meta) for x in segments_info] + ret.append( + { + "file_name": image_file, + "image_id": "_".join( + os.path.splitext(os.path.basename(image_file))[0].split("_")[:3] + ), + "sem_seg_file_name": sem_label_file, + "pan_seg_file_name": label_file, + "segments_info": segments_info, + } + ) + assert len(ret), f"No images found in {image_dir}!" 
+ assert PathManager.isfile( + ret[0]["sem_seg_file_name"] + ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py" # noqa + assert PathManager.isfile( + ret[0]["pan_seg_file_name"] + ), "Please generate panoptic annotation with python cityscapesscripts/preparation/createPanopticImgs.py" # noqa + return ret + + +_RAW_CITYSCAPES_PANOPTIC_SPLITS = { + "cityscapes_fine_panoptic_train": ( + "cityscapes/leftImg8bit/train", + "cityscapes/gtFine/cityscapes_panoptic_train", + "cityscapes/gtFine/cityscapes_panoptic_train.json", + ), + "cityscapes_fine_panoptic_val": ( + "cityscapes/leftImg8bit/val", + "cityscapes/gtFine/cityscapes_panoptic_val", + "cityscapes/gtFine/cityscapes_panoptic_val.json", + ), + # "cityscapes_fine_panoptic_test": not supported yet +} + + +def register_all_cityscapes_panoptic(root): + meta = {} + # The following metadata maps contiguous id from [0, #thing categories + + # #stuff categories) to their names and colors. We have to replica of the + # same name and color under "thing_*" and "stuff_*" because the current + # visualization function in D2 handles thing and class classes differently + # due to some heuristic used in Panoptic FPN. We keep the same naming to + # enable reusing existing visualization functions. + thing_classes = [k["name"] for k in CITYSCAPES_CATEGORIES] + thing_colors = [k["color"] for k in CITYSCAPES_CATEGORIES] + stuff_classes = [k["name"] for k in CITYSCAPES_CATEGORIES] + stuff_colors = [k["color"] for k in CITYSCAPES_CATEGORIES] + + meta["thing_classes"] = thing_classes + meta["thing_colors"] = thing_colors + meta["stuff_classes"] = stuff_classes + meta["stuff_colors"] = stuff_colors + + # There are three types of ids in cityscapes panoptic segmentation: + # (1) category id: like semantic segmentation, it is the class id for each + # pixel. 
Since there are some classes not used in evaluation, the category + # id is not always contiguous and thus we have two set of category ids: + # - original category id: category id in the original dataset, mainly + # used for evaluation. + # - contiguous category id: [0, #classes), in order to train the classifier + # (2) instance id: this id is used to differentiate different instances from + # the same category. For "stuff" classes, the instance id is always 0; for + # "thing" classes, the instance id starts from 1 and 0 is reserved for + # ignored instances (e.g. crowd annotation). + # (3) panoptic id: this is the compact id that encode both category and + # instance id by: category_id * 1000 + instance_id. + thing_dataset_id_to_contiguous_id = {} + stuff_dataset_id_to_contiguous_id = {} + + for k in CITYSCAPES_CATEGORIES: + if k["isthing"] == 1: + thing_dataset_id_to_contiguous_id[k["id"]] = k["trainId"] + else: + stuff_dataset_id_to_contiguous_id[k["id"]] = k["trainId"] + + meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id + meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id + + for key, (image_dir, gt_dir, gt_json) in _RAW_CITYSCAPES_PANOPTIC_SPLITS.items(): + image_dir = os.path.join(root, image_dir) + gt_dir = os.path.join(root, gt_dir) + gt_json = os.path.join(root, gt_json) + + DatasetCatalog.register( + key, lambda x=image_dir, y=gt_dir, z=gt_json: load_cityscapes_panoptic(x, y, z, meta) + ) + MetadataCatalog.get(key).set( + panoptic_root=gt_dir, + image_root=image_dir, + panoptic_json=gt_json, + gt_dir=gt_dir.replace("cityscapes_panoptic_", ""), + evaluator_type="cityscapes_panoptic_seg", + ignore_label=255, + label_divisor=1000, + **meta, + ) diff --git a/data_processing/detectron2/detectron2/data/datasets/coco.py b/data_processing/detectron2/detectron2/data/datasets/coco.py new file mode 100644 index 0000000..ed4f7cc --- /dev/null +++ b/data_processing/detectron2/detectron2/data/datasets/coco.py @@ 
-0,0 +1,539 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import contextlib +import datetime +import io +import json +import logging +import numpy as np +import os +import shutil +import pycocotools.mask as mask_util +from fvcore.common.timer import Timer +from iopath.common.file_io import file_lock +from PIL import Image + +from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes +from detectron2.utils.file_io import PathManager + +from .. import DatasetCatalog, MetadataCatalog + +""" +This file contains functions to parse COCO-format annotations into dicts in "Detectron2 format". +""" + + +logger = logging.getLogger(__name__) + +__all__ = ["load_coco_json", "load_sem_seg", "convert_to_coco_json", "register_coco_instances"] + + +def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None): + """ + Load a json file with COCO's instances annotation format. + Currently supports instance detection, instance segmentation, + and person keypoints annotations. + + Args: + json_file (str): full path to the json file in COCO instances annotation format. + image_root (str or path-like): the directory where the images in this json file exists. + dataset_name (str or None): the name of the dataset (e.g., coco_2017_train). + When provided, this function will also do the following: + + * Put "thing_classes" into the metadata associated with this dataset. + * Map the category ids into a contiguous range (needed by standard dataset format), + and add "thing_dataset_id_to_contiguous_id" to the metadata associated + with this dataset. + + This option should usually be provided, unless users need to load + the original json content and apply more processing manually. + extra_annotation_keys (list[str]): list of per-annotation keys that should also be + loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints", + "category_id", "segmentation"). The values for these keys will be returned as-is. 
+ For example, the densepose annotations are loaded in this way. + + Returns: + list[dict]: a list of dicts in Detectron2 standard dataset dicts format (See + `Using Custom Datasets `_ ) when `dataset_name` is not None. + If `dataset_name` is None, the returned `category_ids` may be + incontiguous and may not conform to the Detectron2 standard format. + + Notes: + 1. This function does not read the image files. + The results do not have the "image" field. + """ + from pycocotools.coco import COCO + + timer = Timer() + json_file = PathManager.get_local_path(json_file) + with contextlib.redirect_stdout(io.StringIO()): + coco_api = COCO(json_file) + if timer.seconds() > 1: + logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds())) + + id_map = None + if dataset_name is not None: + meta = MetadataCatalog.get(dataset_name) + cat_ids = sorted(coco_api.getCatIds()) + cats = coco_api.loadCats(cat_ids) + # The categories in a custom json file may not be sorted. + thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])] + meta.thing_classes = thing_classes + + # In COCO, certain category ids are artificially removed, + # and by convention they are always ignored. + # We deal with COCO's id issue and translate + # the category ids to contiguous ids in [0, 80). + + # It works by looking at the "categories" field in the json, therefore + # if users' own json also have incontiguous ids, we'll + # apply this mapping as well but print a warning. + if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)): + if "coco" not in dataset_name: + logger.warning( + """ +Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you. 
+""" + ) + id_map = {v: i for i, v in enumerate(cat_ids)} + meta.thing_dataset_id_to_contiguous_id = id_map + + # sort indices for reproducible results + img_ids = sorted(coco_api.imgs.keys()) + # imgs is a list of dicts, each looks something like: + # {'license': 4, + # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg', + # 'file_name': 'COCO_val2014_000000001268.jpg', + # 'height': 427, + # 'width': 640, + # 'date_captured': '2013-11-17 05:57:24', + # 'id': 1268} + imgs = coco_api.loadImgs(img_ids) + # anns is a list[list[dict]], where each dict is an annotation + # record for an object. The inner list enumerates the objects in an image + # and the outer list enumerates over images. Example of anns[0]: + # [{'segmentation': [[192.81, + # 247.09, + # ... + # 219.03, + # 249.06]], + # 'area': 1035.749, + # 'iscrowd': 0, + # 'image_id': 1268, + # 'bbox': [192.81, 224.8, 74.73, 33.43], + # 'category_id': 16, + # 'id': 42986}, + # ...] + anns = [coco_api.imgToAnns[img_id] for img_id in img_ids] + total_num_valid_anns = sum([len(x) for x in anns]) + total_num_anns = len(coco_api.anns) + if total_num_valid_anns < total_num_anns: + logger.warning( + f"{json_file} contains {total_num_anns} annotations, but only " + f"{total_num_valid_anns} of them match to images in the file." + ) + + if "minival" not in json_file: + # The popular valminusminival & minival annotations for COCO2014 contain this bug. + # However the ratio of buggy annotations there is tiny and does not affect accuracy. + # Therefore we explicitly white-list them. 
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image] + assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format( + json_file + ) + + imgs_anns = list(zip(imgs, anns)) + logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file)) + + dataset_dicts = [] + + ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or []) + + num_instances_without_valid_segmentation = 0 + + for (img_dict, anno_dict_list) in imgs_anns: + record = {} + record["file_name"] = os.path.join(image_root, img_dict["file_name"]) + record["height"] = img_dict["height"] + record["width"] = img_dict["width"] + image_id = record["image_id"] = img_dict["id"] + + objs = [] + for anno in anno_dict_list: + # Check that the image_id in this annotation is the same as + # the image_id we're looking at. + # This fails only when the data parsing logic or the annotation file is buggy. + + # The original COCO valminusminival2014 & minival2014 annotation files + # actually contains bugs that, together with certain ways of using COCO API, + # can trigger this assertion. + assert anno["image_id"] == image_id + + assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.' + + obj = {key: anno[key] for key in ann_keys if key in anno} + if "bbox" in obj and len(obj["bbox"]) == 0: + raise ValueError( + f"One annotation of image {image_id} contains empty 'bbox' value! " + "This json does not have valid COCO format." 
def load_sem_seg(gt_root, image_root, gt_ext="png", image_ext="jpg"):
    """
    Load semantic segmentation datasets. All files under "gt_root" with "gt_ext" extension are
    treated as ground truth annotations and all files under "image_root" with "image_ext" extension
    as input images. Ground truth and input images are matched using file paths relative to
    "gt_root" and "image_root" respectively without taking into account file extensions.
    This works for COCO as well as some other datasets.

    Args:
        gt_root (str): full path to ground truth semantic segmentation files. Semantic segmentation
            annotations are stored as images with integer values in pixels that represent
            corresponding semantic labels.
        image_root (str): the directory where the input images are.
        gt_ext (str): file extension for ground truth annotations.
        image_ext (str): file extension for input images.

    Returns:
        list[dict]:
            a list of dicts in detectron2 standard format without instance-level
            annotation. Each dict has the keys "file_name" and "sem_seg_file_name".

    Raises:
        AssertionError: if no file ending in `gt_ext` is found under `gt_root`.

    Notes:
        1. This function does not read the image and ground truth files.
           The results do not have the "image" and "sem_seg" fields.
    """

    # We match input images with ground truth based on their relative filepaths (without file
    # extensions) starting from 'image_root' and 'gt_root' respectively.
    def file2id(folder_path, file_path):
        # extract relative path starting from `folder_path`, then drop the file extension
        image_id = os.path.normpath(os.path.relpath(file_path, start=folder_path))
        return os.path.splitext(image_id)[0]

    def strip_ext(file_path, ext):
        # Basename with the trailing `len(ext)` characters removed. Written as an explicit
        # end index so that an empty `ext` keeps the whole name (`name[:-0]` would be "").
        name = os.path.basename(file_path)
        return name[: len(name) - len(ext)]

    input_files = sorted(
        (os.path.join(image_root, f) for f in PathManager.ls(image_root) if f.endswith(image_ext)),
        key=lambda file_path: file2id(image_root, file_path),
    )
    gt_files = sorted(
        (os.path.join(gt_root, f) for f in PathManager.ls(gt_root) if f.endswith(gt_ext)),
        key=lambda file_path: file2id(gt_root, file_path),
    )

    assert len(gt_files) > 0, "No annotations found in {}.".format(gt_root)

    # Use the intersection, so that val2017_100 annotations can run smoothly with val2017 images
    if len(input_files) != len(gt_files):
        # `Logger.warn` is a deprecated alias; use `Logger.warning`.
        logger.warning(
            "Directory {} and {} has {} and {} files, respectively.".format(
                image_root, gt_root, len(input_files), len(gt_files)
            )
        )
        input_basenames = [strip_ext(f, image_ext) for f in input_files]
        gt_basenames = [strip_ext(f, gt_ext) for f in gt_files]
        # sort, otherwise each worker may obtain a list[dict] in different order
        intersect = sorted(set(input_basenames) & set(gt_basenames))
        logger.warning("Will use their intersection of {} files.".format(len(intersect)))
        input_files = [os.path.join(image_root, f + image_ext) for f in intersect]
        gt_files = [os.path.join(gt_root, f + gt_ext) for f in intersect]

    logger.info(
        "Loaded {} images with semantic segmentation from {}".format(len(input_files), image_root)
    )

    # The two lists are paired positionally after sorting / intersecting.
    dataset_dicts = [
        {"file_name": img_path, "sem_seg_file_name": gt_path}
        for img_path, gt_path in zip(input_files, gt_files)
    ]

    return dataset_dicts
format into COCO json format. + + Generic dataset description can be found here: + https://detectron2.readthedocs.io/tutorials/datasets.html#register-a-dataset + + COCO data format description can be found here: + http://cocodataset.org/#format-data + + Args: + dataset_name (str): + name of the source dataset + Must be registered in DatastCatalog and in detectron2's standard format. + Must have corresponding metadata "thing_classes" + Returns: + coco_dict: serializable dict in COCO json format + """ + + dataset_dicts = DatasetCatalog.get(dataset_name) + metadata = MetadataCatalog.get(dataset_name) + + # unmap the category mapping ids for COCO + if hasattr(metadata, "thing_dataset_id_to_contiguous_id"): + reverse_id_mapping = {v: k for k, v in metadata.thing_dataset_id_to_contiguous_id.items()} + reverse_id_mapper = lambda contiguous_id: reverse_id_mapping[contiguous_id] # noqa + else: + reverse_id_mapper = lambda contiguous_id: contiguous_id # noqa + + categories = [ + {"id": reverse_id_mapper(id), "name": name} + for id, name in enumerate(metadata.thing_classes) + ] + + logger.info("Converting dataset dicts into COCO format") + coco_images = [] + coco_annotations = [] + + for image_id, image_dict in enumerate(dataset_dicts): + coco_image = { + "id": image_dict.get("image_id", image_id), + "width": int(image_dict["width"]), + "height": int(image_dict["height"]), + "file_name": str(image_dict["file_name"]), + } + coco_images.append(coco_image) + + anns_per_image = image_dict.get("annotations", []) + for annotation in anns_per_image: + # create a new dict with only COCO fields + coco_annotation = {} + + # COCO requirement: XYWH box format for axis-align and XYWHA for rotated + bbox = annotation["bbox"] + if isinstance(bbox, np.ndarray): + if bbox.ndim != 1: + raise ValueError(f"bbox has to be 1-dimensional. 
Got shape={bbox.shape}.") + bbox = bbox.tolist() + if len(bbox) not in [4, 5]: + raise ValueError(f"bbox has to has length 4 or 5. Got {bbox}.") + from_bbox_mode = annotation["bbox_mode"] + to_bbox_mode = BoxMode.XYWH_ABS if len(bbox) == 4 else BoxMode.XYWHA_ABS + bbox = BoxMode.convert(bbox, from_bbox_mode, to_bbox_mode) + + # COCO requirement: instance area + if "segmentation" in annotation: + # Computing areas for instances by counting the pixels + segmentation = annotation["segmentation"] + # TODO: check segmentation type: RLE, BinaryMask or Polygon + if isinstance(segmentation, list): + polygons = PolygonMasks([segmentation]) + area = polygons.area()[0].item() + elif isinstance(segmentation, dict): # RLE + area = mask_util.area(segmentation).item() + else: + raise TypeError(f"Unknown segmentation type {type(segmentation)}!") + else: + # Computing areas using bounding boxes + if to_bbox_mode == BoxMode.XYWH_ABS: + bbox_xy = BoxMode.convert(bbox, to_bbox_mode, BoxMode.XYXY_ABS) + area = Boxes([bbox_xy]).area()[0].item() + else: + area = RotatedBoxes([bbox]).area()[0].item() + + if "keypoints" in annotation: + keypoints = annotation["keypoints"] # list[int] + for idx, v in enumerate(keypoints): + if idx % 3 != 2: + # COCO's segmentation coordinates are floating points in [0, H or W], + # but keypoint coordinates are integers in [0, H-1 or W-1] + # For COCO format consistency we substract 0.5 + # https://github.com/facebookresearch/detectron2/pull/175#issuecomment-551202163 + keypoints[idx] = v - 0.5 + if "num_keypoints" in annotation: + num_keypoints = annotation["num_keypoints"] + else: + num_keypoints = sum(kp > 0 for kp in keypoints[2::3]) + + # COCO requirement: + # linking annotations to images + # "id" field must start with 1 + coco_annotation["id"] = len(coco_annotations) + 1 + coco_annotation["image_id"] = coco_image["id"] + coco_annotation["bbox"] = [round(float(x), 3) for x in bbox] + 
def convert_to_coco_json(dataset_name, output_file, allow_cached=True):
    """
    Convert a registered dataset into COCO format and write it to a json file.

    Args:
        dataset_name:
            reference from the config file to the catalogs
            must be registered in DatasetCatalog and in detectron2's standard format
        output_file: path of json file that will be saved to
        allow_cached: if json file is already present then skip conversion
    """

    # TODO: The dataset or the conversion script *may* change,
    # a checksum would be useful for validating the cached data

    PathManager.mkdirs(os.path.dirname(output_file))
    with file_lock(output_file):
        cache_hit = PathManager.exists(output_file) and allow_cached
        if cache_hit:
            logger.warning(
                f"Using previously cached COCO format annotations at '{output_file}'. "
                "You need to clear the cache file if your dataset has been modified."
            )
            return

        logger.info(f"Converting annotations of dataset '{dataset_name}' to COCO format ...)")
        converted = convert_to_coco_dict(dataset_name)

        logger.info(f"Caching COCO format annotations at '{output_file}' ...")
        # Write to a sibling ".tmp" file first, then move it over the final path.
        staging_path = output_file + ".tmp"
        with PathManager.open(staging_path, "w") as out:
            json.dump(converted, out)
        shutil.move(staging_path, output_file)
def load_coco_panoptic_json(json_file, image_dir, gt_dir, meta):
    """
    Read a COCO panoptic json file and produce one dataset dict per annotation entry.

    Args:
        image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
        gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
        json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".

    Returns:
        list[dict]: a list of dicts in Detectron2 standard format, each with the keys
        "file_name", "image_id", "pan_seg_file_name" and "segments_info".
    """

    def _to_contiguous(segment, meta):
        # Updates the segment in place: maps its category id to the contiguous id and
        # records whether it came from the thing mapping.
        dataset_id = segment["category_id"]
        thing_map = meta["thing_dataset_id_to_contiguous_id"]
        is_thing = dataset_id in thing_map
        if is_thing:
            segment["category_id"] = thing_map[dataset_id]
        else:
            segment["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][dataset_id]
        segment["isthing"] = is_thing
        return segment

    with PathManager.open(json_file) as fh:
        panoptic_info = json.load(fh)

    records = []
    for entry in panoptic_info["annotations"]:
        # TODO: currently we assume image and label has the same filename but
        # different extension, and images have extension ".jpg" for COCO. Need
        # to make image extension a user-provided argument if we extend this
        # function to support other COCO-like datasets.
        stem = os.path.splitext(entry["file_name"])[0]
        record = {
            "file_name": os.path.join(image_dir, stem + ".jpg"),
            "image_id": int(entry["image_id"]),
            "pan_seg_file_name": os.path.join(gt_dir, entry["file_name"]),
            "segments_info": [_to_contiguous(seg, meta) for seg in entry["segments_info"]],
        }
        records.append(record)

    assert len(records), f"No images found in {image_dir}!"
    # Only the first record's files are checked for existence.
    first = records[0]
    assert PathManager.isfile(first["file_name"]), first["file_name"]
    assert PathManager.isfile(first["pan_seg_file_name"]), first["pan_seg_file_name"]
    return records
+ image_root (str): directory which contains all the images + panoptic_root (str): directory which contains panoptic annotation images in COCO format + panoptic_json (str): path to the json panoptic annotation file in COCO format + sem_seg_root (none): not used, to be consistent with + `register_coco_panoptic_separated`. + instances_json (str): path to the json instance annotation file + """ + panoptic_name = name + DatasetCatalog.register( + panoptic_name, + lambda: load_coco_panoptic_json(panoptic_json, image_root, panoptic_root, metadata), + ) + MetadataCatalog.get(panoptic_name).set( + panoptic_root=panoptic_root, + image_root=image_root, + panoptic_json=panoptic_json, + json_file=instances_json, + evaluator_type="coco_panoptic_seg", + ignore_label=255, + label_divisor=1000, + **metadata, + ) + + +def register_coco_panoptic_separated( + name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json +): + """ + Register a "separated" version of COCO panoptic segmentation dataset named `name`. + The annotations in this registered dataset will contain both instance annotations and + semantic annotations, each with its own contiguous ids. Hence it's called "separated". + + It follows the setting used by the PanopticFPN paper: + + 1. The instance annotations directly come from polygons in the COCO + instances annotation task, rather than from the masks in the COCO panoptic annotations. + + The two format have small differences: + Polygons in the instance annotations may have overlaps. + The mask annotations are produced by labeling the overlapped polygons + with depth ordering. + + 2. The semantic annotations are converted from panoptic annotations, where + all "things" are assigned a semantic id of 0. + All semantic categories will therefore have ids in contiguous + range [1, #stuff_categories]. + + This function will also register a pure semantic segmentation dataset + named ``name + '_stuffonly'``. 
def merge_to_panoptic(detection_dicts, sem_seg_dicts):
    """
    Create dataset dicts for panoptic segmentation by joining detection dicts with
    semantic segmentation dicts on their "file_name" field.

    Args:
        detection_dicts (list[dict]): lists of dicts for object detection or instance segmentation.
        sem_seg_dicts (list[dict]): lists of dicts for semantic segmentation.

    Returns:
        list[dict] (one per detection dict): a shallow copy of each detection dict, updated
            with all (key, value) pairs of the semantic segmentation dict that has the
            same "file_name". On a key clash the semantic segmentation value is kept;
            the function assumes that the same key in different dicts has the same value.
    """
    # Index the semantic entries by file name; a later duplicate replaces an earlier one.
    sem_seg_by_file = {}
    for entry in sem_seg_dicts:
        sem_seg_by_file[entry["file_name"]] = entry
    assert len(sem_seg_by_file) > 0

    merged = []
    for detection in detection_dicts:
        combined = copy.copy(detection)  # shallow copy: the input dict is left untouched
        combined.update(sem_seg_by_file[combined["file_name"]])
        merged.append(combined)
    return merged
def register_lvis_instances(name, metadata, json_file, image_root):
    """
    Register a dataset in LVIS's json annotation format for instance detection and segmentation.

    Args:
        name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train".
        metadata (dict): extra metadata associated with this dataset. It can be an empty dict.
        json_file (str): path to the json instance annotation file.
        image_root (str or path-like): directory which contains all the images.
    """

    def _load():
        # Called by the catalog when the dataset is requested, not at registration time.
        return load_lvis_json(json_file, image_root, name)

    DatasetCatalog.register(name, _load)

    dataset_meta = MetadataCatalog.get(name)
    dataset_meta.set(
        json_file=json_file, image_root=image_root, evaluator_type="lvis", **metadata
    )
+ extra_annotation_keys (list[str]): list of per-annotation keys that should also be + loaded into the dataset dict (besides "bbox", "bbox_mode", "category_id", + "segmentation"). The values for these keys will be returned as-is. + + Returns: + list[dict]: a list of dicts in Detectron2 standard format. (See + `Using Custom Datasets `_ ) + + Notes: + 1. This function does not read the image files. + The results do not have the "image" field. + """ + from lvis import LVIS + + json_file = PathManager.get_local_path(json_file) + + timer = Timer() + lvis_api = LVIS(json_file) + if timer.seconds() > 1: + logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds())) + + if dataset_name is not None: + meta = get_lvis_instances_meta(dataset_name) + MetadataCatalog.get(dataset_name).set(**meta) + + # sort indices for reproducible results + img_ids = sorted(lvis_api.imgs.keys()) + # imgs is a list of dicts, each looks something like: + # {'license': 4, + # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg', + # 'file_name': 'COCO_val2014_000000001268.jpg', + # 'height': 427, + # 'width': 640, + # 'date_captured': '2013-11-17 05:57:24', + # 'id': 1268} + imgs = lvis_api.load_imgs(img_ids) + # anns is a list[list[dict]], where each dict is an annotation + # record for an object. The inner list enumerates the objects in an image + # and the outer list enumerates over images. Example of anns[0]: + # [{'segmentation': [[192.81, + # 247.09, + # ... + # 219.03, + # 249.06]], + # 'area': 1035.749, + # 'image_id': 1268, + # 'bbox': [192.81, 224.8, 74.73, 33.43], + # 'category_id': 16, + # 'id': 42986}, + # ...] 
+ anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids] + + # Sanity check that each annotation has a unique id + ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image] + assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique".format( + json_file + ) + + imgs_anns = list(zip(imgs, anns)) + + logger.info("Loaded {} images in the LVIS format from {}".format(len(imgs_anns), json_file)) + + if extra_annotation_keys: + logger.info( + "The following extra annotation keys will be loaded: {} ".format(extra_annotation_keys) + ) + else: + extra_annotation_keys = [] + + def get_file_name(img_root, img_dict): + # Determine the path including the split folder ("train2017", "val2017", "test2017") from + # the coco_url field. Example: + # 'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg' + split_folder, file_name = img_dict["coco_url"].split("/")[-2:] + return os.path.join(img_root + split_folder, file_name) + + dataset_dicts = [] + + for (img_dict, anno_dict_list) in imgs_anns: + record = {} + record["file_name"] = get_file_name(image_root, img_dict) + record["height"] = img_dict["height"] + record["width"] = img_dict["width"] + record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", []) + record["neg_category_ids"] = img_dict.get("neg_category_ids", []) + image_id = record["image_id"] = img_dict["id"] + + objs = [] + for anno in anno_dict_list: + # Check that the image_id in this annotation is the same as + # the image_id we're looking at. + # This fails only when the data parsing logic or the annotation file is buggy. + assert anno["image_id"] == image_id + obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS} + # LVIS data loader can be used to load COCO dataset categories. In this case `meta` + # variable will have a field with COCO-specific category mapping. 
+ if dataset_name is not None and "thing_dataset_id_to_contiguous_id" in meta: + obj["category_id"] = meta["thing_dataset_id_to_contiguous_id"][anno["category_id"]] + else: + obj["category_id"] = anno["category_id"] - 1 # Convert 1-indexed to 0-indexed + segm = anno["segmentation"] # list[list[float]] + # filter out invalid polygons (< 3 points) + valid_segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6] + assert len(segm) == len( + valid_segm + ), "Annotation contains an invalid polygon with < 3 points" + assert len(segm) > 0 + obj["segmentation"] = segm + for extra_ann_key in extra_annotation_keys: + obj[extra_ann_key] = anno[extra_ann_key] + objs.append(obj) + record["annotations"] = objs + dataset_dicts.append(record) + + return dataset_dicts + + +def get_lvis_instances_meta(dataset_name): + """ + Load LVIS metadata. + + Args: + dataset_name (str): LVIS dataset name without the split name (e.g., "lvis_v0.5"). + + Returns: + dict: LVIS metadata with keys: thing_classes + """ + if "cocofied" in dataset_name: + return _get_coco_instances_meta() + if "v0.5" in dataset_name: + return _get_lvis_instances_meta_v0_5() + elif "v1" in dataset_name: + return _get_lvis_instances_meta_v1() + raise ValueError("No built-in metadata for dataset {}".format(dataset_name)) + + +def _get_lvis_instances_meta_v0_5(): + assert len(LVIS_V0_5_CATEGORIES) == 1230 + cat_ids = [k["id"] for k in LVIS_V0_5_CATEGORIES] + assert min(cat_ids) == 1 and max(cat_ids) == len( + cat_ids + ), "Category ids are not in [1, #categories], as expected" + # Ensure that the category list is sorted by id + lvis_categories = sorted(LVIS_V0_5_CATEGORIES, key=lambda x: x["id"]) + thing_classes = [k["synonyms"][0] for k in lvis_categories] + meta = {"thing_classes": thing_classes} + return meta + + +def _get_lvis_instances_meta_v1(): + assert len(LVIS_V1_CATEGORIES) == 1203 + cat_ids = [k["id"] for k in LVIS_V1_CATEGORIES] + assert min(cat_ids) == 1 and max(cat_ids) == len( + cat_ids + ), 
"Category ids are not in [1, #categories], as expected" + # Ensure that the category list is sorted by id + lvis_categories = sorted(LVIS_V1_CATEGORIES, key=lambda x: x["id"]) + thing_classes = [k["synonyms"][0] for k in lvis_categories] + meta = {"thing_classes": thing_classes, "class_image_count": LVIS_V1_CATEGORY_IMAGE_COUNT} + return meta + + +if __name__ == "__main__": + """ + Test the LVIS json dataset loader. + + Usage: + python -m detectron2.data.datasets.lvis \ + path/to/json path/to/image_root dataset_name vis_limit + """ + import sys + import numpy as np + from detectron2.utils.logger import setup_logger + from PIL import Image + import detectron2.data.datasets # noqa # add pre-defined metadata + from detectron2.utils.visualizer import Visualizer + + logger = setup_logger(name=__name__) + meta = MetadataCatalog.get(sys.argv[3]) + + dicts = load_lvis_json(sys.argv[1], sys.argv[2], sys.argv[3]) + logger.info("Done loading {} samples.".format(len(dicts))) + + dirname = "lvis-data-vis" + os.makedirs(dirname, exist_ok=True) + for d in dicts[: int(sys.argv[4])]: + img = np.array(Image.open(d["file_name"])) + visualizer = Visualizer(img, metadata=meta) + vis = visualizer.draw_dataset_dict(d) + fpath = os.path.join(dirname, os.path.basename(d["file_name"])) + vis.save(fpath) diff --git a/data_processing/detectron2/detectron2/data/datasets/lvis_v0_5_categories.py b/data_processing/detectron2/detectron2/data/datasets/lvis_v0_5_categories.py new file mode 100644 index 0000000..d3dab61 --- /dev/null +++ b/data_processing/detectron2/detectron2/data/datasets/lvis_v0_5_categories.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# Autogen with +# with open("lvis_v0.5_val.json", "r") as f: +# a = json.load(f) +# c = a["categories"] +# for x in c: +# del x["image_count"] +# del x["instance_count"] +# LVIS_CATEGORIES = repr(c) + " # noqa" + +# fmt: off +LVIS_CATEGORIES = [{'frequency': 'r', 'id': 1, 'synset': 'acorn.n.01', 'synonyms': ['acorn'], 'def': 'nut from an oak tree', 'name': 'acorn'}, {'frequency': 'c', 'id': 2, 'synset': 'aerosol.n.02', 'synonyms': ['aerosol_can', 'spray_can'], 'def': 'a dispenser that holds a substance under pressure', 'name': 'aerosol_can'}, {'frequency': 'f', 'id': 3, 'synset': 'air_conditioner.n.01', 'synonyms': ['air_conditioner'], 'def': 'a machine that keeps air cool and dry', 'name': 'air_conditioner'}, {'frequency': 'f', 'id': 4, 'synset': 'airplane.n.01', 'synonyms': ['airplane', 'aeroplane'], 'def': 'an aircraft that has a fixed wing and is powered by propellers or jets', 'name': 'airplane'}, {'frequency': 'c', 'id': 5, 'synset': 'alarm_clock.n.01', 'synonyms': ['alarm_clock'], 'def': 'a clock that wakes a sleeper at some preset time', 'name': 'alarm_clock'}, {'frequency': 'c', 'id': 6, 'synset': 'alcohol.n.01', 'synonyms': ['alcohol', 'alcoholic_beverage'], 'def': 'a liquor or brew containing alcohol as the active agent', 'name': 'alcohol'}, {'frequency': 'r', 'id': 7, 'synset': 'alligator.n.02', 'synonyms': ['alligator', 'gator'], 'def': 'amphibious reptiles related to crocodiles but with shorter broader snouts', 'name': 'alligator'}, {'frequency': 'c', 'id': 8, 'synset': 'almond.n.02', 'synonyms': ['almond'], 'def': 'oval-shaped edible seed of the almond tree', 'name': 'almond'}, {'frequency': 'c', 'id': 9, 'synset': 'ambulance.n.01', 'synonyms': ['ambulance'], 'def': 'a vehicle that takes people to and from hospitals', 'name': 'ambulance'}, {'frequency': 'r', 'id': 10, 'synset': 'amplifier.n.01', 'synonyms': ['amplifier'], 'def': 'electronic equipment that increases strength of signals', 'name': 'amplifier'}, {'frequency': 'c', 'id': 11, 'synset': 
'anklet.n.03', 'synonyms': ['anklet', 'ankle_bracelet'], 'def': 'an ornament worn around the ankle', 'name': 'anklet'}, {'frequency': 'f', 'id': 12, 'synset': 'antenna.n.01', 'synonyms': ['antenna', 'aerial', 'transmitting_aerial'], 'def': 'an electrical device that sends or receives radio or television signals', 'name': 'antenna'}, {'frequency': 'f', 'id': 13, 'synset': 'apple.n.01', 'synonyms': ['apple'], 'def': 'fruit with red or yellow or green skin and sweet to tart crisp whitish flesh', 'name': 'apple'}, {'frequency': 'r', 'id': 14, 'synset': 'apple_juice.n.01', 'synonyms': ['apple_juice'], 'def': 'the juice of apples', 'name': 'apple_juice'}, {'frequency': 'r', 'id': 15, 'synset': 'applesauce.n.01', 'synonyms': ['applesauce'], 'def': 'puree of stewed apples usually sweetened and spiced', 'name': 'applesauce'}, {'frequency': 'r', 'id': 16, 'synset': 'apricot.n.02', 'synonyms': ['apricot'], 'def': 'downy yellow to rosy-colored fruit resembling a small peach', 'name': 'apricot'}, {'frequency': 'f', 'id': 17, 'synset': 'apron.n.01', 'synonyms': ['apron'], 'def': 'a garment of cloth that is tied about the waist and worn to protect clothing', 'name': 'apron'}, {'frequency': 'c', 'id': 18, 'synset': 'aquarium.n.01', 'synonyms': ['aquarium', 'fish_tank'], 'def': 'a tank/pool/bowl filled with water for keeping live fish and underwater animals', 'name': 'aquarium'}, {'frequency': 'c', 'id': 19, 'synset': 'armband.n.02', 'synonyms': ['armband'], 'def': 'a band worn around the upper arm', 'name': 'armband'}, {'frequency': 'f', 'id': 20, 'synset': 'armchair.n.01', 'synonyms': ['armchair'], 'def': 'chair with a support on each side for arms', 'name': 'armchair'}, {'frequency': 'r', 'id': 21, 'synset': 'armoire.n.01', 'synonyms': ['armoire'], 'def': 'a large wardrobe or cabinet', 'name': 'armoire'}, {'frequency': 'r', 'id': 22, 'synset': 'armor.n.01', 'synonyms': ['armor', 'armour'], 'def': 'protective covering made of metal and used in combat', 'name': 'armor'}, 
{'frequency': 'c', 'id': 23, 'synset': 'artichoke.n.02', 'synonyms': ['artichoke'], 'def': 'a thistlelike flower head with edible fleshy leaves and heart', 'name': 'artichoke'}, {'frequency': 'f', 'id': 24, 'synset': 'ashcan.n.01', 'synonyms': ['trash_can', 'garbage_can', 'wastebin', 'dustbin', 'trash_barrel', 'trash_bin'], 'def': 'a bin that holds rubbish until it is collected', 'name': 'trash_can'}, {'frequency': 'c', 'id': 25, 'synset': 'ashtray.n.01', 'synonyms': ['ashtray'], 'def': "a receptacle for the ash from smokers' cigars or cigarettes", 'name': 'ashtray'}, {'frequency': 'c', 'id': 26, 'synset': 'asparagus.n.02', 'synonyms': ['asparagus'], 'def': 'edible young shoots of the asparagus plant', 'name': 'asparagus'}, {'frequency': 'c', 'id': 27, 'synset': 'atomizer.n.01', 'synonyms': ['atomizer', 'atomiser', 'spray', 'sprayer', 'nebulizer', 'nebuliser'], 'def': 'a dispenser that turns a liquid (such as perfume) into a fine mist', 'name': 'atomizer'}, {'frequency': 'c', 'id': 28, 'synset': 'avocado.n.01', 'synonyms': ['avocado'], 'def': 'a pear-shaped fruit with green or blackish skin and rich yellowish pulp enclosing a single large seed', 'name': 'avocado'}, {'frequency': 'c', 'id': 29, 'synset': 'award.n.02', 'synonyms': ['award', 'accolade'], 'def': 'a tangible symbol signifying approval or distinction', 'name': 'award'}, {'frequency': 'f', 'id': 30, 'synset': 'awning.n.01', 'synonyms': ['awning'], 'def': 'a canopy made of canvas to shelter people or things from rain or sun', 'name': 'awning'}, {'frequency': 'r', 'id': 31, 'synset': 'ax.n.01', 'synonyms': ['ax', 'axe'], 'def': 'an edge tool with a heavy bladed head mounted across a handle', 'name': 'ax'}, {'frequency': 'f', 'id': 32, 'synset': 'baby_buggy.n.01', 'synonyms': ['baby_buggy', 'baby_carriage', 'perambulator', 'pram', 'stroller'], 'def': 'a small vehicle with four wheels in which a baby or child is pushed around', 'name': 'baby_buggy'}, {'frequency': 'c', 'id': 33, 'synset': 'backboard.n.01', 
'synonyms': ['basketball_backboard'], 'def': 'a raised vertical board with basket attached; used to play basketball', 'name': 'basketball_backboard'}, {'frequency': 'f', 'id': 34, 'synset': 'backpack.n.01', 'synonyms': ['backpack', 'knapsack', 'packsack', 'rucksack', 'haversack'], 'def': 'a bag carried by a strap on your back or shoulder', 'name': 'backpack'}, {'frequency': 'f', 'id': 35, 'synset': 'bag.n.04', 'synonyms': ['handbag', 'purse', 'pocketbook'], 'def': 'a container used for carrying money and small personal items or accessories', 'name': 'handbag'}, {'frequency': 'f', 'id': 36, 'synset': 'bag.n.06', 'synonyms': ['suitcase', 'baggage', 'luggage'], 'def': 'cases used to carry belongings when traveling', 'name': 'suitcase'}, {'frequency': 'c', 'id': 37, 'synset': 'bagel.n.01', 'synonyms': ['bagel', 'beigel'], 'def': 'glazed yeast-raised doughnut-shaped roll with hard crust', 'name': 'bagel'}, {'frequency': 'r', 'id': 38, 'synset': 'bagpipe.n.01', 'synonyms': ['bagpipe'], 'def': 'a tubular wind instrument; the player blows air into a bag and squeezes it out', 'name': 'bagpipe'}, {'frequency': 'r', 'id': 39, 'synset': 'baguet.n.01', 'synonyms': ['baguet', 'baguette'], 'def': 'narrow French stick loaf', 'name': 'baguet'}, {'frequency': 'r', 'id': 40, 'synset': 'bait.n.02', 'synonyms': ['bait', 'lure'], 'def': 'something used to lure fish or other animals into danger so they can be trapped or killed', 'name': 'bait'}, {'frequency': 'f', 'id': 41, 'synset': 'ball.n.06', 'synonyms': ['ball'], 'def': 'a spherical object used as a plaything', 'name': 'ball'}, {'frequency': 'r', 'id': 42, 'synset': 'ballet_skirt.n.01', 'synonyms': ['ballet_skirt', 'tutu'], 'def': 'very short skirt worn by ballerinas', 'name': 'ballet_skirt'}, {'frequency': 'f', 'id': 43, 'synset': 'balloon.n.01', 'synonyms': ['balloon'], 'def': 'large tough nonrigid bag filled with gas or heated air', 'name': 'balloon'}, {'frequency': 'c', 'id': 44, 'synset': 'bamboo.n.02', 'synonyms': ['bamboo'], 
'def': 'woody tropical grass having hollow woody stems', 'name': 'bamboo'}, {'frequency': 'f', 'id': 45, 'synset': 'banana.n.02', 'synonyms': ['banana'], 'def': 'elongated crescent-shaped yellow fruit with soft sweet flesh', 'name': 'banana'}, {'frequency': 'r', 'id': 46, 'synset': 'band_aid.n.01', 'synonyms': ['Band_Aid'], 'def': 'trade name for an adhesive bandage to cover small cuts or blisters', 'name': 'Band_Aid'}, {'frequency': 'c', 'id': 47, 'synset': 'bandage.n.01', 'synonyms': ['bandage'], 'def': 'a piece of soft material that covers and protects an injured part of the body', 'name': 'bandage'}, {'frequency': 'c', 'id': 48, 'synset': 'bandanna.n.01', 'synonyms': ['bandanna', 'bandana'], 'def': 'large and brightly colored handkerchief; often used as a neckerchief', 'name': 'bandanna'}, {'frequency': 'r', 'id': 49, 'synset': 'banjo.n.01', 'synonyms': ['banjo'], 'def': 'a stringed instrument of the guitar family with a long neck and circular body', 'name': 'banjo'}, {'frequency': 'f', 'id': 50, 'synset': 'banner.n.01', 'synonyms': ['banner', 'streamer'], 'def': 'long strip of cloth or paper used for decoration or advertising', 'name': 'banner'}, {'frequency': 'r', 'id': 51, 'synset': 'barbell.n.01', 'synonyms': ['barbell'], 'def': 'a bar to which heavy discs are attached at each end; used in weightlifting', 'name': 'barbell'}, {'frequency': 'r', 'id': 52, 'synset': 'barge.n.01', 'synonyms': ['barge'], 'def': 'a flatbottom boat for carrying heavy loads (especially on canals)', 'name': 'barge'}, {'frequency': 'f', 'id': 53, 'synset': 'barrel.n.02', 'synonyms': ['barrel', 'cask'], 'def': 'a cylindrical container that holds liquids', 'name': 'barrel'}, {'frequency': 'c', 'id': 54, 'synset': 'barrette.n.01', 'synonyms': ['barrette'], 'def': "a pin for holding women's hair in place", 'name': 'barrette'}, {'frequency': 'c', 'id': 55, 'synset': 'barrow.n.03', 'synonyms': ['barrow', 'garden_cart', 'lawn_cart', 'wheelbarrow'], 'def': 'a cart for carrying small loads; 
has handles and one or more wheels', 'name': 'barrow'}, {'frequency': 'f', 'id': 56, 'synset': 'base.n.03', 'synonyms': ['baseball_base'], 'def': 'a place that the runner must touch before scoring', 'name': 'baseball_base'}, {'frequency': 'f', 'id': 57, 'synset': 'baseball.n.02', 'synonyms': ['baseball'], 'def': 'a ball used in playing baseball', 'name': 'baseball'}, {'frequency': 'f', 'id': 58, 'synset': 'baseball_bat.n.01', 'synonyms': ['baseball_bat'], 'def': 'an implement used in baseball by the batter', 'name': 'baseball_bat'}, {'frequency': 'f', 'id': 59, 'synset': 'baseball_cap.n.01', 'synonyms': ['baseball_cap', 'jockey_cap', 'golf_cap'], 'def': 'a cap with a bill', 'name': 'baseball_cap'}, {'frequency': 'f', 'id': 60, 'synset': 'baseball_glove.n.01', 'synonyms': ['baseball_glove', 'baseball_mitt'], 'def': 'the handwear used by fielders in playing baseball', 'name': 'baseball_glove'}, {'frequency': 'f', 'id': 61, 'synset': 'basket.n.01', 'synonyms': ['basket', 'handbasket'], 'def': 'a container that is usually woven and has handles', 'name': 'basket'}, {'frequency': 'c', 'id': 62, 'synset': 'basket.n.03', 'synonyms': ['basketball_hoop'], 'def': 'metal hoop supporting a net through which players try to throw the basketball', 'name': 'basketball_hoop'}, {'frequency': 'c', 'id': 63, 'synset': 'basketball.n.02', 'synonyms': ['basketball'], 'def': 'an inflated ball used in playing basketball', 'name': 'basketball'}, {'frequency': 'r', 'id': 64, 'synset': 'bass_horn.n.01', 'synonyms': ['bass_horn', 'sousaphone', 'tuba'], 'def': 'the lowest brass wind instrument', 'name': 'bass_horn'}, {'frequency': 'r', 'id': 65, 'synset': 'bat.n.01', 'synonyms': ['bat_(animal)'], 'def': 'nocturnal mouselike mammal with forelimbs modified to form membranous wings', 'name': 'bat_(animal)'}, {'frequency': 'f', 'id': 66, 'synset': 'bath_mat.n.01', 'synonyms': ['bath_mat'], 'def': 'a heavy towel or mat to stand on while drying yourself after a bath', 'name': 'bath_mat'}, 
{'frequency': 'f', 'id': 67, 'synset': 'bath_towel.n.01', 'synonyms': ['bath_towel'], 'def': 'a large towel; to dry yourself after a bath', 'name': 'bath_towel'}, {'frequency': 'c', 'id': 68, 'synset': 'bathrobe.n.01', 'synonyms': ['bathrobe'], 'def': 'a loose-fitting robe of towelling; worn after a bath or swim', 'name': 'bathrobe'}, {'frequency': 'f', 'id': 69, 'synset': 'bathtub.n.01', 'synonyms': ['bathtub', 'bathing_tub'], 'def': 'a large open container that you fill with water and use to wash the body', 'name': 'bathtub'}, {'frequency': 'r', 'id': 70, 'synset': 'batter.n.02', 'synonyms': ['batter_(food)'], 'def': 'a liquid or semiliquid mixture, as of flour, eggs, and milk, used in cooking', 'name': 'batter_(food)'}, {'frequency': 'c', 'id': 71, 'synset': 'battery.n.02', 'synonyms': ['battery'], 'def': 'a portable device that produces electricity', 'name': 'battery'}, {'frequency': 'r', 'id': 72, 'synset': 'beach_ball.n.01', 'synonyms': ['beachball'], 'def': 'large and light ball; for play at the seaside', 'name': 'beachball'}, {'frequency': 'c', 'id': 73, 'synset': 'bead.n.01', 'synonyms': ['bead'], 'def': 'a small ball with a hole through the middle used for ornamentation, jewellery, etc.', 'name': 'bead'}, {'frequency': 'r', 'id': 74, 'synset': 'beaker.n.01', 'synonyms': ['beaker'], 'def': 'a flatbottomed jar made of glass or plastic; used for chemistry', 'name': 'beaker'}, {'frequency': 'c', 'id': 75, 'synset': 'bean_curd.n.01', 'synonyms': ['bean_curd', 'tofu'], 'def': 'cheeselike food made of curdled soybean milk', 'name': 'bean_curd'}, {'frequency': 'c', 'id': 76, 'synset': 'beanbag.n.01', 'synonyms': ['beanbag'], 'def': 'a bag filled with dried beans or similar items; used in games or to sit on', 'name': 'beanbag'}, {'frequency': 'f', 'id': 77, 'synset': 'beanie.n.01', 'synonyms': ['beanie', 'beany'], 'def': 'a small skullcap; formerly worn by schoolboys and college freshmen', 'name': 'beanie'}, {'frequency': 'f', 'id': 78, 'synset': 'bear.n.01', 
'synonyms': ['bear'], 'def': 'large carnivorous or omnivorous mammals with shaggy coats and claws', 'name': 'bear'}, {'frequency': 'f', 'id': 79, 'synset': 'bed.n.01', 'synonyms': ['bed'], 'def': 'a piece of furniture that provides a place to sleep', 'name': 'bed'}, {'frequency': 'c', 'id': 80, 'synset': 'bedspread.n.01', 'synonyms': ['bedspread', 'bedcover', 'bed_covering', 'counterpane', 'spread'], 'def': 'decorative cover for a bed', 'name': 'bedspread'}, {'frequency': 'f', 'id': 81, 'synset': 'beef.n.01', 'synonyms': ['cow'], 'def': 'cattle that are reared for their meat', 'name': 'cow'}, {'frequency': 'c', 'id': 82, 'synset': 'beef.n.02', 'synonyms': ['beef_(food)', 'boeuf_(food)'], 'def': 'meat from an adult domestic bovine', 'name': 'beef_(food)'}, {'frequency': 'r', 'id': 83, 'synset': 'beeper.n.01', 'synonyms': ['beeper', 'pager'], 'def': 'an device that beeps when the person carrying it is being paged', 'name': 'beeper'}, {'frequency': 'f', 'id': 84, 'synset': 'beer_bottle.n.01', 'synonyms': ['beer_bottle'], 'def': 'a bottle that holds beer', 'name': 'beer_bottle'}, {'frequency': 'c', 'id': 85, 'synset': 'beer_can.n.01', 'synonyms': ['beer_can'], 'def': 'a can that holds beer', 'name': 'beer_can'}, {'frequency': 'r', 'id': 86, 'synset': 'beetle.n.01', 'synonyms': ['beetle'], 'def': 'insect with hard wing covers', 'name': 'beetle'}, {'frequency': 'f', 'id': 87, 'synset': 'bell.n.01', 'synonyms': ['bell'], 'def': 'a hollow device made of metal that makes a ringing sound when struck', 'name': 'bell'}, {'frequency': 'f', 'id': 88, 'synset': 'bell_pepper.n.02', 'synonyms': ['bell_pepper', 'capsicum'], 'def': 'large bell-shaped sweet pepper in green or red or yellow or orange or black varieties', 'name': 'bell_pepper'}, {'frequency': 'f', 'id': 89, 'synset': 'belt.n.02', 'synonyms': ['belt'], 'def': 'a band to tie or buckle around the body (usually at the waist)', 'name': 'belt'}, {'frequency': 'f', 'id': 90, 'synset': 'belt_buckle.n.01', 'synonyms': 
['belt_buckle'], 'def': 'the buckle used to fasten a belt', 'name': 'belt_buckle'}, {'frequency': 'f', 'id': 91, 'synset': 'bench.n.01', 'synonyms': ['bench'], 'def': 'a long seat for more than one person', 'name': 'bench'}, {'frequency': 'c', 'id': 92, 'synset': 'beret.n.01', 'synonyms': ['beret'], 'def': 'a cap with no brim or bill; made of soft cloth', 'name': 'beret'}, {'frequency': 'c', 'id': 93, 'synset': 'bib.n.02', 'synonyms': ['bib'], 'def': 'a napkin tied under the chin of a child while eating', 'name': 'bib'}, {'frequency': 'r', 'id': 94, 'synset': 'bible.n.01', 'synonyms': ['Bible'], 'def': 'the sacred writings of the Christian religions', 'name': 'Bible'}, {'frequency': 'f', 'id': 95, 'synset': 'bicycle.n.01', 'synonyms': ['bicycle', 'bike_(bicycle)'], 'def': 'a wheeled vehicle that has two wheels and is moved by foot pedals', 'name': 'bicycle'}, {'frequency': 'f', 'id': 96, 'synset': 'bill.n.09', 'synonyms': ['visor', 'vizor'], 'def': 'a brim that projects to the front to shade the eyes', 'name': 'visor'}, {'frequency': 'c', 'id': 97, 'synset': 'binder.n.03', 'synonyms': ['binder', 'ring-binder'], 'def': 'holds loose papers or magazines', 'name': 'binder'}, {'frequency': 'c', 'id': 98, 'synset': 'binoculars.n.01', 'synonyms': ['binoculars', 'field_glasses', 'opera_glasses'], 'def': 'an optical instrument designed for simultaneous use by both eyes', 'name': 'binoculars'}, {'frequency': 'f', 'id': 99, 'synset': 'bird.n.01', 'synonyms': ['bird'], 'def': 'animal characterized by feathers and wings', 'name': 'bird'}, {'frequency': 'r', 'id': 100, 'synset': 'bird_feeder.n.01', 'synonyms': ['birdfeeder'], 'def': 'an outdoor device that supplies food for wild birds', 'name': 'birdfeeder'}, {'frequency': 'r', 'id': 101, 'synset': 'birdbath.n.01', 'synonyms': ['birdbath'], 'def': 'an ornamental basin (usually in a garden) for birds to bathe in', 'name': 'birdbath'}, {'frequency': 'c', 'id': 102, 'synset': 'birdcage.n.01', 'synonyms': ['birdcage'], 'def': 'a 
cage in which a bird can be kept', 'name': 'birdcage'}, {'frequency': 'c', 'id': 103, 'synset': 'birdhouse.n.01', 'synonyms': ['birdhouse'], 'def': 'a shelter for birds', 'name': 'birdhouse'}, {'frequency': 'f', 'id': 104, 'synset': 'birthday_cake.n.01', 'synonyms': ['birthday_cake'], 'def': 'decorated cake served at a birthday party', 'name': 'birthday_cake'}, {'frequency': 'r', 'id': 105, 'synset': 'birthday_card.n.01', 'synonyms': ['birthday_card'], 'def': 'a card expressing a birthday greeting', 'name': 'birthday_card'}, {'frequency': 'r', 'id': 106, 'synset': 'biscuit.n.01', 'synonyms': ['biscuit_(bread)'], 'def': 'small round bread leavened with baking-powder or soda', 'name': 'biscuit_(bread)'}, {'frequency': 'r', 'id': 107, 'synset': 'black_flag.n.01', 'synonyms': ['pirate_flag'], 'def': 'a flag usually bearing a white skull and crossbones on a black background', 'name': 'pirate_flag'}, {'frequency': 'c', 'id': 108, 'synset': 'black_sheep.n.02', 'synonyms': ['black_sheep'], 'def': 'sheep with a black coat', 'name': 'black_sheep'}, {'frequency': 'c', 'id': 109, 'synset': 'blackboard.n.01', 'synonyms': ['blackboard', 'chalkboard'], 'def': 'sheet of slate; for writing with chalk', 'name': 'blackboard'}, {'frequency': 'f', 'id': 110, 'synset': 'blanket.n.01', 'synonyms': ['blanket'], 'def': 'bedding that keeps a person warm in bed', 'name': 'blanket'}, {'frequency': 'c', 'id': 111, 'synset': 'blazer.n.01', 'synonyms': ['blazer', 'sport_jacket', 'sport_coat', 'sports_jacket', 'sports_coat'], 'def': 'lightweight jacket; often striped in the colors of a club or school', 'name': 'blazer'}, {'frequency': 'f', 'id': 112, 'synset': 'blender.n.01', 'synonyms': ['blender', 'liquidizer', 'liquidiser'], 'def': 'an electrically powered mixer that mix or chop or liquefy foods', 'name': 'blender'}, {'frequency': 'r', 'id': 113, 'synset': 'blimp.n.02', 'synonyms': ['blimp'], 'def': 'a small nonrigid airship used for observation or as a barrage balloon', 'name': 'blimp'}, 
{'frequency': 'c', 'id': 114, 'synset': 'blinker.n.01', 'synonyms': ['blinker', 'flasher'], 'def': 'a light that flashes on and off; used as a signal or to send messages', 'name': 'blinker'}, {'frequency': 'c', 'id': 115, 'synset': 'blueberry.n.02', 'synonyms': ['blueberry'], 'def': 'sweet edible dark-blue berries of blueberry plants', 'name': 'blueberry'}, {'frequency': 'r', 'id': 116, 'synset': 'boar.n.02', 'synonyms': ['boar'], 'def': 'an uncastrated male hog', 'name': 'boar'}, {'frequency': 'r', 'id': 117, 'synset': 'board.n.09', 'synonyms': ['gameboard'], 'def': 'a flat portable surface (usually rectangular) designed for board games', 'name': 'gameboard'}, {'frequency': 'f', 'id': 118, 'synset': 'boat.n.01', 'synonyms': ['boat', 'ship_(boat)'], 'def': 'a vessel for travel on water', 'name': 'boat'}, {'frequency': 'c', 'id': 119, 'synset': 'bobbin.n.01', 'synonyms': ['bobbin', 'spool', 'reel'], 'def': 'a thing around which thread/tape/film or other flexible materials can be wound', 'name': 'bobbin'}, {'frequency': 'r', 'id': 120, 'synset': 'bobby_pin.n.01', 'synonyms': ['bobby_pin', 'hairgrip'], 'def': 'a flat wire hairpin used to hold bobbed hair in place', 'name': 'bobby_pin'}, {'frequency': 'c', 'id': 121, 'synset': 'boiled_egg.n.01', 'synonyms': ['boiled_egg', 'coddled_egg'], 'def': 'egg cooked briefly in the shell in gently boiling water', 'name': 'boiled_egg'}, {'frequency': 'r', 'id': 122, 'synset': 'bolo_tie.n.01', 'synonyms': ['bolo_tie', 'bolo', 'bola_tie', 'bola'], 'def': 'a cord fastened around the neck with an ornamental clasp and worn as a necktie', 'name': 'bolo_tie'}, {'frequency': 'c', 'id': 123, 'synset': 'bolt.n.03', 'synonyms': ['deadbolt'], 'def': 'the part of a lock that is engaged or withdrawn with a key', 'name': 'deadbolt'}, {'frequency': 'f', 'id': 124, 'synset': 'bolt.n.06', 'synonyms': ['bolt'], 'def': 'a screw that screws into a nut to form a fastener', 'name': 'bolt'}, {'frequency': 'r', 'id': 125, 'synset': 'bonnet.n.01', 
'synonyms': ['bonnet'], 'def': 'a hat tied under the chin', 'name': 'bonnet'}, {'frequency': 'f', 'id': 126, 'synset': 'book.n.01', 'synonyms': ['book'], 'def': 'a written work or composition that has been published', 'name': 'book'}, {'frequency': 'r', 'id': 127, 'synset': 'book_bag.n.01', 'synonyms': ['book_bag'], 'def': 'a bag in which students carry their books', 'name': 'book_bag'}, {'frequency': 'c', 'id': 128, 'synset': 'bookcase.n.01', 'synonyms': ['bookcase'], 'def': 'a piece of furniture with shelves for storing books', 'name': 'bookcase'}, {'frequency': 'c', 'id': 129, 'synset': 'booklet.n.01', 'synonyms': ['booklet', 'brochure', 'leaflet', 'pamphlet'], 'def': 'a small book usually having a paper cover', 'name': 'booklet'}, {'frequency': 'r', 'id': 130, 'synset': 'bookmark.n.01', 'synonyms': ['bookmark', 'bookmarker'], 'def': 'a marker (a piece of paper or ribbon) placed between the pages of a book', 'name': 'bookmark'}, {'frequency': 'r', 'id': 131, 'synset': 'boom.n.04', 'synonyms': ['boom_microphone', 'microphone_boom'], 'def': 'a pole carrying an overhead microphone projected over a film or tv set', 'name': 'boom_microphone'}, {'frequency': 'f', 'id': 132, 'synset': 'boot.n.01', 'synonyms': ['boot'], 'def': 'footwear that covers the whole foot and lower leg', 'name': 'boot'}, {'frequency': 'f', 'id': 133, 'synset': 'bottle.n.01', 'synonyms': ['bottle'], 'def': 'a glass or plastic vessel used for storing drinks or other liquids', 'name': 'bottle'}, {'frequency': 'c', 'id': 134, 'synset': 'bottle_opener.n.01', 'synonyms': ['bottle_opener'], 'def': 'an opener for removing caps or corks from bottles', 'name': 'bottle_opener'}, {'frequency': 'c', 'id': 135, 'synset': 'bouquet.n.01', 'synonyms': ['bouquet'], 'def': 'an arrangement of flowers that is usually given as a present', 'name': 'bouquet'}, {'frequency': 'r', 'id': 136, 'synset': 'bow.n.04', 'synonyms': ['bow_(weapon)'], 'def': 'a weapon for shooting arrows', 'name': 'bow_(weapon)'}, {'frequency': 
'f', 'id': 137, 'synset': 'bow.n.08', 'synonyms': ['bow_(decorative_ribbons)'], 'def': 'a decorative interlacing of ribbons', 'name': 'bow_(decorative_ribbons)'}, {'frequency': 'f', 'id': 138, 'synset': 'bow_tie.n.01', 'synonyms': ['bow-tie', 'bowtie'], 'def': "a man's tie that ties in a bow", 'name': 'bow-tie'}, {'frequency': 'f', 'id': 139, 'synset': 'bowl.n.03', 'synonyms': ['bowl'], 'def': 'a dish that is round and open at the top for serving foods', 'name': 'bowl'}, {'frequency': 'r', 'id': 140, 'synset': 'bowl.n.08', 'synonyms': ['pipe_bowl'], 'def': 'a small round container that is open at the top for holding tobacco', 'name': 'pipe_bowl'}, {'frequency': 'c', 'id': 141, 'synset': 'bowler_hat.n.01', 'synonyms': ['bowler_hat', 'bowler', 'derby_hat', 'derby', 'plug_hat'], 'def': 'a felt hat that is round and hard with a narrow brim', 'name': 'bowler_hat'}, {'frequency': 'r', 'id': 142, 'synset': 'bowling_ball.n.01', 'synonyms': ['bowling_ball'], 'def': 'a large ball with finger holes used in the sport of bowling', 'name': 'bowling_ball'}, {'frequency': 'r', 'id': 143, 'synset': 'bowling_pin.n.01', 'synonyms': ['bowling_pin'], 'def': 'a club-shaped wooden object used in bowling', 'name': 'bowling_pin'}, {'frequency': 'r', 'id': 144, 'synset': 'boxing_glove.n.01', 'synonyms': ['boxing_glove'], 'def': 'large glove coverings the fists of a fighter worn for the sport of boxing', 'name': 'boxing_glove'}, {'frequency': 'c', 'id': 145, 'synset': 'brace.n.06', 'synonyms': ['suspenders'], 'def': 'elastic straps that hold trousers up (usually used in the plural)', 'name': 'suspenders'}, {'frequency': 'f', 'id': 146, 'synset': 'bracelet.n.02', 'synonyms': ['bracelet', 'bangle'], 'def': 'jewelry worn around the wrist for decoration', 'name': 'bracelet'}, {'frequency': 'r', 'id': 147, 'synset': 'brass.n.07', 'synonyms': ['brass_plaque'], 'def': 'a memorial made of brass', 'name': 'brass_plaque'}, {'frequency': 'c', 'id': 148, 'synset': 'brassiere.n.01', 'synonyms': 
['brassiere', 'bra', 'bandeau'], 'def': 'an undergarment worn by women to support their breasts', 'name': 'brassiere'}, {'frequency': 'c', 'id': 149, 'synset': 'bread-bin.n.01', 'synonyms': ['bread-bin', 'breadbox'], 'def': 'a container used to keep bread or cake in', 'name': 'bread-bin'}, {'frequency': 'r', 'id': 150, 'synset': 'breechcloth.n.01', 'synonyms': ['breechcloth', 'breechclout', 'loincloth'], 'def': 'a garment that provides covering for the loins', 'name': 'breechcloth'}, {'frequency': 'c', 'id': 151, 'synset': 'bridal_gown.n.01', 'synonyms': ['bridal_gown', 'wedding_gown', 'wedding_dress'], 'def': 'a gown worn by the bride at a wedding', 'name': 'bridal_gown'}, {'frequency': 'c', 'id': 152, 'synset': 'briefcase.n.01', 'synonyms': ['briefcase'], 'def': 'a case with a handle; for carrying papers or files or books', 'name': 'briefcase'}, {'frequency': 'c', 'id': 153, 'synset': 'bristle_brush.n.01', 'synonyms': ['bristle_brush'], 'def': 'a brush that is made with the short stiff hairs of an animal or plant', 'name': 'bristle_brush'}, {'frequency': 'f', 'id': 154, 'synset': 'broccoli.n.01', 'synonyms': ['broccoli'], 'def': 'plant with dense clusters of tight green flower buds', 'name': 'broccoli'}, {'frequency': 'r', 'id': 155, 'synset': 'brooch.n.01', 'synonyms': ['broach'], 'def': 'a decorative pin worn by women', 'name': 'broach'}, {'frequency': 'c', 'id': 156, 'synset': 'broom.n.01', 'synonyms': ['broom'], 'def': 'bundle of straws or twigs attached to a long handle; used for cleaning', 'name': 'broom'}, {'frequency': 'c', 'id': 157, 'synset': 'brownie.n.03', 'synonyms': ['brownie'], 'def': 'square or bar of very rich chocolate cake usually with nuts', 'name': 'brownie'}, {'frequency': 'c', 'id': 158, 'synset': 'brussels_sprouts.n.01', 'synonyms': ['brussels_sprouts'], 'def': 'the small edible cabbage-like buds growing along a stalk', 'name': 'brussels_sprouts'}, {'frequency': 'r', 'id': 159, 'synset': 'bubble_gum.n.01', 'synonyms': ['bubble_gum'], 
'def': 'a kind of chewing gum that can be blown into bubbles', 'name': 'bubble_gum'}, {'frequency': 'f', 'id': 160, 'synset': 'bucket.n.01', 'synonyms': ['bucket', 'pail'], 'def': 'a roughly cylindrical vessel that is open at the top', 'name': 'bucket'}, {'frequency': 'r', 'id': 161, 'synset': 'buggy.n.01', 'synonyms': ['horse_buggy'], 'def': 'a small lightweight carriage; drawn by a single horse', 'name': 'horse_buggy'}, {'frequency': 'c', 'id': 162, 'synset': 'bull.n.11', 'synonyms': ['bull'], 'def': 'mature male cow', 'name': 'bull'}, {'frequency': 'r', 'id': 163, 'synset': 'bulldog.n.01', 'synonyms': ['bulldog'], 'def': 'a thickset short-haired dog with a large head and strong undershot lower jaw', 'name': 'bulldog'}, {'frequency': 'r', 'id': 164, 'synset': 'bulldozer.n.01', 'synonyms': ['bulldozer', 'dozer'], 'def': 'large powerful tractor; a large blade in front flattens areas of ground', 'name': 'bulldozer'}, {'frequency': 'c', 'id': 165, 'synset': 'bullet_train.n.01', 'synonyms': ['bullet_train'], 'def': 'a high-speed passenger train', 'name': 'bullet_train'}, {'frequency': 'c', 'id': 166, 'synset': 'bulletin_board.n.02', 'synonyms': ['bulletin_board', 'notice_board'], 'def': 'a board that hangs on a wall; displays announcements', 'name': 'bulletin_board'}, {'frequency': 'r', 'id': 167, 'synset': 'bulletproof_vest.n.01', 'synonyms': ['bulletproof_vest'], 'def': 'a vest capable of resisting the impact of a bullet', 'name': 'bulletproof_vest'}, {'frequency': 'c', 'id': 168, 'synset': 'bullhorn.n.01', 'synonyms': ['bullhorn', 'megaphone'], 'def': 'a portable loudspeaker with built-in microphone and amplifier', 'name': 'bullhorn'}, {'frequency': 'r', 'id': 169, 'synset': 'bully_beef.n.01', 'synonyms': ['corned_beef', 'corn_beef'], 'def': 'beef cured or pickled in brine', 'name': 'corned_beef'}, {'frequency': 'f', 'id': 170, 'synset': 'bun.n.01', 'synonyms': ['bun', 'roll'], 'def': 'small rounded bread either plain or sweet', 'name': 'bun'}, {'frequency': 'c', 
'id': 171, 'synset': 'bunk_bed.n.01', 'synonyms': ['bunk_bed'], 'def': 'beds built one above the other', 'name': 'bunk_bed'}, {'frequency': 'f', 'id': 172, 'synset': 'buoy.n.01', 'synonyms': ['buoy'], 'def': 'a float attached by rope to the seabed to mark channels in a harbor or underwater hazards', 'name': 'buoy'}, {'frequency': 'r', 'id': 173, 'synset': 'burrito.n.01', 'synonyms': ['burrito'], 'def': 'a flour tortilla folded around a filling', 'name': 'burrito'}, {'frequency': 'f', 'id': 174, 'synset': 'bus.n.01', 'synonyms': ['bus_(vehicle)', 'autobus', 'charabanc', 'double-decker', 'motorbus', 'motorcoach'], 'def': 'a vehicle carrying many passengers; used for public transport', 'name': 'bus_(vehicle)'}, {'frequency': 'c', 'id': 175, 'synset': 'business_card.n.01', 'synonyms': ['business_card'], 'def': "a card on which are printed the person's name and business affiliation", 'name': 'business_card'}, {'frequency': 'c', 'id': 176, 'synset': 'butcher_knife.n.01', 'synonyms': ['butcher_knife'], 'def': 'a large sharp knife for cutting or trimming meat', 'name': 'butcher_knife'}, {'frequency': 'c', 'id': 177, 'synset': 'butter.n.01', 'synonyms': ['butter'], 'def': 'an edible emulsion of fat globules made by churning milk or cream; for cooking and table use', 'name': 'butter'}, {'frequency': 'c', 'id': 178, 'synset': 'butterfly.n.01', 'synonyms': ['butterfly'], 'def': 'insect typically having a slender body with knobbed antennae and broad colorful wings', 'name': 'butterfly'}, {'frequency': 'f', 'id': 179, 'synset': 'button.n.01', 'synonyms': ['button'], 'def': 'a round fastener sewn to shirts and coats etc to fit through buttonholes', 'name': 'button'}, {'frequency': 'f', 'id': 180, 'synset': 'cab.n.03', 'synonyms': ['cab_(taxi)', 'taxi', 'taxicab'], 'def': 'a car that takes passengers where they want to go in exchange for money', 'name': 'cab_(taxi)'}, {'frequency': 'r', 'id': 181, 'synset': 'cabana.n.01', 'synonyms': ['cabana'], 'def': 'a small tent used as a 
dressing room beside the sea or a swimming pool', 'name': 'cabana'}, {'frequency': 'r', 'id': 182, 'synset': 'cabin_car.n.01', 'synonyms': ['cabin_car', 'caboose'], 'def': 'a car on a freight train for use of the train crew; usually the last car on the train', 'name': 'cabin_car'}, {'frequency': 'f', 'id': 183, 'synset': 'cabinet.n.01', 'synonyms': ['cabinet'], 'def': 'a piece of furniture resembling a cupboard with doors and shelves and drawers', 'name': 'cabinet'}, {'frequency': 'r', 'id': 184, 'synset': 'cabinet.n.03', 'synonyms': ['locker', 'storage_locker'], 'def': 'a storage compartment for clothes and valuables; usually it has a lock', 'name': 'locker'}, {'frequency': 'f', 'id': 185, 'synset': 'cake.n.03', 'synonyms': ['cake'], 'def': 'baked goods made from or based on a mixture of flour, sugar, eggs, and fat', 'name': 'cake'}, {'frequency': 'c', 'id': 186, 'synset': 'calculator.n.02', 'synonyms': ['calculator'], 'def': 'a small machine that is used for mathematical calculations', 'name': 'calculator'}, {'frequency': 'f', 'id': 187, 'synset': 'calendar.n.02', 'synonyms': ['calendar'], 'def': 'a list or register of events (appointments/social events/court cases, etc)', 'name': 'calendar'}, {'frequency': 'c', 'id': 188, 'synset': 'calf.n.01', 'synonyms': ['calf'], 'def': 'young of domestic cattle', 'name': 'calf'}, {'frequency': 'c', 'id': 189, 'synset': 'camcorder.n.01', 'synonyms': ['camcorder'], 'def': 'a portable television camera and videocassette recorder', 'name': 'camcorder'}, {'frequency': 'c', 'id': 190, 'synset': 'camel.n.01', 'synonyms': ['camel'], 'def': 'cud-chewing mammal used as a draft or saddle animal in desert regions', 'name': 'camel'}, {'frequency': 'f', 'id': 191, 'synset': 'camera.n.01', 'synonyms': ['camera'], 'def': 'equipment for taking photographs', 'name': 'camera'}, {'frequency': 'c', 'id': 192, 'synset': 'camera_lens.n.01', 'synonyms': ['camera_lens'], 'def': 'a lens that focuses the image in a camera', 'name': 'camera_lens'}, 
{'frequency': 'c', 'id': 193, 'synset': 'camper.n.02', 'synonyms': ['camper_(vehicle)', 'camping_bus', 'motor_home'], 'def': 'a recreational vehicle equipped for camping out while traveling', 'name': 'camper_(vehicle)'}, {'frequency': 'f', 'id': 194, 'synset': 'can.n.01', 'synonyms': ['can', 'tin_can'], 'def': 'airtight sealed metal container for food or drink or paint etc.', 'name': 'can'}, {'frequency': 'c', 'id': 195, 'synset': 'can_opener.n.01', 'synonyms': ['can_opener', 'tin_opener'], 'def': 'a device for cutting cans open', 'name': 'can_opener'}, {'frequency': 'r', 'id': 196, 'synset': 'candelabrum.n.01', 'synonyms': ['candelabrum', 'candelabra'], 'def': 'branched candlestick; ornamental; has several lights', 'name': 'candelabrum'}, {'frequency': 'f', 'id': 197, 'synset': 'candle.n.01', 'synonyms': ['candle', 'candlestick'], 'def': 'stick of wax with a wick in the middle', 'name': 'candle'}, {'frequency': 'f', 'id': 198, 'synset': 'candlestick.n.01', 'synonyms': ['candle_holder'], 'def': 'a holder with sockets for candles', 'name': 'candle_holder'}, {'frequency': 'r', 'id': 199, 'synset': 'candy_bar.n.01', 'synonyms': ['candy_bar'], 'def': 'a candy shaped as a bar', 'name': 'candy_bar'}, {'frequency': 'c', 'id': 200, 'synset': 'candy_cane.n.01', 'synonyms': ['candy_cane'], 'def': 'a hard candy in the shape of a rod (usually with stripes)', 'name': 'candy_cane'}, {'frequency': 'c', 'id': 201, 'synset': 'cane.n.01', 'synonyms': ['walking_cane'], 'def': 'a stick that people can lean on to help them walk', 'name': 'walking_cane'}, {'frequency': 'c', 'id': 202, 'synset': 'canister.n.02', 'synonyms': ['canister', 'cannister'], 'def': 'metal container for storing dry foods such as tea or flour', 'name': 'canister'}, {'frequency': 'r', 'id': 203, 'synset': 'cannon.n.02', 'synonyms': ['cannon'], 'def': 'heavy gun fired from a tank', 'name': 'cannon'}, {'frequency': 'c', 'id': 204, 'synset': 'canoe.n.01', 'synonyms': ['canoe'], 'def': 'small and light boat; pointed at 
both ends; propelled with a paddle', 'name': 'canoe'}, {'frequency': 'r', 'id': 205, 'synset': 'cantaloup.n.02', 'synonyms': ['cantaloup', 'cantaloupe'], 'def': 'the fruit of a cantaloup vine; small to medium-sized melon with yellowish flesh', 'name': 'cantaloup'}, {'frequency': 'r', 'id': 206, 'synset': 'canteen.n.01', 'synonyms': ['canteen'], 'def': 'a flask for carrying water; used by soldiers or travelers', 'name': 'canteen'}, {'frequency': 'c', 'id': 207, 'synset': 'cap.n.01', 'synonyms': ['cap_(headwear)'], 'def': 'a tight-fitting headwear', 'name': 'cap_(headwear)'}, {'frequency': 'f', 'id': 208, 'synset': 'cap.n.02', 'synonyms': ['bottle_cap', 'cap_(container_lid)'], 'def': 'a top (as for a bottle)', 'name': 'bottle_cap'}, {'frequency': 'r', 'id': 209, 'synset': 'cape.n.02', 'synonyms': ['cape'], 'def': 'a sleeveless garment like a cloak but shorter', 'name': 'cape'}, {'frequency': 'c', 'id': 210, 'synset': 'cappuccino.n.01', 'synonyms': ['cappuccino', 'coffee_cappuccino'], 'def': 'equal parts of espresso and steamed milk', 'name': 'cappuccino'}, {'frequency': 'f', 'id': 211, 'synset': 'car.n.01', 'synonyms': ['car_(automobile)', 'auto_(automobile)', 'automobile'], 'def': 'a motor vehicle with four wheels', 'name': 'car_(automobile)'}, {'frequency': 'f', 'id': 212, 'synset': 'car.n.02', 'synonyms': ['railcar_(part_of_a_train)', 'railway_car_(part_of_a_train)', 'railroad_car_(part_of_a_train)'], 'def': 'a wheeled vehicle adapted to the rails of railroad', 'name': 'railcar_(part_of_a_train)'}, {'frequency': 'r', 'id': 213, 'synset': 'car.n.04', 'synonyms': ['elevator_car'], 'def': 'where passengers ride up and down', 'name': 'elevator_car'}, {'frequency': 'r', 'id': 214, 'synset': 'car_battery.n.01', 'synonyms': ['car_battery', 'automobile_battery'], 'def': 'a battery in a motor vehicle', 'name': 'car_battery'}, {'frequency': 'c', 'id': 215, 'synset': 'card.n.02', 'synonyms': ['identity_card'], 'def': 'a card certifying the identity of the bearer', 'name': 
'identity_card'}, {'frequency': 'c', 'id': 216, 'synset': 'card.n.03', 'synonyms': ['card'], 'def': 'a rectangular piece of paper used to send messages (e.g. greetings or pictures)', 'name': 'card'}, {'frequency': 'r', 'id': 217, 'synset': 'cardigan.n.01', 'synonyms': ['cardigan'], 'def': 'knitted jacket that is fastened up the front with buttons or a zipper', 'name': 'cardigan'}, {'frequency': 'r', 'id': 218, 'synset': 'cargo_ship.n.01', 'synonyms': ['cargo_ship', 'cargo_vessel'], 'def': 'a ship designed to carry cargo', 'name': 'cargo_ship'}, {'frequency': 'r', 'id': 219, 'synset': 'carnation.n.01', 'synonyms': ['carnation'], 'def': 'plant with pink to purple-red spice-scented usually double flowers', 'name': 'carnation'}, {'frequency': 'c', 'id': 220, 'synset': 'carriage.n.02', 'synonyms': ['horse_carriage'], 'def': 'a vehicle with wheels drawn by one or more horses', 'name': 'horse_carriage'}, {'frequency': 'f', 'id': 221, 'synset': 'carrot.n.01', 'synonyms': ['carrot'], 'def': 'deep orange edible root of the cultivated carrot plant', 'name': 'carrot'}, {'frequency': 'c', 'id': 222, 'synset': 'carryall.n.01', 'synonyms': ['tote_bag'], 'def': 'a capacious bag or basket', 'name': 'tote_bag'}, {'frequency': 'c', 'id': 223, 'synset': 'cart.n.01', 'synonyms': ['cart'], 'def': 'a heavy open wagon usually having two wheels and drawn by an animal', 'name': 'cart'}, {'frequency': 'c', 'id': 224, 'synset': 'carton.n.02', 'synonyms': ['carton'], 'def': 'a box made of cardboard; opens by flaps on top', 'name': 'carton'}, {'frequency': 'c', 'id': 225, 'synset': 'cash_register.n.01', 'synonyms': ['cash_register', 'register_(for_cash_transactions)'], 'def': 'a cashbox with an adding machine to register transactions', 'name': 'cash_register'}, {'frequency': 'r', 'id': 226, 'synset': 'casserole.n.01', 'synonyms': ['casserole'], 'def': 'food cooked and served in a casserole', 'name': 'casserole'}, {'frequency': 'r', 'id': 227, 'synset': 'cassette.n.01', 'synonyms': ['cassette'], 
'def': 'a container that holds a magnetic tape used for recording or playing sound or video', 'name': 'cassette'}, {'frequency': 'c', 'id': 228, 'synset': 'cast.n.05', 'synonyms': ['cast', 'plaster_cast', 'plaster_bandage'], 'def': 'bandage consisting of a firm covering that immobilizes broken bones while they heal', 'name': 'cast'}, {'frequency': 'f', 'id': 229, 'synset': 'cat.n.01', 'synonyms': ['cat'], 'def': 'a domestic house cat', 'name': 'cat'}, {'frequency': 'c', 'id': 230, 'synset': 'cauliflower.n.02', 'synonyms': ['cauliflower'], 'def': 'edible compact head of white undeveloped flowers', 'name': 'cauliflower'}, {'frequency': 'r', 'id': 231, 'synset': 'caviar.n.01', 'synonyms': ['caviar', 'caviare'], 'def': "salted roe of sturgeon or other large fish; usually served as an hors d'oeuvre", 'name': 'caviar'}, {'frequency': 'c', 'id': 232, 'synset': 'cayenne.n.02', 'synonyms': ['cayenne_(spice)', 'cayenne_pepper_(spice)', 'red_pepper_(spice)'], 'def': 'ground pods and seeds of pungent red peppers of the genus Capsicum', 'name': 'cayenne_(spice)'}, {'frequency': 'c', 'id': 233, 'synset': 'cd_player.n.01', 'synonyms': ['CD_player'], 'def': 'electronic equipment for playing compact discs (CDs)', 'name': 'CD_player'}, {'frequency': 'c', 'id': 234, 'synset': 'celery.n.01', 'synonyms': ['celery'], 'def': 'widely cultivated herb with aromatic leaf stalks that are eaten raw or cooked', 'name': 'celery'}, {'frequency': 'f', 'id': 235, 'synset': 'cellular_telephone.n.01', 'synonyms': ['cellular_telephone', 'cellular_phone', 'cellphone', 'mobile_phone', 'smart_phone'], 'def': 'a hand-held mobile telephone', 'name': 'cellular_telephone'}, {'frequency': 'r', 'id': 236, 'synset': 'chain_mail.n.01', 'synonyms': ['chain_mail', 'ring_mail', 'chain_armor', 'chain_armour', 'ring_armor', 'ring_armour'], 'def': '(Middle Ages) flexible armor made of interlinked metal rings', 'name': 'chain_mail'}, {'frequency': 'f', 'id': 237, 'synset': 'chair.n.01', 'synonyms': ['chair'], 'def': 'a 
seat for one person, with a support for the back', 'name': 'chair'}, {'frequency': 'r', 'id': 238, 'synset': 'chaise_longue.n.01', 'synonyms': ['chaise_longue', 'chaise', 'daybed'], 'def': 'a long chair; for reclining', 'name': 'chaise_longue'}, {'frequency': 'r', 'id': 239, 'synset': 'champagne.n.01', 'synonyms': ['champagne'], 'def': 'a white sparkling wine produced in Champagne or resembling that produced there', 'name': 'champagne'}, {'frequency': 'f', 'id': 240, 'synset': 'chandelier.n.01', 'synonyms': ['chandelier'], 'def': 'branched lighting fixture; often ornate; hangs from the ceiling', 'name': 'chandelier'}, {'frequency': 'r', 'id': 241, 'synset': 'chap.n.04', 'synonyms': ['chap'], 'def': 'leather leggings without a seat; worn over trousers by cowboys to protect their legs', 'name': 'chap'}, {'frequency': 'r', 'id': 242, 'synset': 'checkbook.n.01', 'synonyms': ['checkbook', 'chequebook'], 'def': 'a book issued to holders of checking accounts', 'name': 'checkbook'}, {'frequency': 'r', 'id': 243, 'synset': 'checkerboard.n.01', 'synonyms': ['checkerboard'], 'def': 'a board having 64 squares of two alternating colors', 'name': 'checkerboard'}, {'frequency': 'c', 'id': 244, 'synset': 'cherry.n.03', 'synonyms': ['cherry'], 'def': 'a red fruit with a single hard stone', 'name': 'cherry'}, {'frequency': 'r', 'id': 245, 'synset': 'chessboard.n.01', 'synonyms': ['chessboard'], 'def': 'a checkerboard used to play chess', 'name': 'chessboard'}, {'frequency': 'r', 'id': 246, 'synset': 'chest_of_drawers.n.01', 'synonyms': ['chest_of_drawers_(furniture)', 'bureau_(furniture)', 'chest_(furniture)'], 'def': 'furniture with drawers for keeping clothes', 'name': 'chest_of_drawers_(furniture)'}, {'frequency': 'c', 'id': 247, 'synset': 'chicken.n.02', 'synonyms': ['chicken_(animal)'], 'def': 'a domestic fowl bred for flesh or eggs', 'name': 'chicken_(animal)'}, {'frequency': 'c', 'id': 248, 'synset': 'chicken_wire.n.01', 'synonyms': ['chicken_wire'], 'def': 'a galvanized wire 
network with a hexagonal mesh; used to build fences', 'name': 'chicken_wire'}, {'frequency': 'r', 'id': 249, 'synset': 'chickpea.n.01', 'synonyms': ['chickpea', 'garbanzo'], 'def': 'the seed of the chickpea plant; usually dried', 'name': 'chickpea'}, {'frequency': 'r', 'id': 250, 'synset': 'chihuahua.n.03', 'synonyms': ['Chihuahua'], 'def': 'an old breed of tiny short-haired dog with protruding eyes from Mexico', 'name': 'Chihuahua'}, {'frequency': 'r', 'id': 251, 'synset': 'chili.n.02', 'synonyms': ['chili_(vegetable)', 'chili_pepper_(vegetable)', 'chilli_(vegetable)', 'chilly_(vegetable)', 'chile_(vegetable)'], 'def': 'very hot and finely tapering pepper of special pungency', 'name': 'chili_(vegetable)'}, {'frequency': 'r', 'id': 252, 'synset': 'chime.n.01', 'synonyms': ['chime', 'gong'], 'def': 'an instrument consisting of a set of bells that are struck with a hammer', 'name': 'chime'}, {'frequency': 'r', 'id': 253, 'synset': 'chinaware.n.01', 'synonyms': ['chinaware'], 'def': 'dishware made of high quality porcelain', 'name': 'chinaware'}, {'frequency': 'c', 'id': 254, 'synset': 'chip.n.04', 'synonyms': ['crisp_(potato_chip)', 'potato_chip'], 'def': 'a thin crisp slice of potato fried in deep fat', 'name': 'crisp_(potato_chip)'}, {'frequency': 'r', 'id': 255, 'synset': 'chip.n.06', 'synonyms': ['poker_chip'], 'def': 'a small disk-shaped counter used to represent money when gambling', 'name': 'poker_chip'}, {'frequency': 'c', 'id': 256, 'synset': 'chocolate_bar.n.01', 'synonyms': ['chocolate_bar'], 'def': 'a bar of chocolate candy', 'name': 'chocolate_bar'}, {'frequency': 'c', 'id': 257, 'synset': 'chocolate_cake.n.01', 'synonyms': ['chocolate_cake'], 'def': 'cake containing chocolate', 'name': 'chocolate_cake'}, {'frequency': 'r', 'id': 258, 'synset': 'chocolate_milk.n.01', 'synonyms': ['chocolate_milk'], 'def': 'milk flavored with chocolate syrup', 'name': 'chocolate_milk'}, {'frequency': 'r', 'id': 259, 'synset': 'chocolate_mousse.n.01', 'synonyms': 
['chocolate_mousse'], 'def': 'dessert mousse made with chocolate', 'name': 'chocolate_mousse'}, {'frequency': 'f', 'id': 260, 'synset': 'choker.n.03', 'synonyms': ['choker', 'collar', 'neckband'], 'def': 'necklace that fits tightly around the neck', 'name': 'choker'}, {'frequency': 'f', 'id': 261, 'synset': 'chopping_board.n.01', 'synonyms': ['chopping_board', 'cutting_board', 'chopping_block'], 'def': 'a wooden board where meats or vegetables can be cut', 'name': 'chopping_board'}, {'frequency': 'c', 'id': 262, 'synset': 'chopstick.n.01', 'synonyms': ['chopstick'], 'def': 'one of a pair of slender sticks used as oriental tableware to eat food with', 'name': 'chopstick'}, {'frequency': 'f', 'id': 263, 'synset': 'christmas_tree.n.05', 'synonyms': ['Christmas_tree'], 'def': 'an ornamented evergreen used as a Christmas decoration', 'name': 'Christmas_tree'}, {'frequency': 'c', 'id': 264, 'synset': 'chute.n.02', 'synonyms': ['slide'], 'def': 'sloping channel through which things can descend', 'name': 'slide'}, {'frequency': 'r', 'id': 265, 'synset': 'cider.n.01', 'synonyms': ['cider', 'cyder'], 'def': 'a beverage made from juice pressed from apples', 'name': 'cider'}, {'frequency': 'r', 'id': 266, 'synset': 'cigar_box.n.01', 'synonyms': ['cigar_box'], 'def': 'a box for holding cigars', 'name': 'cigar_box'}, {'frequency': 'c', 'id': 267, 'synset': 'cigarette.n.01', 'synonyms': ['cigarette'], 'def': 'finely ground tobacco wrapped in paper; for smoking', 'name': 'cigarette'}, {'frequency': 'c', 'id': 268, 'synset': 'cigarette_case.n.01', 'synonyms': ['cigarette_case', 'cigarette_pack'], 'def': 'a small flat case for holding cigarettes', 'name': 'cigarette_case'}, {'frequency': 'f', 'id': 269, 'synset': 'cistern.n.02', 'synonyms': ['cistern', 'water_tank'], 'def': 'a tank that holds the water used to flush a toilet', 'name': 'cistern'}, {'frequency': 'r', 'id': 270, 'synset': 'clarinet.n.01', 'synonyms': ['clarinet'], 'def': 'a single-reed instrument with a straight tube', 
'name': 'clarinet'}, {'frequency': 'r', 'id': 271, 'synset': 'clasp.n.01', 'synonyms': ['clasp'], 'def': 'a fastener (as a buckle or hook) that is used to hold two things together', 'name': 'clasp'}, {'frequency': 'c', 'id': 272, 'synset': 'cleansing_agent.n.01', 'synonyms': ['cleansing_agent', 'cleanser', 'cleaner'], 'def': 'a preparation used in cleaning something', 'name': 'cleansing_agent'}, {'frequency': 'r', 'id': 273, 'synset': 'clementine.n.01', 'synonyms': ['clementine'], 'def': 'a variety of mandarin orange', 'name': 'clementine'}, {'frequency': 'c', 'id': 274, 'synset': 'clip.n.03', 'synonyms': ['clip'], 'def': 'any of various small fasteners used to hold loose articles together', 'name': 'clip'}, {'frequency': 'c', 'id': 275, 'synset': 'clipboard.n.01', 'synonyms': ['clipboard'], 'def': 'a small writing board with a clip at the top for holding papers', 'name': 'clipboard'}, {'frequency': 'f', 'id': 276, 'synset': 'clock.n.01', 'synonyms': ['clock', 'timepiece', 'timekeeper'], 'def': 'a timepiece that shows the time of day', 'name': 'clock'}, {'frequency': 'f', 'id': 277, 'synset': 'clock_tower.n.01', 'synonyms': ['clock_tower'], 'def': 'a tower with a large clock visible high up on an outside face', 'name': 'clock_tower'}, {'frequency': 'c', 'id': 278, 'synset': 'clothes_hamper.n.01', 'synonyms': ['clothes_hamper', 'laundry_basket', 'clothes_basket'], 'def': 'a hamper that holds dirty clothes to be washed or wet clothes to be dried', 'name': 'clothes_hamper'}, {'frequency': 'c', 'id': 279, 'synset': 'clothespin.n.01', 'synonyms': ['clothespin', 'clothes_peg'], 'def': 'wood or plastic fastener; for holding clothes on a clothesline', 'name': 'clothespin'}, {'frequency': 'r', 'id': 280, 'synset': 'clutch_bag.n.01', 'synonyms': ['clutch_bag'], 'def': "a woman's strapless purse that is carried in the hand", 'name': 'clutch_bag'}, {'frequency': 'f', 'id': 281, 'synset': 'coaster.n.03', 'synonyms': ['coaster'], 'def': 'a covering (plate or mat) that protects 
the surface of a table', 'name': 'coaster'}, {'frequency': 'f', 'id': 282, 'synset': 'coat.n.01', 'synonyms': ['coat'], 'def': 'an outer garment that has sleeves and covers the body from shoulder down', 'name': 'coat'}, {'frequency': 'c', 'id': 283, 'synset': 'coat_hanger.n.01', 'synonyms': ['coat_hanger', 'clothes_hanger', 'dress_hanger'], 'def': "a hanger that is shaped like a person's shoulders", 'name': 'coat_hanger'}, {'frequency': 'r', 'id': 284, 'synset': 'coatrack.n.01', 'synonyms': ['coatrack', 'hatrack'], 'def': 'a rack with hooks for temporarily holding coats and hats', 'name': 'coatrack'}, {'frequency': 'c', 'id': 285, 'synset': 'cock.n.04', 'synonyms': ['cock', 'rooster'], 'def': 'adult male chicken', 'name': 'cock'}, {'frequency': 'c', 'id': 286, 'synset': 'coconut.n.02', 'synonyms': ['coconut', 'cocoanut'], 'def': 'large hard-shelled brown oval nut with a fibrous husk', 'name': 'coconut'}, {'frequency': 'r', 'id': 287, 'synset': 'coffee_filter.n.01', 'synonyms': ['coffee_filter'], 'def': 'filter (usually of paper) that passes the coffee and retains the coffee grounds', 'name': 'coffee_filter'}, {'frequency': 'f', 'id': 288, 'synset': 'coffee_maker.n.01', 'synonyms': ['coffee_maker', 'coffee_machine'], 'def': 'a kitchen appliance for brewing coffee automatically', 'name': 'coffee_maker'}, {'frequency': 'f', 'id': 289, 'synset': 'coffee_table.n.01', 'synonyms': ['coffee_table', 'cocktail_table'], 'def': 'low table where magazines can be placed and coffee or cocktails are served', 'name': 'coffee_table'}, {'frequency': 'c', 'id': 290, 'synset': 'coffeepot.n.01', 'synonyms': ['coffeepot'], 'def': 'tall pot in which coffee is brewed', 'name': 'coffeepot'}, {'frequency': 'r', 'id': 291, 'synset': 'coil.n.05', 'synonyms': ['coil'], 'def': 'tubing that is wound in a spiral', 'name': 'coil'}, {'frequency': 'c', 'id': 292, 'synset': 'coin.n.01', 'synonyms': ['coin'], 'def': 'a flat metal piece (usually a disc) used as money', 'name': 'coin'}, {'frequency': 
'r', 'id': 293, 'synset': 'colander.n.01', 'synonyms': ['colander', 'cullender'], 'def': 'bowl-shaped strainer; used to wash or drain foods', 'name': 'colander'}, {'frequency': 'c', 'id': 294, 'synset': 'coleslaw.n.01', 'synonyms': ['coleslaw', 'slaw'], 'def': 'basically shredded cabbage', 'name': 'coleslaw'}, {'frequency': 'r', 'id': 295, 'synset': 'coloring_material.n.01', 'synonyms': ['coloring_material', 'colouring_material'], 'def': 'any material used for its color', 'name': 'coloring_material'}, {'frequency': 'r', 'id': 296, 'synset': 'combination_lock.n.01', 'synonyms': ['combination_lock'], 'def': 'lock that can be opened only by turning dials in a special sequence', 'name': 'combination_lock'}, {'frequency': 'c', 'id': 297, 'synset': 'comforter.n.04', 'synonyms': ['pacifier', 'teething_ring'], 'def': 'device used for an infant to suck or bite on', 'name': 'pacifier'}, {'frequency': 'r', 'id': 298, 'synset': 'comic_book.n.01', 'synonyms': ['comic_book'], 'def': 'a magazine devoted to comic strips', 'name': 'comic_book'}, {'frequency': 'f', 'id': 299, 'synset': 'computer_keyboard.n.01', 'synonyms': ['computer_keyboard', 'keyboard_(computer)'], 'def': 'a keyboard that is a data input device for computers', 'name': 'computer_keyboard'}, {'frequency': 'r', 'id': 300, 'synset': 'concrete_mixer.n.01', 'synonyms': ['concrete_mixer', 'cement_mixer'], 'def': 'a machine with a large revolving drum in which cement/concrete is mixed', 'name': 'concrete_mixer'}, {'frequency': 'f', 'id': 301, 'synset': 'cone.n.01', 'synonyms': ['cone', 'traffic_cone'], 'def': 'a cone-shaped object used to direct traffic', 'name': 'cone'}, {'frequency': 'f', 'id': 302, 'synset': 'control.n.09', 'synonyms': ['control', 'controller'], 'def': 'a mechanism that controls the operation of a machine', 'name': 'control'}, {'frequency': 'r', 'id': 303, 'synset': 'convertible.n.01', 'synonyms': ['convertible_(automobile)'], 'def': 'a car that has top that can be folded or removed', 'name': 
'convertible_(automobile)'}, {'frequency': 'r', 'id': 304, 'synset': 'convertible.n.03', 'synonyms': ['sofa_bed'], 'def': 'a sofa that can be converted into a bed', 'name': 'sofa_bed'}, {'frequency': 'c', 'id': 305, 'synset': 'cookie.n.01', 'synonyms': ['cookie', 'cooky', 'biscuit_(cookie)'], 'def': "any of various small flat sweet cakes (`biscuit' is the British term)", 'name': 'cookie'}, {'frequency': 'r', 'id': 306, 'synset': 'cookie_jar.n.01', 'synonyms': ['cookie_jar', 'cooky_jar'], 'def': 'a jar in which cookies are kept (and sometimes money is hidden)', 'name': 'cookie_jar'}, {'frequency': 'r', 'id': 307, 'synset': 'cooking_utensil.n.01', 'synonyms': ['cooking_utensil'], 'def': 'a kitchen utensil made of material that does not melt easily; used for cooking', 'name': 'cooking_utensil'}, {'frequency': 'f', 'id': 308, 'synset': 'cooler.n.01', 'synonyms': ['cooler_(for_food)', 'ice_chest'], 'def': 'an insulated box for storing food often with ice', 'name': 'cooler_(for_food)'}, {'frequency': 'c', 'id': 309, 'synset': 'cork.n.04', 'synonyms': ['cork_(bottle_plug)', 'bottle_cork'], 'def': 'the plug in the mouth of a bottle (especially a wine bottle)', 'name': 'cork_(bottle_plug)'}, {'frequency': 'r', 'id': 310, 'synset': 'corkboard.n.01', 'synonyms': ['corkboard'], 'def': 'a sheet consisting of cork granules', 'name': 'corkboard'}, {'frequency': 'r', 'id': 311, 'synset': 'corkscrew.n.01', 'synonyms': ['corkscrew', 'bottle_screw'], 'def': 'a bottle opener that pulls corks', 'name': 'corkscrew'}, {'frequency': 'c', 'id': 312, 'synset': 'corn.n.03', 'synonyms': ['edible_corn', 'corn', 'maize'], 'def': 'ears of corn that can be prepared and served for human food', 'name': 'edible_corn'}, {'frequency': 'r', 'id': 313, 'synset': 'cornbread.n.01', 'synonyms': ['cornbread'], 'def': 'bread made primarily of cornmeal', 'name': 'cornbread'}, {'frequency': 'c', 'id': 314, 'synset': 'cornet.n.01', 'synonyms': ['cornet', 'horn', 'trumpet'], 'def': 'a brass musical instrument 
with a narrow tube and a flared bell and many valves', 'name': 'cornet'}, {'frequency': 'c', 'id': 315, 'synset': 'cornice.n.01', 'synonyms': ['cornice', 'valance', 'valance_board', 'pelmet'], 'def': 'a decorative framework to conceal curtain fixtures at the top of a window casing', 'name': 'cornice'}, {'frequency': 'r', 'id': 316, 'synset': 'cornmeal.n.01', 'synonyms': ['cornmeal'], 'def': 'coarsely ground corn', 'name': 'cornmeal'}, {'frequency': 'r', 'id': 317, 'synset': 'corset.n.01', 'synonyms': ['corset', 'girdle'], 'def': "a woman's close-fitting foundation garment", 'name': 'corset'}, {'frequency': 'r', 'id': 318, 'synset': 'cos.n.02', 'synonyms': ['romaine_lettuce'], 'def': 'lettuce with long dark-green leaves in a loosely packed elongated head', 'name': 'romaine_lettuce'}, {'frequency': 'c', 'id': 319, 'synset': 'costume.n.04', 'synonyms': ['costume'], 'def': 'the attire characteristic of a country or a time or a social class', 'name': 'costume'}, {'frequency': 'r', 'id': 320, 'synset': 'cougar.n.01', 'synonyms': ['cougar', 'puma', 'catamount', 'mountain_lion', 'panther'], 'def': 'large American feline resembling a lion', 'name': 'cougar'}, {'frequency': 'r', 'id': 321, 'synset': 'coverall.n.01', 'synonyms': ['coverall'], 'def': 'a loose-fitting protective garment that is worn over other clothing', 'name': 'coverall'}, {'frequency': 'r', 'id': 322, 'synset': 'cowbell.n.01', 'synonyms': ['cowbell'], 'def': 'a bell hung around the neck of cow so that the cow can be easily located', 'name': 'cowbell'}, {'frequency': 'f', 'id': 323, 'synset': 'cowboy_hat.n.01', 'synonyms': ['cowboy_hat', 'ten-gallon_hat'], 'def': 'a hat with a wide brim and a soft crown; worn by American ranch hands', 'name': 'cowboy_hat'}, {'frequency': 'r', 'id': 324, 'synset': 'crab.n.01', 'synonyms': ['crab_(animal)'], 'def': 'decapod having eyes on short stalks and a broad flattened shell and pincers', 'name': 'crab_(animal)'}, {'frequency': 'c', 'id': 325, 'synset': 'cracker.n.01', 
'synonyms': ['cracker'], 'def': 'a thin crisp wafer', 'name': 'cracker'}, {'frequency': 'r', 'id': 326, 'synset': 'crape.n.01', 'synonyms': ['crape', 'crepe', 'French_pancake'], 'def': 'small very thin pancake', 'name': 'crape'}, {'frequency': 'f', 'id': 327, 'synset': 'crate.n.01', 'synonyms': ['crate'], 'def': 'a rugged box (usually made of wood); used for shipping', 'name': 'crate'}, {'frequency': 'r', 'id': 328, 'synset': 'crayon.n.01', 'synonyms': ['crayon', 'wax_crayon'], 'def': 'writing or drawing implement made of a colored stick of composition wax', 'name': 'crayon'}, {'frequency': 'r', 'id': 329, 'synset': 'cream_pitcher.n.01', 'synonyms': ['cream_pitcher'], 'def': 'a small pitcher for serving cream', 'name': 'cream_pitcher'}, {'frequency': 'r', 'id': 330, 'synset': 'credit_card.n.01', 'synonyms': ['credit_card', 'charge_card', 'debit_card'], 'def': 'a card, usually plastic, used to pay for goods and services', 'name': 'credit_card'}, {'frequency': 'c', 'id': 331, 'synset': 'crescent_roll.n.01', 'synonyms': ['crescent_roll', 'croissant'], 'def': 'very rich flaky crescent-shaped roll', 'name': 'crescent_roll'}, {'frequency': 'c', 'id': 332, 'synset': 'crib.n.01', 'synonyms': ['crib', 'cot'], 'def': 'baby bed with high sides made of slats', 'name': 'crib'}, {'frequency': 'c', 'id': 333, 'synset': 'crock.n.03', 'synonyms': ['crock_pot', 'earthenware_jar'], 'def': 'an earthen jar (made of baked clay)', 'name': 'crock_pot'}, {'frequency': 'f', 'id': 334, 'synset': 'crossbar.n.01', 'synonyms': ['crossbar'], 'def': 'a horizontal bar that goes across something', 'name': 'crossbar'}, {'frequency': 'r', 'id': 335, 'synset': 'crouton.n.01', 'synonyms': ['crouton'], 'def': 'a small piece of toasted or fried bread; served in soup or salads', 'name': 'crouton'}, {'frequency': 'r', 'id': 336, 'synset': 'crow.n.01', 'synonyms': ['crow'], 'def': 'black birds having a raucous call', 'name': 'crow'}, {'frequency': 'c', 'id': 337, 'synset': 'crown.n.04', 'synonyms': 
['crown'], 'def': 'an ornamental jeweled headdress signifying sovereignty', 'name': 'crown'}, {'frequency': 'c', 'id': 338, 'synset': 'crucifix.n.01', 'synonyms': ['crucifix'], 'def': 'representation of the cross on which Jesus died', 'name': 'crucifix'}, {'frequency': 'c', 'id': 339, 'synset': 'cruise_ship.n.01', 'synonyms': ['cruise_ship', 'cruise_liner'], 'def': 'a passenger ship used commercially for pleasure cruises', 'name': 'cruise_ship'}, {'frequency': 'c', 'id': 340, 'synset': 'cruiser.n.01', 'synonyms': ['police_cruiser', 'patrol_car', 'police_car', 'squad_car'], 'def': 'a car in which policemen cruise the streets', 'name': 'police_cruiser'}, {'frequency': 'c', 'id': 341, 'synset': 'crumb.n.03', 'synonyms': ['crumb'], 'def': 'small piece of e.g. bread or cake', 'name': 'crumb'}, {'frequency': 'r', 'id': 342, 'synset': 'crutch.n.01', 'synonyms': ['crutch'], 'def': 'a wooden or metal staff that fits under the armpit and reaches to the ground', 'name': 'crutch'}, {'frequency': 'c', 'id': 343, 'synset': 'cub.n.03', 'synonyms': ['cub_(animal)'], 'def': 'the young of certain carnivorous mammals such as the bear or wolf or lion', 'name': 'cub_(animal)'}, {'frequency': 'r', 'id': 344, 'synset': 'cube.n.05', 'synonyms': ['cube', 'square_block'], 'def': 'a block in the (approximate) shape of a cube', 'name': 'cube'}, {'frequency': 'f', 'id': 345, 'synset': 'cucumber.n.02', 'synonyms': ['cucumber', 'cuke'], 'def': 'cylindrical green fruit with thin green rind and white flesh eaten as a vegetable', 'name': 'cucumber'}, {'frequency': 'c', 'id': 346, 'synset': 'cufflink.n.01', 'synonyms': ['cufflink'], 'def': 'jewelry consisting of linked buttons used to fasten the cuffs of a shirt', 'name': 'cufflink'}, {'frequency': 'f', 'id': 347, 'synset': 'cup.n.01', 'synonyms': ['cup'], 'def': 'a small open container usually used for drinking; usually has a handle', 'name': 'cup'}, {'frequency': 'c', 'id': 348, 'synset': 'cup.n.08', 'synonyms': ['trophy_cup'], 'def': 'a metal 
vessel with handles that is awarded as a trophy to a competition winner', 'name': 'trophy_cup'}, {'frequency': 'c', 'id': 349, 'synset': 'cupcake.n.01', 'synonyms': ['cupcake'], 'def': 'small cake baked in a muffin tin', 'name': 'cupcake'}, {'frequency': 'r', 'id': 350, 'synset': 'curler.n.01', 'synonyms': ['hair_curler', 'hair_roller', 'hair_crimper'], 'def': 'a cylindrical tube around which the hair is wound to curl it', 'name': 'hair_curler'}, {'frequency': 'r', 'id': 351, 'synset': 'curling_iron.n.01', 'synonyms': ['curling_iron'], 'def': 'a cylindrical home appliance that heats hair that has been curled around it', 'name': 'curling_iron'}, {'frequency': 'f', 'id': 352, 'synset': 'curtain.n.01', 'synonyms': ['curtain', 'drapery'], 'def': 'hanging cloth used as a blind (especially for a window)', 'name': 'curtain'}, {'frequency': 'f', 'id': 353, 'synset': 'cushion.n.03', 'synonyms': ['cushion'], 'def': 'a soft bag filled with air or padding such as feathers or foam rubber', 'name': 'cushion'}, {'frequency': 'r', 'id': 354, 'synset': 'custard.n.01', 'synonyms': ['custard'], 'def': 'sweetened mixture of milk and eggs baked or boiled or frozen', 'name': 'custard'}, {'frequency': 'c', 'id': 355, 'synset': 'cutter.n.06', 'synonyms': ['cutting_tool'], 'def': 'a cutting implement; a tool for cutting', 'name': 'cutting_tool'}, {'frequency': 'r', 'id': 356, 'synset': 'cylinder.n.04', 'synonyms': ['cylinder'], 'def': 'a cylindrical container', 'name': 'cylinder'}, {'frequency': 'r', 'id': 357, 'synset': 'cymbal.n.01', 'synonyms': ['cymbal'], 'def': 'a percussion instrument consisting of a concave brass disk', 'name': 'cymbal'}, {'frequency': 'r', 'id': 358, 'synset': 'dachshund.n.01', 'synonyms': ['dachshund', 'dachsie', 'badger_dog'], 'def': 'small long-bodied short-legged breed of dog having a short sleek coat and long drooping ears', 'name': 'dachshund'}, {'frequency': 'r', 'id': 359, 'synset': 'dagger.n.01', 'synonyms': ['dagger'], 'def': 'a short knife with a pointed 
blade used for piercing or stabbing', 'name': 'dagger'}, {'frequency': 'r', 'id': 360, 'synset': 'dartboard.n.01', 'synonyms': ['dartboard'], 'def': 'a circular board of wood or cork used as the target in the game of darts', 'name': 'dartboard'}, {'frequency': 'r', 'id': 361, 'synset': 'date.n.08', 'synonyms': ['date_(fruit)'], 'def': 'sweet edible fruit of the date palm with a single long woody seed', 'name': 'date_(fruit)'}, {'frequency': 'f', 'id': 362, 'synset': 'deck_chair.n.01', 'synonyms': ['deck_chair', 'beach_chair'], 'def': 'a folding chair for use outdoors; a wooden frame supports a length of canvas', 'name': 'deck_chair'}, {'frequency': 'c', 'id': 363, 'synset': 'deer.n.01', 'synonyms': ['deer', 'cervid'], 'def': "distinguished from Bovidae by the male's having solid deciduous antlers", 'name': 'deer'}, {'frequency': 'c', 'id': 364, 'synset': 'dental_floss.n.01', 'synonyms': ['dental_floss', 'floss'], 'def': 'a soft thread for cleaning the spaces between the teeth', 'name': 'dental_floss'}, {'frequency': 'f', 'id': 365, 'synset': 'desk.n.01', 'synonyms': ['desk'], 'def': 'a piece of furniture with a writing surface and usually drawers or other compartments', 'name': 'desk'}, {'frequency': 'r', 'id': 366, 'synset': 'detergent.n.01', 'synonyms': ['detergent'], 'def': 'a surface-active chemical widely used in industry and laundering', 'name': 'detergent'}, {'frequency': 'c', 'id': 367, 'synset': 'diaper.n.01', 'synonyms': ['diaper'], 'def': 'garment consisting of a folded cloth drawn up between the legs and fastened at the waist', 'name': 'diaper'}, {'frequency': 'r', 'id': 368, 'synset': 'diary.n.01', 'synonyms': ['diary', 'journal'], 'def': 'a daily written record of (usually personal) experiences and observations', 'name': 'diary'}, {'frequency': 'r', 'id': 369, 'synset': 'die.n.01', 'synonyms': ['die', 'dice'], 'def': 'a small cube with 1 to 6 spots on the six faces; used in gambling', 'name': 'die'}, {'frequency': 'r', 'id': 370, 'synset': 
'dinghy.n.01', 'synonyms': ['dinghy', 'dory', 'rowboat'], 'def': 'a small boat of shallow draft with seats and oars with which it is propelled', 'name': 'dinghy'}, {'frequency': 'f', 'id': 371, 'synset': 'dining_table.n.01', 'synonyms': ['dining_table'], 'def': 'a table at which meals are served', 'name': 'dining_table'}, {'frequency': 'r', 'id': 372, 'synset': 'dinner_jacket.n.01', 'synonyms': ['tux', 'tuxedo'], 'def': 'semiformal evening dress for men', 'name': 'tux'}, {'frequency': 'c', 'id': 373, 'synset': 'dish.n.01', 'synonyms': ['dish'], 'def': 'a piece of dishware normally used as a container for holding or serving food', 'name': 'dish'}, {'frequency': 'c', 'id': 374, 'synset': 'dish.n.05', 'synonyms': ['dish_antenna'], 'def': 'directional antenna consisting of a parabolic reflector', 'name': 'dish_antenna'}, {'frequency': 'c', 'id': 375, 'synset': 'dishrag.n.01', 'synonyms': ['dishrag', 'dishcloth'], 'def': 'a cloth for washing dishes', 'name': 'dishrag'}, {'frequency': 'c', 'id': 376, 'synset': 'dishtowel.n.01', 'synonyms': ['dishtowel', 'tea_towel'], 'def': 'a towel for drying dishes', 'name': 'dishtowel'}, {'frequency': 'f', 'id': 377, 'synset': 'dishwasher.n.01', 'synonyms': ['dishwasher', 'dishwashing_machine'], 'def': 'a machine for washing dishes', 'name': 'dishwasher'}, {'frequency': 'r', 'id': 378, 'synset': 'dishwasher_detergent.n.01', 'synonyms': ['dishwasher_detergent', 'dishwashing_detergent', 'dishwashing_liquid'], 'def': 'a low-sudsing detergent designed for use in dishwashers', 'name': 'dishwasher_detergent'}, {'frequency': 'r', 'id': 379, 'synset': 'diskette.n.01', 'synonyms': ['diskette', 'floppy', 'floppy_disk'], 'def': 'a small plastic magnetic disk enclosed in a stiff envelope used to store data', 'name': 'diskette'}, {'frequency': 'c', 'id': 380, 'synset': 'dispenser.n.01', 'synonyms': ['dispenser'], 'def': 'a container so designed that the contents can be used in prescribed amounts', 'name': 'dispenser'}, {'frequency': 'c', 'id': 
381, 'synset': 'dixie_cup.n.01', 'synonyms': ['Dixie_cup', 'paper_cup'], 'def': 'a disposable cup made of paper; for holding drinks', 'name': 'Dixie_cup'}, {'frequency': 'f', 'id': 382, 'synset': 'dog.n.01', 'synonyms': ['dog'], 'def': 'a common domesticated dog', 'name': 'dog'}, {'frequency': 'f', 'id': 383, 'synset': 'dog_collar.n.01', 'synonyms': ['dog_collar'], 'def': 'a collar for a dog', 'name': 'dog_collar'}, {'frequency': 'c', 'id': 384, 'synset': 'doll.n.01', 'synonyms': ['doll'], 'def': 'a toy replica of a HUMAN (NOT AN ANIMAL)', 'name': 'doll'}, {'frequency': 'r', 'id': 385, 'synset': 'dollar.n.02', 'synonyms': ['dollar', 'dollar_bill', 'one_dollar_bill'], 'def': 'a piece of paper money worth one dollar', 'name': 'dollar'}, {'frequency': 'r', 'id': 386, 'synset': 'dolphin.n.02', 'synonyms': ['dolphin'], 'def': 'any of various small toothed whales with a beaklike snout; larger than porpoises', 'name': 'dolphin'}, {'frequency': 'c', 'id': 387, 'synset': 'domestic_ass.n.01', 'synonyms': ['domestic_ass', 'donkey'], 'def': 'domestic beast of burden descended from the African wild ass; patient but stubborn', 'name': 'domestic_ass'}, {'frequency': 'r', 'id': 388, 'synset': 'domino.n.03', 'synonyms': ['eye_mask'], 'def': 'a mask covering the upper part of the face but with holes for the eyes', 'name': 'eye_mask'}, {'frequency': 'r', 'id': 389, 'synset': 'doorbell.n.01', 'synonyms': ['doorbell', 'buzzer'], 'def': 'a button at an outer door that gives a ringing or buzzing signal when pushed', 'name': 'doorbell'}, {'frequency': 'f', 'id': 390, 'synset': 'doorknob.n.01', 'synonyms': ['doorknob', 'doorhandle'], 'def': "a knob used to open a door (often called `doorhandle' in Great Britain)", 'name': 'doorknob'}, {'frequency': 'c', 'id': 391, 'synset': 'doormat.n.02', 'synonyms': ['doormat', 'welcome_mat'], 'def': 'a mat placed outside an exterior door for wiping the shoes before entering', 'name': 'doormat'}, {'frequency': 'f', 'id': 392, 'synset': 'doughnut.n.02', 
'synonyms': ['doughnut', 'donut'], 'def': 'a small ring-shaped friedcake', 'name': 'doughnut'}, {'frequency': 'r', 'id': 393, 'synset': 'dove.n.01', 'synonyms': ['dove'], 'def': 'any of numerous small pigeons', 'name': 'dove'}, {'frequency': 'r', 'id': 394, 'synset': 'dragonfly.n.01', 'synonyms': ['dragonfly'], 'def': 'slender-bodied non-stinging insect having iridescent wings that are outspread at rest', 'name': 'dragonfly'}, {'frequency': 'f', 'id': 395, 'synset': 'drawer.n.01', 'synonyms': ['drawer'], 'def': 'a boxlike container in a piece of furniture; made so as to slide in and out', 'name': 'drawer'}, {'frequency': 'c', 'id': 396, 'synset': 'drawers.n.01', 'synonyms': ['underdrawers', 'boxers', 'boxershorts'], 'def': 'underpants worn by men', 'name': 'underdrawers'}, {'frequency': 'f', 'id': 397, 'synset': 'dress.n.01', 'synonyms': ['dress', 'frock'], 'def': 'a one-piece garment for a woman; has skirt and bodice', 'name': 'dress'}, {'frequency': 'c', 'id': 398, 'synset': 'dress_hat.n.01', 'synonyms': ['dress_hat', 'high_hat', 'opera_hat', 'silk_hat', 'top_hat'], 'def': "a man's hat with a tall crown; usually covered with silk or with beaver fur", 'name': 'dress_hat'}, {'frequency': 'c', 'id': 399, 'synset': 'dress_suit.n.01', 'synonyms': ['dress_suit'], 'def': 'formalwear consisting of full evening dress for men', 'name': 'dress_suit'}, {'frequency': 'c', 'id': 400, 'synset': 'dresser.n.05', 'synonyms': ['dresser'], 'def': 'a cabinet with shelves', 'name': 'dresser'}, {'frequency': 'c', 'id': 401, 'synset': 'drill.n.01', 'synonyms': ['drill'], 'def': 'a tool with a sharp rotating point for making holes in hard materials', 'name': 'drill'}, {'frequency': 'r', 'id': 402, 'synset': 'drinking_fountain.n.01', 'synonyms': ['drinking_fountain'], 'def': 'a public fountain to provide a jet of drinking water', 'name': 'drinking_fountain'}, {'frequency': 'r', 'id': 403, 'synset': 'drone.n.04', 'synonyms': ['drone'], 'def': 'an aircraft without a pilot that is operated 
by remote control', 'name': 'drone'}, {'frequency': 'r', 'id': 404, 'synset': 'dropper.n.01', 'synonyms': ['dropper', 'eye_dropper'], 'def': 'pipet consisting of a small tube with a vacuum bulb at one end for drawing liquid in and releasing it a drop at a time', 'name': 'dropper'}, {'frequency': 'c', 'id': 405, 'synset': 'drum.n.01', 'synonyms': ['drum_(musical_instrument)'], 'def': 'a musical percussion instrument; usually consists of a hollow cylinder with a membrane stretched across each end', 'name': 'drum_(musical_instrument)'}, {'frequency': 'r', 'id': 406, 'synset': 'drumstick.n.02', 'synonyms': ['drumstick'], 'def': 'a stick used for playing a drum', 'name': 'drumstick'}, {'frequency': 'f', 'id': 407, 'synset': 'duck.n.01', 'synonyms': ['duck'], 'def': 'small web-footed broad-billed swimming bird', 'name': 'duck'}, {'frequency': 'r', 'id': 408, 'synset': 'duckling.n.02', 'synonyms': ['duckling'], 'def': 'young duck', 'name': 'duckling'}, {'frequency': 'c', 'id': 409, 'synset': 'duct_tape.n.01', 'synonyms': ['duct_tape'], 'def': 'a wide silvery adhesive tape', 'name': 'duct_tape'}, {'frequency': 'f', 'id': 410, 'synset': 'duffel_bag.n.01', 'synonyms': ['duffel_bag', 'duffle_bag', 'duffel', 'duffle'], 'def': 'a large cylindrical bag of heavy cloth', 'name': 'duffel_bag'}, {'frequency': 'r', 'id': 411, 'synset': 'dumbbell.n.01', 'synonyms': ['dumbbell'], 'def': 'an exercising weight with two ball-like ends connected by a short handle', 'name': 'dumbbell'}, {'frequency': 'c', 'id': 412, 'synset': 'dumpster.n.01', 'synonyms': ['dumpster'], 'def': 'a container designed to receive and transport and dump waste', 'name': 'dumpster'}, {'frequency': 'r', 'id': 413, 'synset': 'dustpan.n.02', 'synonyms': ['dustpan'], 'def': 'a short-handled receptacle into which dust can be swept', 'name': 'dustpan'}, {'frequency': 'r', 'id': 414, 'synset': 'dutch_oven.n.02', 'synonyms': ['Dutch_oven'], 'def': 'iron or earthenware cooking pot; used for stews', 'name': 'Dutch_oven'}, 
{'frequency': 'c', 'id': 415, 'synset': 'eagle.n.01', 'synonyms': ['eagle'], 'def': 'large birds of prey noted for their broad wings and strong soaring flight', 'name': 'eagle'}, {'frequency': 'f', 'id': 416, 'synset': 'earphone.n.01', 'synonyms': ['earphone', 'earpiece', 'headphone'], 'def': 'device for listening to audio that is held over or inserted into the ear', 'name': 'earphone'}, {'frequency': 'r', 'id': 417, 'synset': 'earplug.n.01', 'synonyms': ['earplug'], 'def': 'a soft plug that is inserted into the ear canal to block sound', 'name': 'earplug'}, {'frequency': 'f', 'id': 418, 'synset': 'earring.n.01', 'synonyms': ['earring'], 'def': 'jewelry to ornament the ear', 'name': 'earring'}, {'frequency': 'c', 'id': 419, 'synset': 'easel.n.01', 'synonyms': ['easel'], 'def': "an upright tripod for displaying something (usually an artist's canvas)", 'name': 'easel'}, {'frequency': 'r', 'id': 420, 'synset': 'eclair.n.01', 'synonyms': ['eclair'], 'def': 'oblong cream puff', 'name': 'eclair'}, {'frequency': 'r', 'id': 421, 'synset': 'eel.n.01', 'synonyms': ['eel'], 'def': 'an elongate fish with fatty flesh', 'name': 'eel'}, {'frequency': 'f', 'id': 422, 'synset': 'egg.n.02', 'synonyms': ['egg', 'eggs'], 'def': 'oval reproductive body of a fowl (especially a hen) used as food', 'name': 'egg'}, {'frequency': 'r', 'id': 423, 'synset': 'egg_roll.n.01', 'synonyms': ['egg_roll', 'spring_roll'], 'def': 'minced vegetables and meat wrapped in a pancake and fried', 'name': 'egg_roll'}, {'frequency': 'c', 'id': 424, 'synset': 'egg_yolk.n.01', 'synonyms': ['egg_yolk', 'yolk_(egg)'], 'def': 'the yellow spherical part of an egg', 'name': 'egg_yolk'}, {'frequency': 'c', 'id': 425, 'synset': 'eggbeater.n.02', 'synonyms': ['eggbeater', 'eggwhisk'], 'def': 'a mixer for beating eggs or whipping cream', 'name': 'eggbeater'}, {'frequency': 'c', 'id': 426, 'synset': 'eggplant.n.01', 'synonyms': ['eggplant', 'aubergine'], 'def': 'egg-shaped vegetable having a shiny skin typically dark 
purple', 'name': 'eggplant'}, {'frequency': 'r', 'id': 427, 'synset': 'electric_chair.n.01', 'synonyms': ['electric_chair'], 'def': 'a chair-shaped instrument of execution by electrocution', 'name': 'electric_chair'}, {'frequency': 'f', 'id': 428, 'synset': 'electric_refrigerator.n.01', 'synonyms': ['refrigerator'], 'def': 'a refrigerator in which the coolant is pumped around by an electric motor', 'name': 'refrigerator'}, {'frequency': 'f', 'id': 429, 'synset': 'elephant.n.01', 'synonyms': ['elephant'], 'def': 'a common elephant', 'name': 'elephant'}, {'frequency': 'r', 'id': 430, 'synset': 'elk.n.01', 'synonyms': ['elk', 'moose'], 'def': 'large northern deer with enormous flattened antlers in the male', 'name': 'elk'}, {'frequency': 'c', 'id': 431, 'synset': 'envelope.n.01', 'synonyms': ['envelope'], 'def': 'a flat (usually rectangular) container for a letter, thin package, etc.', 'name': 'envelope'}, {'frequency': 'c', 'id': 432, 'synset': 'eraser.n.01', 'synonyms': ['eraser'], 'def': 'an implement used to erase something', 'name': 'eraser'}, {'frequency': 'r', 'id': 433, 'synset': 'escargot.n.01', 'synonyms': ['escargot'], 'def': 'edible snail usually served in the shell with a sauce of melted butter and garlic', 'name': 'escargot'}, {'frequency': 'r', 'id': 434, 'synset': 'eyepatch.n.01', 'synonyms': ['eyepatch'], 'def': 'a protective cloth covering for an injured eye', 'name': 'eyepatch'}, {'frequency': 'r', 'id': 435, 'synset': 'falcon.n.01', 'synonyms': ['falcon'], 'def': 'birds of prey having long pointed powerful wings adapted for swift flight', 'name': 'falcon'}, {'frequency': 'f', 'id': 436, 'synset': 'fan.n.01', 'synonyms': ['fan'], 'def': 'a device for creating a current of air by movement of a surface or surfaces', 'name': 'fan'}, {'frequency': 'f', 'id': 437, 'synset': 'faucet.n.01', 'synonyms': ['faucet', 'spigot', 'tap'], 'def': 'a regulator for controlling the flow of a liquid from a reservoir', 'name': 'faucet'}, {'frequency': 'r', 'id': 438, 
'synset': 'fedora.n.01', 'synonyms': ['fedora'], 'def': 'a hat made of felt with a creased crown', 'name': 'fedora'}, {'frequency': 'r', 'id': 439, 'synset': 'ferret.n.02', 'synonyms': ['ferret'], 'def': 'domesticated albino variety of the European polecat bred for hunting rats and rabbits', 'name': 'ferret'}, {'frequency': 'c', 'id': 440, 'synset': 'ferris_wheel.n.01', 'synonyms': ['Ferris_wheel'], 'def': 'a large wheel with suspended seats that remain upright as the wheel rotates', 'name': 'Ferris_wheel'}, {'frequency': 'r', 'id': 441, 'synset': 'ferry.n.01', 'synonyms': ['ferry', 'ferryboat'], 'def': 'a boat that transports people or vehicles across a body of water and operates on a regular schedule', 'name': 'ferry'}, {'frequency': 'r', 'id': 442, 'synset': 'fig.n.04', 'synonyms': ['fig_(fruit)'], 'def': 'fleshy sweet pear-shaped yellowish or purple fruit eaten fresh or preserved or dried', 'name': 'fig_(fruit)'}, {'frequency': 'c', 'id': 443, 'synset': 'fighter.n.02', 'synonyms': ['fighter_jet', 'fighter_aircraft', 'attack_aircraft'], 'def': 'a high-speed military or naval airplane designed to destroy enemy targets', 'name': 'fighter_jet'}, {'frequency': 'f', 'id': 444, 'synset': 'figurine.n.01', 'synonyms': ['figurine'], 'def': 'a small carved or molded figure', 'name': 'figurine'}, {'frequency': 'c', 'id': 445, 'synset': 'file.n.03', 'synonyms': ['file_cabinet', 'filing_cabinet'], 'def': 'office furniture consisting of a container for keeping papers in order', 'name': 'file_cabinet'}, {'frequency': 'r', 'id': 446, 'synset': 'file.n.04', 'synonyms': ['file_(tool)'], 'def': 'a steel hand tool with small sharp teeth on some or all of its surfaces; used for smoothing wood or metal', 'name': 'file_(tool)'}, {'frequency': 'f', 'id': 447, 'synset': 'fire_alarm.n.02', 'synonyms': ['fire_alarm', 'smoke_alarm'], 'def': 'an alarm that is tripped off by fire or smoke', 'name': 'fire_alarm'}, {'frequency': 'c', 'id': 448, 'synset': 'fire_engine.n.01', 'synonyms': 
['fire_engine', 'fire_truck'], 'def': 'large trucks that carry firefighters and equipment to the site of a fire', 'name': 'fire_engine'}, {'frequency': 'c', 'id': 449, 'synset': 'fire_extinguisher.n.01', 'synonyms': ['fire_extinguisher', 'extinguisher'], 'def': 'a manually operated device for extinguishing small fires', 'name': 'fire_extinguisher'}, {'frequency': 'c', 'id': 450, 'synset': 'fire_hose.n.01', 'synonyms': ['fire_hose'], 'def': 'a large hose that carries water from a fire hydrant to the site of the fire', 'name': 'fire_hose'}, {'frequency': 'f', 'id': 451, 'synset': 'fireplace.n.01', 'synonyms': ['fireplace'], 'def': 'an open recess in a wall at the base of a chimney where a fire can be built', 'name': 'fireplace'}, {'frequency': 'f', 'id': 452, 'synset': 'fireplug.n.01', 'synonyms': ['fireplug', 'fire_hydrant', 'hydrant'], 'def': 'an upright hydrant for drawing water to use in fighting a fire', 'name': 'fireplug'}, {'frequency': 'c', 'id': 453, 'synset': 'fish.n.01', 'synonyms': ['fish'], 'def': 'any of various mostly cold-blooded aquatic vertebrates usually having scales and breathing through gills', 'name': 'fish'}, {'frequency': 'r', 'id': 454, 'synset': 'fish.n.02', 'synonyms': ['fish_(food)'], 'def': 'the flesh of fish used as food', 'name': 'fish_(food)'}, {'frequency': 'r', 'id': 455, 'synset': 'fishbowl.n.02', 'synonyms': ['fishbowl', 'goldfish_bowl'], 'def': 'a transparent bowl in which small fish are kept', 'name': 'fishbowl'}, {'frequency': 'r', 'id': 456, 'synset': 'fishing_boat.n.01', 'synonyms': ['fishing_boat', 'fishing_vessel'], 'def': 'a vessel for fishing', 'name': 'fishing_boat'}, {'frequency': 'c', 'id': 457, 'synset': 'fishing_rod.n.01', 'synonyms': ['fishing_rod', 'fishing_pole'], 'def': 'a rod that is used in fishing to extend the fishing line', 'name': 'fishing_rod'}, {'frequency': 'f', 'id': 458, 'synset': 'flag.n.01', 'synonyms': ['flag'], 'def': 'emblem usually consisting of a rectangular piece of cloth of distinctive design 
(do not include pole)', 'name': 'flag'}, {'frequency': 'f', 'id': 459, 'synset': 'flagpole.n.02', 'synonyms': ['flagpole', 'flagstaff'], 'def': 'a tall staff or pole on which a flag is raised', 'name': 'flagpole'}, {'frequency': 'c', 'id': 460, 'synset': 'flamingo.n.01', 'synonyms': ['flamingo'], 'def': 'large pink web-footed bird with down-bent bill', 'name': 'flamingo'}, {'frequency': 'c', 'id': 461, 'synset': 'flannel.n.01', 'synonyms': ['flannel'], 'def': 'a soft light woolen fabric; used for clothing', 'name': 'flannel'}, {'frequency': 'r', 'id': 462, 'synset': 'flash.n.10', 'synonyms': ['flash', 'flashbulb'], 'def': 'a lamp for providing momentary light to take a photograph', 'name': 'flash'}, {'frequency': 'c', 'id': 463, 'synset': 'flashlight.n.01', 'synonyms': ['flashlight', 'torch'], 'def': 'a small portable battery-powered electric lamp', 'name': 'flashlight'}, {'frequency': 'r', 'id': 464, 'synset': 'fleece.n.03', 'synonyms': ['fleece'], 'def': 'a soft bulky fabric with deep pile; used chiefly for clothing', 'name': 'fleece'}, {'frequency': 'f', 'id': 465, 'synset': 'flip-flop.n.02', 'synonyms': ['flip-flop_(sandal)'], 'def': 'a backless sandal held to the foot by a thong between two toes', 'name': 'flip-flop_(sandal)'}, {'frequency': 'c', 'id': 466, 'synset': 'flipper.n.01', 'synonyms': ['flipper_(footwear)', 'fin_(footwear)'], 'def': 'a shoe to aid a person in swimming', 'name': 'flipper_(footwear)'}, {'frequency': 'f', 'id': 467, 'synset': 'flower_arrangement.n.01', 'synonyms': ['flower_arrangement', 'floral_arrangement'], 'def': 'a decorative arrangement of flowers', 'name': 'flower_arrangement'}, {'frequency': 'c', 'id': 468, 'synset': 'flute.n.02', 'synonyms': ['flute_glass', 'champagne_flute'], 'def': 'a tall narrow wineglass', 'name': 'flute_glass'}, {'frequency': 'r', 'id': 469, 'synset': 'foal.n.01', 'synonyms': ['foal'], 'def': 'a young horse', 'name': 'foal'}, {'frequency': 'c', 'id': 470, 'synset': 'folding_chair.n.01', 'synonyms': 
['folding_chair'], 'def': 'a chair that can be folded flat for storage', 'name': 'folding_chair'}, {'frequency': 'c', 'id': 471, 'synset': 'food_processor.n.01', 'synonyms': ['food_processor'], 'def': 'a kitchen appliance for shredding, blending, chopping, or slicing food', 'name': 'food_processor'}, {'frequency': 'c', 'id': 472, 'synset': 'football.n.02', 'synonyms': ['football_(American)'], 'def': 'the inflated oblong ball used in playing American football', 'name': 'football_(American)'}, {'frequency': 'r', 'id': 473, 'synset': 'football_helmet.n.01', 'synonyms': ['football_helmet'], 'def': 'a padded helmet with a face mask to protect the head of football players', 'name': 'football_helmet'}, {'frequency': 'c', 'id': 474, 'synset': 'footstool.n.01', 'synonyms': ['footstool', 'footrest'], 'def': 'a low seat or a stool to rest the feet of a seated person', 'name': 'footstool'}, {'frequency': 'f', 'id': 475, 'synset': 'fork.n.01', 'synonyms': ['fork'], 'def': 'cutlery used for serving and eating food', 'name': 'fork'}, {'frequency': 'r', 'id': 476, 'synset': 'forklift.n.01', 'synonyms': ['forklift'], 'def': 'an industrial vehicle with a power operated fork in front that can be inserted under loads to lift and move them', 'name': 'forklift'}, {'frequency': 'r', 'id': 477, 'synset': 'freight_car.n.01', 'synonyms': ['freight_car'], 'def': 'a railway car that carries freight', 'name': 'freight_car'}, {'frequency': 'r', 'id': 478, 'synset': 'french_toast.n.01', 'synonyms': ['French_toast'], 'def': 'bread slice dipped in egg and milk and fried', 'name': 'French_toast'}, {'frequency': 'c', 'id': 479, 'synset': 'freshener.n.01', 'synonyms': ['freshener', 'air_freshener'], 'def': 'anything that freshens', 'name': 'freshener'}, {'frequency': 'f', 'id': 480, 'synset': 'frisbee.n.01', 'synonyms': ['frisbee'], 'def': 'a light, plastic disk propelled with a flip of the wrist for recreation or competition', 'name': 'frisbee'}, {'frequency': 'c', 'id': 481, 'synset': 'frog.n.01', 
'synonyms': ['frog', 'toad', 'toad_frog'], 'def': 'a tailless stout-bodied amphibians with long hind limbs for leaping', 'name': 'frog'}, {'frequency': 'c', 'id': 482, 'synset': 'fruit_juice.n.01', 'synonyms': ['fruit_juice'], 'def': 'drink produced by squeezing or crushing fruit', 'name': 'fruit_juice'}, {'frequency': 'r', 'id': 483, 'synset': 'fruit_salad.n.01', 'synonyms': ['fruit_salad'], 'def': 'salad composed of fruits', 'name': 'fruit_salad'}, {'frequency': 'c', 'id': 484, 'synset': 'frying_pan.n.01', 'synonyms': ['frying_pan', 'frypan', 'skillet'], 'def': 'a pan used for frying foods', 'name': 'frying_pan'}, {'frequency': 'r', 'id': 485, 'synset': 'fudge.n.01', 'synonyms': ['fudge'], 'def': 'soft creamy candy', 'name': 'fudge'}, {'frequency': 'r', 'id': 486, 'synset': 'funnel.n.02', 'synonyms': ['funnel'], 'def': 'a cone-shaped utensil used to channel a substance into a container with a small mouth', 'name': 'funnel'}, {'frequency': 'c', 'id': 487, 'synset': 'futon.n.01', 'synonyms': ['futon'], 'def': 'a pad that is used for sleeping on the floor or on a raised frame', 'name': 'futon'}, {'frequency': 'r', 'id': 488, 'synset': 'gag.n.02', 'synonyms': ['gag', 'muzzle'], 'def': "restraint put into a person's mouth to prevent speaking or shouting", 'name': 'gag'}, {'frequency': 'r', 'id': 489, 'synset': 'garbage.n.03', 'synonyms': ['garbage'], 'def': 'a receptacle where waste can be discarded', 'name': 'garbage'}, {'frequency': 'c', 'id': 490, 'synset': 'garbage_truck.n.01', 'synonyms': ['garbage_truck'], 'def': 'a truck for collecting domestic refuse', 'name': 'garbage_truck'}, {'frequency': 'c', 'id': 491, 'synset': 'garden_hose.n.01', 'synonyms': ['garden_hose'], 'def': 'a hose used for watering a lawn or garden', 'name': 'garden_hose'}, {'frequency': 'c', 'id': 492, 'synset': 'gargle.n.01', 'synonyms': ['gargle', 'mouthwash'], 'def': 'a medicated solution used for gargling and rinsing the mouth', 'name': 'gargle'}, {'frequency': 'r', 'id': 493, 'synset': 
'gargoyle.n.02', 'synonyms': ['gargoyle'], 'def': 'an ornament consisting of a grotesquely carved figure of a person or animal', 'name': 'gargoyle'}, {'frequency': 'c', 'id': 494, 'synset': 'garlic.n.02', 'synonyms': ['garlic', 'ail'], 'def': 'aromatic bulb used as seasoning', 'name': 'garlic'}, {'frequency': 'r', 'id': 495, 'synset': 'gasmask.n.01', 'synonyms': ['gasmask', 'respirator', 'gas_helmet'], 'def': 'a protective face mask with a filter', 'name': 'gasmask'}, {'frequency': 'r', 'id': 496, 'synset': 'gazelle.n.01', 'synonyms': ['gazelle'], 'def': 'small swift graceful antelope of Africa and Asia having lustrous eyes', 'name': 'gazelle'}, {'frequency': 'c', 'id': 497, 'synset': 'gelatin.n.02', 'synonyms': ['gelatin', 'jelly'], 'def': 'an edible jelly made with gelatin and used as a dessert or salad base or a coating for foods', 'name': 'gelatin'}, {'frequency': 'r', 'id': 498, 'synset': 'gem.n.02', 'synonyms': ['gemstone'], 'def': 'a crystalline rock that can be cut and polished for jewelry', 'name': 'gemstone'}, {'frequency': 'c', 'id': 499, 'synset': 'giant_panda.n.01', 'synonyms': ['giant_panda', 'panda', 'panda_bear'], 'def': 'large black-and-white herbivorous mammal of bamboo forests of China and Tibet', 'name': 'giant_panda'}, {'frequency': 'c', 'id': 500, 'synset': 'gift_wrap.n.01', 'synonyms': ['gift_wrap'], 'def': 'attractive wrapping paper suitable for wrapping gifts', 'name': 'gift_wrap'}, {'frequency': 'c', 'id': 501, 'synset': 'ginger.n.03', 'synonyms': ['ginger', 'gingerroot'], 'def': 'the root of the common ginger plant; used fresh as a seasoning', 'name': 'ginger'}, {'frequency': 'f', 'id': 502, 'synset': 'giraffe.n.01', 'synonyms': ['giraffe'], 'def': 'tall animal having a spotted coat and small horns and very long neck and legs', 'name': 'giraffe'}, {'frequency': 'c', 'id': 503, 'synset': 'girdle.n.02', 'synonyms': ['cincture', 'sash', 'waistband', 'waistcloth'], 'def': 'a band of material around the waist that strengthens a skirt or 
trousers', 'name': 'cincture'}, {'frequency': 'f', 'id': 504, 'synset': 'glass.n.02', 'synonyms': ['glass_(drink_container)', 'drinking_glass'], 'def': 'a container for holding liquids while drinking', 'name': 'glass_(drink_container)'}, {'frequency': 'c', 'id': 505, 'synset': 'globe.n.03', 'synonyms': ['globe'], 'def': 'a sphere on which a map (especially of the earth) is represented', 'name': 'globe'}, {'frequency': 'f', 'id': 506, 'synset': 'glove.n.02', 'synonyms': ['glove'], 'def': 'handwear covering the hand', 'name': 'glove'}, {'frequency': 'c', 'id': 507, 'synset': 'goat.n.01', 'synonyms': ['goat'], 'def': 'a common goat', 'name': 'goat'}, {'frequency': 'f', 'id': 508, 'synset': 'goggles.n.01', 'synonyms': ['goggles'], 'def': 'tight-fitting spectacles worn to protect the eyes', 'name': 'goggles'}, {'frequency': 'r', 'id': 509, 'synset': 'goldfish.n.01', 'synonyms': ['goldfish'], 'def': 'small golden or orange-red freshwater fishes used as pond or aquarium pets', 'name': 'goldfish'}, {'frequency': 'r', 'id': 510, 'synset': 'golf_club.n.02', 'synonyms': ['golf_club', 'golf-club'], 'def': 'golf equipment used by a golfer to hit a golf ball', 'name': 'golf_club'}, {'frequency': 'c', 'id': 511, 'synset': 'golfcart.n.01', 'synonyms': ['golfcart'], 'def': 'a small motor vehicle in which golfers can ride between shots', 'name': 'golfcart'}, {'frequency': 'r', 'id': 512, 'synset': 'gondola.n.02', 'synonyms': ['gondola_(boat)'], 'def': 'long narrow flat-bottomed boat propelled by sculling; traditionally used on canals of Venice', 'name': 'gondola_(boat)'}, {'frequency': 'c', 'id': 513, 'synset': 'goose.n.01', 'synonyms': ['goose'], 'def': 'loud, web-footed long-necked aquatic birds usually larger than ducks', 'name': 'goose'}, {'frequency': 'r', 'id': 514, 'synset': 'gorilla.n.01', 'synonyms': ['gorilla'], 'def': 'largest ape', 'name': 'gorilla'}, {'frequency': 'r', 'id': 515, 'synset': 'gourd.n.02', 'synonyms': ['gourd'], 'def': 'any of numerous inedible fruits with 
hard rinds', 'name': 'gourd'}, {'frequency': 'r', 'id': 516, 'synset': 'gown.n.04', 'synonyms': ['surgical_gown', 'scrubs_(surgical_clothing)'], 'def': 'protective garment worn by surgeons during operations', 'name': 'surgical_gown'}, {'frequency': 'f', 'id': 517, 'synset': 'grape.n.01', 'synonyms': ['grape'], 'def': 'any of various juicy fruit with green or purple skins; grow in clusters', 'name': 'grape'}, {'frequency': 'r', 'id': 518, 'synset': 'grasshopper.n.01', 'synonyms': ['grasshopper'], 'def': 'plant-eating insect with hind legs adapted for leaping', 'name': 'grasshopper'}, {'frequency': 'c', 'id': 519, 'synset': 'grater.n.01', 'synonyms': ['grater'], 'def': 'utensil with sharp perforations for shredding foods (as vegetables or cheese)', 'name': 'grater'}, {'frequency': 'c', 'id': 520, 'synset': 'gravestone.n.01', 'synonyms': ['gravestone', 'headstone', 'tombstone'], 'def': 'a stone that is used to mark a grave', 'name': 'gravestone'}, {'frequency': 'r', 'id': 521, 'synset': 'gravy_boat.n.01', 'synonyms': ['gravy_boat', 'gravy_holder'], 'def': 'a dish (often boat-shaped) for serving gravy or sauce', 'name': 'gravy_boat'}, {'frequency': 'c', 'id': 522, 'synset': 'green_bean.n.02', 'synonyms': ['green_bean'], 'def': 'a common bean plant cultivated for its slender green edible pods', 'name': 'green_bean'}, {'frequency': 'c', 'id': 523, 'synset': 'green_onion.n.01', 'synonyms': ['green_onion', 'spring_onion', 'scallion'], 'def': 'a young onion before the bulb has enlarged', 'name': 'green_onion'}, {'frequency': 'r', 'id': 524, 'synset': 'griddle.n.01', 'synonyms': ['griddle'], 'def': 'cooking utensil consisting of a flat heated surface on which food is cooked', 'name': 'griddle'}, {'frequency': 'r', 'id': 525, 'synset': 'grillroom.n.01', 'synonyms': ['grillroom', 'grill_(restaurant)'], 'def': 'a restaurant where food is cooked on a grill', 'name': 'grillroom'}, {'frequency': 'r', 'id': 526, 'synset': 'grinder.n.04', 'synonyms': ['grinder_(tool)'], 'def': 'a 
machine tool that polishes metal', 'name': 'grinder_(tool)'}, {'frequency': 'r', 'id': 527, 'synset': 'grits.n.01', 'synonyms': ['grits', 'hominy_grits'], 'def': 'coarsely ground corn boiled as a breakfast dish', 'name': 'grits'}, {'frequency': 'c', 'id': 528, 'synset': 'grizzly.n.01', 'synonyms': ['grizzly', 'grizzly_bear'], 'def': 'powerful brownish-yellow bear of the uplands of western North America', 'name': 'grizzly'}, {'frequency': 'c', 'id': 529, 'synset': 'grocery_bag.n.01', 'synonyms': ['grocery_bag'], 'def': "a sack for holding customer's groceries", 'name': 'grocery_bag'}, {'frequency': 'r', 'id': 530, 'synset': 'guacamole.n.01', 'synonyms': ['guacamole'], 'def': 'a dip made of mashed avocado mixed with chopped onions and other seasonings', 'name': 'guacamole'}, {'frequency': 'f', 'id': 531, 'synset': 'guitar.n.01', 'synonyms': ['guitar'], 'def': 'a stringed instrument usually having six strings; played by strumming or plucking', 'name': 'guitar'}, {'frequency': 'c', 'id': 532, 'synset': 'gull.n.02', 'synonyms': ['gull', 'seagull'], 'def': 'mostly white aquatic bird having long pointed wings and short legs', 'name': 'gull'}, {'frequency': 'c', 'id': 533, 'synset': 'gun.n.01', 'synonyms': ['gun'], 'def': 'a weapon that discharges a bullet at high velocity from a metal tube', 'name': 'gun'}, {'frequency': 'r', 'id': 534, 'synset': 'hair_spray.n.01', 'synonyms': ['hair_spray'], 'def': 'substance sprayed on the hair to hold it in place', 'name': 'hair_spray'}, {'frequency': 'c', 'id': 535, 'synset': 'hairbrush.n.01', 'synonyms': ['hairbrush'], 'def': "a brush used to groom a person's hair", 'name': 'hairbrush'}, {'frequency': 'c', 'id': 536, 'synset': 'hairnet.n.01', 'synonyms': ['hairnet'], 'def': 'a small net that someone wears over their hair to keep it in place', 'name': 'hairnet'}, {'frequency': 'c', 'id': 537, 'synset': 'hairpin.n.01', 'synonyms': ['hairpin'], 'def': "a double pronged pin used to hold women's hair in place", 'name': 'hairpin'}, 
{'frequency': 'f', 'id': 538, 'synset': 'ham.n.01', 'synonyms': ['ham', 'jambon', 'gammon'], 'def': 'meat cut from the thigh of a hog (usually smoked)', 'name': 'ham'}, {'frequency': 'c', 'id': 539, 'synset': 'hamburger.n.01', 'synonyms': ['hamburger', 'beefburger', 'burger'], 'def': 'a sandwich consisting of a patty of minced beef served on a bun', 'name': 'hamburger'}, {'frequency': 'c', 'id': 540, 'synset': 'hammer.n.02', 'synonyms': ['hammer'], 'def': 'a hand tool with a heavy head and a handle; used to deliver an impulsive force by striking', 'name': 'hammer'}, {'frequency': 'r', 'id': 541, 'synset': 'hammock.n.02', 'synonyms': ['hammock'], 'def': 'a hanging bed of canvas or rope netting (usually suspended between two trees)', 'name': 'hammock'}, {'frequency': 'r', 'id': 542, 'synset': 'hamper.n.02', 'synonyms': ['hamper'], 'def': 'a basket usually with a cover', 'name': 'hamper'}, {'frequency': 'r', 'id': 543, 'synset': 'hamster.n.01', 'synonyms': ['hamster'], 'def': 'short-tailed burrowing rodent with large cheek pouches', 'name': 'hamster'}, {'frequency': 'c', 'id': 544, 'synset': 'hand_blower.n.01', 'synonyms': ['hair_dryer'], 'def': 'a hand-held electric blower that can blow warm air onto the hair', 'name': 'hair_dryer'}, {'frequency': 'r', 'id': 545, 'synset': 'hand_glass.n.01', 'synonyms': ['hand_glass', 'hand_mirror'], 'def': 'a mirror intended to be held in the hand', 'name': 'hand_glass'}, {'frequency': 'f', 'id': 546, 'synset': 'hand_towel.n.01', 'synonyms': ['hand_towel', 'face_towel'], 'def': 'a small towel used to dry the hands or face', 'name': 'hand_towel'}, {'frequency': 'c', 'id': 547, 'synset': 'handcart.n.01', 'synonyms': ['handcart', 'pushcart', 'hand_truck'], 'def': 'wheeled vehicle that can be pushed by a person', 'name': 'handcart'}, {'frequency': 'r', 'id': 548, 'synset': 'handcuff.n.01', 'synonyms': ['handcuff'], 'def': 'shackle that consists of a metal loop that can be locked around the wrist', 'name': 'handcuff'}, {'frequency': 'c', 
'id': 549, 'synset': 'handkerchief.n.01', 'synonyms': ['handkerchief'], 'def': 'a square piece of cloth used for wiping the eyes or nose or as a costume accessory', 'name': 'handkerchief'}, {'frequency': 'f', 'id': 550, 'synset': 'handle.n.01', 'synonyms': ['handle', 'grip', 'handgrip'], 'def': 'the appendage to an object that is designed to be held in order to use or move it', 'name': 'handle'}, {'frequency': 'r', 'id': 551, 'synset': 'handsaw.n.01', 'synonyms': ['handsaw', "carpenter's_saw"], 'def': 'a saw used with one hand for cutting wood', 'name': 'handsaw'}, {'frequency': 'r', 'id': 552, 'synset': 'hardback.n.01', 'synonyms': ['hardback_book', 'hardcover_book'], 'def': 'a book with cardboard or cloth or leather covers', 'name': 'hardback_book'}, {'frequency': 'r', 'id': 553, 'synset': 'harmonium.n.01', 'synonyms': ['harmonium', 'organ_(musical_instrument)', 'reed_organ_(musical_instrument)'], 'def': 'a free-reed instrument in which air is forced through the reeds by bellows', 'name': 'harmonium'}, {'frequency': 'f', 'id': 554, 'synset': 'hat.n.01', 'synonyms': ['hat'], 'def': 'headwear that protects the head from bad weather, sun, or worn for fashion', 'name': 'hat'}, {'frequency': 'r', 'id': 555, 'synset': 'hatbox.n.01', 'synonyms': ['hatbox'], 'def': 'a round piece of luggage for carrying hats', 'name': 'hatbox'}, {'frequency': 'r', 'id': 556, 'synset': 'hatch.n.03', 'synonyms': ['hatch'], 'def': 'a movable barrier covering a hatchway', 'name': 'hatch'}, {'frequency': 'c', 'id': 557, 'synset': 'head_covering.n.01', 'synonyms': ['veil'], 'def': 'a garment that covers the head and face', 'name': 'veil'}, {'frequency': 'f', 'id': 558, 'synset': 'headband.n.01', 'synonyms': ['headband'], 'def': 'a band worn around or over the head', 'name': 'headband'}, {'frequency': 'f', 'id': 559, 'synset': 'headboard.n.01', 'synonyms': ['headboard'], 'def': 'a vertical board or panel forming the head of a bedstead', 'name': 'headboard'}, {'frequency': 'f', 'id': 560, 
'synset': 'headlight.n.01', 'synonyms': ['headlight', 'headlamp'], 'def': 'a powerful light with reflector; attached to the front of an automobile or locomotive', 'name': 'headlight'}, {'frequency': 'c', 'id': 561, 'synset': 'headscarf.n.01', 'synonyms': ['headscarf'], 'def': 'a kerchief worn over the head and tied under the chin', 'name': 'headscarf'}, {'frequency': 'r', 'id': 562, 'synset': 'headset.n.01', 'synonyms': ['headset'], 'def': 'receiver consisting of a pair of headphones', 'name': 'headset'}, {'frequency': 'c', 'id': 563, 'synset': 'headstall.n.01', 'synonyms': ['headstall_(for_horses)', 'headpiece_(for_horses)'], 'def': "the band that is the part of a bridle that fits around a horse's head", 'name': 'headstall_(for_horses)'}, {'frequency': 'r', 'id': 564, 'synset': 'hearing_aid.n.02', 'synonyms': ['hearing_aid'], 'def': 'an acoustic device used to direct sound to the ear of a hearing-impaired person', 'name': 'hearing_aid'}, {'frequency': 'c', 'id': 565, 'synset': 'heart.n.02', 'synonyms': ['heart'], 'def': 'a muscular organ; its contractions move the blood through the body', 'name': 'heart'}, {'frequency': 'c', 'id': 566, 'synset': 'heater.n.01', 'synonyms': ['heater', 'warmer'], 'def': 'device that heats water or supplies warmth to a room', 'name': 'heater'}, {'frequency': 'c', 'id': 567, 'synset': 'helicopter.n.01', 'synonyms': ['helicopter'], 'def': 'an aircraft without wings that obtains its lift from the rotation of overhead blades', 'name': 'helicopter'}, {'frequency': 'f', 'id': 568, 'synset': 'helmet.n.02', 'synonyms': ['helmet'], 'def': 'a protective headgear made of hard material to resist blows', 'name': 'helmet'}, {'frequency': 'r', 'id': 569, 'synset': 'heron.n.02', 'synonyms': ['heron'], 'def': 'grey or white wading bird with long neck and long legs and (usually) long bill', 'name': 'heron'}, {'frequency': 'c', 'id': 570, 'synset': 'highchair.n.01', 'synonyms': ['highchair', 'feeding_chair'], 'def': 'a chair for feeding a very young 
child', 'name': 'highchair'}, {'frequency': 'f', 'id': 571, 'synset': 'hinge.n.01', 'synonyms': ['hinge'], 'def': 'a joint that holds two parts together so that one can swing relative to the other', 'name': 'hinge'}, {'frequency': 'r', 'id': 572, 'synset': 'hippopotamus.n.01', 'synonyms': ['hippopotamus'], 'def': 'massive thick-skinned animal living in or around rivers of tropical Africa', 'name': 'hippopotamus'}, {'frequency': 'r', 'id': 573, 'synset': 'hockey_stick.n.01', 'synonyms': ['hockey_stick'], 'def': 'sports implement consisting of a stick used by hockey players to move the puck', 'name': 'hockey_stick'}, {'frequency': 'c', 'id': 574, 'synset': 'hog.n.03', 'synonyms': ['hog', 'pig'], 'def': 'domestic swine', 'name': 'hog'}, {'frequency': 'f', 'id': 575, 'synset': 'home_plate.n.01', 'synonyms': ['home_plate_(baseball)', 'home_base_(baseball)'], 'def': '(baseball) a rubber slab where the batter stands; it must be touched by a base runner in order to score', 'name': 'home_plate_(baseball)'}, {'frequency': 'c', 'id': 576, 'synset': 'honey.n.01', 'synonyms': ['honey'], 'def': 'a sweet yellow liquid produced by bees', 'name': 'honey'}, {'frequency': 'f', 'id': 577, 'synset': 'hood.n.06', 'synonyms': ['fume_hood', 'exhaust_hood'], 'def': 'metal covering leading to a vent that exhausts smoke or fumes', 'name': 'fume_hood'}, {'frequency': 'f', 'id': 578, 'synset': 'hook.n.05', 'synonyms': ['hook'], 'def': 'a curved or bent implement for suspending or pulling something', 'name': 'hook'}, {'frequency': 'f', 'id': 579, 'synset': 'horse.n.01', 'synonyms': ['horse'], 'def': 'a common horse', 'name': 'horse'}, {'frequency': 'f', 'id': 580, 'synset': 'hose.n.03', 'synonyms': ['hose', 'hosepipe'], 'def': 'a flexible pipe for conveying a liquid or gas', 'name': 'hose'}, {'frequency': 'r', 'id': 581, 'synset': 'hot-air_balloon.n.01', 'synonyms': ['hot-air_balloon'], 'def': 'balloon for travel through the air in a basket suspended below a large bag of heated air', 'name': 
'hot-air_balloon'}, {'frequency': 'r', 'id': 582, 'synset': 'hot_plate.n.01', 'synonyms': ['hotplate'], 'def': 'a portable electric appliance for heating or cooking or keeping food warm', 'name': 'hotplate'}, {'frequency': 'c', 'id': 583, 'synset': 'hot_sauce.n.01', 'synonyms': ['hot_sauce'], 'def': 'a pungent peppery sauce', 'name': 'hot_sauce'}, {'frequency': 'r', 'id': 584, 'synset': 'hourglass.n.01', 'synonyms': ['hourglass'], 'def': 'a sandglass timer that runs for sixty minutes', 'name': 'hourglass'}, {'frequency': 'r', 'id': 585, 'synset': 'houseboat.n.01', 'synonyms': ['houseboat'], 'def': 'a barge that is designed and equipped for use as a dwelling', 'name': 'houseboat'}, {'frequency': 'r', 'id': 586, 'synset': 'hummingbird.n.01', 'synonyms': ['hummingbird'], 'def': 'tiny American bird having brilliant iridescent plumage and long slender bills', 'name': 'hummingbird'}, {'frequency': 'r', 'id': 587, 'synset': 'hummus.n.01', 'synonyms': ['hummus', 'humus', 'hommos', 'hoummos', 'humous'], 'def': 'a thick spread made from mashed chickpeas', 'name': 'hummus'}, {'frequency': 'c', 'id': 588, 'synset': 'ice_bear.n.01', 'synonyms': ['polar_bear'], 'def': 'white bear of Arctic regions', 'name': 'polar_bear'}, {'frequency': 'c', 'id': 589, 'synset': 'ice_cream.n.01', 'synonyms': ['icecream'], 'def': 'frozen dessert containing cream and sugar and flavoring', 'name': 'icecream'}, {'frequency': 'r', 'id': 590, 'synset': 'ice_lolly.n.01', 'synonyms': ['popsicle'], 'def': 'ice cream or water ice on a small wooden stick', 'name': 'popsicle'}, {'frequency': 'c', 'id': 591, 'synset': 'ice_maker.n.01', 'synonyms': ['ice_maker'], 'def': 'an appliance included in some electric refrigerators for making ice cubes', 'name': 'ice_maker'}, {'frequency': 'r', 'id': 592, 'synset': 'ice_pack.n.01', 'synonyms': ['ice_pack', 'ice_bag'], 'def': 'a waterproof bag filled with ice: applied to the body (especially the head) to cool or reduce swelling', 'name': 'ice_pack'}, {'frequency': 'r', 
'id': 593, 'synset': 'ice_skate.n.01', 'synonyms': ['ice_skate'], 'def': 'skate consisting of a boot with a steel blade fitted to the sole', 'name': 'ice_skate'}, {'frequency': 'r', 'id': 594, 'synset': 'ice_tea.n.01', 'synonyms': ['ice_tea', 'iced_tea'], 'def': 'strong tea served over ice', 'name': 'ice_tea'}, {'frequency': 'c', 'id': 595, 'synset': 'igniter.n.01', 'synonyms': ['igniter', 'ignitor', 'lighter'], 'def': 'a substance or device used to start a fire', 'name': 'igniter'}, {'frequency': 'r', 'id': 596, 'synset': 'incense.n.01', 'synonyms': ['incense'], 'def': 'a substance that produces a fragrant odor when burned', 'name': 'incense'}, {'frequency': 'r', 'id': 597, 'synset': 'inhaler.n.01', 'synonyms': ['inhaler', 'inhalator'], 'def': 'a dispenser that produces a chemical vapor to be inhaled through mouth or nose', 'name': 'inhaler'}, {'frequency': 'c', 'id': 598, 'synset': 'ipod.n.01', 'synonyms': ['iPod'], 'def': 'a pocket-sized device used to play music files', 'name': 'iPod'}, {'frequency': 'c', 'id': 599, 'synset': 'iron.n.04', 'synonyms': ['iron_(for_clothing)', 'smoothing_iron_(for_clothing)'], 'def': 'home appliance consisting of a flat metal base that is heated and used to smooth cloth', 'name': 'iron_(for_clothing)'}, {'frequency': 'r', 'id': 600, 'synset': 'ironing_board.n.01', 'synonyms': ['ironing_board'], 'def': 'narrow padded board on collapsible supports; used for ironing clothes', 'name': 'ironing_board'}, {'frequency': 'f', 'id': 601, 'synset': 'jacket.n.01', 'synonyms': ['jacket'], 'def': 'a waist-length coat', 'name': 'jacket'}, {'frequency': 'r', 'id': 602, 'synset': 'jam.n.01', 'synonyms': ['jam'], 'def': 'preserve of crushed fruit', 'name': 'jam'}, {'frequency': 'f', 'id': 603, 'synset': 'jean.n.01', 'synonyms': ['jean', 'blue_jean', 'denim'], 'def': '(usually plural) close-fitting trousers of heavy denim for manual work or casual wear', 'name': 'jean'}, {'frequency': 'c', 'id': 604, 'synset': 'jeep.n.01', 'synonyms': ['jeep', 
'landrover'], 'def': 'a car suitable for traveling over rough terrain', 'name': 'jeep'}, {'frequency': 'r', 'id': 605, 'synset': 'jelly_bean.n.01', 'synonyms': ['jelly_bean', 'jelly_egg'], 'def': 'sugar-glazed jellied candy', 'name': 'jelly_bean'}, {'frequency': 'f', 'id': 606, 'synset': 'jersey.n.03', 'synonyms': ['jersey', 'T-shirt', 'tee_shirt'], 'def': 'a close-fitting pullover shirt', 'name': 'jersey'}, {'frequency': 'c', 'id': 607, 'synset': 'jet.n.01', 'synonyms': ['jet_plane', 'jet-propelled_plane'], 'def': 'an airplane powered by one or more jet engines', 'name': 'jet_plane'}, {'frequency': 'c', 'id': 608, 'synset': 'jewelry.n.01', 'synonyms': ['jewelry', 'jewellery'], 'def': 'an adornment (as a bracelet or ring or necklace) made of precious metals and set with gems (or imitation gems)', 'name': 'jewelry'}, {'frequency': 'r', 'id': 609, 'synset': 'joystick.n.02', 'synonyms': ['joystick'], 'def': 'a control device for computers consisting of a vertical handle that can move freely in two directions', 'name': 'joystick'}, {'frequency': 'r', 'id': 610, 'synset': 'jump_suit.n.01', 'synonyms': ['jumpsuit'], 'def': "one-piece garment fashioned after a parachutist's uniform", 'name': 'jumpsuit'}, {'frequency': 'c', 'id': 611, 'synset': 'kayak.n.01', 'synonyms': ['kayak'], 'def': 'a small canoe consisting of a light frame made watertight with animal skins', 'name': 'kayak'}, {'frequency': 'r', 'id': 612, 'synset': 'keg.n.02', 'synonyms': ['keg'], 'def': 'small cask or barrel', 'name': 'keg'}, {'frequency': 'r', 'id': 613, 'synset': 'kennel.n.01', 'synonyms': ['kennel', 'doghouse'], 'def': 'outbuilding that serves as a shelter for a dog', 'name': 'kennel'}, {'frequency': 'c', 'id': 614, 'synset': 'kettle.n.01', 'synonyms': ['kettle', 'boiler'], 'def': 'a metal pot for stewing or boiling; usually has a lid', 'name': 'kettle'}, {'frequency': 'f', 'id': 615, 'synset': 'key.n.01', 'synonyms': ['key'], 'def': 'metal instrument used to unlock a lock', 'name': 'key'}, 
{'frequency': 'r', 'id': 616, 'synset': 'keycard.n.01', 'synonyms': ['keycard'], 'def': 'a plastic card used to gain access typically to a door', 'name': 'keycard'}, {'frequency': 'r', 'id': 617, 'synset': 'kilt.n.01', 'synonyms': ['kilt'], 'def': 'a knee-length pleated tartan skirt worn by men as part of the traditional dress in the Highlands of northern Scotland', 'name': 'kilt'}, {'frequency': 'c', 'id': 618, 'synset': 'kimono.n.01', 'synonyms': ['kimono'], 'def': 'a loose robe; imitated from robes originally worn by Japanese', 'name': 'kimono'}, {'frequency': 'f', 'id': 619, 'synset': 'kitchen_sink.n.01', 'synonyms': ['kitchen_sink'], 'def': 'a sink in a kitchen', 'name': 'kitchen_sink'}, {'frequency': 'c', 'id': 620, 'synset': 'kitchen_table.n.01', 'synonyms': ['kitchen_table'], 'def': 'a table in the kitchen', 'name': 'kitchen_table'}, {'frequency': 'f', 'id': 621, 'synset': 'kite.n.03', 'synonyms': ['kite'], 'def': 'plaything consisting of a light frame covered with tissue paper; flown in wind at end of a string', 'name': 'kite'}, {'frequency': 'c', 'id': 622, 'synset': 'kitten.n.01', 'synonyms': ['kitten', 'kitty'], 'def': 'young domestic cat', 'name': 'kitten'}, {'frequency': 'c', 'id': 623, 'synset': 'kiwi.n.03', 'synonyms': ['kiwi_fruit'], 'def': 'fuzzy brown egg-shaped fruit with slightly tart green flesh', 'name': 'kiwi_fruit'}, {'frequency': 'f', 'id': 624, 'synset': 'knee_pad.n.01', 'synonyms': ['knee_pad'], 'def': 'protective garment consisting of a pad worn by football or baseball or hockey players', 'name': 'knee_pad'}, {'frequency': 'f', 'id': 625, 'synset': 'knife.n.01', 'synonyms': ['knife'], 'def': 'tool with a blade and point used as a cutting instrument', 'name': 'knife'}, {'frequency': 'r', 'id': 626, 'synset': 'knight.n.02', 'synonyms': ['knight_(chess_piece)', 'horse_(chess_piece)'], 'def': 'a chess game piece shaped to resemble the head of a horse', 'name': 'knight_(chess_piece)'}, {'frequency': 'r', 'id': 627, 'synset': 
'knitting_needle.n.01', 'synonyms': ['knitting_needle'], 'def': 'needle consisting of a slender rod with pointed ends; usually used in pairs', 'name': 'knitting_needle'}, {'frequency': 'f', 'id': 628, 'synset': 'knob.n.02', 'synonyms': ['knob'], 'def': 'a round handle often found on a door', 'name': 'knob'}, {'frequency': 'r', 'id': 629, 'synset': 'knocker.n.05', 'synonyms': ['knocker_(on_a_door)', 'doorknocker'], 'def': 'a device (usually metal and ornamental) attached by a hinge to a door', 'name': 'knocker_(on_a_door)'}, {'frequency': 'r', 'id': 630, 'synset': 'koala.n.01', 'synonyms': ['koala', 'koala_bear'], 'def': 'sluggish tailless Australian marsupial with grey furry ears and coat', 'name': 'koala'}, {'frequency': 'r', 'id': 631, 'synset': 'lab_coat.n.01', 'synonyms': ['lab_coat', 'laboratory_coat'], 'def': 'a light coat worn to protect clothing from substances used while working in a laboratory', 'name': 'lab_coat'}, {'frequency': 'f', 'id': 632, 'synset': 'ladder.n.01', 'synonyms': ['ladder'], 'def': 'steps consisting of two parallel members connected by rungs', 'name': 'ladder'}, {'frequency': 'c', 'id': 633, 'synset': 'ladle.n.01', 'synonyms': ['ladle'], 'def': 'a spoon-shaped vessel with a long handle frequently used to transfer liquids', 'name': 'ladle'}, {'frequency': 'r', 'id': 634, 'synset': 'ladybug.n.01', 'synonyms': ['ladybug', 'ladybeetle', 'ladybird_beetle'], 'def': 'small round bright-colored and spotted beetle, typically red and black', 'name': 'ladybug'}, {'frequency': 'c', 'id': 635, 'synset': 'lamb.n.01', 'synonyms': ['lamb_(animal)'], 'def': 'young sheep', 'name': 'lamb_(animal)'}, {'frequency': 'r', 'id': 636, 'synset': 'lamb_chop.n.01', 'synonyms': ['lamb-chop', 'lambchop'], 'def': 'chop cut from a lamb', 'name': 'lamb-chop'}, {'frequency': 'f', 'id': 637, 'synset': 'lamp.n.02', 'synonyms': ['lamp'], 'def': 'a piece of furniture holding one or more electric light bulbs', 'name': 'lamp'}, {'frequency': 'f', 'id': 638, 'synset': 
'lamppost.n.01', 'synonyms': ['lamppost'], 'def': 'a metal post supporting an outdoor lamp (such as a streetlight)', 'name': 'lamppost'}, {'frequency': 'f', 'id': 639, 'synset': 'lampshade.n.01', 'synonyms': ['lampshade'], 'def': 'a protective ornamental shade used to screen a light bulb from direct view', 'name': 'lampshade'}, {'frequency': 'c', 'id': 640, 'synset': 'lantern.n.01', 'synonyms': ['lantern'], 'def': 'light in a transparent protective case', 'name': 'lantern'}, {'frequency': 'f', 'id': 641, 'synset': 'lanyard.n.02', 'synonyms': ['lanyard', 'laniard'], 'def': 'a cord worn around the neck to hold a knife or whistle, etc.', 'name': 'lanyard'}, {'frequency': 'f', 'id': 642, 'synset': 'laptop.n.01', 'synonyms': ['laptop_computer', 'notebook_computer'], 'def': 'a portable computer small enough to use in your lap', 'name': 'laptop_computer'}, {'frequency': 'r', 'id': 643, 'synset': 'lasagna.n.01', 'synonyms': ['lasagna', 'lasagne'], 'def': 'baked dish of layers of lasagna pasta with sauce and cheese and meat or vegetables', 'name': 'lasagna'}, {'frequency': 'c', 'id': 644, 'synset': 'latch.n.02', 'synonyms': ['latch'], 'def': 'a bar that can be lowered or slid into a groove to fasten a door or gate', 'name': 'latch'}, {'frequency': 'r', 'id': 645, 'synset': 'lawn_mower.n.01', 'synonyms': ['lawn_mower'], 'def': 'garden tool for mowing grass on lawns', 'name': 'lawn_mower'}, {'frequency': 'r', 'id': 646, 'synset': 'leather.n.01', 'synonyms': ['leather'], 'def': 'an animal skin made smooth and flexible by removing the hair and then tanning', 'name': 'leather'}, {'frequency': 'c', 'id': 647, 'synset': 'legging.n.01', 'synonyms': ['legging_(clothing)', 'leging_(clothing)', 'leg_covering'], 'def': 'a garment covering the leg (usually extending from the knee to the ankle)', 'name': 'legging_(clothing)'}, {'frequency': 'c', 'id': 648, 'synset': 'lego.n.01', 'synonyms': ['Lego', 'Lego_set'], 'def': "a child's plastic construction set for making models from blocks", 
'name': 'Lego'}, {'frequency': 'f', 'id': 649, 'synset': 'lemon.n.01', 'synonyms': ['lemon'], 'def': 'yellow oval fruit with juicy acidic flesh', 'name': 'lemon'}, {'frequency': 'r', 'id': 650, 'synset': 'lemonade.n.01', 'synonyms': ['lemonade'], 'def': 'sweetened beverage of diluted lemon juice', 'name': 'lemonade'}, {'frequency': 'f', 'id': 651, 'synset': 'lettuce.n.02', 'synonyms': ['lettuce'], 'def': 'leafy plant commonly eaten in salad or on sandwiches', 'name': 'lettuce'}, {'frequency': 'f', 'id': 652, 'synset': 'license_plate.n.01', 'synonyms': ['license_plate', 'numberplate'], 'def': "a plate mounted on the front and back of car and bearing the car's registration number", 'name': 'license_plate'}, {'frequency': 'f', 'id': 653, 'synset': 'life_buoy.n.01', 'synonyms': ['life_buoy', 'lifesaver', 'life_belt', 'life_ring'], 'def': 'a ring-shaped life preserver used to prevent drowning (NOT a life-jacket or vest)', 'name': 'life_buoy'}, {'frequency': 'f', 'id': 654, 'synset': 'life_jacket.n.01', 'synonyms': ['life_jacket', 'life_vest'], 'def': 'life preserver consisting of a sleeveless jacket of buoyant or inflatable design', 'name': 'life_jacket'}, {'frequency': 'f', 'id': 655, 'synset': 'light_bulb.n.01', 'synonyms': ['lightbulb'], 'def': 'glass bulb or tube shaped electric device that emits light (DO NOT MARK LAMPS AS A WHOLE)', 'name': 'lightbulb'}, {'frequency': 'r', 'id': 656, 'synset': 'lightning_rod.n.02', 'synonyms': ['lightning_rod', 'lightning_conductor'], 'def': 'a metallic conductor that is attached to a high point and leads to the ground', 'name': 'lightning_rod'}, {'frequency': 'c', 'id': 657, 'synset': 'lime.n.06', 'synonyms': ['lime'], 'def': 'the green acidic fruit of any of various lime trees', 'name': 'lime'}, {'frequency': 'r', 'id': 658, 'synset': 'limousine.n.01', 'synonyms': ['limousine'], 'def': 'long luxurious car; usually driven by a chauffeur', 'name': 'limousine'}, {'frequency': 'r', 'id': 659, 'synset': 'linen.n.02', 'synonyms': 
['linen_paper'], 'def': 'a high-quality paper made of linen fibers or with a linen finish', 'name': 'linen_paper'}, {'frequency': 'c', 'id': 660, 'synset': 'lion.n.01', 'synonyms': ['lion'], 'def': 'large gregarious predatory cat of Africa and India', 'name': 'lion'}, {'frequency': 'c', 'id': 661, 'synset': 'lip_balm.n.01', 'synonyms': ['lip_balm'], 'def': 'a balm applied to the lips', 'name': 'lip_balm'}, {'frequency': 'c', 'id': 662, 'synset': 'lipstick.n.01', 'synonyms': ['lipstick', 'lip_rouge'], 'def': 'makeup that is used to color the lips', 'name': 'lipstick'}, {'frequency': 'r', 'id': 663, 'synset': 'liquor.n.01', 'synonyms': ['liquor', 'spirits', 'hard_liquor', 'liqueur', 'cordial'], 'def': 'an alcoholic beverage that is distilled rather than fermented', 'name': 'liquor'}, {'frequency': 'r', 'id': 664, 'synset': 'lizard.n.01', 'synonyms': ['lizard'], 'def': 'a reptile with usually two pairs of legs and a tapering tail', 'name': 'lizard'}, {'frequency': 'r', 'id': 665, 'synset': 'loafer.n.02', 'synonyms': ['Loafer_(type_of_shoe)'], 'def': 'a low leather step-in shoe', 'name': 'Loafer_(type_of_shoe)'}, {'frequency': 'f', 'id': 666, 'synset': 'log.n.01', 'synonyms': ['log'], 'def': 'a segment of the trunk of a tree when stripped of branches', 'name': 'log'}, {'frequency': 'c', 'id': 667, 'synset': 'lollipop.n.02', 'synonyms': ['lollipop'], 'def': 'hard candy on a stick', 'name': 'lollipop'}, {'frequency': 'c', 'id': 668, 'synset': 'lotion.n.01', 'synonyms': ['lotion'], 'def': 'any of various cosmetic preparations that are applied to the skin', 'name': 'lotion'}, {'frequency': 'f', 'id': 669, 'synset': 'loudspeaker.n.01', 'synonyms': ['speaker_(stero_equipment)'], 'def': 'electronic device that produces sound often as part of a stereo system', 'name': 'speaker_(stero_equipment)'}, {'frequency': 'c', 'id': 670, 'synset': 'love_seat.n.01', 'synonyms': ['loveseat'], 'def': 'small sofa that seats two people', 'name': 'loveseat'}, {'frequency': 'r', 'id': 671, 
'synset': 'machine_gun.n.01', 'synonyms': ['machine_gun'], 'def': 'a rapidly firing automatic gun', 'name': 'machine_gun'}, {'frequency': 'f', 'id': 672, 'synset': 'magazine.n.02', 'synonyms': ['magazine'], 'def': 'a paperback periodic publication', 'name': 'magazine'}, {'frequency': 'f', 'id': 673, 'synset': 'magnet.n.01', 'synonyms': ['magnet'], 'def': 'a device that attracts iron and produces a magnetic field', 'name': 'magnet'}, {'frequency': 'r', 'id': 674, 'synset': 'mail_slot.n.01', 'synonyms': ['mail_slot'], 'def': 'a slot (usually in a door) through which mail can be delivered', 'name': 'mail_slot'}, {'frequency': 'c', 'id': 675, 'synset': 'mailbox.n.01', 'synonyms': ['mailbox_(at_home)', 'letter_box_(at_home)'], 'def': 'a private box for delivery of mail', 'name': 'mailbox_(at_home)'}, {'frequency': 'r', 'id': 676, 'synset': 'mallet.n.01', 'synonyms': ['mallet'], 'def': 'a sports implement with a long handle and a hammer-like head used to hit a ball', 'name': 'mallet'}, {'frequency': 'r', 'id': 677, 'synset': 'mammoth.n.01', 'synonyms': ['mammoth'], 'def': 'any of numerous extinct elephants widely distributed in the Pleistocene', 'name': 'mammoth'}, {'frequency': 'c', 'id': 678, 'synset': 'mandarin.n.05', 'synonyms': ['mandarin_orange'], 'def': 'a somewhat flat reddish-orange loose skinned citrus of China', 'name': 'mandarin_orange'}, {'frequency': 'c', 'id': 679, 'synset': 'manger.n.01', 'synonyms': ['manger', 'trough'], 'def': 'a container (usually in a barn or stable) from which cattle or horses feed', 'name': 'manger'}, {'frequency': 'f', 'id': 680, 'synset': 'manhole.n.01', 'synonyms': ['manhole'], 'def': 'a hole (usually with a flush cover) through which a person can gain access to an underground structure', 'name': 'manhole'}, {'frequency': 'c', 'id': 681, 'synset': 'map.n.01', 'synonyms': ['map'], 'def': "a diagrammatic representation of the earth's surface (or part of it)", 'name': 'map'}, {'frequency': 'c', 'id': 682, 'synset': 'marker.n.03', 
'synonyms': ['marker'], 'def': 'a writing implement for making a mark', 'name': 'marker'}, {'frequency': 'r', 'id': 683, 'synset': 'martini.n.01', 'synonyms': ['martini'], 'def': 'a cocktail made of gin (or vodka) with dry vermouth', 'name': 'martini'}, {'frequency': 'r', 'id': 684, 'synset': 'mascot.n.01', 'synonyms': ['mascot'], 'def': 'a person or animal that is adopted by a team or other group as a symbolic figure', 'name': 'mascot'}, {'frequency': 'c', 'id': 685, 'synset': 'mashed_potato.n.01', 'synonyms': ['mashed_potato'], 'def': 'potato that has been peeled and boiled and then mashed', 'name': 'mashed_potato'}, {'frequency': 'r', 'id': 686, 'synset': 'masher.n.02', 'synonyms': ['masher'], 'def': 'a kitchen utensil used for mashing (e.g. potatoes)', 'name': 'masher'}, {'frequency': 'f', 'id': 687, 'synset': 'mask.n.04', 'synonyms': ['mask', 'facemask'], 'def': 'a protective covering worn over the face', 'name': 'mask'}, {'frequency': 'f', 'id': 688, 'synset': 'mast.n.01', 'synonyms': ['mast'], 'def': 'a vertical spar for supporting sails', 'name': 'mast'}, {'frequency': 'c', 'id': 689, 'synset': 'mat.n.03', 'synonyms': ['mat_(gym_equipment)', 'gym_mat'], 'def': 'sports equipment consisting of a piece of thick padding on the floor for gymnastics', 'name': 'mat_(gym_equipment)'}, {'frequency': 'r', 'id': 690, 'synset': 'matchbox.n.01', 'synonyms': ['matchbox'], 'def': 'a box for holding matches', 'name': 'matchbox'}, {'frequency': 'f', 'id': 691, 'synset': 'mattress.n.01', 'synonyms': ['mattress'], 'def': 'a thick pad filled with resilient material used as a bed or part of a bed', 'name': 'mattress'}, {'frequency': 'c', 'id': 692, 'synset': 'measuring_cup.n.01', 'synonyms': ['measuring_cup'], 'def': 'graduated cup used to measure liquid or granular ingredients', 'name': 'measuring_cup'}, {'frequency': 'c', 'id': 693, 'synset': 'measuring_stick.n.01', 'synonyms': ['measuring_stick', 'ruler_(measuring_stick)', 'measuring_rod'], 'def': 'measuring instrument 
having a sequence of marks at regular intervals', 'name': 'measuring_stick'}, {'frequency': 'c', 'id': 694, 'synset': 'meatball.n.01', 'synonyms': ['meatball'], 'def': 'ground meat formed into a ball and fried or simmered in broth', 'name': 'meatball'}, {'frequency': 'c', 'id': 695, 'synset': 'medicine.n.02', 'synonyms': ['medicine'], 'def': 'something that treats or prevents or alleviates the symptoms of disease', 'name': 'medicine'}, {'frequency': 'r', 'id': 696, 'synset': 'melon.n.01', 'synonyms': ['melon'], 'def': 'fruit of the gourd family having a hard rind and sweet juicy flesh', 'name': 'melon'}, {'frequency': 'f', 'id': 697, 'synset': 'microphone.n.01', 'synonyms': ['microphone'], 'def': 'device for converting sound waves into electrical energy', 'name': 'microphone'}, {'frequency': 'r', 'id': 698, 'synset': 'microscope.n.01', 'synonyms': ['microscope'], 'def': 'magnifier of the image of small objects', 'name': 'microscope'}, {'frequency': 'f', 'id': 699, 'synset': 'microwave.n.02', 'synonyms': ['microwave_oven'], 'def': 'kitchen appliance that cooks food by passing an electromagnetic wave through it', 'name': 'microwave_oven'}, {'frequency': 'r', 'id': 700, 'synset': 'milestone.n.01', 'synonyms': ['milestone', 'milepost'], 'def': 'stone post at side of a road to show distances', 'name': 'milestone'}, {'frequency': 'c', 'id': 701, 'synset': 'milk.n.01', 'synonyms': ['milk'], 'def': 'a white nutritious liquid secreted by mammals and used as food by human beings', 'name': 'milk'}, {'frequency': 'f', 'id': 702, 'synset': 'minivan.n.01', 'synonyms': ['minivan'], 'def': 'a small box-shaped passenger van', 'name': 'minivan'}, {'frequency': 'r', 'id': 703, 'synset': 'mint.n.05', 'synonyms': ['mint_candy'], 'def': 'a candy that is flavored with a mint oil', 'name': 'mint_candy'}, {'frequency': 'f', 'id': 704, 'synset': 'mirror.n.01', 'synonyms': ['mirror'], 'def': 'polished surface that forms images by reflecting light', 'name': 'mirror'}, {'frequency': 'c', 'id': 
705, 'synset': 'mitten.n.01', 'synonyms': ['mitten'], 'def': 'glove that encases the thumb separately and the other four fingers together', 'name': 'mitten'}, {'frequency': 'c', 'id': 706, 'synset': 'mixer.n.04', 'synonyms': ['mixer_(kitchen_tool)', 'stand_mixer'], 'def': 'a kitchen utensil that is used for mixing foods', 'name': 'mixer_(kitchen_tool)'}, {'frequency': 'c', 'id': 707, 'synset': 'money.n.03', 'synonyms': ['money'], 'def': 'the official currency issued by a government or national bank', 'name': 'money'}, {'frequency': 'f', 'id': 708, 'synset': 'monitor.n.04', 'synonyms': ['monitor_(computer_equipment) computer_monitor'], 'def': 'a computer monitor', 'name': 'monitor_(computer_equipment) computer_monitor'}, {'frequency': 'c', 'id': 709, 'synset': 'monkey.n.01', 'synonyms': ['monkey'], 'def': 'any of various long-tailed primates', 'name': 'monkey'}, {'frequency': 'f', 'id': 710, 'synset': 'motor.n.01', 'synonyms': ['motor'], 'def': 'machine that converts other forms of energy into mechanical energy and so imparts motion', 'name': 'motor'}, {'frequency': 'f', 'id': 711, 'synset': 'motor_scooter.n.01', 'synonyms': ['motor_scooter', 'scooter'], 'def': 'a wheeled vehicle with small wheels and a low-powered engine', 'name': 'motor_scooter'}, {'frequency': 'r', 'id': 712, 'synset': 'motor_vehicle.n.01', 'synonyms': ['motor_vehicle', 'automotive_vehicle'], 'def': 'a self-propelled wheeled vehicle that does not run on rails', 'name': 'motor_vehicle'}, {'frequency': 'r', 'id': 713, 'synset': 'motorboat.n.01', 'synonyms': ['motorboat', 'powerboat'], 'def': 'a boat propelled by an internal-combustion engine', 'name': 'motorboat'}, {'frequency': 'f', 'id': 714, 'synset': 'motorcycle.n.01', 'synonyms': ['motorcycle'], 'def': 'a motor vehicle with two wheels and a strong frame', 'name': 'motorcycle'}, {'frequency': 'f', 'id': 715, 'synset': 'mound.n.01', 'synonyms': ['mound_(baseball)', "pitcher's_mound"], 'def': '(baseball) the slight elevation on which the pitcher 
stands', 'name': 'mound_(baseball)'}, {'frequency': 'r', 'id': 716, 'synset': 'mouse.n.01', 'synonyms': ['mouse_(animal_rodent)'], 'def': 'a small rodent with pointed snouts and small ears on elongated bodies with slender usually hairless tails', 'name': 'mouse_(animal_rodent)'}, {'frequency': 'f', 'id': 717, 'synset': 'mouse.n.04', 'synonyms': ['mouse_(computer_equipment)', 'computer_mouse'], 'def': 'a computer input device that controls an on-screen pointer', 'name': 'mouse_(computer_equipment)'}, {'frequency': 'f', 'id': 718, 'synset': 'mousepad.n.01', 'synonyms': ['mousepad'], 'def': 'a small portable pad that provides an operating surface for a computer mouse', 'name': 'mousepad'}, {'frequency': 'c', 'id': 719, 'synset': 'muffin.n.01', 'synonyms': ['muffin'], 'def': 'a sweet quick bread baked in a cup-shaped pan', 'name': 'muffin'}, {'frequency': 'f', 'id': 720, 'synset': 'mug.n.04', 'synonyms': ['mug'], 'def': 'with handle and usually cylindrical', 'name': 'mug'}, {'frequency': 'f', 'id': 721, 'synset': 'mushroom.n.02', 'synonyms': ['mushroom'], 'def': 'a common mushroom', 'name': 'mushroom'}, {'frequency': 'r', 'id': 722, 'synset': 'music_stool.n.01', 'synonyms': ['music_stool', 'piano_stool'], 'def': 'a stool for piano players; usually adjustable in height', 'name': 'music_stool'}, {'frequency': 'r', 'id': 723, 'synset': 'musical_instrument.n.01', 'synonyms': ['musical_instrument', 'instrument_(musical)'], 'def': 'any of various devices or contrivances that can be used to produce musical tones or sounds', 'name': 'musical_instrument'}, {'frequency': 'r', 'id': 724, 'synset': 'nailfile.n.01', 'synonyms': ['nailfile'], 'def': 'a small flat file for shaping the nails', 'name': 'nailfile'}, {'frequency': 'r', 'id': 725, 'synset': 'nameplate.n.01', 'synonyms': ['nameplate'], 'def': 'a plate bearing a name', 'name': 'nameplate'}, {'frequency': 'f', 'id': 726, 'synset': 'napkin.n.01', 'synonyms': ['napkin', 'table_napkin', 'serviette'], 'def': 'a small piece of 
table linen or paper that is used to wipe the mouth and to cover the lap in order to protect clothing', 'name': 'napkin'}, {'frequency': 'r', 'id': 727, 'synset': 'neckerchief.n.01', 'synonyms': ['neckerchief'], 'def': 'a kerchief worn around the neck', 'name': 'neckerchief'}, {'frequency': 'f', 'id': 728, 'synset': 'necklace.n.01', 'synonyms': ['necklace'], 'def': 'jewelry consisting of a cord or chain (often bearing gems) worn about the neck as an ornament', 'name': 'necklace'}, {'frequency': 'f', 'id': 729, 'synset': 'necktie.n.01', 'synonyms': ['necktie', 'tie_(necktie)'], 'def': 'neckwear consisting of a long narrow piece of material worn under a collar and tied in knot at the front', 'name': 'necktie'}, {'frequency': 'r', 'id': 730, 'synset': 'needle.n.03', 'synonyms': ['needle'], 'def': 'a sharp pointed implement (usually metal)', 'name': 'needle'}, {'frequency': 'c', 'id': 731, 'synset': 'nest.n.01', 'synonyms': ['nest'], 'def': 'a structure in which animals lay eggs or give birth to their young', 'name': 'nest'}, {'frequency': 'r', 'id': 732, 'synset': 'newsstand.n.01', 'synonyms': ['newsstand'], 'def': 'a stall where newspapers and other periodicals are sold', 'name': 'newsstand'}, {'frequency': 'c', 'id': 733, 'synset': 'nightwear.n.01', 'synonyms': ['nightshirt', 'nightwear', 'sleepwear', 'nightclothes'], 'def': 'garments designed to be worn in bed', 'name': 'nightshirt'}, {'frequency': 'r', 'id': 734, 'synset': 'nosebag.n.01', 'synonyms': ['nosebag_(for_animals)', 'feedbag'], 'def': 'a canvas bag that is used to feed an animal (such as a horse); covers the muzzle and fastens at the top of the head', 'name': 'nosebag_(for_animals)'}, {'frequency': 'r', 'id': 735, 'synset': 'noseband.n.01', 'synonyms': ['noseband_(for_animals)', 'nosepiece_(for_animals)'], 'def': "a strap that is the part of a bridle that goes over the animal's nose", 'name': 'noseband_(for_animals)'}, {'frequency': 'f', 'id': 736, 'synset': 'notebook.n.01', 'synonyms': ['notebook'], 
'def': 'a book with blank pages for recording notes or memoranda', 'name': 'notebook'}, {'frequency': 'c', 'id': 737, 'synset': 'notepad.n.01', 'synonyms': ['notepad'], 'def': 'a pad of paper for keeping notes', 'name': 'notepad'}, {'frequency': 'c', 'id': 738, 'synset': 'nut.n.03', 'synonyms': ['nut'], 'def': 'a small metal block (usually square or hexagonal) with internal screw thread to be fitted onto a bolt', 'name': 'nut'}, {'frequency': 'r', 'id': 739, 'synset': 'nutcracker.n.01', 'synonyms': ['nutcracker'], 'def': 'a hand tool used to crack nuts open', 'name': 'nutcracker'}, {'frequency': 'c', 'id': 740, 'synset': 'oar.n.01', 'synonyms': ['oar'], 'def': 'an implement used to propel or steer a boat', 'name': 'oar'}, {'frequency': 'r', 'id': 741, 'synset': 'octopus.n.01', 'synonyms': ['octopus_(food)'], 'def': 'tentacles of octopus prepared as food', 'name': 'octopus_(food)'}, {'frequency': 'r', 'id': 742, 'synset': 'octopus.n.02', 'synonyms': ['octopus_(animal)'], 'def': 'bottom-living cephalopod having a soft oval body with eight long tentacles', 'name': 'octopus_(animal)'}, {'frequency': 'c', 'id': 743, 'synset': 'oil_lamp.n.01', 'synonyms': ['oil_lamp', 'kerosene_lamp', 'kerosine_lamp'], 'def': 'a lamp that burns oil (as kerosine) for light', 'name': 'oil_lamp'}, {'frequency': 'c', 'id': 744, 'synset': 'olive_oil.n.01', 'synonyms': ['olive_oil'], 'def': 'oil from olives', 'name': 'olive_oil'}, {'frequency': 'r', 'id': 745, 'synset': 'omelet.n.01', 'synonyms': ['omelet', 'omelette'], 'def': 'beaten eggs cooked until just set; may be folded around e.g. 
ham or cheese or jelly', 'name': 'omelet'}, {'frequency': 'f', 'id': 746, 'synset': 'onion.n.01', 'synonyms': ['onion'], 'def': 'the bulb of an onion plant', 'name': 'onion'}, {'frequency': 'f', 'id': 747, 'synset': 'orange.n.01', 'synonyms': ['orange_(fruit)'], 'def': 'orange (FRUIT of an orange tree)', 'name': 'orange_(fruit)'}, {'frequency': 'c', 'id': 748, 'synset': 'orange_juice.n.01', 'synonyms': ['orange_juice'], 'def': 'bottled or freshly squeezed juice of oranges', 'name': 'orange_juice'}, {'frequency': 'r', 'id': 749, 'synset': 'oregano.n.01', 'synonyms': ['oregano', 'marjoram'], 'def': 'aromatic Eurasian perennial herb used in cooking and baking', 'name': 'oregano'}, {'frequency': 'c', 'id': 750, 'synset': 'ostrich.n.02', 'synonyms': ['ostrich'], 'def': 'fast-running African flightless bird with two-toed feet; largest living bird', 'name': 'ostrich'}, {'frequency': 'c', 'id': 751, 'synset': 'ottoman.n.03', 'synonyms': ['ottoman', 'pouf', 'pouffe', 'hassock'], 'def': 'thick cushion used as a seat', 'name': 'ottoman'}, {'frequency': 'c', 'id': 752, 'synset': 'overall.n.01', 'synonyms': ['overalls_(clothing)'], 'def': 'work clothing consisting of denim trousers usually with a bib and shoulder straps', 'name': 'overalls_(clothing)'}, {'frequency': 'c', 'id': 753, 'synset': 'owl.n.01', 'synonyms': ['owl'], 'def': 'nocturnal bird of prey with hawk-like beak and claws and large head with front-facing eyes', 'name': 'owl'}, {'frequency': 'c', 'id': 754, 'synset': 'packet.n.03', 'synonyms': ['packet'], 'def': 'a small package or bundle', 'name': 'packet'}, {'frequency': 'r', 'id': 755, 'synset': 'pad.n.03', 'synonyms': ['inkpad', 'inking_pad', 'stamp_pad'], 'def': 'absorbent material saturated with ink used to transfer ink evenly to a rubber stamp', 'name': 'inkpad'}, {'frequency': 'c', 'id': 756, 'synset': 'pad.n.04', 'synonyms': ['pad'], 'def': 'a flat mass of soft material used for protection, stuffing, or comfort', 'name': 'pad'}, {'frequency': 'c', 'id': 
757, 'synset': 'paddle.n.04', 'synonyms': ['paddle', 'boat_paddle'], 'def': 'a short light oar used without an oarlock to propel a canoe or small boat', 'name': 'paddle'}, {'frequency': 'c', 'id': 758, 'synset': 'padlock.n.01', 'synonyms': ['padlock'], 'def': 'a detachable, portable lock', 'name': 'padlock'}, {'frequency': 'r', 'id': 759, 'synset': 'paintbox.n.01', 'synonyms': ['paintbox'], 'def': "a box containing a collection of cubes or tubes of artists' paint", 'name': 'paintbox'}, {'frequency': 'c', 'id': 760, 'synset': 'paintbrush.n.01', 'synonyms': ['paintbrush'], 'def': 'a brush used as an applicator to apply paint', 'name': 'paintbrush'}, {'frequency': 'f', 'id': 761, 'synset': 'painting.n.01', 'synonyms': ['painting'], 'def': 'graphic art consisting of an artistic composition made by applying paints to a surface', 'name': 'painting'}, {'frequency': 'c', 'id': 762, 'synset': 'pajama.n.02', 'synonyms': ['pajamas', 'pyjamas'], 'def': 'loose-fitting nightclothes worn for sleeping or lounging', 'name': 'pajamas'}, {'frequency': 'c', 'id': 763, 'synset': 'palette.n.02', 'synonyms': ['palette', 'pallet'], 'def': 'board that provides a flat surface on which artists mix paints and the range of colors used', 'name': 'palette'}, {'frequency': 'f', 'id': 764, 'synset': 'pan.n.01', 'synonyms': ['pan_(for_cooking)', 'cooking_pan'], 'def': 'cooking utensil consisting of a wide metal vessel', 'name': 'pan_(for_cooking)'}, {'frequency': 'r', 'id': 765, 'synset': 'pan.n.03', 'synonyms': ['pan_(metal_container)'], 'def': 'shallow container made of metal', 'name': 'pan_(metal_container)'}, {'frequency': 'c', 'id': 766, 'synset': 'pancake.n.01', 'synonyms': ['pancake'], 'def': 'a flat cake of thin batter fried on both sides on a griddle', 'name': 'pancake'}, {'frequency': 'r', 'id': 767, 'synset': 'pantyhose.n.01', 'synonyms': ['pantyhose'], 'def': "a woman's tights consisting of underpants and stockings", 'name': 'pantyhose'}, {'frequency': 'r', 'id': 768, 'synset': 
'papaya.n.02', 'synonyms': ['papaya'], 'def': 'large oval melon-like tropical fruit with yellowish flesh', 'name': 'papaya'}, {'frequency': 'r', 'id': 769, 'synset': 'paper_clip.n.01', 'synonyms': ['paperclip'], 'def': 'a wire or plastic clip for holding sheets of paper together', 'name': 'paperclip'}, {'frequency': 'f', 'id': 770, 'synset': 'paper_plate.n.01', 'synonyms': ['paper_plate'], 'def': 'a disposable plate made of cardboard', 'name': 'paper_plate'}, {'frequency': 'f', 'id': 771, 'synset': 'paper_towel.n.01', 'synonyms': ['paper_towel'], 'def': 'a disposable towel made of absorbent paper', 'name': 'paper_towel'}, {'frequency': 'r', 'id': 772, 'synset': 'paperback_book.n.01', 'synonyms': ['paperback_book', 'paper-back_book', 'softback_book', 'soft-cover_book'], 'def': 'a book with paper covers', 'name': 'paperback_book'}, {'frequency': 'r', 'id': 773, 'synset': 'paperweight.n.01', 'synonyms': ['paperweight'], 'def': 'a weight used to hold down a stack of papers', 'name': 'paperweight'}, {'frequency': 'c', 'id': 774, 'synset': 'parachute.n.01', 'synonyms': ['parachute'], 'def': 'rescue equipment consisting of a device that fills with air and retards your fall', 'name': 'parachute'}, {'frequency': 'r', 'id': 775, 'synset': 'parakeet.n.01', 'synonyms': ['parakeet', 'parrakeet', 'parroket', 'paraquet', 'paroquet', 'parroquet'], 'def': 'any of numerous small slender long-tailed parrots', 'name': 'parakeet'}, {'frequency': 'c', 'id': 776, 'synset': 'parasail.n.01', 'synonyms': ['parasail_(sports)'], 'def': 'parachute that will lift a person up into the air when it is towed by a motorboat or a car', 'name': 'parasail_(sports)'}, {'frequency': 'r', 'id': 777, 'synset': 'parchment.n.01', 'synonyms': ['parchment'], 'def': 'a superior paper resembling sheepskin', 'name': 'parchment'}, {'frequency': 'r', 'id': 778, 'synset': 'parka.n.01', 'synonyms': ['parka', 'anorak'], 'def': "a kind of heavy jacket (`windcheater' is a British term)", 'name': 'parka'}, {'frequency': 
'f', 'id': 779, 'synset': 'parking_meter.n.01', 'synonyms': ['parking_meter'], 'def': 'a coin-operated timer located next to a parking space', 'name': 'parking_meter'}, {'frequency': 'c', 'id': 780, 'synset': 'parrot.n.01', 'synonyms': ['parrot'], 'def': 'usually brightly colored tropical birds with short hooked beaks and the ability to mimic sounds', 'name': 'parrot'}, {'frequency': 'c', 'id': 781, 'synset': 'passenger_car.n.01', 'synonyms': ['passenger_car_(part_of_a_train)', 'coach_(part_of_a_train)'], 'def': 'a railcar where passengers ride', 'name': 'passenger_car_(part_of_a_train)'}, {'frequency': 'r', 'id': 782, 'synset': 'passenger_ship.n.01', 'synonyms': ['passenger_ship'], 'def': 'a ship built to carry passengers', 'name': 'passenger_ship'}, {'frequency': 'r', 'id': 783, 'synset': 'passport.n.02', 'synonyms': ['passport'], 'def': 'a document issued by a country to a citizen allowing that person to travel abroad and re-enter the home country', 'name': 'passport'}, {'frequency': 'f', 'id': 784, 'synset': 'pastry.n.02', 'synonyms': ['pastry'], 'def': 'any of various baked foods made of dough or batter', 'name': 'pastry'}, {'frequency': 'r', 'id': 785, 'synset': 'patty.n.01', 'synonyms': ['patty_(food)'], 'def': 'small flat mass of chopped food', 'name': 'patty_(food)'}, {'frequency': 'c', 'id': 786, 'synset': 'pea.n.01', 'synonyms': ['pea_(food)'], 'def': 'seed of a pea plant used for food', 'name': 'pea_(food)'}, {'frequency': 'c', 'id': 787, 'synset': 'peach.n.03', 'synonyms': ['peach'], 'def': 'downy juicy fruit with sweet yellowish or whitish flesh', 'name': 'peach'}, {'frequency': 'c', 'id': 788, 'synset': 'peanut_butter.n.01', 'synonyms': ['peanut_butter'], 'def': 'a spread made from ground peanuts', 'name': 'peanut_butter'}, {'frequency': 'c', 'id': 789, 'synset': 'pear.n.01', 'synonyms': ['pear'], 'def': 'sweet juicy gritty-textured fruit available in many varieties', 'name': 'pear'}, {'frequency': 'r', 'id': 790, 'synset': 'peeler.n.03', 'synonyms': 
['peeler_(tool_for_fruit_and_vegetables)'], 'def': 'a device for peeling vegetables or fruits', 'name': 'peeler_(tool_for_fruit_and_vegetables)'}, {'frequency': 'r', 'id': 791, 'synset': 'pegboard.n.01', 'synonyms': ['pegboard'], 'def': 'a board perforated with regularly spaced holes into which pegs can be fitted', 'name': 'pegboard'}, {'frequency': 'c', 'id': 792, 'synset': 'pelican.n.01', 'synonyms': ['pelican'], 'def': 'large long-winged warm-water seabird having a large bill with a distensible pouch for fish', 'name': 'pelican'}, {'frequency': 'f', 'id': 793, 'synset': 'pen.n.01', 'synonyms': ['pen'], 'def': 'a writing implement with a point from which ink flows', 'name': 'pen'}, {'frequency': 'c', 'id': 794, 'synset': 'pencil.n.01', 'synonyms': ['pencil'], 'def': 'a thin cylindrical pointed writing implement made of wood and graphite', 'name': 'pencil'}, {'frequency': 'r', 'id': 795, 'synset': 'pencil_box.n.01', 'synonyms': ['pencil_box', 'pencil_case'], 'def': 'a box for holding pencils', 'name': 'pencil_box'}, {'frequency': 'r', 'id': 796, 'synset': 'pencil_sharpener.n.01', 'synonyms': ['pencil_sharpener'], 'def': 'a rotary implement for sharpening the point on pencils', 'name': 'pencil_sharpener'}, {'frequency': 'r', 'id': 797, 'synset': 'pendulum.n.01', 'synonyms': ['pendulum'], 'def': 'an apparatus consisting of an object mounted so that it swings freely under the influence of gravity', 'name': 'pendulum'}, {'frequency': 'c', 'id': 798, 'synset': 'penguin.n.01', 'synonyms': ['penguin'], 'def': 'short-legged flightless birds of cold southern regions having webbed feet and wings modified as flippers', 'name': 'penguin'}, {'frequency': 'r', 'id': 799, 'synset': 'pennant.n.02', 'synonyms': ['pennant'], 'def': 'a flag longer than it is wide (and often tapering)', 'name': 'pennant'}, {'frequency': 'r', 'id': 800, 'synset': 'penny.n.02', 'synonyms': ['penny_(coin)'], 'def': 'a coin worth one-hundredth of the value of the basic unit', 'name': 'penny_(coin)'}, 
{'frequency': 'c', 'id': 801, 'synset': 'pepper.n.03', 'synonyms': ['pepper', 'peppercorn'], 'def': 'pungent seasoning from the berry of the common pepper plant; whole or ground', 'name': 'pepper'}, {'frequency': 'c', 'id': 802, 'synset': 'pepper_mill.n.01', 'synonyms': ['pepper_mill', 'pepper_grinder'], 'def': 'a mill for grinding pepper', 'name': 'pepper_mill'}, {'frequency': 'c', 'id': 803, 'synset': 'perfume.n.02', 'synonyms': ['perfume'], 'def': 'a toiletry that emits and diffuses a fragrant odor', 'name': 'perfume'}, {'frequency': 'r', 'id': 804, 'synset': 'persimmon.n.02', 'synonyms': ['persimmon'], 'def': 'orange fruit resembling a plum; edible when fully ripe', 'name': 'persimmon'}, {'frequency': 'f', 'id': 805, 'synset': 'person.n.01', 'synonyms': ['baby', 'child', 'boy', 'girl', 'man', 'woman', 'person', 'human'], 'def': 'a human being', 'name': 'baby'}, {'frequency': 'r', 'id': 806, 'synset': 'pet.n.01', 'synonyms': ['pet'], 'def': 'a domesticated animal kept for companionship or amusement', 'name': 'pet'}, {'frequency': 'r', 'id': 807, 'synset': 'petfood.n.01', 'synonyms': ['petfood', 'pet-food'], 'def': 'food prepared for animal pets', 'name': 'petfood'}, {'frequency': 'r', 'id': 808, 'synset': 'pew.n.01', 'synonyms': ['pew_(church_bench)', 'church_bench'], 'def': 'long bench with backs; used in church by the congregation', 'name': 'pew_(church_bench)'}, {'frequency': 'r', 'id': 809, 'synset': 'phonebook.n.01', 'synonyms': ['phonebook', 'telephone_book', 'telephone_directory'], 'def': 'a directory containing an alphabetical list of telephone subscribers and their telephone numbers', 'name': 'phonebook'}, {'frequency': 'c', 'id': 810, 'synset': 'phonograph_record.n.01', 'synonyms': ['phonograph_record', 'phonograph_recording', 'record_(phonograph_recording)'], 'def': 'sound recording consisting of a typically black disk with a continuous groove', 'name': 'phonograph_record'}, {'frequency': 'c', 'id': 811, 'synset': 'piano.n.01', 'synonyms': ['piano'], 
'def': 'a keyboard instrument that is played by depressing keys that cause hammers to strike tuned strings and produce sounds', 'name': 'piano'}, {'frequency': 'f', 'id': 812, 'synset': 'pickle.n.01', 'synonyms': ['pickle'], 'def': 'vegetables (especially cucumbers) preserved in brine or vinegar', 'name': 'pickle'}, {'frequency': 'f', 'id': 813, 'synset': 'pickup.n.01', 'synonyms': ['pickup_truck'], 'def': 'a light truck with an open body and low sides and a tailboard', 'name': 'pickup_truck'}, {'frequency': 'c', 'id': 814, 'synset': 'pie.n.01', 'synonyms': ['pie'], 'def': 'dish baked in pastry-lined pan often with a pastry top', 'name': 'pie'}, {'frequency': 'c', 'id': 815, 'synset': 'pigeon.n.01', 'synonyms': ['pigeon'], 'def': 'wild and domesticated birds having a heavy body and short legs', 'name': 'pigeon'}, {'frequency': 'r', 'id': 816, 'synset': 'piggy_bank.n.01', 'synonyms': ['piggy_bank', 'penny_bank'], 'def': "a child's coin bank (often shaped like a pig)", 'name': 'piggy_bank'}, {'frequency': 'f', 'id': 817, 'synset': 'pillow.n.01', 'synonyms': ['pillow'], 'def': 'a cushion to support the head of a sleeping person', 'name': 'pillow'}, {'frequency': 'r', 'id': 818, 'synset': 'pin.n.09', 'synonyms': ['pin_(non_jewelry)'], 'def': 'a small slender (often pointed) piece of wood or metal used to support or fasten or attach things', 'name': 'pin_(non_jewelry)'}, {'frequency': 'f', 'id': 819, 'synset': 'pineapple.n.02', 'synonyms': ['pineapple'], 'def': 'large sweet fleshy tropical fruit with a tuft of stiff leaves', 'name': 'pineapple'}, {'frequency': 'c', 'id': 820, 'synset': 'pinecone.n.01', 'synonyms': ['pinecone'], 'def': 'the seed-producing cone of a pine tree', 'name': 'pinecone'}, {'frequency': 'r', 'id': 821, 'synset': 'ping-pong_ball.n.01', 'synonyms': ['ping-pong_ball'], 'def': 'light hollow ball used in playing table tennis', 'name': 'ping-pong_ball'}, {'frequency': 'r', 'id': 822, 'synset': 'pinwheel.n.03', 'synonyms': ['pinwheel'], 'def': 'a toy 
consisting of vanes of colored paper or plastic that is pinned to a stick and spins when it is pointed into the wind', 'name': 'pinwheel'}, {'frequency': 'r', 'id': 823, 'synset': 'pipe.n.01', 'synonyms': ['tobacco_pipe'], 'def': 'a tube with a small bowl at one end; used for smoking tobacco', 'name': 'tobacco_pipe'}, {'frequency': 'f', 'id': 824, 'synset': 'pipe.n.02', 'synonyms': ['pipe', 'piping'], 'def': 'a long tube made of metal or plastic that is used to carry water or oil or gas etc.', 'name': 'pipe'}, {'frequency': 'r', 'id': 825, 'synset': 'pistol.n.01', 'synonyms': ['pistol', 'handgun'], 'def': 'a firearm that is held and fired with one hand', 'name': 'pistol'}, {'frequency': 'r', 'id': 826, 'synset': 'pita.n.01', 'synonyms': ['pita_(bread)', 'pocket_bread'], 'def': 'usually small round bread that can open into a pocket for filling', 'name': 'pita_(bread)'}, {'frequency': 'f', 'id': 827, 'synset': 'pitcher.n.02', 'synonyms': ['pitcher_(vessel_for_liquid)', 'ewer'], 'def': 'an open vessel with a handle and a spout for pouring', 'name': 'pitcher_(vessel_for_liquid)'}, {'frequency': 'r', 'id': 828, 'synset': 'pitchfork.n.01', 'synonyms': ['pitchfork'], 'def': 'a long-handled hand tool with sharp widely spaced prongs for lifting and pitching hay', 'name': 'pitchfork'}, {'frequency': 'f', 'id': 829, 'synset': 'pizza.n.01', 'synonyms': ['pizza'], 'def': 'Italian open pie made of thin bread dough spread with a spiced mixture of e.g. 
tomato sauce and cheese', 'name': 'pizza'}, {'frequency': 'f', 'id': 830, 'synset': 'place_mat.n.01', 'synonyms': ['place_mat'], 'def': 'a mat placed on a table for an individual place setting', 'name': 'place_mat'}, {'frequency': 'f', 'id': 831, 'synset': 'plate.n.04', 'synonyms': ['plate'], 'def': 'dish on which food is served or from which food is eaten', 'name': 'plate'}, {'frequency': 'c', 'id': 832, 'synset': 'platter.n.01', 'synonyms': ['platter'], 'def': 'a large shallow dish used for serving food', 'name': 'platter'}, {'frequency': 'r', 'id': 833, 'synset': 'playing_card.n.01', 'synonyms': ['playing_card'], 'def': 'one of a pack of cards that are used to play card games', 'name': 'playing_card'}, {'frequency': 'r', 'id': 834, 'synset': 'playpen.n.01', 'synonyms': ['playpen'], 'def': 'a portable enclosure in which babies may be left to play', 'name': 'playpen'}, {'frequency': 'c', 'id': 835, 'synset': 'pliers.n.01', 'synonyms': ['pliers', 'plyers'], 'def': 'a gripping hand tool with two hinged arms and (usually) serrated jaws', 'name': 'pliers'}, {'frequency': 'r', 'id': 836, 'synset': 'plow.n.01', 'synonyms': ['plow_(farm_equipment)', 'plough_(farm_equipment)'], 'def': 'a farm tool having one or more heavy blades to break the soil and cut a furrow prior to sowing', 'name': 'plow_(farm_equipment)'}, {'frequency': 'r', 'id': 837, 'synset': 'pocket_watch.n.01', 'synonyms': ['pocket_watch'], 'def': 'a watch that is carried in a small watch pocket', 'name': 'pocket_watch'}, {'frequency': 'c', 'id': 838, 'synset': 'pocketknife.n.01', 'synonyms': ['pocketknife'], 'def': 'a knife with a blade that folds into the handle; suitable for carrying in the pocket', 'name': 'pocketknife'}, {'frequency': 'c', 'id': 839, 'synset': 'poker.n.01', 'synonyms': ['poker_(fire_stirring_tool)', 'stove_poker', 'fire_hook'], 'def': 'fire iron consisting of a metal rod with a handle; used to stir a fire', 'name': 'poker_(fire_stirring_tool)'}, {'frequency': 'f', 'id': 840, 'synset': 
'pole.n.01', 'synonyms': ['pole', 'post'], 'def': 'a long (usually round) rod of wood or metal or plastic', 'name': 'pole'}, {'frequency': 'r', 'id': 841, 'synset': 'police_van.n.01', 'synonyms': ['police_van', 'police_wagon', 'paddy_wagon', 'patrol_wagon'], 'def': 'van used by police to transport prisoners', 'name': 'police_van'}, {'frequency': 'f', 'id': 842, 'synset': 'polo_shirt.n.01', 'synonyms': ['polo_shirt', 'sport_shirt'], 'def': 'a shirt with short sleeves designed for comfort and casual wear', 'name': 'polo_shirt'}, {'frequency': 'r', 'id': 843, 'synset': 'poncho.n.01', 'synonyms': ['poncho'], 'def': 'a blanket-like cloak with a hole in the center for the head', 'name': 'poncho'}, {'frequency': 'c', 'id': 844, 'synset': 'pony.n.05', 'synonyms': ['pony'], 'def': 'any of various breeds of small gentle horses usually less than five feet high at the shoulder', 'name': 'pony'}, {'frequency': 'r', 'id': 845, 'synset': 'pool_table.n.01', 'synonyms': ['pool_table', 'billiard_table', 'snooker_table'], 'def': 'game equipment consisting of a heavy table on which pool is played', 'name': 'pool_table'}, {'frequency': 'f', 'id': 846, 'synset': 'pop.n.02', 'synonyms': ['pop_(soda)', 'soda_(pop)', 'tonic', 'soft_drink'], 'def': 'a sweet drink containing carbonated water and flavoring', 'name': 'pop_(soda)'}, {'frequency': 'r', 'id': 847, 'synset': 'portrait.n.02', 'synonyms': ['portrait', 'portrayal'], 'def': 'any likeness of a person, in any medium', 'name': 'portrait'}, {'frequency': 'c', 'id': 848, 'synset': 'postbox.n.01', 'synonyms': ['postbox_(public)', 'mailbox_(public)'], 'def': 'public box for deposit of mail', 'name': 'postbox_(public)'}, {'frequency': 'c', 'id': 849, 'synset': 'postcard.n.01', 'synonyms': ['postcard', 'postal_card', 'mailing-card'], 'def': 'a card for sending messages by post without an envelope', 'name': 'postcard'}, {'frequency': 'f', 'id': 850, 'synset': 'poster.n.01', 'synonyms': ['poster', 'placard'], 'def': 'a sign posted in a public 
place as an advertisement', 'name': 'poster'}, {'frequency': 'f', 'id': 851, 'synset': 'pot.n.01', 'synonyms': ['pot'], 'def': 'metal or earthenware cooking vessel that is usually round and deep; often has a handle and lid', 'name': 'pot'}, {'frequency': 'f', 'id': 852, 'synset': 'pot.n.04', 'synonyms': ['flowerpot'], 'def': 'a container in which plants are cultivated', 'name': 'flowerpot'}, {'frequency': 'f', 'id': 853, 'synset': 'potato.n.01', 'synonyms': ['potato'], 'def': 'an edible tuber native to South America', 'name': 'potato'}, {'frequency': 'c', 'id': 854, 'synset': 'potholder.n.01', 'synonyms': ['potholder'], 'def': 'an insulated pad for holding hot pots', 'name': 'potholder'}, {'frequency': 'c', 'id': 855, 'synset': 'pottery.n.01', 'synonyms': ['pottery', 'clayware'], 'def': 'ceramic ware made from clay and baked in a kiln', 'name': 'pottery'}, {'frequency': 'c', 'id': 856, 'synset': 'pouch.n.01', 'synonyms': ['pouch'], 'def': 'a small or medium size container for holding or carrying things', 'name': 'pouch'}, {'frequency': 'r', 'id': 857, 'synset': 'power_shovel.n.01', 'synonyms': ['power_shovel', 'excavator', 'digger'], 'def': 'a machine for excavating', 'name': 'power_shovel'}, {'frequency': 'c', 'id': 858, 'synset': 'prawn.n.01', 'synonyms': ['prawn', 'shrimp'], 'def': 'any of various edible decapod crustaceans', 'name': 'prawn'}, {'frequency': 'f', 'id': 859, 'synset': 'printer.n.03', 'synonyms': ['printer', 'printing_machine'], 'def': 'a machine that prints', 'name': 'printer'}, {'frequency': 'c', 'id': 860, 'synset': 'projectile.n.01', 'synonyms': ['projectile_(weapon)', 'missile'], 'def': 'a weapon that is forcibly thrown or projected at a targets', 'name': 'projectile_(weapon)'}, {'frequency': 'c', 'id': 861, 'synset': 'projector.n.02', 'synonyms': ['projector'], 'def': 'an optical instrument that projects an enlarged image onto a screen', 'name': 'projector'}, {'frequency': 'f', 'id': 862, 'synset': 'propeller.n.01', 'synonyms': ['propeller', 
'propellor'], 'def': 'a mechanical device that rotates to push against air or water', 'name': 'propeller'}, {'frequency': 'r', 'id': 863, 'synset': 'prune.n.01', 'synonyms': ['prune'], 'def': 'dried plum', 'name': 'prune'}, {'frequency': 'r', 'id': 864, 'synset': 'pudding.n.01', 'synonyms': ['pudding'], 'def': 'any of various soft thick unsweetened baked dishes', 'name': 'pudding'}, {'frequency': 'r', 'id': 865, 'synset': 'puffer.n.02', 'synonyms': ['puffer_(fish)', 'pufferfish', 'blowfish', 'globefish'], 'def': 'fishes whose elongated spiny body can inflate itself with water or air to form a globe', 'name': 'puffer_(fish)'}, {'frequency': 'r', 'id': 866, 'synset': 'puffin.n.01', 'synonyms': ['puffin'], 'def': 'seabirds having short necks and brightly colored compressed bills', 'name': 'puffin'}, {'frequency': 'r', 'id': 867, 'synset': 'pug.n.01', 'synonyms': ['pug-dog'], 'def': 'small compact smooth-coated breed of Asiatic origin having a tightly curled tail and broad flat wrinkled muzzle', 'name': 'pug-dog'}, {'frequency': 'c', 'id': 868, 'synset': 'pumpkin.n.02', 'synonyms': ['pumpkin'], 'def': 'usually large pulpy deep-yellow round fruit of the squash family maturing in late summer or early autumn', 'name': 'pumpkin'}, {'frequency': 'r', 'id': 869, 'synset': 'punch.n.03', 'synonyms': ['puncher'], 'def': 'a tool for making holes or indentations', 'name': 'puncher'}, {'frequency': 'r', 'id': 870, 'synset': 'puppet.n.01', 'synonyms': ['puppet', 'marionette'], 'def': 'a small figure of a person operated from above with strings by a puppeteer', 'name': 'puppet'}, {'frequency': 'r', 'id': 871, 'synset': 'puppy.n.01', 'synonyms': ['puppy'], 'def': 'a young dog', 'name': 'puppy'}, {'frequency': 'r', 'id': 872, 'synset': 'quesadilla.n.01', 'synonyms': ['quesadilla'], 'def': 'a tortilla that is filled with cheese and heated', 'name': 'quesadilla'}, {'frequency': 'r', 'id': 873, 'synset': 'quiche.n.02', 'synonyms': ['quiche'], 'def': 'a tart filled with rich unsweetened 
custard; often contains other ingredients (as cheese or ham or seafood or vegetables)', 'name': 'quiche'}, {'frequency': 'f', 'id': 874, 'synset': 'quilt.n.01', 'synonyms': ['quilt', 'comforter'], 'def': 'bedding made of two layers of cloth filled with stuffing and stitched together', 'name': 'quilt'}, {'frequency': 'c', 'id': 875, 'synset': 'rabbit.n.01', 'synonyms': ['rabbit'], 'def': 'any of various burrowing animals of the family Leporidae having long ears and short tails', 'name': 'rabbit'}, {'frequency': 'r', 'id': 876, 'synset': 'racer.n.02', 'synonyms': ['race_car', 'racing_car'], 'def': 'a fast car that competes in races', 'name': 'race_car'}, {'frequency': 'c', 'id': 877, 'synset': 'racket.n.04', 'synonyms': ['racket', 'racquet'], 'def': 'a sports implement used to strike a ball in various games', 'name': 'racket'}, {'frequency': 'r', 'id': 878, 'synset': 'radar.n.01', 'synonyms': ['radar'], 'def': 'measuring instrument in which the echo of a pulse of microwave radiation is used to detect and locate distant objects', 'name': 'radar'}, {'frequency': 'c', 'id': 879, 'synset': 'radiator.n.03', 'synonyms': ['radiator'], 'def': 'a mechanism consisting of a metal honeycomb through which hot fluids circulate', 'name': 'radiator'}, {'frequency': 'c', 'id': 880, 'synset': 'radio_receiver.n.01', 'synonyms': ['radio_receiver', 'radio_set', 'radio', 'tuner_(radio)'], 'def': 'an electronic receiver that detects and demodulates and amplifies transmitted radio signals', 'name': 'radio_receiver'}, {'frequency': 'c', 'id': 881, 'synset': 'radish.n.03', 'synonyms': ['radish', 'daikon'], 'def': 'pungent edible root of any of various cultivated radish plants', 'name': 'radish'}, {'frequency': 'c', 'id': 882, 'synset': 'raft.n.01', 'synonyms': ['raft'], 'def': 'a flat float (usually made of logs or planks) that can be used for transport or as a platform for swimmers', 'name': 'raft'}, {'frequency': 'r', 'id': 883, 'synset': 'rag_doll.n.01', 'synonyms': ['rag_doll'], 'def': 'a 
cloth doll that is stuffed and (usually) painted', 'name': 'rag_doll'}, {'frequency': 'c', 'id': 884, 'synset': 'raincoat.n.01', 'synonyms': ['raincoat', 'waterproof_jacket'], 'def': 'a water-resistant coat', 'name': 'raincoat'}, {'frequency': 'c', 'id': 885, 'synset': 'ram.n.05', 'synonyms': ['ram_(animal)'], 'def': 'uncastrated adult male sheep', 'name': 'ram_(animal)'}, {'frequency': 'c', 'id': 886, 'synset': 'raspberry.n.02', 'synonyms': ['raspberry'], 'def': 'red or black edible aggregate berries usually smaller than the related blackberries', 'name': 'raspberry'}, {'frequency': 'r', 'id': 887, 'synset': 'rat.n.01', 'synonyms': ['rat'], 'def': 'any of various long-tailed rodents similar to but larger than a mouse', 'name': 'rat'}, {'frequency': 'c', 'id': 888, 'synset': 'razorblade.n.01', 'synonyms': ['razorblade'], 'def': 'a blade that has very sharp edge', 'name': 'razorblade'}, {'frequency': 'c', 'id': 889, 'synset': 'reamer.n.01', 'synonyms': ['reamer_(juicer)', 'juicer', 'juice_reamer'], 'def': 'a squeezer with a conical ridged center that is used for squeezing juice from citrus fruit', 'name': 'reamer_(juicer)'}, {'frequency': 'f', 'id': 890, 'synset': 'rearview_mirror.n.01', 'synonyms': ['rearview_mirror'], 'def': 'car mirror that reflects the view out of the rear window', 'name': 'rearview_mirror'}, {'frequency': 'c', 'id': 891, 'synset': 'receipt.n.02', 'synonyms': ['receipt'], 'def': 'an acknowledgment (usually tangible) that payment has been made', 'name': 'receipt'}, {'frequency': 'c', 'id': 892, 'synset': 'recliner.n.01', 'synonyms': ['recliner', 'reclining_chair', 'lounger_(chair)'], 'def': 'an armchair whose back can be lowered and foot can be raised to allow the sitter to recline in it', 'name': 'recliner'}, {'frequency': 'r', 'id': 893, 'synset': 'record_player.n.01', 'synonyms': ['record_player', 'phonograph_(record_player)', 'turntable'], 'def': 'machine in which rotating records cause a stylus to vibrate and the vibrations are amplified 
acoustically or electronically', 'name': 'record_player'}, {'frequency': 'r', 'id': 894, 'synset': 'red_cabbage.n.02', 'synonyms': ['red_cabbage'], 'def': 'compact head of purplish-red leaves', 'name': 'red_cabbage'}, {'frequency': 'f', 'id': 895, 'synset': 'reflector.n.01', 'synonyms': ['reflector'], 'def': 'device that reflects light, radiation, etc.', 'name': 'reflector'}, {'frequency': 'f', 'id': 896, 'synset': 'remote_control.n.01', 'synonyms': ['remote_control'], 'def': 'a device that can be used to control a machine or apparatus from a distance', 'name': 'remote_control'}, {'frequency': 'c', 'id': 897, 'synset': 'rhinoceros.n.01', 'synonyms': ['rhinoceros'], 'def': 'massive powerful herbivorous odd-toed ungulate of southeast Asia and Africa having very thick skin and one or two horns on the snout', 'name': 'rhinoceros'}, {'frequency': 'r', 'id': 898, 'synset': 'rib.n.03', 'synonyms': ['rib_(food)'], 'def': 'cut of meat including one or more ribs', 'name': 'rib_(food)'}, {'frequency': 'r', 'id': 899, 'synset': 'rifle.n.01', 'synonyms': ['rifle'], 'def': 'a shoulder firearm with a long barrel', 'name': 'rifle'}, {'frequency': 'f', 'id': 900, 'synset': 'ring.n.08', 'synonyms': ['ring'], 'def': 'jewelry consisting of a circlet of precious metal (often set with jewels) worn on the finger', 'name': 'ring'}, {'frequency': 'r', 'id': 901, 'synset': 'river_boat.n.01', 'synonyms': ['river_boat'], 'def': 'a boat used on rivers or to ply a river', 'name': 'river_boat'}, {'frequency': 'r', 'id': 902, 'synset': 'road_map.n.02', 'synonyms': ['road_map'], 'def': '(NOT A ROAD) a MAP showing roads (for automobile travel)', 'name': 'road_map'}, {'frequency': 'c', 'id': 903, 'synset': 'robe.n.01', 'synonyms': ['robe'], 'def': 'any loose flowing garment', 'name': 'robe'}, {'frequency': 'c', 'id': 904, 'synset': 'rocking_chair.n.01', 'synonyms': ['rocking_chair'], 'def': 'a chair mounted on rockers', 'name': 'rocking_chair'}, {'frequency': 'r', 'id': 905, 'synset': 
'roller_skate.n.01', 'synonyms': ['roller_skate'], 'def': 'a shoe with pairs of rollers (small hard wheels) fixed to the sole', 'name': 'roller_skate'}, {'frequency': 'r', 'id': 906, 'synset': 'rollerblade.n.01', 'synonyms': ['Rollerblade'], 'def': 'an in-line variant of a roller skate', 'name': 'Rollerblade'}, {'frequency': 'c', 'id': 907, 'synset': 'rolling_pin.n.01', 'synonyms': ['rolling_pin'], 'def': 'utensil consisting of a cylinder (usually of wood) with a handle at each end; used to roll out dough', 'name': 'rolling_pin'}, {'frequency': 'r', 'id': 908, 'synset': 'root_beer.n.01', 'synonyms': ['root_beer'], 'def': 'carbonated drink containing extracts of roots and herbs', 'name': 'root_beer'}, {'frequency': 'c', 'id': 909, 'synset': 'router.n.02', 'synonyms': ['router_(computer_equipment)'], 'def': 'a device that forwards data packets between computer networks', 'name': 'router_(computer_equipment)'}, {'frequency': 'f', 'id': 910, 'synset': 'rubber_band.n.01', 'synonyms': ['rubber_band', 'elastic_band'], 'def': 'a narrow band of elastic rubber used to hold things (such as papers) together', 'name': 'rubber_band'}, {'frequency': 'c', 'id': 911, 'synset': 'runner.n.08', 'synonyms': ['runner_(carpet)'], 'def': 'a long narrow carpet', 'name': 'runner_(carpet)'}, {'frequency': 'f', 'id': 912, 'synset': 'sack.n.01', 'synonyms': ['plastic_bag', 'paper_bag'], 'def': "a bag made of paper or plastic for holding customer's purchases", 'name': 'plastic_bag'}, {'frequency': 'f', 'id': 913, 'synset': 'saddle.n.01', 'synonyms': ['saddle_(on_an_animal)'], 'def': 'a seat for the rider of a horse or camel', 'name': 'saddle_(on_an_animal)'}, {'frequency': 'f', 'id': 914, 'synset': 'saddle_blanket.n.01', 'synonyms': ['saddle_blanket', 'saddlecloth', 'horse_blanket'], 'def': 'stable gear consisting of a blanket placed under the saddle', 'name': 'saddle_blanket'}, {'frequency': 'c', 'id': 915, 'synset': 'saddlebag.n.01', 'synonyms': ['saddlebag'], 'def': 'a large bag (or pair of 
bags) hung over a saddle', 'name': 'saddlebag'}, {'frequency': 'r', 'id': 916, 'synset': 'safety_pin.n.01', 'synonyms': ['safety_pin'], 'def': 'a pin in the form of a clasp; has a guard so the point of the pin will not stick the user', 'name': 'safety_pin'}, {'frequency': 'c', 'id': 917, 'synset': 'sail.n.01', 'synonyms': ['sail'], 'def': 'a large piece of fabric by means of which wind is used to propel a sailing vessel', 'name': 'sail'}, {'frequency': 'c', 'id': 918, 'synset': 'salad.n.01', 'synonyms': ['salad'], 'def': 'food mixtures either arranged on a plate or tossed and served with a moist dressing; usually consisting of or including greens', 'name': 'salad'}, {'frequency': 'r', 'id': 919, 'synset': 'salad_plate.n.01', 'synonyms': ['salad_plate', 'salad_bowl'], 'def': 'a plate or bowl for individual servings of salad', 'name': 'salad_plate'}, {'frequency': 'r', 'id': 920, 'synset': 'salami.n.01', 'synonyms': ['salami'], 'def': 'highly seasoned fatty sausage of pork and beef usually dried', 'name': 'salami'}, {'frequency': 'r', 'id': 921, 'synset': 'salmon.n.01', 'synonyms': ['salmon_(fish)'], 'def': 'any of various large food and game fishes of northern waters', 'name': 'salmon_(fish)'}, {'frequency': 'r', 'id': 922, 'synset': 'salmon.n.03', 'synonyms': ['salmon_(food)'], 'def': 'flesh of any of various marine or freshwater fish of the family Salmonidae', 'name': 'salmon_(food)'}, {'frequency': 'r', 'id': 923, 'synset': 'salsa.n.01', 'synonyms': ['salsa'], 'def': 'spicy sauce of tomatoes and onions and chili peppers to accompany Mexican foods', 'name': 'salsa'}, {'frequency': 'f', 'id': 924, 'synset': 'saltshaker.n.01', 'synonyms': ['saltshaker'], 'def': 'a shaker with a perforated top for sprinkling salt', 'name': 'saltshaker'}, {'frequency': 'f', 'id': 925, 'synset': 'sandal.n.01', 'synonyms': ['sandal_(type_of_shoe)'], 'def': 'a shoe consisting of a sole fastened by straps to the foot', 'name': 'sandal_(type_of_shoe)'}, {'frequency': 'f', 'id': 926, 
'synset': 'sandwich.n.01', 'synonyms': ['sandwich'], 'def': 'two (or more) slices of bread with a filling between them', 'name': 'sandwich'}, {'frequency': 'r', 'id': 927, 'synset': 'satchel.n.01', 'synonyms': ['satchel'], 'def': 'luggage consisting of a small case with a flat bottom and (usually) a shoulder strap', 'name': 'satchel'}, {'frequency': 'r', 'id': 928, 'synset': 'saucepan.n.01', 'synonyms': ['saucepan'], 'def': 'a deep pan with a handle; used for stewing or boiling', 'name': 'saucepan'}, {'frequency': 'f', 'id': 929, 'synset': 'saucer.n.02', 'synonyms': ['saucer'], 'def': 'a small shallow dish for holding a cup at the table', 'name': 'saucer'}, {'frequency': 'f', 'id': 930, 'synset': 'sausage.n.01', 'synonyms': ['sausage'], 'def': 'highly seasoned minced meat stuffed in casings', 'name': 'sausage'}, {'frequency': 'r', 'id': 931, 'synset': 'sawhorse.n.01', 'synonyms': ['sawhorse', 'sawbuck'], 'def': 'a framework for holding wood that is being sawed', 'name': 'sawhorse'}, {'frequency': 'r', 'id': 932, 'synset': 'sax.n.02', 'synonyms': ['saxophone'], 'def': "a wind instrument with a `J'-shaped form typically made of brass", 'name': 'saxophone'}, {'frequency': 'f', 'id': 933, 'synset': 'scale.n.07', 'synonyms': ['scale_(measuring_instrument)'], 'def': 'a measuring instrument for weighing; shows amount of mass', 'name': 'scale_(measuring_instrument)'}, {'frequency': 'r', 'id': 934, 'synset': 'scarecrow.n.01', 'synonyms': ['scarecrow', 'strawman'], 'def': 'an effigy in the shape of a man to frighten birds away from seeds', 'name': 'scarecrow'}, {'frequency': 'f', 'id': 935, 'synset': 'scarf.n.01', 'synonyms': ['scarf'], 'def': 'a garment worn around the head or neck or shoulders for warmth or decoration', 'name': 'scarf'}, {'frequency': 'c', 'id': 936, 'synset': 'school_bus.n.01', 'synonyms': ['school_bus'], 'def': 'a bus used to transport children to or from school', 'name': 'school_bus'}, {'frequency': 'f', 'id': 937, 'synset': 'scissors.n.01', 'synonyms': 
['scissors'], 'def': 'a tool having two crossed pivoting blades with looped handles', 'name': 'scissors'}, {'frequency': 'c', 'id': 938, 'synset': 'scoreboard.n.01', 'synonyms': ['scoreboard'], 'def': 'a large board for displaying the score of a contest (and some other information)', 'name': 'scoreboard'}, {'frequency': 'c', 'id': 939, 'synset': 'scrambled_eggs.n.01', 'synonyms': ['scrambled_eggs'], 'def': 'eggs beaten and cooked to a soft firm consistency while stirring', 'name': 'scrambled_eggs'}, {'frequency': 'r', 'id': 940, 'synset': 'scraper.n.01', 'synonyms': ['scraper'], 'def': 'any of various hand tools for scraping', 'name': 'scraper'}, {'frequency': 'r', 'id': 941, 'synset': 'scratcher.n.03', 'synonyms': ['scratcher'], 'def': 'a device used for scratching', 'name': 'scratcher'}, {'frequency': 'c', 'id': 942, 'synset': 'screwdriver.n.01', 'synonyms': ['screwdriver'], 'def': 'a hand tool for driving screws; has a tip that fits into the head of a screw', 'name': 'screwdriver'}, {'frequency': 'c', 'id': 943, 'synset': 'scrub_brush.n.01', 'synonyms': ['scrubbing_brush'], 'def': 'a brush with short stiff bristles for heavy cleaning', 'name': 'scrubbing_brush'}, {'frequency': 'c', 'id': 944, 'synset': 'sculpture.n.01', 'synonyms': ['sculpture'], 'def': 'a three-dimensional work of art', 'name': 'sculpture'}, {'frequency': 'r', 'id': 945, 'synset': 'seabird.n.01', 'synonyms': ['seabird', 'seafowl'], 'def': 'a bird that frequents coastal waters and the open ocean: gulls; pelicans; gannets; cormorants; albatrosses; petrels; etc.', 'name': 'seabird'}, {'frequency': 'r', 'id': 946, 'synset': 'seahorse.n.02', 'synonyms': ['seahorse'], 'def': 'small fish with horse-like heads bent sharply downward and curled tails', 'name': 'seahorse'}, {'frequency': 'r', 'id': 947, 'synset': 'seaplane.n.01', 'synonyms': ['seaplane', 'hydroplane'], 'def': 'an airplane that can land on or take off from water', 'name': 'seaplane'}, {'frequency': 'c', 'id': 948, 'synset': 
'seashell.n.01', 'synonyms': ['seashell'], 'def': 'the shell of a marine organism', 'name': 'seashell'}, {'frequency': 'r', 'id': 949, 'synset': 'seedling.n.01', 'synonyms': ['seedling'], 'def': 'young plant or tree grown from a seed', 'name': 'seedling'}, {'frequency': 'c', 'id': 950, 'synset': 'serving_dish.n.01', 'synonyms': ['serving_dish'], 'def': 'a dish used for serving food', 'name': 'serving_dish'}, {'frequency': 'r', 'id': 951, 'synset': 'sewing_machine.n.01', 'synonyms': ['sewing_machine'], 'def': 'a textile machine used as a home appliance for sewing', 'name': 'sewing_machine'}, {'frequency': 'r', 'id': 952, 'synset': 'shaker.n.03', 'synonyms': ['shaker'], 'def': 'a container in which something can be shaken', 'name': 'shaker'}, {'frequency': 'c', 'id': 953, 'synset': 'shampoo.n.01', 'synonyms': ['shampoo'], 'def': 'cleansing agent consisting of soaps or detergents used for washing the hair', 'name': 'shampoo'}, {'frequency': 'r', 'id': 954, 'synset': 'shark.n.01', 'synonyms': ['shark'], 'def': 'typically large carnivorous fishes with sharpe teeth', 'name': 'shark'}, {'frequency': 'r', 'id': 955, 'synset': 'sharpener.n.01', 'synonyms': ['sharpener'], 'def': 'any implement that is used to make something (an edge or a point) sharper', 'name': 'sharpener'}, {'frequency': 'r', 'id': 956, 'synset': 'sharpie.n.03', 'synonyms': ['Sharpie'], 'def': 'a pen with indelible ink that will write on any surface', 'name': 'Sharpie'}, {'frequency': 'r', 'id': 957, 'synset': 'shaver.n.03', 'synonyms': ['shaver_(electric)', 'electric_shaver', 'electric_razor'], 'def': 'a razor powered by an electric motor', 'name': 'shaver_(electric)'}, {'frequency': 'c', 'id': 958, 'synset': 'shaving_cream.n.01', 'synonyms': ['shaving_cream', 'shaving_soap'], 'def': 'toiletry consisting that forms a rich lather for softening the beard before shaving', 'name': 'shaving_cream'}, {'frequency': 'r', 'id': 959, 'synset': 'shawl.n.01', 'synonyms': ['shawl'], 'def': 'cloak consisting of an 
oblong piece of cloth used to cover the head and shoulders', 'name': 'shawl'}, {'frequency': 'r', 'id': 960, 'synset': 'shears.n.01', 'synonyms': ['shears'], 'def': 'large scissors with strong blades', 'name': 'shears'}, {'frequency': 'f', 'id': 961, 'synset': 'sheep.n.01', 'synonyms': ['sheep'], 'def': 'woolly usually horned ruminant mammal related to the goat', 'name': 'sheep'}, {'frequency': 'r', 'id': 962, 'synset': 'shepherd_dog.n.01', 'synonyms': ['shepherd_dog', 'sheepdog'], 'def': 'any of various usually long-haired breeds of dog reared to herd and guard sheep', 'name': 'shepherd_dog'}, {'frequency': 'r', 'id': 963, 'synset': 'sherbert.n.01', 'synonyms': ['sherbert', 'sherbet'], 'def': 'a frozen dessert made primarily of fruit juice and sugar', 'name': 'sherbert'}, {'frequency': 'r', 'id': 964, 'synset': 'shield.n.02', 'synonyms': ['shield'], 'def': 'armor carried on the arm to intercept blows', 'name': 'shield'}, {'frequency': 'f', 'id': 965, 'synset': 'shirt.n.01', 'synonyms': ['shirt'], 'def': 'a garment worn on the upper half of the body', 'name': 'shirt'}, {'frequency': 'f', 'id': 966, 'synset': 'shoe.n.01', 'synonyms': ['shoe', 'sneaker_(type_of_shoe)', 'tennis_shoe'], 'def': 'common footwear covering the foot', 'name': 'shoe'}, {'frequency': 'c', 'id': 967, 'synset': 'shopping_bag.n.01', 'synonyms': ['shopping_bag'], 'def': 'a bag made of plastic or strong paper (often with handles); used to transport goods after shopping', 'name': 'shopping_bag'}, {'frequency': 'c', 'id': 968, 'synset': 'shopping_cart.n.01', 'synonyms': ['shopping_cart'], 'def': 'a handcart that holds groceries or other goods while shopping', 'name': 'shopping_cart'}, {'frequency': 'f', 'id': 969, 'synset': 'short_pants.n.01', 'synonyms': ['short_pants', 'shorts_(clothing)', 'trunks_(clothing)'], 'def': 'trousers that end at or above the knee', 'name': 'short_pants'}, {'frequency': 'r', 'id': 970, 'synset': 'shot_glass.n.01', 'synonyms': ['shot_glass'], 'def': 'a small glass 
adequate to hold a single swallow of whiskey', 'name': 'shot_glass'}, {'frequency': 'c', 'id': 971, 'synset': 'shoulder_bag.n.01', 'synonyms': ['shoulder_bag'], 'def': 'a large handbag that can be carried by a strap looped over the shoulder', 'name': 'shoulder_bag'}, {'frequency': 'c', 'id': 972, 'synset': 'shovel.n.01', 'synonyms': ['shovel'], 'def': 'a hand tool for lifting loose material such as snow, dirt, etc.', 'name': 'shovel'}, {'frequency': 'f', 'id': 973, 'synset': 'shower.n.01', 'synonyms': ['shower_head'], 'def': 'a plumbing fixture that sprays water over you', 'name': 'shower_head'}, {'frequency': 'f', 'id': 974, 'synset': 'shower_curtain.n.01', 'synonyms': ['shower_curtain'], 'def': 'a curtain that keeps water from splashing out of the shower area', 'name': 'shower_curtain'}, {'frequency': 'r', 'id': 975, 'synset': 'shredder.n.01', 'synonyms': ['shredder_(for_paper)'], 'def': 'a device that shreds documents', 'name': 'shredder_(for_paper)'}, {'frequency': 'r', 'id': 976, 'synset': 'sieve.n.01', 'synonyms': ['sieve', 'screen_(sieve)'], 'def': 'a strainer for separating lumps from powdered material or grading particles', 'name': 'sieve'}, {'frequency': 'f', 'id': 977, 'synset': 'signboard.n.01', 'synonyms': ['signboard'], 'def': 'structure displaying a board on which advertisements can be posted', 'name': 'signboard'}, {'frequency': 'c', 'id': 978, 'synset': 'silo.n.01', 'synonyms': ['silo'], 'def': 'a cylindrical tower used for storing goods', 'name': 'silo'}, {'frequency': 'f', 'id': 979, 'synset': 'sink.n.01', 'synonyms': ['sink'], 'def': 'plumbing fixture consisting of a water basin fixed to a wall or floor and having a drainpipe', 'name': 'sink'}, {'frequency': 'f', 'id': 980, 'synset': 'skateboard.n.01', 'synonyms': ['skateboard'], 'def': 'a board with wheels that is ridden in a standing or crouching position and propelled by foot', 'name': 'skateboard'}, {'frequency': 'c', 'id': 981, 'synset': 'skewer.n.01', 'synonyms': ['skewer'], 'def': 'a long 
pin for holding meat in position while it is being roasted', 'name': 'skewer'}, {'frequency': 'f', 'id': 982, 'synset': 'ski.n.01', 'synonyms': ['ski'], 'def': 'sports equipment for skiing on snow', 'name': 'ski'}, {'frequency': 'f', 'id': 983, 'synset': 'ski_boot.n.01', 'synonyms': ['ski_boot'], 'def': 'a stiff boot that is fastened to a ski with a ski binding', 'name': 'ski_boot'}, {'frequency': 'f', 'id': 984, 'synset': 'ski_parka.n.01', 'synonyms': ['ski_parka', 'ski_jacket'], 'def': 'a parka to be worn while skiing', 'name': 'ski_parka'}, {'frequency': 'f', 'id': 985, 'synset': 'ski_pole.n.01', 'synonyms': ['ski_pole'], 'def': 'a pole with metal points used as an aid in skiing', 'name': 'ski_pole'}, {'frequency': 'f', 'id': 986, 'synset': 'skirt.n.02', 'synonyms': ['skirt'], 'def': 'a garment hanging from the waist; worn mainly by girls and women', 'name': 'skirt'}, {'frequency': 'c', 'id': 987, 'synset': 'sled.n.01', 'synonyms': ['sled', 'sledge', 'sleigh'], 'def': 'a vehicle or flat object for transportation over snow by sliding or pulled by dogs, etc.', 'name': 'sled'}, {'frequency': 'c', 'id': 988, 'synset': 'sleeping_bag.n.01', 'synonyms': ['sleeping_bag'], 'def': 'large padded bag designed to be slept in outdoors', 'name': 'sleeping_bag'}, {'frequency': 'r', 'id': 989, 'synset': 'sling.n.05', 'synonyms': ['sling_(bandage)', 'triangular_bandage'], 'def': 'bandage to support an injured forearm; slung over the shoulder or neck', 'name': 'sling_(bandage)'}, {'frequency': 'c', 'id': 990, 'synset': 'slipper.n.01', 'synonyms': ['slipper_(footwear)', 'carpet_slipper_(footwear)'], 'def': 'low footwear that can be slipped on and off easily; usually worn indoors', 'name': 'slipper_(footwear)'}, {'frequency': 'r', 'id': 991, 'synset': 'smoothie.n.02', 'synonyms': ['smoothie'], 'def': 'a thick smooth drink consisting of fresh fruit pureed with ice cream or yoghurt or milk', 'name': 'smoothie'}, {'frequency': 'r', 'id': 992, 'synset': 'snake.n.01', 'synonyms': 
['snake', 'serpent'], 'def': 'limbless scaly elongate reptile; some are venomous', 'name': 'snake'}, {'frequency': 'f', 'id': 993, 'synset': 'snowboard.n.01', 'synonyms': ['snowboard'], 'def': 'a board that resembles a broad ski or a small surfboard; used in a standing position to slide down snow-covered slopes', 'name': 'snowboard'}, {'frequency': 'c', 'id': 994, 'synset': 'snowman.n.01', 'synonyms': ['snowman'], 'def': 'a figure of a person made of packed snow', 'name': 'snowman'}, {'frequency': 'c', 'id': 995, 'synset': 'snowmobile.n.01', 'synonyms': ['snowmobile'], 'def': 'tracked vehicle for travel on snow having skis in front', 'name': 'snowmobile'}, {'frequency': 'f', 'id': 996, 'synset': 'soap.n.01', 'synonyms': ['soap'], 'def': 'a cleansing agent made from the salts of vegetable or animal fats', 'name': 'soap'}, {'frequency': 'f', 'id': 997, 'synset': 'soccer_ball.n.01', 'synonyms': ['soccer_ball'], 'def': "an inflated ball used in playing soccer (called `football' outside of the United States)", 'name': 'soccer_ball'}, {'frequency': 'f', 'id': 998, 'synset': 'sock.n.01', 'synonyms': ['sock'], 'def': 'cloth covering for the foot; worn inside the shoe; reaches to between the ankle and the knee', 'name': 'sock'}, {'frequency': 'r', 'id': 999, 'synset': 'soda_fountain.n.02', 'synonyms': ['soda_fountain'], 'def': 'an apparatus for dispensing soda water', 'name': 'soda_fountain'}, {'frequency': 'r', 'id': 1000, 'synset': 'soda_water.n.01', 'synonyms': ['carbonated_water', 'club_soda', 'seltzer', 'sparkling_water'], 'def': 'effervescent beverage artificially charged with carbon dioxide', 'name': 'carbonated_water'}, {'frequency': 'f', 'id': 1001, 'synset': 'sofa.n.01', 'synonyms': ['sofa', 'couch', 'lounge'], 'def': 'an upholstered seat for more than one person', 'name': 'sofa'}, {'frequency': 'r', 'id': 1002, 'synset': 'softball.n.01', 'synonyms': ['softball'], 'def': 'ball used in playing softball', 'name': 'softball'}, {'frequency': 'c', 'id': 1003, 'synset': 
'solar_array.n.01', 'synonyms': ['solar_array', 'solar_battery', 'solar_panel'], 'def': 'electrical device consisting of a large array of connected solar cells', 'name': 'solar_array'}, {'frequency': 'r', 'id': 1004, 'synset': 'sombrero.n.02', 'synonyms': ['sombrero'], 'def': 'a straw hat with a tall crown and broad brim; worn in American southwest and in Mexico', 'name': 'sombrero'}, {'frequency': 'c', 'id': 1005, 'synset': 'soup.n.01', 'synonyms': ['soup'], 'def': 'liquid food especially of meat or fish or vegetable stock often containing pieces of solid food', 'name': 'soup'}, {'frequency': 'r', 'id': 1006, 'synset': 'soup_bowl.n.01', 'synonyms': ['soup_bowl'], 'def': 'a bowl for serving soup', 'name': 'soup_bowl'}, {'frequency': 'c', 'id': 1007, 'synset': 'soupspoon.n.01', 'synonyms': ['soupspoon'], 'def': 'a spoon with a rounded bowl for eating soup', 'name': 'soupspoon'}, {'frequency': 'c', 'id': 1008, 'synset': 'sour_cream.n.01', 'synonyms': ['sour_cream', 'soured_cream'], 'def': 'soured light cream', 'name': 'sour_cream'}, {'frequency': 'r', 'id': 1009, 'synset': 'soya_milk.n.01', 'synonyms': ['soya_milk', 'soybean_milk', 'soymilk'], 'def': 'a milk substitute containing soybean flour and water; used in some infant formulas and in making tofu', 'name': 'soya_milk'}, {'frequency': 'r', 'id': 1010, 'synset': 'space_shuttle.n.01', 'synonyms': ['space_shuttle'], 'def': "a reusable spacecraft with wings for a controlled descent through the Earth's atmosphere", 'name': 'space_shuttle'}, {'frequency': 'r', 'id': 1011, 'synset': 'sparkler.n.02', 'synonyms': ['sparkler_(fireworks)'], 'def': 'a firework that burns slowly and throws out a shower of sparks', 'name': 'sparkler_(fireworks)'}, {'frequency': 'f', 'id': 1012, 'synset': 'spatula.n.02', 'synonyms': ['spatula'], 'def': 'a hand tool with a thin flexible blade used to mix or spread soft substances', 'name': 'spatula'}, {'frequency': 'r', 'id': 1013, 'synset': 'spear.n.01', 'synonyms': ['spear', 'lance'], 'def': 
'a long pointed rod used as a tool or weapon', 'name': 'spear'}, {'frequency': 'f', 'id': 1014, 'synset': 'spectacles.n.01', 'synonyms': ['spectacles', 'specs', 'eyeglasses', 'glasses'], 'def': 'optical instrument consisting of a frame that holds a pair of lenses for correcting defective vision', 'name': 'spectacles'}, {'frequency': 'c', 'id': 1015, 'synset': 'spice_rack.n.01', 'synonyms': ['spice_rack'], 'def': 'a rack for displaying containers filled with spices', 'name': 'spice_rack'}, {'frequency': 'r', 'id': 1016, 'synset': 'spider.n.01', 'synonyms': ['spider'], 'def': 'predatory arachnid with eight legs, two poison fangs, two feelers, and usually two silk-spinning organs at the back end of the body', 'name': 'spider'}, {'frequency': 'c', 'id': 1017, 'synset': 'sponge.n.01', 'synonyms': ['sponge'], 'def': 'a porous mass usable to absorb water typically used for cleaning', 'name': 'sponge'}, {'frequency': 'f', 'id': 1018, 'synset': 'spoon.n.01', 'synonyms': ['spoon'], 'def': 'a piece of cutlery with a shallow bowl-shaped container and a handle', 'name': 'spoon'}, {'frequency': 'c', 'id': 1019, 'synset': 'sportswear.n.01', 'synonyms': ['sportswear', 'athletic_wear', 'activewear'], 'def': 'attire worn for sport or for casual wear', 'name': 'sportswear'}, {'frequency': 'c', 'id': 1020, 'synset': 'spotlight.n.02', 'synonyms': ['spotlight'], 'def': 'a lamp that produces a strong beam of light to illuminate a restricted area; used to focus attention of a stage performer', 'name': 'spotlight'}, {'frequency': 'r', 'id': 1021, 'synset': 'squirrel.n.01', 'synonyms': ['squirrel'], 'def': 'a kind of arboreal rodent having a long bushy tail', 'name': 'squirrel'}, {'frequency': 'c', 'id': 1022, 'synset': 'stapler.n.01', 'synonyms': ['stapler_(stapling_machine)'], 'def': 'a machine that inserts staples into sheets of paper in order to fasten them together', 'name': 'stapler_(stapling_machine)'}, {'frequency': 'r', 'id': 1023, 'synset': 'starfish.n.01', 'synonyms': 
['starfish', 'sea_star'], 'def': 'echinoderms characterized by five arms extending from a central disk', 'name': 'starfish'}, {'frequency': 'f', 'id': 1024, 'synset': 'statue.n.01', 'synonyms': ['statue_(sculpture)'], 'def': 'a sculpture representing a human or animal', 'name': 'statue_(sculpture)'}, {'frequency': 'c', 'id': 1025, 'synset': 'steak.n.01', 'synonyms': ['steak_(food)'], 'def': 'a slice of meat cut from the fleshy part of an animal or large fish', 'name': 'steak_(food)'}, {'frequency': 'r', 'id': 1026, 'synset': 'steak_knife.n.01', 'synonyms': ['steak_knife'], 'def': 'a sharp table knife used in eating steak', 'name': 'steak_knife'}, {'frequency': 'r', 'id': 1027, 'synset': 'steamer.n.02', 'synonyms': ['steamer_(kitchen_appliance)'], 'def': 'a cooking utensil that can be used to cook food by steaming it', 'name': 'steamer_(kitchen_appliance)'}, {'frequency': 'f', 'id': 1028, 'synset': 'steering_wheel.n.01', 'synonyms': ['steering_wheel'], 'def': 'a handwheel that is used for steering', 'name': 'steering_wheel'}, {'frequency': 'r', 'id': 1029, 'synset': 'stencil.n.01', 'synonyms': ['stencil'], 'def': 'a sheet of material (metal, plastic, etc.) 
that has been perforated with a pattern; ink or paint can pass through the perforations to create the printed pattern on the surface below', 'name': 'stencil'}, {'frequency': 'r', 'id': 1030, 'synset': 'step_ladder.n.01', 'synonyms': ['stepladder'], 'def': 'a folding portable ladder hinged at the top', 'name': 'stepladder'}, {'frequency': 'c', 'id': 1031, 'synset': 'step_stool.n.01', 'synonyms': ['step_stool'], 'def': 'a stool that has one or two steps that fold under the seat', 'name': 'step_stool'}, {'frequency': 'c', 'id': 1032, 'synset': 'stereo.n.01', 'synonyms': ['stereo_(sound_system)'], 'def': 'electronic device for playing audio', 'name': 'stereo_(sound_system)'}, {'frequency': 'r', 'id': 1033, 'synset': 'stew.n.02', 'synonyms': ['stew'], 'def': 'food prepared by stewing especially meat or fish with vegetables', 'name': 'stew'}, {'frequency': 'r', 'id': 1034, 'synset': 'stirrer.n.02', 'synonyms': ['stirrer'], 'def': 'an implement used for stirring', 'name': 'stirrer'}, {'frequency': 'f', 'id': 1035, 'synset': 'stirrup.n.01', 'synonyms': ['stirrup'], 'def': "support consisting of metal loops into which rider's feet go", 'name': 'stirrup'}, {'frequency': 'c', 'id': 1036, 'synset': 'stocking.n.01', 'synonyms': ['stockings_(leg_wear)'], 'def': 'close-fitting hosiery to cover the foot and leg; come in matched pairs', 'name': 'stockings_(leg_wear)'}, {'frequency': 'f', 'id': 1037, 'synset': 'stool.n.01', 'synonyms': ['stool'], 'def': 'a simple seat without a back or arms', 'name': 'stool'}, {'frequency': 'f', 'id': 1038, 'synset': 'stop_sign.n.01', 'synonyms': ['stop_sign'], 'def': 'a traffic sign to notify drivers that they must come to a complete stop', 'name': 'stop_sign'}, {'frequency': 'f', 'id': 1039, 'synset': 'stoplight.n.01', 'synonyms': ['brake_light'], 'def': 'a red light on the rear of a motor vehicle that signals when the brakes are applied', 'name': 'brake_light'}, {'frequency': 'f', 'id': 1040, 'synset': 'stove.n.01', 'synonyms': ['stove', 
'kitchen_stove', 'range_(kitchen_appliance)', 'kitchen_range', 'cooking_stove'], 'def': 'a kitchen appliance used for cooking food', 'name': 'stove'}, {'frequency': 'c', 'id': 1041, 'synset': 'strainer.n.01', 'synonyms': ['strainer'], 'def': 'a filter to retain larger pieces while smaller pieces and liquids pass through', 'name': 'strainer'}, {'frequency': 'f', 'id': 1042, 'synset': 'strap.n.01', 'synonyms': ['strap'], 'def': 'an elongated strip of material for binding things together or holding', 'name': 'strap'}, {'frequency': 'f', 'id': 1043, 'synset': 'straw.n.04', 'synonyms': ['straw_(for_drinking)', 'drinking_straw'], 'def': 'a thin paper or plastic tube used to suck liquids into the mouth', 'name': 'straw_(for_drinking)'}, {'frequency': 'f', 'id': 1044, 'synset': 'strawberry.n.01', 'synonyms': ['strawberry'], 'def': 'sweet fleshy red fruit', 'name': 'strawberry'}, {'frequency': 'f', 'id': 1045, 'synset': 'street_sign.n.01', 'synonyms': ['street_sign'], 'def': 'a sign visible from the street', 'name': 'street_sign'}, {'frequency': 'f', 'id': 1046, 'synset': 'streetlight.n.01', 'synonyms': ['streetlight', 'street_lamp'], 'def': 'a lamp supported on a lamppost; for illuminating a street', 'name': 'streetlight'}, {'frequency': 'r', 'id': 1047, 'synset': 'string_cheese.n.01', 'synonyms': ['string_cheese'], 'def': 'cheese formed in long strings twisted together', 'name': 'string_cheese'}, {'frequency': 'r', 'id': 1048, 'synset': 'stylus.n.02', 'synonyms': ['stylus'], 'def': 'a pointed tool for writing or drawing or engraving', 'name': 'stylus'}, {'frequency': 'r', 'id': 1049, 'synset': 'subwoofer.n.01', 'synonyms': ['subwoofer'], 'def': 'a loudspeaker that is designed to reproduce very low bass frequencies', 'name': 'subwoofer'}, {'frequency': 'r', 'id': 1050, 'synset': 'sugar_bowl.n.01', 'synonyms': ['sugar_bowl'], 'def': 'a dish in which sugar is served', 'name': 'sugar_bowl'}, {'frequency': 'r', 'id': 1051, 'synset': 'sugarcane.n.01', 'synonyms': 
['sugarcane_(plant)'], 'def': 'juicy canes whose sap is a source of molasses and commercial sugar; fresh canes are sometimes chewed for the juice', 'name': 'sugarcane_(plant)'}, {'frequency': 'c', 'id': 1052, 'synset': 'suit.n.01', 'synonyms': ['suit_(clothing)'], 'def': 'a set of garments (usually including a jacket and trousers or skirt) for outerwear all of the same fabric and color', 'name': 'suit_(clothing)'}, {'frequency': 'c', 'id': 1053, 'synset': 'sunflower.n.01', 'synonyms': ['sunflower'], 'def': 'any plant of the genus Helianthus having large flower heads with dark disk florets and showy yellow rays', 'name': 'sunflower'}, {'frequency': 'f', 'id': 1054, 'synset': 'sunglasses.n.01', 'synonyms': ['sunglasses'], 'def': 'spectacles that are darkened or polarized to protect the eyes from the glare of the sun', 'name': 'sunglasses'}, {'frequency': 'c', 'id': 1055, 'synset': 'sunhat.n.01', 'synonyms': ['sunhat'], 'def': 'a hat with a broad brim that protects the face from direct exposure to the sun', 'name': 'sunhat'}, {'frequency': 'r', 'id': 1056, 'synset': 'sunscreen.n.01', 'synonyms': ['sunscreen', 'sunblock'], 'def': 'a cream spread on the skin; contains a chemical to filter out ultraviolet light and so protect from sunburn', 'name': 'sunscreen'}, {'frequency': 'f', 'id': 1057, 'synset': 'surfboard.n.01', 'synonyms': ['surfboard'], 'def': 'a narrow buoyant board for riding surf', 'name': 'surfboard'}, {'frequency': 'c', 'id': 1058, 'synset': 'sushi.n.01', 'synonyms': ['sushi'], 'def': 'rice (with raw fish) wrapped in seaweed', 'name': 'sushi'}, {'frequency': 'c', 'id': 1059, 'synset': 'swab.n.02', 'synonyms': ['mop'], 'def': 'cleaning implement consisting of absorbent material fastened to a handle; for cleaning floors', 'name': 'mop'}, {'frequency': 'c', 'id': 1060, 'synset': 'sweat_pants.n.01', 'synonyms': ['sweat_pants'], 'def': 'loose-fitting trousers with elastic cuffs; worn by athletes', 'name': 'sweat_pants'}, {'frequency': 'c', 'id': 1061, 'synset': 
'sweatband.n.02', 'synonyms': ['sweatband'], 'def': 'a band of material tied around the forehead or wrist to absorb sweat', 'name': 'sweatband'}, {'frequency': 'f', 'id': 1062, 'synset': 'sweater.n.01', 'synonyms': ['sweater'], 'def': 'a crocheted or knitted garment covering the upper part of the body', 'name': 'sweater'}, {'frequency': 'f', 'id': 1063, 'synset': 'sweatshirt.n.01', 'synonyms': ['sweatshirt'], 'def': 'cotton knit pullover with long sleeves worn during athletic activity', 'name': 'sweatshirt'}, {'frequency': 'c', 'id': 1064, 'synset': 'sweet_potato.n.02', 'synonyms': ['sweet_potato'], 'def': 'the edible tuberous root of the sweet potato vine', 'name': 'sweet_potato'}, {'frequency': 'f', 'id': 1065, 'synset': 'swimsuit.n.01', 'synonyms': ['swimsuit', 'swimwear', 'bathing_suit', 'swimming_costume', 'bathing_costume', 'swimming_trunks', 'bathing_trunks'], 'def': 'garment worn for swimming', 'name': 'swimsuit'}, {'frequency': 'c', 'id': 1066, 'synset': 'sword.n.01', 'synonyms': ['sword'], 'def': 'a cutting or thrusting weapon that has a long metal blade', 'name': 'sword'}, {'frequency': 'r', 'id': 1067, 'synset': 'syringe.n.01', 'synonyms': ['syringe'], 'def': 'a medical instrument used to inject or withdraw fluids', 'name': 'syringe'}, {'frequency': 'r', 'id': 1068, 'synset': 'tabasco.n.02', 'synonyms': ['Tabasco_sauce'], 'def': 'very spicy sauce (trade name Tabasco) made from fully-aged red peppers', 'name': 'Tabasco_sauce'}, {'frequency': 'r', 'id': 1069, 'synset': 'table-tennis_table.n.01', 'synonyms': ['table-tennis_table', 'ping-pong_table'], 'def': 'a table used for playing table tennis', 'name': 'table-tennis_table'}, {'frequency': 'f', 'id': 1070, 'synset': 'table.n.02', 'synonyms': ['table'], 'def': 'a piece of furniture having a smooth flat top that is usually supported by one or more vertical legs', 'name': 'table'}, {'frequency': 'c', 'id': 1071, 'synset': 'table_lamp.n.01', 'synonyms': ['table_lamp'], 'def': 'a lamp that sits on a table', 
'name': 'table_lamp'}, {'frequency': 'f', 'id': 1072, 'synset': 'tablecloth.n.01', 'synonyms': ['tablecloth'], 'def': 'a covering spread over a dining table', 'name': 'tablecloth'}, {'frequency': 'r', 'id': 1073, 'synset': 'tachometer.n.01', 'synonyms': ['tachometer'], 'def': 'measuring instrument for indicating speed of rotation', 'name': 'tachometer'}, {'frequency': 'r', 'id': 1074, 'synset': 'taco.n.02', 'synonyms': ['taco'], 'def': 'a small tortilla cupped around a filling', 'name': 'taco'}, {'frequency': 'f', 'id': 1075, 'synset': 'tag.n.02', 'synonyms': ['tag'], 'def': 'a label associated with something for the purpose of identification or information', 'name': 'tag'}, {'frequency': 'f', 'id': 1076, 'synset': 'taillight.n.01', 'synonyms': ['taillight', 'rear_light'], 'def': 'lamp (usually red) mounted at the rear of a motor vehicle', 'name': 'taillight'}, {'frequency': 'r', 'id': 1077, 'synset': 'tambourine.n.01', 'synonyms': ['tambourine'], 'def': 'a shallow drum with a single drumhead and with metallic disks in the sides', 'name': 'tambourine'}, {'frequency': 'r', 'id': 1078, 'synset': 'tank.n.01', 'synonyms': ['army_tank', 'armored_combat_vehicle', 'armoured_combat_vehicle'], 'def': 'an enclosed armored military vehicle; has a cannon and moves on caterpillar treads', 'name': 'army_tank'}, {'frequency': 'c', 'id': 1079, 'synset': 'tank.n.02', 'synonyms': ['tank_(storage_vessel)', 'storage_tank'], 'def': 'a large (usually metallic) vessel for holding gases or liquids', 'name': 'tank_(storage_vessel)'}, {'frequency': 'f', 'id': 1080, 'synset': 'tank_top.n.01', 'synonyms': ['tank_top_(clothing)'], 'def': 'a tight-fitting sleeveless shirt with wide shoulder straps and low neck and no front opening', 'name': 'tank_top_(clothing)'}, {'frequency': 'c', 'id': 1081, 'synset': 'tape.n.01', 'synonyms': ['tape_(sticky_cloth_or_paper)'], 'def': 'a long thin piece of cloth or paper as used for binding or fastening', 'name': 'tape_(sticky_cloth_or_paper)'}, {'frequency': 
'c', 'id': 1082, 'synset': 'tape.n.04', 'synonyms': ['tape_measure', 'measuring_tape'], 'def': 'measuring instrument consisting of a narrow strip (cloth or metal) marked in inches or centimeters and used for measuring lengths', 'name': 'tape_measure'}, {'frequency': 'c', 'id': 1083, 'synset': 'tapestry.n.02', 'synonyms': ['tapestry'], 'def': 'a heavy textile with a woven design; used for curtains and upholstery', 'name': 'tapestry'}, {'frequency': 'f', 'id': 1084, 'synset': 'tarpaulin.n.01', 'synonyms': ['tarp'], 'def': 'waterproofed canvas', 'name': 'tarp'}, {'frequency': 'c', 'id': 1085, 'synset': 'tartan.n.01', 'synonyms': ['tartan', 'plaid'], 'def': 'a cloth having a crisscross design', 'name': 'tartan'}, {'frequency': 'c', 'id': 1086, 'synset': 'tassel.n.01', 'synonyms': ['tassel'], 'def': 'adornment consisting of a bunch of cords fastened at one end', 'name': 'tassel'}, {'frequency': 'r', 'id': 1087, 'synset': 'tea_bag.n.01', 'synonyms': ['tea_bag'], 'def': 'a measured amount of tea in a bag for an individual serving of tea', 'name': 'tea_bag'}, {'frequency': 'c', 'id': 1088, 'synset': 'teacup.n.02', 'synonyms': ['teacup'], 'def': 'a cup from which tea is drunk', 'name': 'teacup'}, {'frequency': 'c', 'id': 1089, 'synset': 'teakettle.n.01', 'synonyms': ['teakettle'], 'def': 'kettle for boiling water to make tea', 'name': 'teakettle'}, {'frequency': 'c', 'id': 1090, 'synset': 'teapot.n.01', 'synonyms': ['teapot'], 'def': 'pot for brewing tea; usually has a spout and handle', 'name': 'teapot'}, {'frequency': 'f', 'id': 1091, 'synset': 'teddy.n.01', 'synonyms': ['teddy_bear'], 'def': "plaything consisting of a child's toy bear (usually plush and stuffed with soft materials)", 'name': 'teddy_bear'}, {'frequency': 'f', 'id': 1092, 'synset': 'telephone.n.01', 'synonyms': ['telephone', 'phone', 'telephone_set'], 'def': 'electronic device for communicating by voice over long distances', 'name': 'telephone'}, {'frequency': 'c', 'id': 1093, 'synset': 
'telephone_booth.n.01', 'synonyms': ['telephone_booth', 'phone_booth', 'call_box', 'telephone_box', 'telephone_kiosk'], 'def': 'booth for using a telephone', 'name': 'telephone_booth'}, {'frequency': 'f', 'id': 1094, 'synset': 'telephone_pole.n.01', 'synonyms': ['telephone_pole', 'telegraph_pole', 'telegraph_post'], 'def': 'tall pole supporting telephone wires', 'name': 'telephone_pole'}, {'frequency': 'r', 'id': 1095, 'synset': 'telephoto_lens.n.01', 'synonyms': ['telephoto_lens', 'zoom_lens'], 'def': 'a camera lens that magnifies the image', 'name': 'telephoto_lens'}, {'frequency': 'c', 'id': 1096, 'synset': 'television_camera.n.01', 'synonyms': ['television_camera', 'tv_camera'], 'def': 'television equipment for capturing and recording video', 'name': 'television_camera'}, {'frequency': 'f', 'id': 1097, 'synset': 'television_receiver.n.01', 'synonyms': ['television_set', 'tv', 'tv_set'], 'def': 'an electronic device that receives television signals and displays them on a screen', 'name': 'television_set'}, {'frequency': 'f', 'id': 1098, 'synset': 'tennis_ball.n.01', 'synonyms': ['tennis_ball'], 'def': 'ball about the size of a fist used in playing tennis', 'name': 'tennis_ball'}, {'frequency': 'f', 'id': 1099, 'synset': 'tennis_racket.n.01', 'synonyms': ['tennis_racket'], 'def': 'a racket used to play tennis', 'name': 'tennis_racket'}, {'frequency': 'r', 'id': 1100, 'synset': 'tequila.n.01', 'synonyms': ['tequila'], 'def': 'Mexican liquor made from fermented juices of an agave plant', 'name': 'tequila'}, {'frequency': 'c', 'id': 1101, 'synset': 'thermometer.n.01', 'synonyms': ['thermometer'], 'def': 'measuring instrument for measuring temperature', 'name': 'thermometer'}, {'frequency': 'c', 'id': 1102, 'synset': 'thermos.n.01', 'synonyms': ['thermos_bottle'], 'def': 'vacuum flask that preserves temperature of hot or cold drinks', 'name': 'thermos_bottle'}, {'frequency': 'c', 'id': 1103, 'synset': 'thermostat.n.01', 'synonyms': ['thermostat'], 'def': 'a regulator 
for automatically regulating temperature by starting or stopping the supply of heat', 'name': 'thermostat'}, {'frequency': 'r', 'id': 1104, 'synset': 'thimble.n.02', 'synonyms': ['thimble'], 'def': 'a small metal cap to protect the finger while sewing; can be used as a small container', 'name': 'thimble'}, {'frequency': 'c', 'id': 1105, 'synset': 'thread.n.01', 'synonyms': ['thread', 'yarn'], 'def': 'a fine cord of twisted fibers (of cotton or silk or wool or nylon etc.) used in sewing and weaving', 'name': 'thread'}, {'frequency': 'c', 'id': 1106, 'synset': 'thumbtack.n.01', 'synonyms': ['thumbtack', 'drawing_pin', 'pushpin'], 'def': 'a tack for attaching papers to a bulletin board or drawing board', 'name': 'thumbtack'}, {'frequency': 'c', 'id': 1107, 'synset': 'tiara.n.01', 'synonyms': ['tiara'], 'def': 'a jeweled headdress worn by women on formal occasions', 'name': 'tiara'}, {'frequency': 'c', 'id': 1108, 'synset': 'tiger.n.02', 'synonyms': ['tiger'], 'def': 'large feline of forests in most of Asia having a tawny coat with black stripes', 'name': 'tiger'}, {'frequency': 'c', 'id': 1109, 'synset': 'tights.n.01', 'synonyms': ['tights_(clothing)', 'leotards'], 'def': 'skintight knit hose covering the body from the waist to the feet worn by acrobats and dancers and as stockings by women and girls', 'name': 'tights_(clothing)'}, {'frequency': 'c', 'id': 1110, 'synset': 'timer.n.01', 'synonyms': ['timer', 'stopwatch'], 'def': 'a timepiece that measures a time interval and signals its end', 'name': 'timer'}, {'frequency': 'f', 'id': 1111, 'synset': 'tinfoil.n.01', 'synonyms': ['tinfoil'], 'def': 'foil made of tin or an alloy of tin and lead', 'name': 'tinfoil'}, {'frequency': 'r', 'id': 1112, 'synset': 'tinsel.n.01', 'synonyms': ['tinsel'], 'def': 'a showy decoration that is basically valueless', 'name': 'tinsel'}, {'frequency': 'f', 'id': 1113, 'synset': 'tissue.n.02', 'synonyms': ['tissue_paper'], 'def': 'a soft thin (usually translucent) paper', 'name': 
'tissue_paper'}, {'frequency': 'c', 'id': 1114, 'synset': 'toast.n.01', 'synonyms': ['toast_(food)'], 'def': 'slice of bread that has been toasted', 'name': 'toast_(food)'}, {'frequency': 'f', 'id': 1115, 'synset': 'toaster.n.02', 'synonyms': ['toaster'], 'def': 'a kitchen appliance (usually electric) for toasting bread', 'name': 'toaster'}, {'frequency': 'c', 'id': 1116, 'synset': 'toaster_oven.n.01', 'synonyms': ['toaster_oven'], 'def': 'kitchen appliance consisting of a small electric oven for toasting or warming food', 'name': 'toaster_oven'}, {'frequency': 'f', 'id': 1117, 'synset': 'toilet.n.02', 'synonyms': ['toilet'], 'def': 'a plumbing fixture for defecation and urination', 'name': 'toilet'}, {'frequency': 'f', 'id': 1118, 'synset': 'toilet_tissue.n.01', 'synonyms': ['toilet_tissue', 'toilet_paper', 'bathroom_tissue'], 'def': 'a soft thin absorbent paper for use in toilets', 'name': 'toilet_tissue'}, {'frequency': 'f', 'id': 1119, 'synset': 'tomato.n.01', 'synonyms': ['tomato'], 'def': 'mildly acid red or yellow pulpy fruit eaten as a vegetable', 'name': 'tomato'}, {'frequency': 'c', 'id': 1120, 'synset': 'tongs.n.01', 'synonyms': ['tongs'], 'def': 'any of various devices for taking hold of objects; usually have two hinged legs with handles above and pointed hooks below', 'name': 'tongs'}, {'frequency': 'c', 'id': 1121, 'synset': 'toolbox.n.01', 'synonyms': ['toolbox'], 'def': 'a box or chest or cabinet for holding hand tools', 'name': 'toolbox'}, {'frequency': 'f', 'id': 1122, 'synset': 'toothbrush.n.01', 'synonyms': ['toothbrush'], 'def': 'small brush; has long handle; used to clean teeth', 'name': 'toothbrush'}, {'frequency': 'f', 'id': 1123, 'synset': 'toothpaste.n.01', 'synonyms': ['toothpaste'], 'def': 'a dentifrice in the form of a paste', 'name': 'toothpaste'}, {'frequency': 'c', 'id': 1124, 'synset': 'toothpick.n.01', 'synonyms': ['toothpick'], 'def': 'pick consisting of a small strip of wood or plastic; used to pick food from between the teeth', 
'name': 'toothpick'}, {'frequency': 'c', 'id': 1125, 'synset': 'top.n.09', 'synonyms': ['cover'], 'def': 'covering for a hole (especially a hole in the top of a container)', 'name': 'cover'}, {'frequency': 'c', 'id': 1126, 'synset': 'tortilla.n.01', 'synonyms': ['tortilla'], 'def': 'thin unleavened pancake made from cornmeal or wheat flour', 'name': 'tortilla'}, {'frequency': 'c', 'id': 1127, 'synset': 'tow_truck.n.01', 'synonyms': ['tow_truck'], 'def': 'a truck equipped to hoist and pull wrecked cars (or to remove cars from no-parking zones)', 'name': 'tow_truck'}, {'frequency': 'f', 'id': 1128, 'synset': 'towel.n.01', 'synonyms': ['towel'], 'def': 'a rectangular piece of absorbent cloth (or paper) for drying or wiping', 'name': 'towel'}, {'frequency': 'f', 'id': 1129, 'synset': 'towel_rack.n.01', 'synonyms': ['towel_rack', 'towel_rail', 'towel_bar'], 'def': 'a rack consisting of one or more bars on which towels can be hung', 'name': 'towel_rack'}, {'frequency': 'f', 'id': 1130, 'synset': 'toy.n.03', 'synonyms': ['toy'], 'def': 'a device regarded as providing amusement', 'name': 'toy'}, {'frequency': 'c', 'id': 1131, 'synset': 'tractor.n.01', 'synonyms': ['tractor_(farm_equipment)'], 'def': 'a wheeled vehicle with large wheels; used in farming and other applications', 'name': 'tractor_(farm_equipment)'}, {'frequency': 'f', 'id': 1132, 'synset': 'traffic_light.n.01', 'synonyms': ['traffic_light'], 'def': 'a device to control vehicle traffic often consisting of three or more lights', 'name': 'traffic_light'}, {'frequency': 'r', 'id': 1133, 'synset': 'trail_bike.n.01', 'synonyms': ['dirt_bike'], 'def': 'a lightweight motorcycle equipped with rugged tires and suspension for off-road use', 'name': 'dirt_bike'}, {'frequency': 'c', 'id': 1134, 'synset': 'trailer_truck.n.01', 'synonyms': ['trailer_truck', 'tractor_trailer', 'trucking_rig', 'articulated_lorry', 'semi_truck'], 'def': 'a truck consisting of a tractor and trailer together', 'name': 'trailer_truck'}, 
{'frequency': 'f', 'id': 1135, 'synset': 'train.n.01', 'synonyms': ['train_(railroad_vehicle)', 'railroad_train'], 'def': 'public or private transport provided by a line of railway cars coupled together and drawn by a locomotive', 'name': 'train_(railroad_vehicle)'}, {'frequency': 'r', 'id': 1136, 'synset': 'trampoline.n.01', 'synonyms': ['trampoline'], 'def': 'gymnastic apparatus consisting of a strong canvas sheet attached with springs to a metal frame', 'name': 'trampoline'}, {'frequency': 'f', 'id': 1137, 'synset': 'tray.n.01', 'synonyms': ['tray'], 'def': 'an open receptacle for holding or displaying or serving articles or food', 'name': 'tray'}, {'frequency': 'r', 'id': 1138, 'synset': 'tree_house.n.01', 'synonyms': ['tree_house'], 'def': '(NOT A TREE) a PLAYHOUSE built in the branches of a tree', 'name': 'tree_house'}, {'frequency': 'r', 'id': 1139, 'synset': 'trench_coat.n.01', 'synonyms': ['trench_coat'], 'def': 'a military style raincoat; belted with deep pockets', 'name': 'trench_coat'}, {'frequency': 'r', 'id': 1140, 'synset': 'triangle.n.05', 'synonyms': ['triangle_(musical_instrument)'], 'def': 'a percussion instrument consisting of a metal bar bent in the shape of an open triangle', 'name': 'triangle_(musical_instrument)'}, {'frequency': 'r', 'id': 1141, 'synset': 'tricycle.n.01', 'synonyms': ['tricycle'], 'def': 'a vehicle with three wheels that is moved by foot pedals', 'name': 'tricycle'}, {'frequency': 'c', 'id': 1142, 'synset': 'tripod.n.01', 'synonyms': ['tripod'], 'def': 'a three-legged rack used for support', 'name': 'tripod'}, {'frequency': 'f', 'id': 1143, 'synset': 'trouser.n.01', 'synonyms': ['trousers', 'pants_(clothing)'], 'def': 'a garment extending from the waist to the knee or ankle, covering each leg separately', 'name': 'trousers'}, {'frequency': 'f', 'id': 1144, 'synset': 'truck.n.01', 'synonyms': ['truck'], 'def': 'an automotive vehicle suitable for hauling', 'name': 'truck'}, {'frequency': 'r', 'id': 1145, 'synset': 
'truffle.n.03', 'synonyms': ['truffle_(chocolate)', 'chocolate_truffle'], 'def': 'creamy chocolate candy', 'name': 'truffle_(chocolate)'}, {'frequency': 'c', 'id': 1146, 'synset': 'trunk.n.02', 'synonyms': ['trunk'], 'def': 'luggage consisting of a large strong case used when traveling or for storage', 'name': 'trunk'}, {'frequency': 'r', 'id': 1147, 'synset': 'tub.n.02', 'synonyms': ['vat'], 'def': 'a large open vessel for holding or storing liquids', 'name': 'vat'}, {'frequency': 'c', 'id': 1148, 'synset': 'turban.n.01', 'synonyms': ['turban'], 'def': 'a traditional headdress consisting of a long scarf wrapped around the head', 'name': 'turban'}, {'frequency': 'r', 'id': 1149, 'synset': 'turkey.n.01', 'synonyms': ['turkey_(bird)'], 'def': 'large gallinaceous bird with fan-shaped tail; widely domesticated for food', 'name': 'turkey_(bird)'}, {'frequency': 'c', 'id': 1150, 'synset': 'turkey.n.04', 'synonyms': ['turkey_(food)'], 'def': 'flesh of large domesticated fowl usually roasted', 'name': 'turkey_(food)'}, {'frequency': 'r', 'id': 1151, 'synset': 'turnip.n.01', 'synonyms': ['turnip'], 'def': 'widely cultivated plant having a large fleshy edible white or yellow root', 'name': 'turnip'}, {'frequency': 'c', 'id': 1152, 'synset': 'turtle.n.02', 'synonyms': ['turtle'], 'def': 'any of various aquatic and land reptiles having a bony shell and flipper-like limbs for swimming', 'name': 'turtle'}, {'frequency': 'r', 'id': 1153, 'synset': 'turtleneck.n.01', 'synonyms': ['turtleneck_(clothing)', 'polo-neck'], 'def': 'a sweater or jersey with a high close-fitting collar', 'name': 'turtleneck_(clothing)'}, {'frequency': 'r', 'id': 1154, 'synset': 'typewriter.n.01', 'synonyms': ['typewriter'], 'def': 'hand-operated character printer for printing written messages one character at a time', 'name': 'typewriter'}, {'frequency': 'f', 'id': 1155, 'synset': 'umbrella.n.01', 'synonyms': ['umbrella'], 'def': 'a lightweight handheld collapsible canopy', 'name': 'umbrella'}, 
{'frequency': 'c', 'id': 1156, 'synset': 'underwear.n.01', 'synonyms': ['underwear', 'underclothes', 'underclothing', 'underpants'], 'def': 'undergarment worn next to the skin and under the outer garments', 'name': 'underwear'}, {'frequency': 'r', 'id': 1157, 'synset': 'unicycle.n.01', 'synonyms': ['unicycle'], 'def': 'a vehicle with a single wheel that is driven by pedals', 'name': 'unicycle'}, {'frequency': 'c', 'id': 1158, 'synset': 'urinal.n.01', 'synonyms': ['urinal'], 'def': 'a plumbing fixture (usually attached to the wall) used by men to urinate', 'name': 'urinal'}, {'frequency': 'r', 'id': 1159, 'synset': 'urn.n.01', 'synonyms': ['urn'], 'def': 'a large vase that usually has a pedestal or feet', 'name': 'urn'}, {'frequency': 'c', 'id': 1160, 'synset': 'vacuum.n.04', 'synonyms': ['vacuum_cleaner'], 'def': 'an electrical home appliance that cleans by suction', 'name': 'vacuum_cleaner'}, {'frequency': 'c', 'id': 1161, 'synset': 'valve.n.03', 'synonyms': ['valve'], 'def': 'control consisting of a mechanical device for controlling the flow of a fluid', 'name': 'valve'}, {'frequency': 'f', 'id': 1162, 'synset': 'vase.n.01', 'synonyms': ['vase'], 'def': 'an open jar of glass or porcelain used as an ornament or to hold flowers', 'name': 'vase'}, {'frequency': 'c', 'id': 1163, 'synset': 'vending_machine.n.01', 'synonyms': ['vending_machine'], 'def': 'a slot machine for selling goods', 'name': 'vending_machine'}, {'frequency': 'f', 'id': 1164, 'synset': 'vent.n.01', 'synonyms': ['vent', 'blowhole', 'air_vent'], 'def': 'a hole for the escape of gas or air', 'name': 'vent'}, {'frequency': 'c', 'id': 1165, 'synset': 'videotape.n.01', 'synonyms': ['videotape'], 'def': 'a video recording made on magnetic tape', 'name': 'videotape'}, {'frequency': 'r', 'id': 1166, 'synset': 'vinegar.n.01', 'synonyms': ['vinegar'], 'def': 'sour-tasting liquid produced usually by oxidation of the alcohol in wine or cider and used as a condiment or food preservative', 'name': 'vinegar'}, 
{'frequency': 'r', 'id': 1167, 'synset': 'violin.n.01', 'synonyms': ['violin', 'fiddle'], 'def': 'bowed stringed instrument that is the highest member of the violin family', 'name': 'violin'}, {'frequency': 'r', 'id': 1168, 'synset': 'vodka.n.01', 'synonyms': ['vodka'], 'def': 'unaged colorless liquor originating in Russia', 'name': 'vodka'}, {'frequency': 'r', 'id': 1169, 'synset': 'volleyball.n.02', 'synonyms': ['volleyball'], 'def': 'an inflated ball used in playing volleyball', 'name': 'volleyball'}, {'frequency': 'r', 'id': 1170, 'synset': 'vulture.n.01', 'synonyms': ['vulture'], 'def': 'any of various large birds of prey having naked heads and weak claws and feeding chiefly on carrion', 'name': 'vulture'}, {'frequency': 'c', 'id': 1171, 'synset': 'waffle.n.01', 'synonyms': ['waffle'], 'def': 'pancake batter baked in a waffle iron', 'name': 'waffle'}, {'frequency': 'r', 'id': 1172, 'synset': 'waffle_iron.n.01', 'synonyms': ['waffle_iron'], 'def': 'a kitchen appliance for baking waffles', 'name': 'waffle_iron'}, {'frequency': 'c', 'id': 1173, 'synset': 'wagon.n.01', 'synonyms': ['wagon'], 'def': 'any of various kinds of wheeled vehicles drawn by an animal or a tractor', 'name': 'wagon'}, {'frequency': 'c', 'id': 1174, 'synset': 'wagon_wheel.n.01', 'synonyms': ['wagon_wheel'], 'def': 'a wheel of a wagon', 'name': 'wagon_wheel'}, {'frequency': 'c', 'id': 1175, 'synset': 'walking_stick.n.01', 'synonyms': ['walking_stick'], 'def': 'a stick carried in the hand for support in walking', 'name': 'walking_stick'}, {'frequency': 'c', 'id': 1176, 'synset': 'wall_clock.n.01', 'synonyms': ['wall_clock'], 'def': 'a clock mounted on a wall', 'name': 'wall_clock'}, {'frequency': 'f', 'id': 1177, 'synset': 'wall_socket.n.01', 'synonyms': ['wall_socket', 'wall_plug', 'electric_outlet', 'electrical_outlet', 'outlet', 'electric_receptacle'], 'def': 'receptacle providing a place in a wiring system where current can be taken to run electrical devices', 'name': 'wall_socket'}, 
{'frequency': 'c', 'id': 1178, 'synset': 'wallet.n.01', 'synonyms': ['wallet', 'billfold'], 'def': 'a pocket-size case for holding papers and paper money', 'name': 'wallet'}, {'frequency': 'r', 'id': 1179, 'synset': 'walrus.n.01', 'synonyms': ['walrus'], 'def': 'either of two large northern marine mammals having ivory tusks and tough hide over thick blubber', 'name': 'walrus'}, {'frequency': 'r', 'id': 1180, 'synset': 'wardrobe.n.01', 'synonyms': ['wardrobe'], 'def': 'a tall piece of furniture that provides storage space for clothes; has a door and rails or hooks for hanging clothes', 'name': 'wardrobe'}, {'frequency': 'r', 'id': 1181, 'synset': 'wasabi.n.02', 'synonyms': ['wasabi'], 'def': 'the thick green root of the wasabi plant that the Japanese use in cooking and that tastes like strong horseradish', 'name': 'wasabi'}, {'frequency': 'c', 'id': 1182, 'synset': 'washer.n.03', 'synonyms': ['automatic_washer', 'washing_machine'], 'def': 'a home appliance for washing clothes and linens automatically', 'name': 'automatic_washer'}, {'frequency': 'f', 'id': 1183, 'synset': 'watch.n.01', 'synonyms': ['watch', 'wristwatch'], 'def': 'a small, portable timepiece', 'name': 'watch'}, {'frequency': 'f', 'id': 1184, 'synset': 'water_bottle.n.01', 'synonyms': ['water_bottle'], 'def': 'a bottle for holding water', 'name': 'water_bottle'}, {'frequency': 'c', 'id': 1185, 'synset': 'water_cooler.n.01', 'synonyms': ['water_cooler'], 'def': 'a device for cooling and dispensing drinking water', 'name': 'water_cooler'}, {'frequency': 'c', 'id': 1186, 'synset': 'water_faucet.n.01', 'synonyms': ['water_faucet', 'water_tap', 'tap_(water_faucet)'], 'def': 'a faucet for drawing water from a pipe or cask', 'name': 'water_faucet'}, {'frequency': 'r', 'id': 1187, 'synset': 'water_filter.n.01', 'synonyms': ['water_filter'], 'def': 'a filter to remove impurities from the water supply', 'name': 'water_filter'}, {'frequency': 'r', 'id': 1188, 'synset': 'water_heater.n.01', 'synonyms': 
['water_heater', 'hot-water_heater'], 'def': 'a heater and storage tank to supply heated water', 'name': 'water_heater'}, {'frequency': 'r', 'id': 1189, 'synset': 'water_jug.n.01', 'synonyms': ['water_jug'], 'def': 'a jug that holds water', 'name': 'water_jug'}, {'frequency': 'r', 'id': 1190, 'synset': 'water_pistol.n.01', 'synonyms': ['water_gun', 'squirt_gun'], 'def': 'plaything consisting of a toy pistol that squirts water', 'name': 'water_gun'}, {'frequency': 'c', 'id': 1191, 'synset': 'water_scooter.n.01', 'synonyms': ['water_scooter', 'sea_scooter', 'jet_ski'], 'def': 'a motorboat resembling a motor scooter (NOT A SURFBOARD OR WATER SKI)', 'name': 'water_scooter'}, {'frequency': 'c', 'id': 1192, 'synset': 'water_ski.n.01', 'synonyms': ['water_ski'], 'def': 'broad ski for skimming over water towed by a speedboat (DO NOT MARK WATER)', 'name': 'water_ski'}, {'frequency': 'c', 'id': 1193, 'synset': 'water_tower.n.01', 'synonyms': ['water_tower'], 'def': 'a large reservoir for water', 'name': 'water_tower'}, {'frequency': 'c', 'id': 1194, 'synset': 'watering_can.n.01', 'synonyms': ['watering_can'], 'def': 'a container with a handle and a spout with a perforated nozzle; used to sprinkle water over plants', 'name': 'watering_can'}, {'frequency': 'c', 'id': 1195, 'synset': 'watermelon.n.02', 'synonyms': ['watermelon'], 'def': 'large oblong or roundish melon with a hard green rind and sweet watery red or occasionally yellowish pulp', 'name': 'watermelon'}, {'frequency': 'f', 'id': 1196, 'synset': 'weathervane.n.01', 'synonyms': ['weathervane', 'vane_(weathervane)', 'wind_vane'], 'def': 'mechanical device attached to an elevated structure; rotates freely to show the direction of the wind', 'name': 'weathervane'}, {'frequency': 'c', 'id': 1197, 'synset': 'webcam.n.01', 'synonyms': ['webcam'], 'def': 'a digital camera designed to take digital photographs and transmit them over the internet', 'name': 'webcam'}, {'frequency': 'c', 'id': 1198, 'synset': 'wedding_cake.n.01', 
'synonyms': ['wedding_cake', 'bridecake'], 'def': 'a rich cake with two or more tiers and covered with frosting and decorations; served at a wedding reception', 'name': 'wedding_cake'}, {'frequency': 'c', 'id': 1199, 'synset': 'wedding_ring.n.01', 'synonyms': ['wedding_ring', 'wedding_band'], 'def': 'a ring given to the bride and/or groom at the wedding', 'name': 'wedding_ring'}, {'frequency': 'f', 'id': 1200, 'synset': 'wet_suit.n.01', 'synonyms': ['wet_suit'], 'def': 'a close-fitting garment made of a permeable material; worn in cold water to retain body heat', 'name': 'wet_suit'}, {'frequency': 'f', 'id': 1201, 'synset': 'wheel.n.01', 'synonyms': ['wheel'], 'def': 'a circular frame with spokes (or a solid disc) that can rotate on a shaft or axle', 'name': 'wheel'}, {'frequency': 'c', 'id': 1202, 'synset': 'wheelchair.n.01', 'synonyms': ['wheelchair'], 'def': 'a movable chair mounted on large wheels', 'name': 'wheelchair'}, {'frequency': 'c', 'id': 1203, 'synset': 'whipped_cream.n.01', 'synonyms': ['whipped_cream'], 'def': 'cream that has been beaten until light and fluffy', 'name': 'whipped_cream'}, {'frequency': 'r', 'id': 1204, 'synset': 'whiskey.n.01', 'synonyms': ['whiskey'], 'def': 'a liquor made from fermented mash of grain', 'name': 'whiskey'}, {'frequency': 'r', 'id': 1205, 'synset': 'whistle.n.03', 'synonyms': ['whistle'], 'def': 'a small wind instrument that produces a whistling sound by blowing into it', 'name': 'whistle'}, {'frequency': 'r', 'id': 1206, 'synset': 'wick.n.02', 'synonyms': ['wick'], 'def': 'a loosely woven cord in a candle or oil lamp that is lit on fire', 'name': 'wick'}, {'frequency': 'c', 'id': 1207, 'synset': 'wig.n.01', 'synonyms': ['wig'], 'def': 'hairpiece covering the head and made of real or synthetic hair', 'name': 'wig'}, {'frequency': 'c', 'id': 1208, 'synset': 'wind_chime.n.01', 'synonyms': ['wind_chime'], 'def': 'a decorative arrangement of pieces of metal or glass or pottery that hang together loosely so the wind can 
cause them to tinkle', 'name': 'wind_chime'}, {'frequency': 'c', 'id': 1209, 'synset': 'windmill.n.01', 'synonyms': ['windmill'], 'def': 'a mill that is powered by the wind', 'name': 'windmill'}, {'frequency': 'c', 'id': 1210, 'synset': 'window_box.n.01', 'synonyms': ['window_box_(for_plants)'], 'def': 'a container for growing plants on a windowsill', 'name': 'window_box_(for_plants)'}, {'frequency': 'f', 'id': 1211, 'synset': 'windshield_wiper.n.01', 'synonyms': ['windshield_wiper', 'windscreen_wiper', 'wiper_(for_windshield/screen)'], 'def': 'a mechanical device that cleans the windshield', 'name': 'windshield_wiper'}, {'frequency': 'c', 'id': 1212, 'synset': 'windsock.n.01', 'synonyms': ['windsock', 'air_sock', 'air-sleeve', 'wind_sleeve', 'wind_cone'], 'def': 'a truncated cloth cone mounted on a mast/pole; shows wind direction', 'name': 'windsock'}, {'frequency': 'f', 'id': 1213, 'synset': 'wine_bottle.n.01', 'synonyms': ['wine_bottle'], 'def': 'a bottle for holding wine', 'name': 'wine_bottle'}, {'frequency': 'r', 'id': 1214, 'synset': 'wine_bucket.n.01', 'synonyms': ['wine_bucket', 'wine_cooler'], 'def': 'a bucket of ice used to chill a bottle of wine', 'name': 'wine_bucket'}, {'frequency': 'f', 'id': 1215, 'synset': 'wineglass.n.01', 'synonyms': ['wineglass'], 'def': 'a glass that has a stem and in which wine is served', 'name': 'wineglass'}, {'frequency': 'r', 'id': 1216, 'synset': 'wing_chair.n.01', 'synonyms': ['wing_chair'], 'def': 'easy chair having wings on each side of a high back', 'name': 'wing_chair'}, {'frequency': 'c', 'id': 1217, 'synset': 'winker.n.02', 'synonyms': ['blinder_(for_horses)'], 'def': 'blinds that prevent a horse from seeing something on either side', 'name': 'blinder_(for_horses)'}, {'frequency': 'c', 'id': 1218, 'synset': 'wok.n.01', 'synonyms': ['wok'], 'def': 'pan with a convex bottom; used for frying in Chinese cooking', 'name': 'wok'}, {'frequency': 'r', 'id': 1219, 'synset': 'wolf.n.01', 'synonyms': ['wolf'], 'def': 'a wild 
carnivorous mammal of the dog family, living and hunting in packs', 'name': 'wolf'}, {'frequency': 'c', 'id': 1220, 'synset': 'wooden_spoon.n.02', 'synonyms': ['wooden_spoon'], 'def': 'a spoon made of wood', 'name': 'wooden_spoon'}, {'frequency': 'c', 'id': 1221, 'synset': 'wreath.n.01', 'synonyms': ['wreath'], 'def': 'an arrangement of flowers, leaves, or stems fastened in a ring', 'name': 'wreath'}, {'frequency': 'c', 'id': 1222, 'synset': 'wrench.n.03', 'synonyms': ['wrench', 'spanner'], 'def': 'a hand tool that is used to hold or twist a nut or bolt', 'name': 'wrench'}, {'frequency': 'c', 'id': 1223, 'synset': 'wristband.n.01', 'synonyms': ['wristband'], 'def': 'band consisting of a part of a sleeve that covers the wrist', 'name': 'wristband'}, {'frequency': 'f', 'id': 1224, 'synset': 'wristlet.n.01', 'synonyms': ['wristlet', 'wrist_band'], 'def': 'a band or bracelet worn around the wrist', 'name': 'wristlet'}, {'frequency': 'r', 'id': 1225, 'synset': 'yacht.n.01', 'synonyms': ['yacht'], 'def': 'an expensive vessel propelled by sail or power and used for cruising or racing', 'name': 'yacht'}, {'frequency': 'r', 'id': 1226, 'synset': 'yak.n.02', 'synonyms': ['yak'], 'def': 'large long-haired wild ox of Tibet often domesticated', 'name': 'yak'}, {'frequency': 'c', 'id': 1227, 'synset': 'yogurt.n.01', 'synonyms': ['yogurt', 'yoghurt', 'yoghourt'], 'def': 'a custard-like food made from curdled milk', 'name': 'yogurt'}, {'frequency': 'r', 'id': 1228, 'synset': 'yoke.n.07', 'synonyms': ['yoke_(animal_equipment)'], 'def': 'gear joining two animals at the neck; NOT egg yolk', 'name': 'yoke_(animal_equipment)'}, {'frequency': 'f', 'id': 1229, 'synset': 'zebra.n.01', 'synonyms': ['zebra'], 'def': 'any of several fleet black-and-white striped African equines', 'name': 'zebra'}, {'frequency': 'c', 'id': 1230, 'synset': 'zucchini.n.02', 'synonyms': ['zucchini', 'courgette'], 'def': 'small cucumber-shaped vegetable marrow; typically dark green', 'name': 'zucchini'}] # noqa 
+# fmt: on diff --git a/data_processing/detectron2/detectron2/data/datasets/lvis_v1_categories.py b/data_processing/detectron2/detectron2/data/datasets/lvis_v1_categories.py new file mode 100644 index 0000000..7374e69 --- /dev/null +++ b/data_processing/detectron2/detectron2/data/datasets/lvis_v1_categories.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Autogen with +# with open("lvis_v1_val.json", "r") as f: +# a = json.load(f) +# c = a["categories"] +# for x in c: +# del x["image_count"] +# del x["instance_count"] +# LVIS_CATEGORIES = repr(c) + " # noqa" +# with open("/tmp/lvis_categories.py", "wt") as f: +# f.write(f"LVIS_CATEGORIES = {LVIS_CATEGORIES}") +# Then paste the contents of that file below + +# fmt: off +LVIS_CATEGORIES = [{'frequency': 'c', 'synset': 'aerosol.n.02', 'synonyms': ['aerosol_can', 'spray_can'], 'id': 1, 'def': 'a dispenser that holds a substance under pressure', 'name': 'aerosol_can'}, {'frequency': 'f', 'synset': 'air_conditioner.n.01', 'synonyms': ['air_conditioner'], 'id': 2, 'def': 'a machine that keeps air cool and dry', 'name': 'air_conditioner'}, {'frequency': 'f', 'synset': 'airplane.n.01', 'synonyms': ['airplane', 'aeroplane'], 'id': 3, 'def': 'an aircraft that has a fixed wing and is powered by propellers or jets', 'name': 'airplane'}, {'frequency': 'f', 'synset': 'alarm_clock.n.01', 'synonyms': ['alarm_clock'], 'id': 4, 'def': 'a clock that wakes a sleeper at some preset time', 'name': 'alarm_clock'}, {'frequency': 'c', 'synset': 'alcohol.n.01', 'synonyms': ['alcohol', 'alcoholic_beverage'], 'id': 5, 'def': 'a liquor or brew containing alcohol as the active agent', 'name': 'alcohol'}, {'frequency': 'c', 'synset': 'alligator.n.02', 'synonyms': ['alligator', 'gator'], 'id': 6, 'def': 'amphibious reptiles related to crocodiles but with shorter broader snouts', 'name': 'alligator'}, {'frequency': 'c', 'synset': 'almond.n.02', 'synonyms': ['almond'], 'id': 7, 'def': 'oval-shaped edible seed of the almond 
tree', 'name': 'almond'}, {'frequency': 'c', 'synset': 'ambulance.n.01', 'synonyms': ['ambulance'], 'id': 8, 'def': 'a vehicle that takes people to and from hospitals', 'name': 'ambulance'}, {'frequency': 'c', 'synset': 'amplifier.n.01', 'synonyms': ['amplifier'], 'id': 9, 'def': 'electronic equipment that increases strength of signals', 'name': 'amplifier'}, {'frequency': 'c', 'synset': 'anklet.n.03', 'synonyms': ['anklet', 'ankle_bracelet'], 'id': 10, 'def': 'an ornament worn around the ankle', 'name': 'anklet'}, {'frequency': 'f', 'synset': 'antenna.n.01', 'synonyms': ['antenna', 'aerial', 'transmitting_aerial'], 'id': 11, 'def': 'an electrical device that sends or receives radio or television signals', 'name': 'antenna'}, {'frequency': 'f', 'synset': 'apple.n.01', 'synonyms': ['apple'], 'id': 12, 'def': 'fruit with red or yellow or green skin and sweet to tart crisp whitish flesh', 'name': 'apple'}, {'frequency': 'r', 'synset': 'applesauce.n.01', 'synonyms': ['applesauce'], 'id': 13, 'def': 'puree of stewed apples usually sweetened and spiced', 'name': 'applesauce'}, {'frequency': 'r', 'synset': 'apricot.n.02', 'synonyms': ['apricot'], 'id': 14, 'def': 'downy yellow to rosy-colored fruit resembling a small peach', 'name': 'apricot'}, {'frequency': 'f', 'synset': 'apron.n.01', 'synonyms': ['apron'], 'id': 15, 'def': 'a garment of cloth that is tied about the waist and worn to protect clothing', 'name': 'apron'}, {'frequency': 'c', 'synset': 'aquarium.n.01', 'synonyms': ['aquarium', 'fish_tank'], 'id': 16, 'def': 'a tank/pool/bowl filled with water for keeping live fish and underwater animals', 'name': 'aquarium'}, {'frequency': 'r', 'synset': 'arctic.n.02', 'synonyms': ['arctic_(type_of_shoe)', 'galosh', 'golosh', 'rubber_(type_of_shoe)', 'gumshoe'], 'id': 17, 'def': 'a waterproof overshoe that protects shoes from water or snow', 'name': 'arctic_(type_of_shoe)'}, {'frequency': 'c', 'synset': 'armband.n.02', 'synonyms': ['armband'], 'id': 18, 'def': 'a band worn 
around the upper arm', 'name': 'armband'}, {'frequency': 'f', 'synset': 'armchair.n.01', 'synonyms': ['armchair'], 'id': 19, 'def': 'chair with a support on each side for arms', 'name': 'armchair'}, {'frequency': 'r', 'synset': 'armoire.n.01', 'synonyms': ['armoire'], 'id': 20, 'def': 'a large wardrobe or cabinet', 'name': 'armoire'}, {'frequency': 'r', 'synset': 'armor.n.01', 'synonyms': ['armor', 'armour'], 'id': 21, 'def': 'protective covering made of metal and used in combat', 'name': 'armor'}, {'frequency': 'c', 'synset': 'artichoke.n.02', 'synonyms': ['artichoke'], 'id': 22, 'def': 'a thistlelike flower head with edible fleshy leaves and heart', 'name': 'artichoke'}, {'frequency': 'f', 'synset': 'ashcan.n.01', 'synonyms': ['trash_can', 'garbage_can', 'wastebin', 'dustbin', 'trash_barrel', 'trash_bin'], 'id': 23, 'def': 'a bin that holds rubbish until it is collected', 'name': 'trash_can'}, {'frequency': 'c', 'synset': 'ashtray.n.01', 'synonyms': ['ashtray'], 'id': 24, 'def': "a receptacle for the ash from smokers' cigars or cigarettes", 'name': 'ashtray'}, {'frequency': 'c', 'synset': 'asparagus.n.02', 'synonyms': ['asparagus'], 'id': 25, 'def': 'edible young shoots of the asparagus plant', 'name': 'asparagus'}, {'frequency': 'c', 'synset': 'atomizer.n.01', 'synonyms': ['atomizer', 'atomiser', 'spray', 'sprayer', 'nebulizer', 'nebuliser'], 'id': 26, 'def': 'a dispenser that turns a liquid (such as perfume) into a fine mist', 'name': 'atomizer'}, {'frequency': 'f', 'synset': 'avocado.n.01', 'synonyms': ['avocado'], 'id': 27, 'def': 'a pear-shaped fruit with green or blackish skin and rich yellowish pulp enclosing a single large seed', 'name': 'avocado'}, {'frequency': 'c', 'synset': 'award.n.02', 'synonyms': ['award', 'accolade'], 'id': 28, 'def': 'a tangible symbol signifying approval or distinction', 'name': 'award'}, {'frequency': 'f', 'synset': 'awning.n.01', 'synonyms': ['awning'], 'id': 29, 'def': 'a canopy made of canvas to shelter people or things from 
rain or sun', 'name': 'awning'}, {'frequency': 'r', 'synset': 'ax.n.01', 'synonyms': ['ax', 'axe'], 'id': 30, 'def': 'an edge tool with a heavy bladed head mounted across a handle', 'name': 'ax'}, {'frequency': 'r', 'synset': 'baboon.n.01', 'synonyms': ['baboon'], 'id': 31, 'def': 'large terrestrial monkeys having doglike muzzles', 'name': 'baboon'}, {'frequency': 'f', 'synset': 'baby_buggy.n.01', 'synonyms': ['baby_buggy', 'baby_carriage', 'perambulator', 'pram', 'stroller'], 'id': 32, 'def': 'a small vehicle with four wheels in which a baby or child is pushed around', 'name': 'baby_buggy'}, {'frequency': 'c', 'synset': 'backboard.n.01', 'synonyms': ['basketball_backboard'], 'id': 33, 'def': 'a raised vertical board with basket attached; used to play basketball', 'name': 'basketball_backboard'}, {'frequency': 'f', 'synset': 'backpack.n.01', 'synonyms': ['backpack', 'knapsack', 'packsack', 'rucksack', 'haversack'], 'id': 34, 'def': 'a bag carried by a strap on your back or shoulder', 'name': 'backpack'}, {'frequency': 'f', 'synset': 'bag.n.04', 'synonyms': ['handbag', 'purse', 'pocketbook'], 'id': 35, 'def': 'a container used for carrying money and small personal items or accessories', 'name': 'handbag'}, {'frequency': 'f', 'synset': 'bag.n.06', 'synonyms': ['suitcase', 'baggage', 'luggage'], 'id': 36, 'def': 'cases used to carry belongings when traveling', 'name': 'suitcase'}, {'frequency': 'c', 'synset': 'bagel.n.01', 'synonyms': ['bagel', 'beigel'], 'id': 37, 'def': 'glazed yeast-raised doughnut-shaped roll with hard crust', 'name': 'bagel'}, {'frequency': 'r', 'synset': 'bagpipe.n.01', 'synonyms': ['bagpipe'], 'id': 38, 'def': 'a tubular wind instrument; the player blows air into a bag and squeezes it out', 'name': 'bagpipe'}, {'frequency': 'r', 'synset': 'baguet.n.01', 'synonyms': ['baguet', 'baguette'], 'id': 39, 'def': 'narrow French stick loaf', 'name': 'baguet'}, {'frequency': 'r', 'synset': 'bait.n.02', 'synonyms': ['bait', 'lure'], 'id': 40, 'def': 
'something used to lure fish or other animals into danger so they can be trapped or killed', 'name': 'bait'}, {'frequency': 'f', 'synset': 'ball.n.06', 'synonyms': ['ball'], 'id': 41, 'def': 'a spherical object used as a plaything', 'name': 'ball'}, {'frequency': 'r', 'synset': 'ballet_skirt.n.01', 'synonyms': ['ballet_skirt', 'tutu'], 'id': 42, 'def': 'very short skirt worn by ballerinas', 'name': 'ballet_skirt'}, {'frequency': 'f', 'synset': 'balloon.n.01', 'synonyms': ['balloon'], 'id': 43, 'def': 'large tough nonrigid bag filled with gas or heated air', 'name': 'balloon'}, {'frequency': 'c', 'synset': 'bamboo.n.02', 'synonyms': ['bamboo'], 'id': 44, 'def': 'woody tropical grass having hollow woody stems', 'name': 'bamboo'}, {'frequency': 'f', 'synset': 'banana.n.02', 'synonyms': ['banana'], 'id': 45, 'def': 'elongated crescent-shaped yellow fruit with soft sweet flesh', 'name': 'banana'}, {'frequency': 'c', 'synset': 'band_aid.n.01', 'synonyms': ['Band_Aid'], 'id': 46, 'def': 'trade name for an adhesive bandage to cover small cuts or blisters', 'name': 'Band_Aid'}, {'frequency': 'c', 'synset': 'bandage.n.01', 'synonyms': ['bandage'], 'id': 47, 'def': 'a piece of soft material that covers and protects an injured part of the body', 'name': 'bandage'}, {'frequency': 'f', 'synset': 'bandanna.n.01', 'synonyms': ['bandanna', 'bandana'], 'id': 48, 'def': 'large and brightly colored handkerchief; often used as a neckerchief', 'name': 'bandanna'}, {'frequency': 'r', 'synset': 'banjo.n.01', 'synonyms': ['banjo'], 'id': 49, 'def': 'a stringed instrument of the guitar family with a long neck and circular body', 'name': 'banjo'}, {'frequency': 'f', 'synset': 'banner.n.01', 'synonyms': ['banner', 'streamer'], 'id': 50, 'def': 'long strip of cloth or paper used for decoration or advertising', 'name': 'banner'}, {'frequency': 'r', 'synset': 'barbell.n.01', 'synonyms': ['barbell'], 'id': 51, 'def': 'a bar to which heavy discs are attached at each end; used in weightlifting', 
'name': 'barbell'}, {'frequency': 'r', 'synset': 'barge.n.01', 'synonyms': ['barge'], 'id': 52, 'def': 'a flatbottom boat for carrying heavy loads (especially on canals)', 'name': 'barge'}, {'frequency': 'f', 'synset': 'barrel.n.02', 'synonyms': ['barrel', 'cask'], 'id': 53, 'def': 'a cylindrical container that holds liquids', 'name': 'barrel'}, {'frequency': 'c', 'synset': 'barrette.n.01', 'synonyms': ['barrette'], 'id': 54, 'def': "a pin for holding women's hair in place", 'name': 'barrette'}, {'frequency': 'c', 'synset': 'barrow.n.03', 'synonyms': ['barrow', 'garden_cart', 'lawn_cart', 'wheelbarrow'], 'id': 55, 'def': 'a cart for carrying small loads; has handles and one or more wheels', 'name': 'barrow'}, {'frequency': 'f', 'synset': 'base.n.03', 'synonyms': ['baseball_base'], 'id': 56, 'def': 'a place that the runner must touch before scoring', 'name': 'baseball_base'}, {'frequency': 'f', 'synset': 'baseball.n.02', 'synonyms': ['baseball'], 'id': 57, 'def': 'a ball used in playing baseball', 'name': 'baseball'}, {'frequency': 'f', 'synset': 'baseball_bat.n.01', 'synonyms': ['baseball_bat'], 'id': 58, 'def': 'an implement used in baseball by the batter', 'name': 'baseball_bat'}, {'frequency': 'f', 'synset': 'baseball_cap.n.01', 'synonyms': ['baseball_cap', 'jockey_cap', 'golf_cap'], 'id': 59, 'def': 'a cap with a bill', 'name': 'baseball_cap'}, {'frequency': 'f', 'synset': 'baseball_glove.n.01', 'synonyms': ['baseball_glove', 'baseball_mitt'], 'id': 60, 'def': 'the handwear used by fielders in playing baseball', 'name': 'baseball_glove'}, {'frequency': 'f', 'synset': 'basket.n.01', 'synonyms': ['basket', 'handbasket'], 'id': 61, 'def': 'a container that is usually woven and has handles', 'name': 'basket'}, {'frequency': 'c', 'synset': 'basketball.n.02', 'synonyms': ['basketball'], 'id': 62, 'def': 'an inflated ball used in playing basketball', 'name': 'basketball'}, {'frequency': 'r', 'synset': 'bass_horn.n.01', 'synonyms': ['bass_horn', 'sousaphone', 'tuba'], 
'id': 63, 'def': 'the lowest brass wind instrument', 'name': 'bass_horn'}, {'frequency': 'c', 'synset': 'bat.n.01', 'synonyms': ['bat_(animal)'], 'id': 64, 'def': 'nocturnal mouselike mammal with forelimbs modified to form membranous wings', 'name': 'bat_(animal)'}, {'frequency': 'f', 'synset': 'bath_mat.n.01', 'synonyms': ['bath_mat'], 'id': 65, 'def': 'a heavy towel or mat to stand on while drying yourself after a bath', 'name': 'bath_mat'}, {'frequency': 'f', 'synset': 'bath_towel.n.01', 'synonyms': ['bath_towel'], 'id': 66, 'def': 'a large towel; to dry yourself after a bath', 'name': 'bath_towel'}, {'frequency': 'c', 'synset': 'bathrobe.n.01', 'synonyms': ['bathrobe'], 'id': 67, 'def': 'a loose-fitting robe of towelling; worn after a bath or swim', 'name': 'bathrobe'}, {'frequency': 'f', 'synset': 'bathtub.n.01', 'synonyms': ['bathtub', 'bathing_tub'], 'id': 68, 'def': 'a large open container that you fill with water and use to wash the body', 'name': 'bathtub'}, {'frequency': 'r', 'synset': 'batter.n.02', 'synonyms': ['batter_(food)'], 'id': 69, 'def': 'a liquid or semiliquid mixture, as of flour, eggs, and milk, used in cooking', 'name': 'batter_(food)'}, {'frequency': 'c', 'synset': 'battery.n.02', 'synonyms': ['battery'], 'id': 70, 'def': 'a portable device that produces electricity', 'name': 'battery'}, {'frequency': 'r', 'synset': 'beach_ball.n.01', 'synonyms': ['beachball'], 'id': 71, 'def': 'large and light ball; for play at the seaside', 'name': 'beachball'}, {'frequency': 'c', 'synset': 'bead.n.01', 'synonyms': ['bead'], 'id': 72, 'def': 'a small ball with a hole through the middle used for ornamentation, jewellery, etc.', 'name': 'bead'}, {'frequency': 'c', 'synset': 'bean_curd.n.01', 'synonyms': ['bean_curd', 'tofu'], 'id': 73, 'def': 'cheeselike food made of curdled soybean milk', 'name': 'bean_curd'}, {'frequency': 'c', 'synset': 'beanbag.n.01', 'synonyms': ['beanbag'], 'id': 74, 'def': 'a bag filled with dried beans or similar items; used in 
games or to sit on', 'name': 'beanbag'}, {'frequency': 'f', 'synset': 'beanie.n.01', 'synonyms': ['beanie', 'beany'], 'id': 75, 'def': 'a small skullcap; formerly worn by schoolboys and college freshmen', 'name': 'beanie'}, {'frequency': 'f', 'synset': 'bear.n.01', 'synonyms': ['bear'], 'id': 76, 'def': 'large carnivorous or omnivorous mammals with shaggy coats and claws', 'name': 'bear'}, {'frequency': 'f', 'synset': 'bed.n.01', 'synonyms': ['bed'], 'id': 77, 'def': 'a piece of furniture that provides a place to sleep', 'name': 'bed'}, {'frequency': 'r', 'synset': 'bedpan.n.01', 'synonyms': ['bedpan'], 'id': 78, 'def': 'a shallow vessel used by a bedridden patient for defecation and urination', 'name': 'bedpan'}, {'frequency': 'f', 'synset': 'bedspread.n.01', 'synonyms': ['bedspread', 'bedcover', 'bed_covering', 'counterpane', 'spread'], 'id': 79, 'def': 'decorative cover for a bed', 'name': 'bedspread'}, {'frequency': 'f', 'synset': 'beef.n.01', 'synonyms': ['cow'], 'id': 80, 'def': 'cattle/cow', 'name': 'cow'}, {'frequency': 'f', 'synset': 'beef.n.02', 'synonyms': ['beef_(food)', 'boeuf_(food)'], 'id': 81, 'def': 'meat from an adult domestic bovine', 'name': 'beef_(food)'}, {'frequency': 'r', 'synset': 'beeper.n.01', 'synonyms': ['beeper', 'pager'], 'id': 82, 'def': 'an device that beeps when the person carrying it is being paged', 'name': 'beeper'}, {'frequency': 'f', 'synset': 'beer_bottle.n.01', 'synonyms': ['beer_bottle'], 'id': 83, 'def': 'a bottle that holds beer', 'name': 'beer_bottle'}, {'frequency': 'c', 'synset': 'beer_can.n.01', 'synonyms': ['beer_can'], 'id': 84, 'def': 'a can that holds beer', 'name': 'beer_can'}, {'frequency': 'r', 'synset': 'beetle.n.01', 'synonyms': ['beetle'], 'id': 85, 'def': 'insect with hard wing covers', 'name': 'beetle'}, {'frequency': 'f', 'synset': 'bell.n.01', 'synonyms': ['bell'], 'id': 86, 'def': 'a hollow device made of metal that makes a ringing sound when struck', 'name': 'bell'}, {'frequency': 'f', 'synset': 
'bell_pepper.n.02', 'synonyms': ['bell_pepper', 'capsicum'], 'id': 87, 'def': 'large bell-shaped sweet pepper in green or red or yellow or orange or black varieties', 'name': 'bell_pepper'}, {'frequency': 'f', 'synset': 'belt.n.02', 'synonyms': ['belt'], 'id': 88, 'def': 'a band to tie or buckle around the body (usually at the waist)', 'name': 'belt'}, {'frequency': 'f', 'synset': 'belt_buckle.n.01', 'synonyms': ['belt_buckle'], 'id': 89, 'def': 'the buckle used to fasten a belt', 'name': 'belt_buckle'}, {'frequency': 'f', 'synset': 'bench.n.01', 'synonyms': ['bench'], 'id': 90, 'def': 'a long seat for more than one person', 'name': 'bench'}, {'frequency': 'c', 'synset': 'beret.n.01', 'synonyms': ['beret'], 'id': 91, 'def': 'a cap with no brim or bill; made of soft cloth', 'name': 'beret'}, {'frequency': 'c', 'synset': 'bib.n.02', 'synonyms': ['bib'], 'id': 92, 'def': 'a napkin tied under the chin of a child while eating', 'name': 'bib'}, {'frequency': 'r', 'synset': 'bible.n.01', 'synonyms': ['Bible'], 'id': 93, 'def': 'the sacred writings of the Christian religions', 'name': 'Bible'}, {'frequency': 'f', 'synset': 'bicycle.n.01', 'synonyms': ['bicycle', 'bike_(bicycle)'], 'id': 94, 'def': 'a wheeled vehicle that has two wheels and is moved by foot pedals', 'name': 'bicycle'}, {'frequency': 'f', 'synset': 'bill.n.09', 'synonyms': ['visor', 'vizor'], 'id': 95, 'def': 'a brim that projects to the front to shade the eyes', 'name': 'visor'}, {'frequency': 'f', 'synset': 'billboard.n.01', 'synonyms': ['billboard'], 'id': 96, 'def': 'large outdoor signboard', 'name': 'billboard'}, {'frequency': 'c', 'synset': 'binder.n.03', 'synonyms': ['binder', 'ring-binder'], 'id': 97, 'def': 'holds loose papers or magazines', 'name': 'binder'}, {'frequency': 'c', 'synset': 'binoculars.n.01', 'synonyms': ['binoculars', 'field_glasses', 'opera_glasses'], 'id': 98, 'def': 'an optical instrument designed for simultaneous use by both eyes', 'name': 'binoculars'}, {'frequency': 'f', 
'synset': 'bird.n.01', 'synonyms': ['bird'], 'id': 99, 'def': 'animal characterized by feathers and wings', 'name': 'bird'}, {'frequency': 'c', 'synset': 'bird_feeder.n.01', 'synonyms': ['birdfeeder'], 'id': 100, 'def': 'an outdoor device that supplies food for wild birds', 'name': 'birdfeeder'}, {'frequency': 'c', 'synset': 'birdbath.n.01', 'synonyms': ['birdbath'], 'id': 101, 'def': 'an ornamental basin (usually in a garden) for birds to bathe in', 'name': 'birdbath'}, {'frequency': 'c', 'synset': 'birdcage.n.01', 'synonyms': ['birdcage'], 'id': 102, 'def': 'a cage in which a bird can be kept', 'name': 'birdcage'}, {'frequency': 'c', 'synset': 'birdhouse.n.01', 'synonyms': ['birdhouse'], 'id': 103, 'def': 'a shelter for birds', 'name': 'birdhouse'}, {'frequency': 'f', 'synset': 'birthday_cake.n.01', 'synonyms': ['birthday_cake'], 'id': 104, 'def': 'decorated cake served at a birthday party', 'name': 'birthday_cake'}, {'frequency': 'r', 'synset': 'birthday_card.n.01', 'synonyms': ['birthday_card'], 'id': 105, 'def': 'a card expressing a birthday greeting', 'name': 'birthday_card'}, {'frequency': 'r', 'synset': 'black_flag.n.01', 'synonyms': ['pirate_flag'], 'id': 106, 'def': 'a flag usually bearing a white skull and crossbones on a black background', 'name': 'pirate_flag'}, {'frequency': 'c', 'synset': 'black_sheep.n.02', 'synonyms': ['black_sheep'], 'id': 107, 'def': 'sheep with a black coat', 'name': 'black_sheep'}, {'frequency': 'c', 'synset': 'blackberry.n.01', 'synonyms': ['blackberry'], 'id': 108, 'def': 'large sweet black or very dark purple edible aggregate fruit', 'name': 'blackberry'}, {'frequency': 'f', 'synset': 'blackboard.n.01', 'synonyms': ['blackboard', 'chalkboard'], 'id': 109, 'def': 'sheet of slate; for writing with chalk', 'name': 'blackboard'}, {'frequency': 'f', 'synset': 'blanket.n.01', 'synonyms': ['blanket'], 'id': 110, 'def': 'bedding that keeps a person warm in bed', 'name': 'blanket'}, {'frequency': 'c', 'synset': 'blazer.n.01', 
'synonyms': ['blazer', 'sport_jacket', 'sport_coat', 'sports_jacket', 'sports_coat'], 'id': 111, 'def': 'lightweight jacket; often striped in the colors of a club or school', 'name': 'blazer'}, {'frequency': 'f', 'synset': 'blender.n.01', 'synonyms': ['blender', 'liquidizer', 'liquidiser'], 'id': 112, 'def': 'an electrically powered mixer that mix or chop or liquefy foods', 'name': 'blender'}, {'frequency': 'r', 'synset': 'blimp.n.02', 'synonyms': ['blimp'], 'id': 113, 'def': 'a small nonrigid airship used for observation or as a barrage balloon', 'name': 'blimp'}, {'frequency': 'f', 'synset': 'blinker.n.01', 'synonyms': ['blinker', 'flasher'], 'id': 114, 'def': 'a light that flashes on and off; used as a signal or to send messages', 'name': 'blinker'}, {'frequency': 'f', 'synset': 'blouse.n.01', 'synonyms': ['blouse'], 'id': 115, 'def': 'a top worn by women', 'name': 'blouse'}, {'frequency': 'f', 'synset': 'blueberry.n.02', 'synonyms': ['blueberry'], 'id': 116, 'def': 'sweet edible dark-blue berries of blueberry plants', 'name': 'blueberry'}, {'frequency': 'r', 'synset': 'board.n.09', 'synonyms': ['gameboard'], 'id': 117, 'def': 'a flat portable surface (usually rectangular) designed for board games', 'name': 'gameboard'}, {'frequency': 'f', 'synset': 'boat.n.01', 'synonyms': ['boat', 'ship_(boat)'], 'id': 118, 'def': 'a vessel for travel on water', 'name': 'boat'}, {'frequency': 'r', 'synset': 'bob.n.05', 'synonyms': ['bob', 'bobber', 'bobfloat'], 'id': 119, 'def': 'a small float usually made of cork; attached to a fishing line', 'name': 'bob'}, {'frequency': 'c', 'synset': 'bobbin.n.01', 'synonyms': ['bobbin', 'spool', 'reel'], 'id': 120, 'def': 'a thing around which thread/tape/film or other flexible materials can be wound', 'name': 'bobbin'}, {'frequency': 'c', 'synset': 'bobby_pin.n.01', 'synonyms': ['bobby_pin', 'hairgrip'], 'id': 121, 'def': 'a flat wire hairpin used to hold bobbed hair in place', 'name': 'bobby_pin'}, {'frequency': 'c', 'synset': 
'boiled_egg.n.01', 'synonyms': ['boiled_egg', 'coddled_egg'], 'id': 122, 'def': 'egg cooked briefly in the shell in gently boiling water', 'name': 'boiled_egg'}, {'frequency': 'r', 'synset': 'bolo_tie.n.01', 'synonyms': ['bolo_tie', 'bolo', 'bola_tie', 'bola'], 'id': 123, 'def': 'a cord fastened around the neck with an ornamental clasp and worn as a necktie', 'name': 'bolo_tie'}, {'frequency': 'c', 'synset': 'bolt.n.03', 'synonyms': ['deadbolt'], 'id': 124, 'def': 'the part of a lock that is engaged or withdrawn with a key', 'name': 'deadbolt'}, {'frequency': 'f', 'synset': 'bolt.n.06', 'synonyms': ['bolt'], 'id': 125, 'def': 'a screw that screws into a nut to form a fastener', 'name': 'bolt'}, {'frequency': 'r', 'synset': 'bonnet.n.01', 'synonyms': ['bonnet'], 'id': 126, 'def': 'a hat tied under the chin', 'name': 'bonnet'}, {'frequency': 'f', 'synset': 'book.n.01', 'synonyms': ['book'], 'id': 127, 'def': 'a written work or composition that has been published', 'name': 'book'}, {'frequency': 'c', 'synset': 'bookcase.n.01', 'synonyms': ['bookcase'], 'id': 128, 'def': 'a piece of furniture with shelves for storing books', 'name': 'bookcase'}, {'frequency': 'c', 'synset': 'booklet.n.01', 'synonyms': ['booklet', 'brochure', 'leaflet', 'pamphlet'], 'id': 129, 'def': 'a small book usually having a paper cover', 'name': 'booklet'}, {'frequency': 'r', 'synset': 'bookmark.n.01', 'synonyms': ['bookmark', 'bookmarker'], 'id': 130, 'def': 'a marker (a piece of paper or ribbon) placed between the pages of a book', 'name': 'bookmark'}, {'frequency': 'r', 'synset': 'boom.n.04', 'synonyms': ['boom_microphone', 'microphone_boom'], 'id': 131, 'def': 'a pole carrying an overhead microphone projected over a film or tv set', 'name': 'boom_microphone'}, {'frequency': 'f', 'synset': 'boot.n.01', 'synonyms': ['boot'], 'id': 132, 'def': 'footwear that covers the whole foot and lower leg', 'name': 'boot'}, {'frequency': 'f', 'synset': 'bottle.n.01', 'synonyms': ['bottle'], 'id': 133, 
'def': 'a glass or plastic vessel used for storing drinks or other liquids', 'name': 'bottle'}, {'frequency': 'c', 'synset': 'bottle_opener.n.01', 'synonyms': ['bottle_opener'], 'id': 134, 'def': 'an opener for removing caps or corks from bottles', 'name': 'bottle_opener'}, {'frequency': 'c', 'synset': 'bouquet.n.01', 'synonyms': ['bouquet'], 'id': 135, 'def': 'an arrangement of flowers that is usually given as a present', 'name': 'bouquet'}, {'frequency': 'r', 'synset': 'bow.n.04', 'synonyms': ['bow_(weapon)'], 'id': 136, 'def': 'a weapon for shooting arrows', 'name': 'bow_(weapon)'}, {'frequency': 'f', 'synset': 'bow.n.08', 'synonyms': ['bow_(decorative_ribbons)'], 'id': 137, 'def': 'a decorative interlacing of ribbons', 'name': 'bow_(decorative_ribbons)'}, {'frequency': 'f', 'synset': 'bow_tie.n.01', 'synonyms': ['bow-tie', 'bowtie'], 'id': 138, 'def': "a man's tie that ties in a bow", 'name': 'bow-tie'}, {'frequency': 'f', 'synset': 'bowl.n.03', 'synonyms': ['bowl'], 'id': 139, 'def': 'a dish that is round and open at the top for serving foods', 'name': 'bowl'}, {'frequency': 'r', 'synset': 'bowl.n.08', 'synonyms': ['pipe_bowl'], 'id': 140, 'def': 'a small round container that is open at the top for holding tobacco', 'name': 'pipe_bowl'}, {'frequency': 'c', 'synset': 'bowler_hat.n.01', 'synonyms': ['bowler_hat', 'bowler', 'derby_hat', 'derby', 'plug_hat'], 'id': 141, 'def': 'a felt hat that is round and hard with a narrow brim', 'name': 'bowler_hat'}, {'frequency': 'r', 'synset': 'bowling_ball.n.01', 'synonyms': ['bowling_ball'], 'id': 142, 'def': 'a large ball with finger holes used in the sport of bowling', 'name': 'bowling_ball'}, {'frequency': 'f', 'synset': 'box.n.01', 'synonyms': ['box'], 'id': 143, 'def': 'a (usually rectangular) container; may have a lid', 'name': 'box'}, {'frequency': 'r', 'synset': 'boxing_glove.n.01', 'synonyms': ['boxing_glove'], 'id': 144, 'def': 'large glove coverings the fists of a fighter worn for the sport of boxing', 'name': 
'boxing_glove'}, {'frequency': 'c', 'synset': 'brace.n.06', 'synonyms': ['suspenders'], 'id': 145, 'def': 'elastic straps that hold trousers up (usually used in the plural)', 'name': 'suspenders'}, {'frequency': 'f', 'synset': 'bracelet.n.02', 'synonyms': ['bracelet', 'bangle'], 'id': 146, 'def': 'jewelry worn around the wrist for decoration', 'name': 'bracelet'}, {'frequency': 'r', 'synset': 'brass.n.07', 'synonyms': ['brass_plaque'], 'id': 147, 'def': 'a memorial made of brass', 'name': 'brass_plaque'}, {'frequency': 'c', 'synset': 'brassiere.n.01', 'synonyms': ['brassiere', 'bra', 'bandeau'], 'id': 148, 'def': 'an undergarment worn by women to support their breasts', 'name': 'brassiere'}, {'frequency': 'c', 'synset': 'bread-bin.n.01', 'synonyms': ['bread-bin', 'breadbox'], 'id': 149, 'def': 'a container used to keep bread or cake in', 'name': 'bread-bin'}, {'frequency': 'f', 'synset': 'bread.n.01', 'synonyms': ['bread'], 'id': 150, 'def': 'food made from dough of flour or meal and usually raised with yeast or baking powder and then baked', 'name': 'bread'}, {'frequency': 'r', 'synset': 'breechcloth.n.01', 'synonyms': ['breechcloth', 'breechclout', 'loincloth'], 'id': 151, 'def': 'a garment that provides covering for the loins', 'name': 'breechcloth'}, {'frequency': 'f', 'synset': 'bridal_gown.n.01', 'synonyms': ['bridal_gown', 'wedding_gown', 'wedding_dress'], 'id': 152, 'def': 'a gown worn by the bride at a wedding', 'name': 'bridal_gown'}, {'frequency': 'c', 'synset': 'briefcase.n.01', 'synonyms': ['briefcase'], 'id': 153, 'def': 'a case with a handle; for carrying papers or files or books', 'name': 'briefcase'}, {'frequency': 'f', 'synset': 'broccoli.n.01', 'synonyms': ['broccoli'], 'id': 154, 'def': 'plant with dense clusters of tight green flower buds', 'name': 'broccoli'}, {'frequency': 'r', 'synset': 'brooch.n.01', 'synonyms': ['broach'], 'id': 155, 'def': 'a decorative pin worn by women', 'name': 'broach'}, {'frequency': 'c', 'synset': 'broom.n.01', 
'synonyms': ['broom'], 'id': 156, 'def': 'bundle of straws or twigs attached to a long handle; used for cleaning', 'name': 'broom'}, {'frequency': 'c', 'synset': 'brownie.n.03', 'synonyms': ['brownie'], 'id': 157, 'def': 'square or bar of very rich chocolate cake usually with nuts', 'name': 'brownie'}, {'frequency': 'c', 'synset': 'brussels_sprouts.n.01', 'synonyms': ['brussels_sprouts'], 'id': 158, 'def': 'the small edible cabbage-like buds growing along a stalk', 'name': 'brussels_sprouts'}, {'frequency': 'r', 'synset': 'bubble_gum.n.01', 'synonyms': ['bubble_gum'], 'id': 159, 'def': 'a kind of chewing gum that can be blown into bubbles', 'name': 'bubble_gum'}, {'frequency': 'f', 'synset': 'bucket.n.01', 'synonyms': ['bucket', 'pail'], 'id': 160, 'def': 'a roughly cylindrical vessel that is open at the top', 'name': 'bucket'}, {'frequency': 'r', 'synset': 'buggy.n.01', 'synonyms': ['horse_buggy'], 'id': 161, 'def': 'a small lightweight carriage; drawn by a single horse', 'name': 'horse_buggy'}, {'frequency': 'c', 'synset': 'bull.n.11', 'synonyms': ['horned_cow'], 'id': 162, 'def': 'a cow with horns', 'name': 'bull'}, {'frequency': 'c', 'synset': 'bulldog.n.01', 'synonyms': ['bulldog'], 'id': 163, 'def': 'a thickset short-haired dog with a large head and strong undershot lower jaw', 'name': 'bulldog'}, {'frequency': 'r', 'synset': 'bulldozer.n.01', 'synonyms': ['bulldozer', 'dozer'], 'id': 164, 'def': 'large powerful tractor; a large blade in front flattens areas of ground', 'name': 'bulldozer'}, {'frequency': 'c', 'synset': 'bullet_train.n.01', 'synonyms': ['bullet_train'], 'id': 165, 'def': 'a high-speed passenger train', 'name': 'bullet_train'}, {'frequency': 'c', 'synset': 'bulletin_board.n.02', 'synonyms': ['bulletin_board', 'notice_board'], 'id': 166, 'def': 'a board that hangs on a wall; displays announcements', 'name': 'bulletin_board'}, {'frequency': 'r', 'synset': 'bulletproof_vest.n.01', 'synonyms': ['bulletproof_vest'], 'id': 167, 'def': 'a vest 
capable of resisting the impact of a bullet', 'name': 'bulletproof_vest'}, {'frequency': 'c', 'synset': 'bullhorn.n.01', 'synonyms': ['bullhorn', 'megaphone'], 'id': 168, 'def': 'a portable loudspeaker with built-in microphone and amplifier', 'name': 'bullhorn'}, {'frequency': 'f', 'synset': 'bun.n.01', 'synonyms': ['bun', 'roll'], 'id': 169, 'def': 'small rounded bread either plain or sweet', 'name': 'bun'}, {'frequency': 'c', 'synset': 'bunk_bed.n.01', 'synonyms': ['bunk_bed'], 'id': 170, 'def': 'beds built one above the other', 'name': 'bunk_bed'}, {'frequency': 'f', 'synset': 'buoy.n.01', 'synonyms': ['buoy'], 'id': 171, 'def': 'a float attached by rope to the seabed to mark channels in a harbor or underwater hazards', 'name': 'buoy'}, {'frequency': 'r', 'synset': 'burrito.n.01', 'synonyms': ['burrito'], 'id': 172, 'def': 'a flour tortilla folded around a filling', 'name': 'burrito'}, {'frequency': 'f', 'synset': 'bus.n.01', 'synonyms': ['bus_(vehicle)', 'autobus', 'charabanc', 'double-decker', 'motorbus', 'motorcoach'], 'id': 173, 'def': 'a vehicle carrying many passengers; used for public transport', 'name': 'bus_(vehicle)'}, {'frequency': 'c', 'synset': 'business_card.n.01', 'synonyms': ['business_card'], 'id': 174, 'def': "a card on which are printed the person's name and business affiliation", 'name': 'business_card'}, {'frequency': 'f', 'synset': 'butter.n.01', 'synonyms': ['butter'], 'id': 175, 'def': 'an edible emulsion of fat globules made by churning milk or cream; for cooking and table use', 'name': 'butter'}, {'frequency': 'c', 'synset': 'butterfly.n.01', 'synonyms': ['butterfly'], 'id': 176, 'def': 'insect typically having a slender body with knobbed antennae and broad colorful wings', 'name': 'butterfly'}, {'frequency': 'f', 'synset': 'button.n.01', 'synonyms': ['button'], 'id': 177, 'def': 'a round fastener sewn to shirts and coats etc to fit through buttonholes', 'name': 'button'}, {'frequency': 'f', 'synset': 'cab.n.03', 'synonyms': 
['cab_(taxi)', 'taxi', 'taxicab'], 'id': 178, 'def': 'a car that takes passengers where they want to go in exchange for money', 'name': 'cab_(taxi)'}, {'frequency': 'r', 'synset': 'cabana.n.01', 'synonyms': ['cabana'], 'id': 179, 'def': 'a small tent used as a dressing room beside the sea or a swimming pool', 'name': 'cabana'}, {'frequency': 'c', 'synset': 'cabin_car.n.01', 'synonyms': ['cabin_car', 'caboose'], 'id': 180, 'def': 'a car on a freight train for use of the train crew; usually the last car on the train', 'name': 'cabin_car'}, {'frequency': 'f', 'synset': 'cabinet.n.01', 'synonyms': ['cabinet'], 'id': 181, 'def': 'a piece of furniture resembling a cupboard with doors and shelves and drawers', 'name': 'cabinet'}, {'frequency': 'r', 'synset': 'cabinet.n.03', 'synonyms': ['locker', 'storage_locker'], 'id': 182, 'def': 'a storage compartment for clothes and valuables; usually it has a lock', 'name': 'locker'}, {'frequency': 'f', 'synset': 'cake.n.03', 'synonyms': ['cake'], 'id': 183, 'def': 'baked goods made from or based on a mixture of flour, sugar, eggs, and fat', 'name': 'cake'}, {'frequency': 'c', 'synset': 'calculator.n.02', 'synonyms': ['calculator'], 'id': 184, 'def': 'a small machine that is used for mathematical calculations', 'name': 'calculator'}, {'frequency': 'f', 'synset': 'calendar.n.02', 'synonyms': ['calendar'], 'id': 185, 'def': 'a list or register of events (appointments/social events/court cases, etc)', 'name': 'calendar'}, {'frequency': 'c', 'synset': 'calf.n.01', 'synonyms': ['calf'], 'id': 186, 'def': 'young of domestic cattle', 'name': 'calf'}, {'frequency': 'c', 'synset': 'camcorder.n.01', 'synonyms': ['camcorder'], 'id': 187, 'def': 'a portable television camera and videocassette recorder', 'name': 'camcorder'}, {'frequency': 'c', 'synset': 'camel.n.01', 'synonyms': ['camel'], 'id': 188, 'def': 'cud-chewing mammal used as a draft or saddle animal in desert regions', 'name': 'camel'}, {'frequency': 'f', 'synset': 'camera.n.01', 
'synonyms': ['camera'], 'id': 189, 'def': 'equipment for taking photographs', 'name': 'camera'}, {'frequency': 'c', 'synset': 'camera_lens.n.01', 'synonyms': ['camera_lens'], 'id': 190, 'def': 'a lens that focuses the image in a camera', 'name': 'camera_lens'}, {'frequency': 'c', 'synset': 'camper.n.02', 'synonyms': ['camper_(vehicle)', 'camping_bus', 'motor_home'], 'id': 191, 'def': 'a recreational vehicle equipped for camping out while traveling', 'name': 'camper_(vehicle)'}, {'frequency': 'f', 'synset': 'can.n.01', 'synonyms': ['can', 'tin_can'], 'id': 192, 'def': 'airtight sealed metal container for food or drink or paint etc.', 'name': 'can'}, {'frequency': 'c', 'synset': 'can_opener.n.01', 'synonyms': ['can_opener', 'tin_opener'], 'id': 193, 'def': 'a device for cutting cans open', 'name': 'can_opener'}, {'frequency': 'f', 'synset': 'candle.n.01', 'synonyms': ['candle', 'candlestick'], 'id': 194, 'def': 'stick of wax with a wick in the middle', 'name': 'candle'}, {'frequency': 'f', 'synset': 'candlestick.n.01', 'synonyms': ['candle_holder'], 'id': 195, 'def': 'a holder with sockets for candles', 'name': 'candle_holder'}, {'frequency': 'r', 'synset': 'candy_bar.n.01', 'synonyms': ['candy_bar'], 'id': 196, 'def': 'a candy shaped as a bar', 'name': 'candy_bar'}, {'frequency': 'c', 'synset': 'candy_cane.n.01', 'synonyms': ['candy_cane'], 'id': 197, 'def': 'a hard candy in the shape of a rod (usually with stripes)', 'name': 'candy_cane'}, {'frequency': 'c', 'synset': 'cane.n.01', 'synonyms': ['walking_cane'], 'id': 198, 'def': 'a stick that people can lean on to help them walk', 'name': 'walking_cane'}, {'frequency': 'c', 'synset': 'canister.n.02', 'synonyms': ['canister', 'cannister'], 'id': 199, 'def': 'metal container for storing dry foods such as tea or flour', 'name': 'canister'}, {'frequency': 'c', 'synset': 'canoe.n.01', 'synonyms': ['canoe'], 'id': 200, 'def': 'small and light boat; pointed at both ends; propelled with a paddle', 'name': 'canoe'}, 
{'frequency': 'c', 'synset': 'cantaloup.n.02', 'synonyms': ['cantaloup', 'cantaloupe'], 'id': 201, 'def': 'the fruit of a cantaloup vine; small to medium-sized melon with yellowish flesh', 'name': 'cantaloup'}, {'frequency': 'r', 'synset': 'canteen.n.01', 'synonyms': ['canteen'], 'id': 202, 'def': 'a flask for carrying water; used by soldiers or travelers', 'name': 'canteen'}, {'frequency': 'f', 'synset': 'cap.n.01', 'synonyms': ['cap_(headwear)'], 'id': 203, 'def': 'a tight-fitting headwear', 'name': 'cap_(headwear)'}, {'frequency': 'f', 'synset': 'cap.n.02', 'synonyms': ['bottle_cap', 'cap_(container_lid)'], 'id': 204, 'def': 'a top (as for a bottle)', 'name': 'bottle_cap'}, {'frequency': 'c', 'synset': 'cape.n.02', 'synonyms': ['cape'], 'id': 205, 'def': 'a sleeveless garment like a cloak but shorter', 'name': 'cape'}, {'frequency': 'c', 'synset': 'cappuccino.n.01', 'synonyms': ['cappuccino', 'coffee_cappuccino'], 'id': 206, 'def': 'equal parts of espresso and steamed milk', 'name': 'cappuccino'}, {'frequency': 'f', 'synset': 'car.n.01', 'synonyms': ['car_(automobile)', 'auto_(automobile)', 'automobile'], 'id': 207, 'def': 'a motor vehicle with four wheels', 'name': 'car_(automobile)'}, {'frequency': 'f', 'synset': 'car.n.02', 'synonyms': ['railcar_(part_of_a_train)', 'railway_car_(part_of_a_train)', 'railroad_car_(part_of_a_train)'], 'id': 208, 'def': 'a wheeled vehicle adapted to the rails of railroad (mark each individual railcar separately)', 'name': 'railcar_(part_of_a_train)'}, {'frequency': 'r', 'synset': 'car.n.04', 'synonyms': ['elevator_car'], 'id': 209, 'def': 'where passengers ride up and down', 'name': 'elevator_car'}, {'frequency': 'r', 'synset': 'car_battery.n.01', 'synonyms': ['car_battery', 'automobile_battery'], 'id': 210, 'def': 'a battery in a motor vehicle', 'name': 'car_battery'}, {'frequency': 'c', 'synset': 'card.n.02', 'synonyms': ['identity_card'], 'id': 211, 'def': 'a card certifying the identity of the bearer', 'name': 
'identity_card'}, {'frequency': 'c', 'synset': 'card.n.03', 'synonyms': ['card'], 'id': 212, 'def': 'a rectangular piece of paper used to send messages (e.g. greetings or pictures)', 'name': 'card'}, {'frequency': 'c', 'synset': 'cardigan.n.01', 'synonyms': ['cardigan'], 'id': 213, 'def': 'knitted jacket that is fastened up the front with buttons or a zipper', 'name': 'cardigan'}, {'frequency': 'r', 'synset': 'cargo_ship.n.01', 'synonyms': ['cargo_ship', 'cargo_vessel'], 'id': 214, 'def': 'a ship designed to carry cargo', 'name': 'cargo_ship'}, {'frequency': 'r', 'synset': 'carnation.n.01', 'synonyms': ['carnation'], 'id': 215, 'def': 'plant with pink to purple-red spice-scented usually double flowers', 'name': 'carnation'}, {'frequency': 'c', 'synset': 'carriage.n.02', 'synonyms': ['horse_carriage'], 'id': 216, 'def': 'a vehicle with wheels drawn by one or more horses', 'name': 'horse_carriage'}, {'frequency': 'f', 'synset': 'carrot.n.01', 'synonyms': ['carrot'], 'id': 217, 'def': 'deep orange edible root of the cultivated carrot plant', 'name': 'carrot'}, {'frequency': 'f', 'synset': 'carryall.n.01', 'synonyms': ['tote_bag'], 'id': 218, 'def': 'a capacious bag or basket', 'name': 'tote_bag'}, {'frequency': 'c', 'synset': 'cart.n.01', 'synonyms': ['cart'], 'id': 219, 'def': 'a heavy open wagon usually having two wheels and drawn by an animal', 'name': 'cart'}, {'frequency': 'c', 'synset': 'carton.n.02', 'synonyms': ['carton'], 'id': 220, 'def': 'a container made of cardboard for holding food or drink', 'name': 'carton'}, {'frequency': 'c', 'synset': 'cash_register.n.01', 'synonyms': ['cash_register', 'register_(for_cash_transactions)'], 'id': 221, 'def': 'a cashbox with an adding machine to register transactions', 'name': 'cash_register'}, {'frequency': 'r', 'synset': 'casserole.n.01', 'synonyms': ['casserole'], 'id': 222, 'def': 'food cooked and served in a casserole', 'name': 'casserole'}, {'frequency': 'r', 'synset': 'cassette.n.01', 'synonyms': ['cassette'], 
'id': 223, 'def': 'a container that holds a magnetic tape used for recording or playing sound or video', 'name': 'cassette'}, {'frequency': 'c', 'synset': 'cast.n.05', 'synonyms': ['cast', 'plaster_cast', 'plaster_bandage'], 'id': 224, 'def': 'bandage consisting of a firm covering that immobilizes broken bones while they heal', 'name': 'cast'}, {'frequency': 'f', 'synset': 'cat.n.01', 'synonyms': ['cat'], 'id': 225, 'def': 'a domestic house cat', 'name': 'cat'}, {'frequency': 'f', 'synset': 'cauliflower.n.02', 'synonyms': ['cauliflower'], 'id': 226, 'def': 'edible compact head of white undeveloped flowers', 'name': 'cauliflower'}, {'frequency': 'c', 'synset': 'cayenne.n.02', 'synonyms': ['cayenne_(spice)', 'cayenne_pepper_(spice)', 'red_pepper_(spice)'], 'id': 227, 'def': 'ground pods and seeds of pungent red peppers of the genus Capsicum', 'name': 'cayenne_(spice)'}, {'frequency': 'c', 'synset': 'cd_player.n.01', 'synonyms': ['CD_player'], 'id': 228, 'def': 'electronic equipment for playing compact discs (CDs)', 'name': 'CD_player'}, {'frequency': 'f', 'synset': 'celery.n.01', 'synonyms': ['celery'], 'id': 229, 'def': 'widely cultivated herb with aromatic leaf stalks that are eaten raw or cooked', 'name': 'celery'}, {'frequency': 'f', 'synset': 'cellular_telephone.n.01', 'synonyms': ['cellular_telephone', 'cellular_phone', 'cellphone', 'mobile_phone', 'smart_phone'], 'id': 230, 'def': 'a hand-held mobile telephone', 'name': 'cellular_telephone'}, {'frequency': 'r', 'synset': 'chain_mail.n.01', 'synonyms': ['chain_mail', 'ring_mail', 'chain_armor', 'chain_armour', 'ring_armor', 'ring_armour'], 'id': 231, 'def': '(Middle Ages) flexible armor made of interlinked metal rings', 'name': 'chain_mail'}, {'frequency': 'f', 'synset': 'chair.n.01', 'synonyms': ['chair'], 'id': 232, 'def': 'a seat for one person, with a support for the back', 'name': 'chair'}, {'frequency': 'r', 'synset': 'chaise_longue.n.01', 'synonyms': ['chaise_longue', 'chaise', 'daybed'], 'id': 233, 
'def': 'a long chair; for reclining', 'name': 'chaise_longue'}, {'frequency': 'r', 'synset': 'chalice.n.01', 'synonyms': ['chalice'], 'id': 234, 'def': 'a bowl-shaped drinking vessel; especially the Eucharistic cup', 'name': 'chalice'}, {'frequency': 'f', 'synset': 'chandelier.n.01', 'synonyms': ['chandelier'], 'id': 235, 'def': 'branched lighting fixture; often ornate; hangs from the ceiling', 'name': 'chandelier'}, {'frequency': 'r', 'synset': 'chap.n.04', 'synonyms': ['chap'], 'id': 236, 'def': 'leather leggings without a seat; worn over trousers by cowboys to protect their legs', 'name': 'chap'}, {'frequency': 'r', 'synset': 'checkbook.n.01', 'synonyms': ['checkbook', 'chequebook'], 'id': 237, 'def': 'a book issued to holders of checking accounts', 'name': 'checkbook'}, {'frequency': 'r', 'synset': 'checkerboard.n.01', 'synonyms': ['checkerboard'], 'id': 238, 'def': 'a board having 64 squares of two alternating colors', 'name': 'checkerboard'}, {'frequency': 'c', 'synset': 'cherry.n.03', 'synonyms': ['cherry'], 'id': 239, 'def': 'a red fruit with a single hard stone', 'name': 'cherry'}, {'frequency': 'r', 'synset': 'chessboard.n.01', 'synonyms': ['chessboard'], 'id': 240, 'def': 'a checkerboard used to play chess', 'name': 'chessboard'}, {'frequency': 'c', 'synset': 'chicken.n.02', 'synonyms': ['chicken_(animal)'], 'id': 241, 'def': 'a domestic fowl bred for flesh or eggs', 'name': 'chicken_(animal)'}, {'frequency': 'c', 'synset': 'chickpea.n.01', 'synonyms': ['chickpea', 'garbanzo'], 'id': 242, 'def': 'the seed of the chickpea plant; usually dried', 'name': 'chickpea'}, {'frequency': 'c', 'synset': 'chili.n.02', 'synonyms': ['chili_(vegetable)', 'chili_pepper_(vegetable)', 'chilli_(vegetable)', 'chilly_(vegetable)', 'chile_(vegetable)'], 'id': 243, 'def': 'very hot and finely tapering pepper of special pungency', 'name': 'chili_(vegetable)'}, {'frequency': 'r', 'synset': 'chime.n.01', 'synonyms': ['chime', 'gong'], 'id': 244, 'def': 'an instrument consisting 
of a set of bells that are struck with a hammer', 'name': 'chime'}, {'frequency': 'r', 'synset': 'chinaware.n.01', 'synonyms': ['chinaware'], 'id': 245, 'def': 'dishware made of high quality porcelain', 'name': 'chinaware'}, {'frequency': 'c', 'synset': 'chip.n.04', 'synonyms': ['crisp_(potato_chip)', 'potato_chip'], 'id': 246, 'def': 'a thin crisp slice of potato fried in deep fat', 'name': 'crisp_(potato_chip)'}, {'frequency': 'r', 'synset': 'chip.n.06', 'synonyms': ['poker_chip'], 'id': 247, 'def': 'a small disk-shaped counter used to represent money when gambling', 'name': 'poker_chip'}, {'frequency': 'c', 'synset': 'chocolate_bar.n.01', 'synonyms': ['chocolate_bar'], 'id': 248, 'def': 'a bar of chocolate candy', 'name': 'chocolate_bar'}, {'frequency': 'c', 'synset': 'chocolate_cake.n.01', 'synonyms': ['chocolate_cake'], 'id': 249, 'def': 'cake containing chocolate', 'name': 'chocolate_cake'}, {'frequency': 'r', 'synset': 'chocolate_milk.n.01', 'synonyms': ['chocolate_milk'], 'id': 250, 'def': 'milk flavored with chocolate syrup', 'name': 'chocolate_milk'}, {'frequency': 'r', 'synset': 'chocolate_mousse.n.01', 'synonyms': ['chocolate_mousse'], 'id': 251, 'def': 'dessert mousse made with chocolate', 'name': 'chocolate_mousse'}, {'frequency': 'f', 'synset': 'choker.n.03', 'synonyms': ['choker', 'collar', 'neckband'], 'id': 252, 'def': 'shirt collar, animal collar, or tight-fitting necklace', 'name': 'choker'}, {'frequency': 'f', 'synset': 'chopping_board.n.01', 'synonyms': ['chopping_board', 'cutting_board', 'chopping_block'], 'id': 253, 'def': 'a wooden board where meats or vegetables can be cut', 'name': 'chopping_board'}, {'frequency': 'f', 'synset': 'chopstick.n.01', 'synonyms': ['chopstick'], 'id': 254, 'def': 'one of a pair of slender sticks used as oriental tableware to eat food with', 'name': 'chopstick'}, {'frequency': 'f', 'synset': 'christmas_tree.n.05', 'synonyms': ['Christmas_tree'], 'id': 255, 'def': 'an ornamented evergreen used as a Christmas 
decoration', 'name': 'Christmas_tree'}, {'frequency': 'c', 'synset': 'chute.n.02', 'synonyms': ['slide'], 'id': 256, 'def': 'sloping channel through which things can descend', 'name': 'slide'}, {'frequency': 'r', 'synset': 'cider.n.01', 'synonyms': ['cider', 'cyder'], 'id': 257, 'def': 'a beverage made from juice pressed from apples', 'name': 'cider'}, {'frequency': 'r', 'synset': 'cigar_box.n.01', 'synonyms': ['cigar_box'], 'id': 258, 'def': 'a box for holding cigars', 'name': 'cigar_box'}, {'frequency': 'f', 'synset': 'cigarette.n.01', 'synonyms': ['cigarette'], 'id': 259, 'def': 'finely ground tobacco wrapped in paper; for smoking', 'name': 'cigarette'}, {'frequency': 'c', 'synset': 'cigarette_case.n.01', 'synonyms': ['cigarette_case', 'cigarette_pack'], 'id': 260, 'def': 'a small flat case for holding cigarettes', 'name': 'cigarette_case'}, {'frequency': 'f', 'synset': 'cistern.n.02', 'synonyms': ['cistern', 'water_tank'], 'id': 261, 'def': 'a tank that holds the water used to flush a toilet', 'name': 'cistern'}, {'frequency': 'r', 'synset': 'clarinet.n.01', 'synonyms': ['clarinet'], 'id': 262, 'def': 'a single-reed instrument with a straight tube', 'name': 'clarinet'}, {'frequency': 'c', 'synset': 'clasp.n.01', 'synonyms': ['clasp'], 'id': 263, 'def': 'a fastener (as a buckle or hook) that is used to hold two things together', 'name': 'clasp'}, {'frequency': 'c', 'synset': 'cleansing_agent.n.01', 'synonyms': ['cleansing_agent', 'cleanser', 'cleaner'], 'id': 264, 'def': 'a preparation used in cleaning something', 'name': 'cleansing_agent'}, {'frequency': 'r', 'synset': 'cleat.n.02', 'synonyms': ['cleat_(for_securing_rope)'], 'id': 265, 'def': 'a fastener (usually with two projecting horns) around which a rope can be secured', 'name': 'cleat_(for_securing_rope)'}, {'frequency': 'r', 'synset': 'clementine.n.01', 'synonyms': ['clementine'], 'id': 266, 'def': 'a variety of mandarin orange', 'name': 'clementine'}, {'frequency': 'c', 'synset': 'clip.n.03', 
'synonyms': ['clip'], 'id': 267, 'def': 'any of various small fasteners used to hold loose articles together', 'name': 'clip'}, {'frequency': 'c', 'synset': 'clipboard.n.01', 'synonyms': ['clipboard'], 'id': 268, 'def': 'a small writing board with a clip at the top for holding papers', 'name': 'clipboard'}, {'frequency': 'r', 'synset': 'clipper.n.03', 'synonyms': ['clippers_(for_plants)'], 'id': 269, 'def': 'shears for cutting grass or shrubbery (often used in the plural)', 'name': 'clippers_(for_plants)'}, {'frequency': 'r', 'synset': 'cloak.n.02', 'synonyms': ['cloak'], 'id': 270, 'def': 'a loose outer garment', 'name': 'cloak'}, {'frequency': 'f', 'synset': 'clock.n.01', 'synonyms': ['clock', 'timepiece', 'timekeeper'], 'id': 271, 'def': 'a timepiece that shows the time of day', 'name': 'clock'}, {'frequency': 'f', 'synset': 'clock_tower.n.01', 'synonyms': ['clock_tower'], 'id': 272, 'def': 'a tower with a large clock visible high up on an outside face', 'name': 'clock_tower'}, {'frequency': 'c', 'synset': 'clothes_hamper.n.01', 'synonyms': ['clothes_hamper', 'laundry_basket', 'clothes_basket'], 'id': 273, 'def': 'a hamper that holds dirty clothes to be washed or wet clothes to be dried', 'name': 'clothes_hamper'}, {'frequency': 'c', 'synset': 'clothespin.n.01', 'synonyms': ['clothespin', 'clothes_peg'], 'id': 274, 'def': 'wood or plastic fastener; for holding clothes on a clothesline', 'name': 'clothespin'}, {'frequency': 'r', 'synset': 'clutch_bag.n.01', 'synonyms': ['clutch_bag'], 'id': 275, 'def': "a woman's strapless purse that is carried in the hand", 'name': 'clutch_bag'}, {'frequency': 'f', 'synset': 'coaster.n.03', 'synonyms': ['coaster'], 'id': 276, 'def': 'a covering (plate or mat) that protects the surface of a table', 'name': 'coaster'}, {'frequency': 'f', 'synset': 'coat.n.01', 'synonyms': ['coat'], 'id': 277, 'def': 'an outer garment that has sleeves and covers the body from shoulder down', 'name': 'coat'}, {'frequency': 'c', 'synset': 
'coat_hanger.n.01', 'synonyms': ['coat_hanger', 'clothes_hanger', 'dress_hanger'], 'id': 278, 'def': "a hanger that is shaped like a person's shoulders", 'name': 'coat_hanger'}, {'frequency': 'c', 'synset': 'coatrack.n.01', 'synonyms': ['coatrack', 'hatrack'], 'id': 279, 'def': 'a rack with hooks for temporarily holding coats and hats', 'name': 'coatrack'}, {'frequency': 'c', 'synset': 'cock.n.04', 'synonyms': ['cock', 'rooster'], 'id': 280, 'def': 'adult male chicken', 'name': 'cock'}, {'frequency': 'r', 'synset': 'cockroach.n.01', 'synonyms': ['cockroach'], 'id': 281, 'def': 'any of numerous chiefly nocturnal insects; some are domestic pests', 'name': 'cockroach'}, {'frequency': 'r', 'synset': 'cocoa.n.01', 'synonyms': ['cocoa_(beverage)', 'hot_chocolate_(beverage)', 'drinking_chocolate'], 'id': 282, 'def': 'a beverage made from cocoa powder and milk and sugar; usually drunk hot', 'name': 'cocoa_(beverage)'}, {'frequency': 'c', 'synset': 'coconut.n.02', 'synonyms': ['coconut', 'cocoanut'], 'id': 283, 'def': 'large hard-shelled brown oval nut with a fibrous husk', 'name': 'coconut'}, {'frequency': 'f', 'synset': 'coffee_maker.n.01', 'synonyms': ['coffee_maker', 'coffee_machine'], 'id': 284, 'def': 'a kitchen appliance for brewing coffee automatically', 'name': 'coffee_maker'}, {'frequency': 'f', 'synset': 'coffee_table.n.01', 'synonyms': ['coffee_table', 'cocktail_table'], 'id': 285, 'def': 'low table where magazines can be placed and coffee or cocktails are served', 'name': 'coffee_table'}, {'frequency': 'c', 'synset': 'coffeepot.n.01', 'synonyms': ['coffeepot'], 'id': 286, 'def': 'tall pot in which coffee is brewed', 'name': 'coffeepot'}, {'frequency': 'r', 'synset': 'coil.n.05', 'synonyms': ['coil'], 'id': 287, 'def': 'tubing that is wound in a spiral', 'name': 'coil'}, {'frequency': 'c', 'synset': 'coin.n.01', 'synonyms': ['coin'], 'id': 288, 'def': 'a flat metal piece (usually a disc) used as money', 'name': 'coin'}, {'frequency': 'c', 'synset': 
'colander.n.01', 'synonyms': ['colander', 'cullender'], 'id': 289, 'def': 'bowl-shaped strainer; used to wash or drain foods', 'name': 'colander'}, {'frequency': 'c', 'synset': 'coleslaw.n.01', 'synonyms': ['coleslaw', 'slaw'], 'id': 290, 'def': 'basically shredded cabbage', 'name': 'coleslaw'}, {'frequency': 'r', 'synset': 'coloring_material.n.01', 'synonyms': ['coloring_material', 'colouring_material'], 'id': 291, 'def': 'any material used for its color', 'name': 'coloring_material'}, {'frequency': 'r', 'synset': 'combination_lock.n.01', 'synonyms': ['combination_lock'], 'id': 292, 'def': 'lock that can be opened only by turning dials in a special sequence', 'name': 'combination_lock'}, {'frequency': 'c', 'synset': 'comforter.n.04', 'synonyms': ['pacifier', 'teething_ring'], 'id': 293, 'def': 'device used for an infant to suck or bite on', 'name': 'pacifier'}, {'frequency': 'r', 'synset': 'comic_book.n.01', 'synonyms': ['comic_book'], 'id': 294, 'def': 'a magazine devoted to comic strips', 'name': 'comic_book'}, {'frequency': 'r', 'synset': 'compass.n.01', 'synonyms': ['compass'], 'id': 295, 'def': 'navigational instrument for finding directions', 'name': 'compass'}, {'frequency': 'f', 'synset': 'computer_keyboard.n.01', 'synonyms': ['computer_keyboard', 'keyboard_(computer)'], 'id': 296, 'def': 'a keyboard that is a data input device for computers', 'name': 'computer_keyboard'}, {'frequency': 'f', 'synset': 'condiment.n.01', 'synonyms': ['condiment'], 'id': 297, 'def': 'a preparation (a sauce or relish or spice) to enhance flavor or enjoyment', 'name': 'condiment'}, {'frequency': 'f', 'synset': 'cone.n.01', 'synonyms': ['cone', 'traffic_cone'], 'id': 298, 'def': 'a cone-shaped object used to direct traffic', 'name': 'cone'}, {'frequency': 'f', 'synset': 'control.n.09', 'synonyms': ['control', 'controller'], 'id': 299, 'def': 'a mechanism that controls the operation of a machine', 'name': 'control'}, {'frequency': 'r', 'synset': 'convertible.n.01', 'synonyms': 
['convertible_(automobile)'], 'id': 300, 'def': 'a car that has top that can be folded or removed', 'name': 'convertible_(automobile)'}, {'frequency': 'r', 'synset': 'convertible.n.03', 'synonyms': ['sofa_bed'], 'id': 301, 'def': 'a sofa that can be converted into a bed', 'name': 'sofa_bed'}, {'frequency': 'r', 'synset': 'cooker.n.01', 'synonyms': ['cooker'], 'id': 302, 'def': 'a utensil for cooking', 'name': 'cooker'}, {'frequency': 'f', 'synset': 'cookie.n.01', 'synonyms': ['cookie', 'cooky', 'biscuit_(cookie)'], 'id': 303, 'def': "any of various small flat sweet cakes (`biscuit' is the British term)", 'name': 'cookie'}, {'frequency': 'r', 'synset': 'cooking_utensil.n.01', 'synonyms': ['cooking_utensil'], 'id': 304, 'def': 'a kitchen utensil made of material that does not melt easily; used for cooking', 'name': 'cooking_utensil'}, {'frequency': 'f', 'synset': 'cooler.n.01', 'synonyms': ['cooler_(for_food)', 'ice_chest'], 'id': 305, 'def': 'an insulated box for storing food often with ice', 'name': 'cooler_(for_food)'}, {'frequency': 'f', 'synset': 'cork.n.04', 'synonyms': ['cork_(bottle_plug)', 'bottle_cork'], 'id': 306, 'def': 'the plug in the mouth of a bottle (especially a wine bottle)', 'name': 'cork_(bottle_plug)'}, {'frequency': 'r', 'synset': 'corkboard.n.01', 'synonyms': ['corkboard'], 'id': 307, 'def': 'a sheet consisting of cork granules', 'name': 'corkboard'}, {'frequency': 'c', 'synset': 'corkscrew.n.01', 'synonyms': ['corkscrew', 'bottle_screw'], 'id': 308, 'def': 'a bottle opener that pulls corks', 'name': 'corkscrew'}, {'frequency': 'f', 'synset': 'corn.n.03', 'synonyms': ['edible_corn', 'corn', 'maize'], 'id': 309, 'def': 'ears or kernels of corn that can be prepared and served for human food (only mark individual ears or kernels)', 'name': 'edible_corn'}, {'frequency': 'r', 'synset': 'cornbread.n.01', 'synonyms': ['cornbread'], 'id': 310, 'def': 'bread made primarily of cornmeal', 'name': 'cornbread'}, {'frequency': 'c', 'synset': 'cornet.n.01', 
'synonyms': ['cornet', 'horn', 'trumpet'], 'id': 311, 'def': 'a brass musical instrument with a narrow tube and a flared bell and many valves', 'name': 'cornet'}, {'frequency': 'c', 'synset': 'cornice.n.01', 'synonyms': ['cornice', 'valance', 'valance_board', 'pelmet'], 'id': 312, 'def': 'a decorative framework to conceal curtain fixtures at the top of a window casing', 'name': 'cornice'}, {'frequency': 'r', 'synset': 'cornmeal.n.01', 'synonyms': ['cornmeal'], 'id': 313, 'def': 'coarsely ground corn', 'name': 'cornmeal'}, {'frequency': 'c', 'synset': 'corset.n.01', 'synonyms': ['corset', 'girdle'], 'id': 314, 'def': "a woman's close-fitting foundation garment", 'name': 'corset'}, {'frequency': 'c', 'synset': 'costume.n.04', 'synonyms': ['costume'], 'id': 315, 'def': 'the attire characteristic of a country or a time or a social class', 'name': 'costume'}, {'frequency': 'r', 'synset': 'cougar.n.01', 'synonyms': ['cougar', 'puma', 'catamount', 'mountain_lion', 'panther'], 'id': 316, 'def': 'large American feline resembling a lion', 'name': 'cougar'}, {'frequency': 'r', 'synset': 'coverall.n.01', 'synonyms': ['coverall'], 'id': 317, 'def': 'a loose-fitting protective garment that is worn over other clothing', 'name': 'coverall'}, {'frequency': 'c', 'synset': 'cowbell.n.01', 'synonyms': ['cowbell'], 'id': 318, 'def': 'a bell hung around the neck of cow so that the cow can be easily located', 'name': 'cowbell'}, {'frequency': 'f', 'synset': 'cowboy_hat.n.01', 'synonyms': ['cowboy_hat', 'ten-gallon_hat'], 'id': 319, 'def': 'a hat with a wide brim and a soft crown; worn by American ranch hands', 'name': 'cowboy_hat'}, {'frequency': 'c', 'synset': 'crab.n.01', 'synonyms': ['crab_(animal)'], 'id': 320, 'def': 'decapod having eyes on short stalks and a broad flattened shell and pincers', 'name': 'crab_(animal)'}, {'frequency': 'r', 'synset': 'crab.n.05', 'synonyms': ['crabmeat'], 'id': 321, 'def': 'the edible flesh of any of various crabs', 'name': 'crabmeat'}, {'frequency': 
'c', 'synset': 'cracker.n.01', 'synonyms': ['cracker'], 'id': 322, 'def': 'a thin crisp wafer', 'name': 'cracker'}, {'frequency': 'r', 'synset': 'crape.n.01', 'synonyms': ['crape', 'crepe', 'French_pancake'], 'id': 323, 'def': 'small very thin pancake', 'name': 'crape'}, {'frequency': 'f', 'synset': 'crate.n.01', 'synonyms': ['crate'], 'id': 324, 'def': 'a rugged box (usually made of wood); used for shipping', 'name': 'crate'}, {'frequency': 'c', 'synset': 'crayon.n.01', 'synonyms': ['crayon', 'wax_crayon'], 'id': 325, 'def': 'writing or drawing implement made of a colored stick of composition wax', 'name': 'crayon'}, {'frequency': 'r', 'synset': 'cream_pitcher.n.01', 'synonyms': ['cream_pitcher'], 'id': 326, 'def': 'a small pitcher for serving cream', 'name': 'cream_pitcher'}, {'frequency': 'c', 'synset': 'crescent_roll.n.01', 'synonyms': ['crescent_roll', 'croissant'], 'id': 327, 'def': 'very rich flaky crescent-shaped roll', 'name': 'crescent_roll'}, {'frequency': 'c', 'synset': 'crib.n.01', 'synonyms': ['crib', 'cot'], 'id': 328, 'def': 'baby bed with high sides made of slats', 'name': 'crib'}, {'frequency': 'c', 'synset': 'crock.n.03', 'synonyms': ['crock_pot', 'earthenware_jar'], 'id': 329, 'def': 'an earthen jar (made of baked clay) or a modern electric crockpot', 'name': 'crock_pot'}, {'frequency': 'f', 'synset': 'crossbar.n.01', 'synonyms': ['crossbar'], 'id': 330, 'def': 'a horizontal bar that goes across something', 'name': 'crossbar'}, {'frequency': 'r', 'synset': 'crouton.n.01', 'synonyms': ['crouton'], 'id': 331, 'def': 'a small piece of toasted or fried bread; served in soup or salads', 'name': 'crouton'}, {'frequency': 'c', 'synset': 'crow.n.01', 'synonyms': ['crow'], 'id': 332, 'def': 'black birds having a raucous call', 'name': 'crow'}, {'frequency': 'r', 'synset': 'crowbar.n.01', 'synonyms': ['crowbar', 'wrecking_bar', 'pry_bar'], 'id': 333, 'def': 'a heavy iron lever with one end forged into a wedge', 'name': 'crowbar'}, {'frequency': 'c', 
'synset': 'crown.n.04', 'synonyms': ['crown'], 'id': 334, 'def': 'an ornamental jeweled headdress signifying sovereignty', 'name': 'crown'}, {'frequency': 'c', 'synset': 'crucifix.n.01', 'synonyms': ['crucifix'], 'id': 335, 'def': 'representation of the cross on which Jesus died', 'name': 'crucifix'}, {'frequency': 'c', 'synset': 'cruise_ship.n.01', 'synonyms': ['cruise_ship', 'cruise_liner'], 'id': 336, 'def': 'a passenger ship used commercially for pleasure cruises', 'name': 'cruise_ship'}, {'frequency': 'c', 'synset': 'cruiser.n.01', 'synonyms': ['police_cruiser', 'patrol_car', 'police_car', 'squad_car'], 'id': 337, 'def': 'a car in which policemen cruise the streets', 'name': 'police_cruiser'}, {'frequency': 'f', 'synset': 'crumb.n.03', 'synonyms': ['crumb'], 'id': 338, 'def': 'small piece of e.g. bread or cake', 'name': 'crumb'}, {'frequency': 'c', 'synset': 'crutch.n.01', 'synonyms': ['crutch'], 'id': 339, 'def': 'a wooden or metal staff that fits under the armpit and reaches to the ground', 'name': 'crutch'}, {'frequency': 'c', 'synset': 'cub.n.03', 'synonyms': ['cub_(animal)'], 'id': 340, 'def': 'the young of certain carnivorous mammals such as the bear or wolf or lion', 'name': 'cub_(animal)'}, {'frequency': 'c', 'synset': 'cube.n.05', 'synonyms': ['cube', 'square_block'], 'id': 341, 'def': 'a block in the (approximate) shape of a cube', 'name': 'cube'}, {'frequency': 'f', 'synset': 'cucumber.n.02', 'synonyms': ['cucumber', 'cuke'], 'id': 342, 'def': 'cylindrical green fruit with thin green rind and white flesh eaten as a vegetable', 'name': 'cucumber'}, {'frequency': 'c', 'synset': 'cufflink.n.01', 'synonyms': ['cufflink'], 'id': 343, 'def': 'jewelry consisting of linked buttons used to fasten the cuffs of a shirt', 'name': 'cufflink'}, {'frequency': 'f', 'synset': 'cup.n.01', 'synonyms': ['cup'], 'id': 344, 'def': 'a small open container usually used for drinking; usually has a handle', 'name': 'cup'}, {'frequency': 'c', 'synset': 'cup.n.08', 'synonyms': 
['trophy_cup'], 'id': 345, 'def': 'a metal award or cup-shaped vessel with handles that is awarded as a trophy to a competition winner', 'name': 'trophy_cup'}, {'frequency': 'f', 'synset': 'cupboard.n.01', 'synonyms': ['cupboard', 'closet'], 'id': 346, 'def': 'a small room (or recess) or cabinet used for storage space', 'name': 'cupboard'}, {'frequency': 'f', 'synset': 'cupcake.n.01', 'synonyms': ['cupcake'], 'id': 347, 'def': 'small cake baked in a muffin tin', 'name': 'cupcake'}, {'frequency': 'r', 'synset': 'curler.n.01', 'synonyms': ['hair_curler', 'hair_roller', 'hair_crimper'], 'id': 348, 'def': 'a cylindrical tube around which the hair is wound to curl it', 'name': 'hair_curler'}, {'frequency': 'r', 'synset': 'curling_iron.n.01', 'synonyms': ['curling_iron'], 'id': 349, 'def': 'a cylindrical home appliance that heats hair that has been curled around it', 'name': 'curling_iron'}, {'frequency': 'f', 'synset': 'curtain.n.01', 'synonyms': ['curtain', 'drapery'], 'id': 350, 'def': 'hanging cloth used as a blind (especially for a window)', 'name': 'curtain'}, {'frequency': 'f', 'synset': 'cushion.n.03', 'synonyms': ['cushion'], 'id': 351, 'def': 'a soft bag filled with air or padding such as feathers or foam rubber', 'name': 'cushion'}, {'frequency': 'r', 'synset': 'cylinder.n.04', 'synonyms': ['cylinder'], 'id': 352, 'def': 'a cylindrical container', 'name': 'cylinder'}, {'frequency': 'r', 'synset': 'cymbal.n.01', 'synonyms': ['cymbal'], 'id': 353, 'def': 'a percussion instrument consisting of a concave brass disk', 'name': 'cymbal'}, {'frequency': 'r', 'synset': 'dagger.n.01', 'synonyms': ['dagger'], 'id': 354, 'def': 'a short knife with a pointed blade used for piercing or stabbing', 'name': 'dagger'}, {'frequency': 'r', 'synset': 'dalmatian.n.02', 'synonyms': ['dalmatian'], 'id': 355, 'def': 'a large breed having a smooth white coat with black or brown spots', 'name': 'dalmatian'}, {'frequency': 'c', 'synset': 'dartboard.n.01', 'synonyms': ['dartboard'], 'id': 
356, 'def': 'a circular board of wood or cork used as the target in the game of darts', 'name': 'dartboard'}, {'frequency': 'r', 'synset': 'date.n.08', 'synonyms': ['date_(fruit)'], 'id': 357, 'def': 'sweet edible fruit of the date palm with a single long woody seed', 'name': 'date_(fruit)'}, {'frequency': 'f', 'synset': 'deck_chair.n.01', 'synonyms': ['deck_chair', 'beach_chair'], 'id': 358, 'def': 'a folding chair for use outdoors; a wooden frame supports a length of canvas', 'name': 'deck_chair'}, {'frequency': 'c', 'synset': 'deer.n.01', 'synonyms': ['deer', 'cervid'], 'id': 359, 'def': "distinguished from Bovidae by the male's having solid deciduous antlers", 'name': 'deer'}, {'frequency': 'c', 'synset': 'dental_floss.n.01', 'synonyms': ['dental_floss', 'floss'], 'id': 360, 'def': 'a soft thread for cleaning the spaces between the teeth', 'name': 'dental_floss'}, {'frequency': 'f', 'synset': 'desk.n.01', 'synonyms': ['desk'], 'id': 361, 'def': 'a piece of furniture with a writing surface and usually drawers or other compartments', 'name': 'desk'}, {'frequency': 'r', 'synset': 'detergent.n.01', 'synonyms': ['detergent'], 'id': 362, 'def': 'a surface-active chemical widely used in industry and laundering', 'name': 'detergent'}, {'frequency': 'c', 'synset': 'diaper.n.01', 'synonyms': ['diaper'], 'id': 363, 'def': 'garment consisting of a folded cloth drawn up between the legs and fastened at the waist', 'name': 'diaper'}, {'frequency': 'r', 'synset': 'diary.n.01', 'synonyms': ['diary', 'journal'], 'id': 364, 'def': 'yearly planner book', 'name': 'diary'}, {'frequency': 'r', 'synset': 'die.n.01', 'synonyms': ['die', 'dice'], 'id': 365, 'def': 'a small cube with 1 to 6 spots on the six faces; used in gambling', 'name': 'die'}, {'frequency': 'r', 'synset': 'dinghy.n.01', 'synonyms': ['dinghy', 'dory', 'rowboat'], 'id': 366, 'def': 'a small boat of shallow draft with seats and oars with which it is propelled', 'name': 'dinghy'}, {'frequency': 'f', 'synset': 
'dining_table.n.01', 'synonyms': ['dining_table'], 'id': 367, 'def': 'a table at which meals are served', 'name': 'dining_table'}, {'frequency': 'r', 'synset': 'dinner_jacket.n.01', 'synonyms': ['tux', 'tuxedo'], 'id': 368, 'def': 'semiformal evening dress for men', 'name': 'tux'}, {'frequency': 'f', 'synset': 'dish.n.01', 'synonyms': ['dish'], 'id': 369, 'def': 'a piece of dishware normally used as a container for holding or serving food', 'name': 'dish'}, {'frequency': 'c', 'synset': 'dish.n.05', 'synonyms': ['dish_antenna'], 'id': 370, 'def': 'directional antenna consisting of a parabolic reflector', 'name': 'dish_antenna'}, {'frequency': 'c', 'synset': 'dishrag.n.01', 'synonyms': ['dishrag', 'dishcloth'], 'id': 371, 'def': 'a cloth for washing dishes or cleaning in general', 'name': 'dishrag'}, {'frequency': 'f', 'synset': 'dishtowel.n.01', 'synonyms': ['dishtowel', 'tea_towel'], 'id': 372, 'def': 'a towel for drying dishes', 'name': 'dishtowel'}, {'frequency': 'f', 'synset': 'dishwasher.n.01', 'synonyms': ['dishwasher', 'dishwashing_machine'], 'id': 373, 'def': 'a machine for washing dishes', 'name': 'dishwasher'}, {'frequency': 'r', 'synset': 'dishwasher_detergent.n.01', 'synonyms': ['dishwasher_detergent', 'dishwashing_detergent', 'dishwashing_liquid', 'dishsoap'], 'id': 374, 'def': 'dishsoap or dish detergent designed for use in dishwashers', 'name': 'dishwasher_detergent'}, {'frequency': 'f', 'synset': 'dispenser.n.01', 'synonyms': ['dispenser'], 'id': 375, 'def': 'a container so designed that the contents can be used in prescribed amounts', 'name': 'dispenser'}, {'frequency': 'r', 'synset': 'diving_board.n.01', 'synonyms': ['diving_board'], 'id': 376, 'def': 'a springboard from which swimmers can dive', 'name': 'diving_board'}, {'frequency': 'f', 'synset': 'dixie_cup.n.01', 'synonyms': ['Dixie_cup', 'paper_cup'], 'id': 377, 'def': 'a disposable cup made of paper; for holding drinks', 'name': 'Dixie_cup'}, {'frequency': 'f', 'synset': 'dog.n.01', 
'synonyms': ['dog'], 'id': 378, 'def': 'a common domesticated dog', 'name': 'dog'}, {'frequency': 'f', 'synset': 'dog_collar.n.01', 'synonyms': ['dog_collar'], 'id': 379, 'def': 'a collar for a dog', 'name': 'dog_collar'}, {'frequency': 'f', 'synset': 'doll.n.01', 'synonyms': ['doll'], 'id': 380, 'def': 'a toy replica of a HUMAN (NOT AN ANIMAL)', 'name': 'doll'}, {'frequency': 'r', 'synset': 'dollar.n.02', 'synonyms': ['dollar', 'dollar_bill', 'one_dollar_bill'], 'id': 381, 'def': 'a piece of paper money worth one dollar', 'name': 'dollar'}, {'frequency': 'r', 'synset': 'dollhouse.n.01', 'synonyms': ['dollhouse', "doll's_house"], 'id': 382, 'def': "a house so small that it is likened to a child's plaything", 'name': 'dollhouse'}, {'frequency': 'c', 'synset': 'dolphin.n.02', 'synonyms': ['dolphin'], 'id': 383, 'def': 'any of various small toothed whales with a beaklike snout; larger than porpoises', 'name': 'dolphin'}, {'frequency': 'c', 'synset': 'domestic_ass.n.01', 'synonyms': ['domestic_ass', 'donkey'], 'id': 384, 'def': 'domestic beast of burden descended from the African wild ass; patient but stubborn', 'name': 'domestic_ass'}, {'frequency': 'f', 'synset': 'doorknob.n.01', 'synonyms': ['doorknob', 'doorhandle'], 'id': 385, 'def': "a knob used to open a door (often called `doorhandle' in Great Britain)", 'name': 'doorknob'}, {'frequency': 'c', 'synset': 'doormat.n.02', 'synonyms': ['doormat', 'welcome_mat'], 'id': 386, 'def': 'a mat placed outside an exterior door for wiping the shoes before entering', 'name': 'doormat'}, {'frequency': 'f', 'synset': 'doughnut.n.02', 'synonyms': ['doughnut', 'donut'], 'id': 387, 'def': 'a small ring-shaped friedcake', 'name': 'doughnut'}, {'frequency': 'r', 'synset': 'dove.n.01', 'synonyms': ['dove'], 'id': 388, 'def': 'any of numerous small pigeons', 'name': 'dove'}, {'frequency': 'r', 'synset': 'dragonfly.n.01', 'synonyms': ['dragonfly'], 'id': 389, 'def': 'slender-bodied non-stinging insect having iridescent wings that are 
outspread at rest', 'name': 'dragonfly'}, {'frequency': 'f', 'synset': 'drawer.n.01', 'synonyms': ['drawer'], 'id': 390, 'def': 'a boxlike container in a piece of furniture; made so as to slide in and out', 'name': 'drawer'}, {'frequency': 'c', 'synset': 'drawers.n.01', 'synonyms': ['underdrawers', 'boxers', 'boxershorts'], 'id': 391, 'def': 'underpants worn by men', 'name': 'underdrawers'}, {'frequency': 'f', 'synset': 'dress.n.01', 'synonyms': ['dress', 'frock'], 'id': 392, 'def': 'a one-piece garment for a woman; has skirt and bodice', 'name': 'dress'}, {'frequency': 'c', 'synset': 'dress_hat.n.01', 'synonyms': ['dress_hat', 'high_hat', 'opera_hat', 'silk_hat', 'top_hat'], 'id': 393, 'def': "a man's hat with a tall crown; usually covered with silk or with beaver fur", 'name': 'dress_hat'}, {'frequency': 'f', 'synset': 'dress_suit.n.01', 'synonyms': ['dress_suit'], 'id': 394, 'def': 'formalwear consisting of full evening dress for men', 'name': 'dress_suit'}, {'frequency': 'f', 'synset': 'dresser.n.05', 'synonyms': ['dresser'], 'id': 395, 'def': 'a cabinet with shelves', 'name': 'dresser'}, {'frequency': 'c', 'synset': 'drill.n.01', 'synonyms': ['drill'], 'id': 396, 'def': 'a tool with a sharp rotating point for making holes in hard materials', 'name': 'drill'}, {'frequency': 'r', 'synset': 'drone.n.04', 'synonyms': ['drone'], 'id': 397, 'def': 'an aircraft without a pilot that is operated by remote control', 'name': 'drone'}, {'frequency': 'r', 'synset': 'dropper.n.01', 'synonyms': ['dropper', 'eye_dropper'], 'id': 398, 'def': 'pipet consisting of a small tube with a vacuum bulb at one end for drawing liquid in and releasing it a drop at a time', 'name': 'dropper'}, {'frequency': 'c', 'synset': 'drum.n.01', 'synonyms': ['drum_(musical_instrument)'], 'id': 399, 'def': 'a musical percussion instrument; usually consists of a hollow cylinder with a membrane stretched across each end', 'name': 'drum_(musical_instrument)'}, {'frequency': 'r', 'synset': 
'drumstick.n.02', 'synonyms': ['drumstick'], 'id': 400, 'def': 'a stick used for playing a drum', 'name': 'drumstick'}, {'frequency': 'f', 'synset': 'duck.n.01', 'synonyms': ['duck'], 'id': 401, 'def': 'small web-footed broad-billed swimming bird', 'name': 'duck'}, {'frequency': 'c', 'synset': 'duckling.n.02', 'synonyms': ['duckling'], 'id': 402, 'def': 'young duck', 'name': 'duckling'}, {'frequency': 'c', 'synset': 'duct_tape.n.01', 'synonyms': ['duct_tape'], 'id': 403, 'def': 'a wide silvery adhesive tape', 'name': 'duct_tape'}, {'frequency': 'f', 'synset': 'duffel_bag.n.01', 'synonyms': ['duffel_bag', 'duffle_bag', 'duffel', 'duffle'], 'id': 404, 'def': 'a large cylindrical bag of heavy cloth (does not include suitcases)', 'name': 'duffel_bag'}, {'frequency': 'r', 'synset': 'dumbbell.n.01', 'synonyms': ['dumbbell'], 'id': 405, 'def': 'an exercising weight with two ball-like ends connected by a short handle', 'name': 'dumbbell'}, {'frequency': 'c', 'synset': 'dumpster.n.01', 'synonyms': ['dumpster'], 'id': 406, 'def': 'a container designed to receive and transport and dump waste', 'name': 'dumpster'}, {'frequency': 'r', 'synset': 'dustpan.n.02', 'synonyms': ['dustpan'], 'id': 407, 'def': 'a short-handled receptacle into which dust can be swept', 'name': 'dustpan'}, {'frequency': 'c', 'synset': 'eagle.n.01', 'synonyms': ['eagle'], 'id': 408, 'def': 'large birds of prey noted for their broad wings and strong soaring flight', 'name': 'eagle'}, {'frequency': 'f', 'synset': 'earphone.n.01', 'synonyms': ['earphone', 'earpiece', 'headphone'], 'id': 409, 'def': 'device for listening to audio that is held over or inserted into the ear', 'name': 'earphone'}, {'frequency': 'r', 'synset': 'earplug.n.01', 'synonyms': ['earplug'], 'id': 410, 'def': 'a soft plug that is inserted into the ear canal to block sound', 'name': 'earplug'}, {'frequency': 'f', 'synset': 'earring.n.01', 'synonyms': ['earring'], 'id': 411, 'def': 'jewelry to ornament the ear', 'name': 'earring'}, 
{'frequency': 'c', 'synset': 'easel.n.01', 'synonyms': ['easel'], 'id': 412, 'def': "an upright tripod for displaying something (usually an artist's canvas)", 'name': 'easel'}, {'frequency': 'r', 'synset': 'eclair.n.01', 'synonyms': ['eclair'], 'id': 413, 'def': 'oblong cream puff', 'name': 'eclair'}, {'frequency': 'r', 'synset': 'eel.n.01', 'synonyms': ['eel'], 'id': 414, 'def': 'an elongate fish with fatty flesh', 'name': 'eel'}, {'frequency': 'f', 'synset': 'egg.n.02', 'synonyms': ['egg', 'eggs'], 'id': 415, 'def': 'oval reproductive body of a fowl (especially a hen) used as food', 'name': 'egg'}, {'frequency': 'r', 'synset': 'egg_roll.n.01', 'synonyms': ['egg_roll', 'spring_roll'], 'id': 416, 'def': 'minced vegetables and meat wrapped in a pancake and fried', 'name': 'egg_roll'}, {'frequency': 'c', 'synset': 'egg_yolk.n.01', 'synonyms': ['egg_yolk', 'yolk_(egg)'], 'id': 417, 'def': 'the yellow spherical part of an egg', 'name': 'egg_yolk'}, {'frequency': 'c', 'synset': 'eggbeater.n.02', 'synonyms': ['eggbeater', 'eggwhisk'], 'id': 418, 'def': 'a mixer for beating eggs or whipping cream', 'name': 'eggbeater'}, {'frequency': 'c', 'synset': 'eggplant.n.01', 'synonyms': ['eggplant', 'aubergine'], 'id': 419, 'def': 'egg-shaped vegetable having a shiny skin typically dark purple', 'name': 'eggplant'}, {'frequency': 'r', 'synset': 'electric_chair.n.01', 'synonyms': ['electric_chair'], 'id': 420, 'def': 'a chair-shaped instrument of execution by electrocution', 'name': 'electric_chair'}, {'frequency': 'f', 'synset': 'electric_refrigerator.n.01', 'synonyms': ['refrigerator'], 'id': 421, 'def': 'a refrigerator in which the coolant is pumped around by an electric motor', 'name': 'refrigerator'}, {'frequency': 'f', 'synset': 'elephant.n.01', 'synonyms': ['elephant'], 'id': 422, 'def': 'a common elephant', 'name': 'elephant'}, {'frequency': 'c', 'synset': 'elk.n.01', 'synonyms': ['elk', 'moose'], 'id': 423, 'def': 'large northern deer with enormous flattened antlers in the 
male', 'name': 'elk'}, {'frequency': 'c', 'synset': 'envelope.n.01', 'synonyms': ['envelope'], 'id': 424, 'def': 'a flat (usually rectangular) container for a letter, thin package, etc.', 'name': 'envelope'}, {'frequency': 'c', 'synset': 'eraser.n.01', 'synonyms': ['eraser'], 'id': 425, 'def': 'an implement used to erase something', 'name': 'eraser'}, {'frequency': 'r', 'synset': 'escargot.n.01', 'synonyms': ['escargot'], 'id': 426, 'def': 'edible snail usually served in the shell with a sauce of melted butter and garlic', 'name': 'escargot'}, {'frequency': 'r', 'synset': 'eyepatch.n.01', 'synonyms': ['eyepatch'], 'id': 427, 'def': 'a protective cloth covering for an injured eye', 'name': 'eyepatch'}, {'frequency': 'r', 'synset': 'falcon.n.01', 'synonyms': ['falcon'], 'id': 428, 'def': 'birds of prey having long pointed powerful wings adapted for swift flight', 'name': 'falcon'}, {'frequency': 'f', 'synset': 'fan.n.01', 'synonyms': ['fan'], 'id': 429, 'def': 'a device for creating a current of air by movement of a surface or surfaces', 'name': 'fan'}, {'frequency': 'f', 'synset': 'faucet.n.01', 'synonyms': ['faucet', 'spigot', 'tap'], 'id': 430, 'def': 'a regulator for controlling the flow of a liquid from a reservoir', 'name': 'faucet'}, {'frequency': 'r', 'synset': 'fedora.n.01', 'synonyms': ['fedora'], 'id': 431, 'def': 'a hat made of felt with a creased crown', 'name': 'fedora'}, {'frequency': 'r', 'synset': 'ferret.n.02', 'synonyms': ['ferret'], 'id': 432, 'def': 'domesticated albino variety of the European polecat bred for hunting rats and rabbits', 'name': 'ferret'}, {'frequency': 'c', 'synset': 'ferris_wheel.n.01', 'synonyms': ['Ferris_wheel'], 'id': 433, 'def': 'a large wheel with suspended seats that remain upright as the wheel rotates', 'name': 'Ferris_wheel'}, {'frequency': 'c', 'synset': 'ferry.n.01', 'synonyms': ['ferry', 'ferryboat'], 'id': 434, 'def': 'a boat that transports people or vehicles across a body of water and operates on a regular 
schedule', 'name': 'ferry'}, {'frequency': 'r', 'synset': 'fig.n.04', 'synonyms': ['fig_(fruit)'], 'id': 435, 'def': 'fleshy sweet pear-shaped yellowish or purple fruit eaten fresh or preserved or dried', 'name': 'fig_(fruit)'}, {'frequency': 'c', 'synset': 'fighter.n.02', 'synonyms': ['fighter_jet', 'fighter_aircraft', 'attack_aircraft'], 'id': 436, 'def': 'a high-speed military or naval airplane designed to destroy enemy targets', 'name': 'fighter_jet'}, {'frequency': 'f', 'synset': 'figurine.n.01', 'synonyms': ['figurine'], 'id': 437, 'def': 'a small carved or molded figure', 'name': 'figurine'}, {'frequency': 'c', 'synset': 'file.n.03', 'synonyms': ['file_cabinet', 'filing_cabinet'], 'id': 438, 'def': 'office furniture consisting of a container for keeping papers in order', 'name': 'file_cabinet'}, {'frequency': 'r', 'synset': 'file.n.04', 'synonyms': ['file_(tool)'], 'id': 439, 'def': 'a steel hand tool with small sharp teeth on some or all of its surfaces; used for smoothing wood or metal', 'name': 'file_(tool)'}, {'frequency': 'f', 'synset': 'fire_alarm.n.02', 'synonyms': ['fire_alarm', 'smoke_alarm'], 'id': 440, 'def': 'an alarm that is tripped off by fire or smoke', 'name': 'fire_alarm'}, {'frequency': 'f', 'synset': 'fire_engine.n.01', 'synonyms': ['fire_engine', 'fire_truck'], 'id': 441, 'def': 'large trucks that carry firefighters and equipment to the site of a fire', 'name': 'fire_engine'}, {'frequency': 'f', 'synset': 'fire_extinguisher.n.01', 'synonyms': ['fire_extinguisher', 'extinguisher'], 'id': 442, 'def': 'a manually operated device for extinguishing small fires', 'name': 'fire_extinguisher'}, {'frequency': 'c', 'synset': 'fire_hose.n.01', 'synonyms': ['fire_hose'], 'id': 443, 'def': 'a large hose that carries water from a fire hydrant to the site of the fire', 'name': 'fire_hose'}, {'frequency': 'f', 'synset': 'fireplace.n.01', 'synonyms': ['fireplace'], 'id': 444, 'def': 'an open recess in a wall at the base of a chimney where a fire can be 
built', 'name': 'fireplace'}, {'frequency': 'f', 'synset': 'fireplug.n.01', 'synonyms': ['fireplug', 'fire_hydrant', 'hydrant'], 'id': 445, 'def': 'an upright hydrant for drawing water to use in fighting a fire', 'name': 'fireplug'}, {'frequency': 'r', 'synset': 'first-aid_kit.n.01', 'synonyms': ['first-aid_kit'], 'id': 446, 'def': 'kit consisting of a set of bandages and medicines for giving first aid', 'name': 'first-aid_kit'}, {'frequency': 'f', 'synset': 'fish.n.01', 'synonyms': ['fish'], 'id': 447, 'def': 'any of various mostly cold-blooded aquatic vertebrates usually having scales and breathing through gills', 'name': 'fish'}, {'frequency': 'c', 'synset': 'fish.n.02', 'synonyms': ['fish_(food)'], 'id': 448, 'def': 'the flesh of fish used as food', 'name': 'fish_(food)'}, {'frequency': 'r', 'synset': 'fishbowl.n.02', 'synonyms': ['fishbowl', 'goldfish_bowl'], 'id': 449, 'def': 'a transparent bowl in which small fish are kept', 'name': 'fishbowl'}, {'frequency': 'c', 'synset': 'fishing_rod.n.01', 'synonyms': ['fishing_rod', 'fishing_pole'], 'id': 450, 'def': 'a rod that is used in fishing to extend the fishing line', 'name': 'fishing_rod'}, {'frequency': 'f', 'synset': 'flag.n.01', 'synonyms': ['flag'], 'id': 451, 'def': 'emblem usually consisting of a rectangular piece of cloth of distinctive design (do not include pole)', 'name': 'flag'}, {'frequency': 'f', 'synset': 'flagpole.n.02', 'synonyms': ['flagpole', 'flagstaff'], 'id': 452, 'def': 'a tall staff or pole on which a flag is raised', 'name': 'flagpole'}, {'frequency': 'c', 'synset': 'flamingo.n.01', 'synonyms': ['flamingo'], 'id': 453, 'def': 'large pink web-footed bird with down-bent bill', 'name': 'flamingo'}, {'frequency': 'c', 'synset': 'flannel.n.01', 'synonyms': ['flannel'], 'id': 454, 'def': 'a soft light woolen fabric; used for clothing', 'name': 'flannel'}, {'frequency': 'c', 'synset': 'flap.n.01', 'synonyms': ['flap'], 'id': 455, 'def': 'any broad thin covering attached at one edge, such as a 
mud flap next to a wheel or a flap on an airplane wing', 'name': 'flap'}, {'frequency': 'r', 'synset': 'flash.n.10', 'synonyms': ['flash', 'flashbulb'], 'id': 456, 'def': 'a lamp for providing momentary light to take a photograph', 'name': 'flash'}, {'frequency': 'c', 'synset': 'flashlight.n.01', 'synonyms': ['flashlight', 'torch'], 'id': 457, 'def': 'a small portable battery-powered electric lamp', 'name': 'flashlight'}, {'frequency': 'r', 'synset': 'fleece.n.03', 'synonyms': ['fleece'], 'id': 458, 'def': 'a soft bulky fabric with deep pile; used chiefly for clothing', 'name': 'fleece'}, {'frequency': 'f', 'synset': 'flip-flop.n.02', 'synonyms': ['flip-flop_(sandal)'], 'id': 459, 'def': 'a backless sandal held to the foot by a thong between two toes', 'name': 'flip-flop_(sandal)'}, {'frequency': 'c', 'synset': 'flipper.n.01', 'synonyms': ['flipper_(footwear)', 'fin_(footwear)'], 'id': 460, 'def': 'a shoe to aid a person in swimming', 'name': 'flipper_(footwear)'}, {'frequency': 'f', 'synset': 'flower_arrangement.n.01', 'synonyms': ['flower_arrangement', 'floral_arrangement'], 'id': 461, 'def': 'a decorative arrangement of flowers', 'name': 'flower_arrangement'}, {'frequency': 'c', 'synset': 'flute.n.02', 'synonyms': ['flute_glass', 'champagne_flute'], 'id': 462, 'def': 'a tall narrow wineglass', 'name': 'flute_glass'}, {'frequency': 'c', 'synset': 'foal.n.01', 'synonyms': ['foal'], 'id': 463, 'def': 'a young horse', 'name': 'foal'}, {'frequency': 'c', 'synset': 'folding_chair.n.01', 'synonyms': ['folding_chair'], 'id': 464, 'def': 'a chair that can be folded flat for storage', 'name': 'folding_chair'}, {'frequency': 'c', 'synset': 'food_processor.n.01', 'synonyms': ['food_processor'], 'id': 465, 'def': 'a kitchen appliance for shredding, blending, chopping, or slicing food', 'name': 'food_processor'}, {'frequency': 'c', 'synset': 'football.n.02', 'synonyms': ['football_(American)'], 'id': 466, 'def': 'the inflated oblong ball used in playing American football', 
'name': 'football_(American)'}, {'frequency': 'r', 'synset': 'football_helmet.n.01', 'synonyms': ['football_helmet'], 'id': 467, 'def': 'a padded helmet with a face mask to protect the head of football players', 'name': 'football_helmet'}, {'frequency': 'c', 'synset': 'footstool.n.01', 'synonyms': ['footstool', 'footrest'], 'id': 468, 'def': 'a low seat or a stool to rest the feet of a seated person', 'name': 'footstool'}, {'frequency': 'f', 'synset': 'fork.n.01', 'synonyms': ['fork'], 'id': 469, 'def': 'cutlery used for serving and eating food', 'name': 'fork'}, {'frequency': 'c', 'synset': 'forklift.n.01', 'synonyms': ['forklift'], 'id': 470, 'def': 'an industrial vehicle with a power operated fork in front that can be inserted under loads to lift and move them', 'name': 'forklift'}, {'frequency': 'c', 'synset': 'freight_car.n.01', 'synonyms': ['freight_car'], 'id': 471, 'def': 'a railway car that carries freight', 'name': 'freight_car'}, {'frequency': 'c', 'synset': 'french_toast.n.01', 'synonyms': ['French_toast'], 'id': 472, 'def': 'bread slice dipped in egg and milk and fried', 'name': 'French_toast'}, {'frequency': 'c', 'synset': 'freshener.n.01', 'synonyms': ['freshener', 'air_freshener'], 'id': 473, 'def': 'anything that freshens air by removing or covering odor', 'name': 'freshener'}, {'frequency': 'f', 'synset': 'frisbee.n.01', 'synonyms': ['frisbee'], 'id': 474, 'def': 'a light, plastic disk propelled with a flip of the wrist for recreation or competition', 'name': 'frisbee'}, {'frequency': 'c', 'synset': 'frog.n.01', 'synonyms': ['frog', 'toad', 'toad_frog'], 'id': 475, 'def': 'a tailless stout-bodied amphibians with long hind limbs for leaping', 'name': 'frog'}, {'frequency': 'c', 'synset': 'fruit_juice.n.01', 'synonyms': ['fruit_juice'], 'id': 476, 'def': 'drink produced by squeezing or crushing fruit', 'name': 'fruit_juice'}, {'frequency': 'f', 'synset': 'frying_pan.n.01', 'synonyms': ['frying_pan', 'frypan', 'skillet'], 'id': 477, 'def': 'a pan 
used for frying foods', 'name': 'frying_pan'}, {'frequency': 'r', 'synset': 'fudge.n.01', 'synonyms': ['fudge'], 'id': 478, 'def': 'soft creamy candy', 'name': 'fudge'}, {'frequency': 'r', 'synset': 'funnel.n.02', 'synonyms': ['funnel'], 'id': 479, 'def': 'a cone-shaped utensil used to channel a substance into a container with a small mouth', 'name': 'funnel'}, {'frequency': 'r', 'synset': 'futon.n.01', 'synonyms': ['futon'], 'id': 480, 'def': 'a pad that is used for sleeping on the floor or on a raised frame', 'name': 'futon'}, {'frequency': 'r', 'synset': 'gag.n.02', 'synonyms': ['gag', 'muzzle'], 'id': 481, 'def': "restraint put into a person's mouth to prevent speaking or shouting", 'name': 'gag'}, {'frequency': 'r', 'synset': 'garbage.n.03', 'synonyms': ['garbage'], 'id': 482, 'def': 'a receptacle where waste can be discarded', 'name': 'garbage'}, {'frequency': 'c', 'synset': 'garbage_truck.n.01', 'synonyms': ['garbage_truck'], 'id': 483, 'def': 'a truck for collecting domestic refuse', 'name': 'garbage_truck'}, {'frequency': 'c', 'synset': 'garden_hose.n.01', 'synonyms': ['garden_hose'], 'id': 484, 'def': 'a hose used for watering a lawn or garden', 'name': 'garden_hose'}, {'frequency': 'c', 'synset': 'gargle.n.01', 'synonyms': ['gargle', 'mouthwash'], 'id': 485, 'def': 'a medicated solution used for gargling and rinsing the mouth', 'name': 'gargle'}, {'frequency': 'r', 'synset': 'gargoyle.n.02', 'synonyms': ['gargoyle'], 'id': 486, 'def': 'an ornament consisting of a grotesquely carved figure of a person or animal', 'name': 'gargoyle'}, {'frequency': 'c', 'synset': 'garlic.n.02', 'synonyms': ['garlic', 'ail'], 'id': 487, 'def': 'aromatic bulb used as seasoning', 'name': 'garlic'}, {'frequency': 'r', 'synset': 'gasmask.n.01', 'synonyms': ['gasmask', 'respirator', 'gas_helmet'], 'id': 488, 'def': 'a protective face mask with a filter', 'name': 'gasmask'}, {'frequency': 'c', 'synset': 'gazelle.n.01', 'synonyms': ['gazelle'], 'id': 489, 'def': 'small swift 
graceful antelope of Africa and Asia having lustrous eyes', 'name': 'gazelle'}, {'frequency': 'c', 'synset': 'gelatin.n.02', 'synonyms': ['gelatin', 'jelly'], 'id': 490, 'def': 'an edible jelly made with gelatin and used as a dessert or salad base or a coating for foods', 'name': 'gelatin'}, {'frequency': 'r', 'synset': 'gem.n.02', 'synonyms': ['gemstone'], 'id': 491, 'def': 'a crystalline rock that can be cut and polished for jewelry', 'name': 'gemstone'}, {'frequency': 'r', 'synset': 'generator.n.02', 'synonyms': ['generator'], 'id': 492, 'def': 'engine that converts mechanical energy into electrical energy by electromagnetic induction', 'name': 'generator'}, {'frequency': 'c', 'synset': 'giant_panda.n.01', 'synonyms': ['giant_panda', 'panda', 'panda_bear'], 'id': 493, 'def': 'large black-and-white herbivorous mammal of bamboo forests of China and Tibet', 'name': 'giant_panda'}, {'frequency': 'c', 'synset': 'gift_wrap.n.01', 'synonyms': ['gift_wrap'], 'id': 494, 'def': 'attractive wrapping paper suitable for wrapping gifts', 'name': 'gift_wrap'}, {'frequency': 'c', 'synset': 'ginger.n.03', 'synonyms': ['ginger', 'gingerroot'], 'id': 495, 'def': 'the root of the common ginger plant; used fresh as a seasoning', 'name': 'ginger'}, {'frequency': 'f', 'synset': 'giraffe.n.01', 'synonyms': ['giraffe'], 'id': 496, 'def': 'tall animal having a spotted coat and small horns and very long neck and legs', 'name': 'giraffe'}, {'frequency': 'c', 'synset': 'girdle.n.02', 'synonyms': ['cincture', 'sash', 'waistband', 'waistcloth'], 'id': 497, 'def': 'a band of material around the waist that strengthens a skirt or trousers', 'name': 'cincture'}, {'frequency': 'f', 'synset': 'glass.n.02', 'synonyms': ['glass_(drink_container)', 'drinking_glass'], 'id': 498, 'def': 'a container for holding liquids while drinking', 'name': 'glass_(drink_container)'}, {'frequency': 'c', 'synset': 'globe.n.03', 'synonyms': ['globe'], 'id': 499, 'def': 'a sphere on which a map (especially of the earth) 
is represented', 'name': 'globe'}, {'frequency': 'f', 'synset': 'glove.n.02', 'synonyms': ['glove'], 'id': 500, 'def': 'handwear covering the hand', 'name': 'glove'}, {'frequency': 'c', 'synset': 'goat.n.01', 'synonyms': ['goat'], 'id': 501, 'def': 'a common goat', 'name': 'goat'}, {'frequency': 'f', 'synset': 'goggles.n.01', 'synonyms': ['goggles'], 'id': 502, 'def': 'tight-fitting spectacles worn to protect the eyes', 'name': 'goggles'}, {'frequency': 'r', 'synset': 'goldfish.n.01', 'synonyms': ['goldfish'], 'id': 503, 'def': 'small golden or orange-red freshwater fishes used as pond or aquarium pets', 'name': 'goldfish'}, {'frequency': 'c', 'synset': 'golf_club.n.02', 'synonyms': ['golf_club', 'golf-club'], 'id': 504, 'def': 'golf equipment used by a golfer to hit a golf ball', 'name': 'golf_club'}, {'frequency': 'c', 'synset': 'golfcart.n.01', 'synonyms': ['golfcart'], 'id': 505, 'def': 'a small motor vehicle in which golfers can ride between shots', 'name': 'golfcart'}, {'frequency': 'r', 'synset': 'gondola.n.02', 'synonyms': ['gondola_(boat)'], 'id': 506, 'def': 'long narrow flat-bottomed boat propelled by sculling; traditionally used on canals of Venice', 'name': 'gondola_(boat)'}, {'frequency': 'c', 'synset': 'goose.n.01', 'synonyms': ['goose'], 'id': 507, 'def': 'loud, web-footed long-necked aquatic birds usually larger than ducks', 'name': 'goose'}, {'frequency': 'r', 'synset': 'gorilla.n.01', 'synonyms': ['gorilla'], 'id': 508, 'def': 'largest ape', 'name': 'gorilla'}, {'frequency': 'r', 'synset': 'gourd.n.02', 'synonyms': ['gourd'], 'id': 509, 'def': 'any of numerous inedible fruits with hard rinds', 'name': 'gourd'}, {'frequency': 'f', 'synset': 'grape.n.01', 'synonyms': ['grape'], 'id': 510, 'def': 'any of various juicy fruit with green or purple skins; grow in clusters', 'name': 'grape'}, {'frequency': 'c', 'synset': 'grater.n.01', 'synonyms': ['grater'], 'id': 511, 'def': 'utensil with sharp perforations for shredding foods (as vegetables or 
cheese)', 'name': 'grater'}, {'frequency': 'c', 'synset': 'gravestone.n.01', 'synonyms': ['gravestone', 'headstone', 'tombstone'], 'id': 512, 'def': 'a stone that is used to mark a grave', 'name': 'gravestone'}, {'frequency': 'r', 'synset': 'gravy_boat.n.01', 'synonyms': ['gravy_boat', 'gravy_holder'], 'id': 513, 'def': 'a dish (often boat-shaped) for serving gravy or sauce', 'name': 'gravy_boat'}, {'frequency': 'f', 'synset': 'green_bean.n.02', 'synonyms': ['green_bean'], 'id': 514, 'def': 'a common bean plant cultivated for its slender green edible pods', 'name': 'green_bean'}, {'frequency': 'f', 'synset': 'green_onion.n.01', 'synonyms': ['green_onion', 'spring_onion', 'scallion'], 'id': 515, 'def': 'a young onion before the bulb has enlarged', 'name': 'green_onion'}, {'frequency': 'r', 'synset': 'griddle.n.01', 'synonyms': ['griddle'], 'id': 516, 'def': 'cooking utensil consisting of a flat heated surface on which food is cooked', 'name': 'griddle'}, {'frequency': 'f', 'synset': 'grill.n.02', 'synonyms': ['grill', 'grille', 'grillwork', 'radiator_grille'], 'id': 517, 'def': 'a framework of metal bars used as a partition or a grate', 'name': 'grill'}, {'frequency': 'r', 'synset': 'grits.n.01', 'synonyms': ['grits', 'hominy_grits'], 'id': 518, 'def': 'coarsely ground corn boiled as a breakfast dish', 'name': 'grits'}, {'frequency': 'c', 'synset': 'grizzly.n.01', 'synonyms': ['grizzly', 'grizzly_bear'], 'id': 519, 'def': 'powerful brownish-yellow bear of the uplands of western North America', 'name': 'grizzly'}, {'frequency': 'c', 'synset': 'grocery_bag.n.01', 'synonyms': ['grocery_bag'], 'id': 520, 'def': "a sack for holding customer's groceries", 'name': 'grocery_bag'}, {'frequency': 'f', 'synset': 'guitar.n.01', 'synonyms': ['guitar'], 'id': 521, 'def': 'a stringed instrument usually having six strings; played by strumming or plucking', 'name': 'guitar'}, {'frequency': 'c', 'synset': 'gull.n.02', 'synonyms': ['gull', 'seagull'], 'id': 522, 'def': 'mostly white 
aquatic bird having long pointed wings and short legs', 'name': 'gull'}, {'frequency': 'c', 'synset': 'gun.n.01', 'synonyms': ['gun'], 'id': 523, 'def': 'a weapon that discharges a bullet at high velocity from a metal tube', 'name': 'gun'}, {'frequency': 'f', 'synset': 'hairbrush.n.01', 'synonyms': ['hairbrush'], 'id': 524, 'def': "a brush used to groom a person's hair", 'name': 'hairbrush'}, {'frequency': 'c', 'synset': 'hairnet.n.01', 'synonyms': ['hairnet'], 'id': 525, 'def': 'a small net that someone wears over their hair to keep it in place', 'name': 'hairnet'}, {'frequency': 'c', 'synset': 'hairpin.n.01', 'synonyms': ['hairpin'], 'id': 526, 'def': "a double pronged pin used to hold women's hair in place", 'name': 'hairpin'}, {'frequency': 'r', 'synset': 'halter.n.03', 'synonyms': ['halter_top'], 'id': 527, 'def': "a woman's top that fastens behind the back and neck leaving the back and arms uncovered", 'name': 'halter_top'}, {'frequency': 'f', 'synset': 'ham.n.01', 'synonyms': ['ham', 'jambon', 'gammon'], 'id': 528, 'def': 'meat cut from the thigh of a hog (usually smoked)', 'name': 'ham'}, {'frequency': 'c', 'synset': 'hamburger.n.01', 'synonyms': ['hamburger', 'beefburger', 'burger'], 'id': 529, 'def': 'a sandwich consisting of a patty of minced beef served on a bun', 'name': 'hamburger'}, {'frequency': 'c', 'synset': 'hammer.n.02', 'synonyms': ['hammer'], 'id': 530, 'def': 'a hand tool with a heavy head and a handle; used to deliver an impulsive force by striking', 'name': 'hammer'}, {'frequency': 'c', 'synset': 'hammock.n.02', 'synonyms': ['hammock'], 'id': 531, 'def': 'a hanging bed of canvas or rope netting (usually suspended between two trees)', 'name': 'hammock'}, {'frequency': 'r', 'synset': 'hamper.n.02', 'synonyms': ['hamper'], 'id': 532, 'def': 'a basket usually with a cover', 'name': 'hamper'}, {'frequency': 'c', 'synset': 'hamster.n.01', 'synonyms': ['hamster'], 'id': 533, 'def': 'short-tailed burrowing rodent with large cheek pouches', 'name': 
'hamster'}, {'frequency': 'f', 'synset': 'hand_blower.n.01', 'synonyms': ['hair_dryer'], 'id': 534, 'def': 'a hand-held electric blower that can blow warm air onto the hair', 'name': 'hair_dryer'}, {'frequency': 'r', 'synset': 'hand_glass.n.01', 'synonyms': ['hand_glass', 'hand_mirror'], 'id': 535, 'def': 'a mirror intended to be held in the hand', 'name': 'hand_glass'}, {'frequency': 'f', 'synset': 'hand_towel.n.01', 'synonyms': ['hand_towel', 'face_towel'], 'id': 536, 'def': 'a small towel used to dry the hands or face', 'name': 'hand_towel'}, {'frequency': 'c', 'synset': 'handcart.n.01', 'synonyms': ['handcart', 'pushcart', 'hand_truck'], 'id': 537, 'def': 'wheeled vehicle that can be pushed by a person', 'name': 'handcart'}, {'frequency': 'r', 'synset': 'handcuff.n.01', 'synonyms': ['handcuff'], 'id': 538, 'def': 'shackle that consists of a metal loop that can be locked around the wrist', 'name': 'handcuff'}, {'frequency': 'c', 'synset': 'handkerchief.n.01', 'synonyms': ['handkerchief'], 'id': 539, 'def': 'a square piece of cloth used for wiping the eyes or nose or as a costume accessory', 'name': 'handkerchief'}, {'frequency': 'f', 'synset': 'handle.n.01', 'synonyms': ['handle', 'grip', 'handgrip'], 'id': 540, 'def': 'the appendage to an object that is designed to be held in order to use or move it', 'name': 'handle'}, {'frequency': 'r', 'synset': 'handsaw.n.01', 'synonyms': ['handsaw', "carpenter's_saw"], 'id': 541, 'def': 'a saw used with one hand for cutting wood', 'name': 'handsaw'}, {'frequency': 'r', 'synset': 'hardback.n.01', 'synonyms': ['hardback_book', 'hardcover_book'], 'id': 542, 'def': 'a book with cardboard or cloth or leather covers', 'name': 'hardback_book'}, {'frequency': 'r', 'synset': 'harmonium.n.01', 'synonyms': ['harmonium', 'organ_(musical_instrument)', 'reed_organ_(musical_instrument)'], 'id': 543, 'def': 'a free-reed instrument in which air is forced through the reeds by bellows', 'name': 'harmonium'}, {'frequency': 'f', 'synset': 
'hat.n.01', 'synonyms': ['hat'], 'id': 544, 'def': 'headwear that protects the head from bad weather, sun, or worn for fashion', 'name': 'hat'}, {'frequency': 'r', 'synset': 'hatbox.n.01', 'synonyms': ['hatbox'], 'id': 545, 'def': 'a round piece of luggage for carrying hats', 'name': 'hatbox'}, {'frequency': 'c', 'synset': 'head_covering.n.01', 'synonyms': ['veil'], 'id': 546, 'def': 'a garment that covers the head OR face', 'name': 'veil'}, {'frequency': 'f', 'synset': 'headband.n.01', 'synonyms': ['headband'], 'id': 547, 'def': 'a band worn around or over the head', 'name': 'headband'}, {'frequency': 'f', 'synset': 'headboard.n.01', 'synonyms': ['headboard'], 'id': 548, 'def': 'a vertical board or panel forming the head of a bedstead', 'name': 'headboard'}, {'frequency': 'f', 'synset': 'headlight.n.01', 'synonyms': ['headlight', 'headlamp'], 'id': 549, 'def': 'a powerful light with reflector; attached to the front of an automobile or locomotive', 'name': 'headlight'}, {'frequency': 'c', 'synset': 'headscarf.n.01', 'synonyms': ['headscarf'], 'id': 550, 'def': 'a kerchief worn over the head and tied under the chin', 'name': 'headscarf'}, {'frequency': 'r', 'synset': 'headset.n.01', 'synonyms': ['headset'], 'id': 551, 'def': 'receiver consisting of a pair of headphones', 'name': 'headset'}, {'frequency': 'c', 'synset': 'headstall.n.01', 'synonyms': ['headstall_(for_horses)', 'headpiece_(for_horses)'], 'id': 552, 'def': "the band that is the part of a bridle that fits around a horse's head", 'name': 'headstall_(for_horses)'}, {'frequency': 'c', 'synset': 'heart.n.02', 'synonyms': ['heart'], 'id': 553, 'def': 'a muscular organ; its contractions move the blood through the body', 'name': 'heart'}, {'frequency': 'c', 'synset': 'heater.n.01', 'synonyms': ['heater', 'warmer'], 'id': 554, 'def': 'device that heats water or supplies warmth to a room', 'name': 'heater'}, {'frequency': 'c', 'synset': 'helicopter.n.01', 'synonyms': ['helicopter'], 'id': 555, 'def': 'an aircraft 
without wings that obtains its lift from the rotation of overhead blades', 'name': 'helicopter'}, {'frequency': 'f', 'synset': 'helmet.n.02', 'synonyms': ['helmet'], 'id': 556, 'def': 'a protective headgear made of hard material to resist blows', 'name': 'helmet'}, {'frequency': 'r', 'synset': 'heron.n.02', 'synonyms': ['heron'], 'id': 557, 'def': 'grey or white wading bird with long neck and long legs and (usually) long bill', 'name': 'heron'}, {'frequency': 'c', 'synset': 'highchair.n.01', 'synonyms': ['highchair', 'feeding_chair'], 'id': 558, 'def': 'a chair for feeding a very young child', 'name': 'highchair'}, {'frequency': 'f', 'synset': 'hinge.n.01', 'synonyms': ['hinge'], 'id': 559, 'def': 'a joint that holds two parts together so that one can swing relative to the other', 'name': 'hinge'}, {'frequency': 'r', 'synset': 'hippopotamus.n.01', 'synonyms': ['hippopotamus'], 'id': 560, 'def': 'massive thick-skinned animal living in or around rivers of tropical Africa', 'name': 'hippopotamus'}, {'frequency': 'r', 'synset': 'hockey_stick.n.01', 'synonyms': ['hockey_stick'], 'id': 561, 'def': 'sports implement consisting of a stick used by hockey players to move the puck', 'name': 'hockey_stick'}, {'frequency': 'c', 'synset': 'hog.n.03', 'synonyms': ['hog', 'pig'], 'id': 562, 'def': 'domestic swine', 'name': 'hog'}, {'frequency': 'f', 'synset': 'home_plate.n.01', 'synonyms': ['home_plate_(baseball)', 'home_base_(baseball)'], 'id': 563, 'def': '(baseball) a rubber slab where the batter stands; it must be touched by a base runner in order to score', 'name': 'home_plate_(baseball)'}, {'frequency': 'c', 'synset': 'honey.n.01', 'synonyms': ['honey'], 'id': 564, 'def': 'a sweet yellow liquid produced by bees', 'name': 'honey'}, {'frequency': 'f', 'synset': 'hood.n.06', 'synonyms': ['fume_hood', 'exhaust_hood'], 'id': 565, 'def': 'metal covering leading to a vent that exhausts smoke or fumes', 'name': 'fume_hood'}, {'frequency': 'f', 'synset': 'hook.n.05', 'synonyms': 
['hook'], 'id': 566, 'def': 'a curved or bent implement for suspending or pulling something', 'name': 'hook'}, {'frequency': 'r', 'synset': 'hookah.n.01', 'synonyms': ['hookah', 'narghile', 'nargileh', 'sheesha', 'shisha', 'water_pipe'], 'id': 567, 'def': 'a tobacco pipe with a long flexible tube connected to a container where the smoke is cooled by passing through water', 'name': 'hookah'}, {'frequency': 'r', 'synset': 'hornet.n.01', 'synonyms': ['hornet'], 'id': 568, 'def': 'large stinging wasp', 'name': 'hornet'}, {'frequency': 'f', 'synset': 'horse.n.01', 'synonyms': ['horse'], 'id': 569, 'def': 'a common horse', 'name': 'horse'}, {'frequency': 'f', 'synset': 'hose.n.03', 'synonyms': ['hose', 'hosepipe'], 'id': 570, 'def': 'a flexible pipe for conveying a liquid or gas', 'name': 'hose'}, {'frequency': 'r', 'synset': 'hot-air_balloon.n.01', 'synonyms': ['hot-air_balloon'], 'id': 571, 'def': 'balloon for travel through the air in a basket suspended below a large bag of heated air', 'name': 'hot-air_balloon'}, {'frequency': 'r', 'synset': 'hot_plate.n.01', 'synonyms': ['hotplate'], 'id': 572, 'def': 'a portable electric appliance for heating or cooking or keeping food warm', 'name': 'hotplate'}, {'frequency': 'c', 'synset': 'hot_sauce.n.01', 'synonyms': ['hot_sauce'], 'id': 573, 'def': 'a pungent peppery sauce', 'name': 'hot_sauce'}, {'frequency': 'r', 'synset': 'hourglass.n.01', 'synonyms': ['hourglass'], 'id': 574, 'def': 'a sandglass timer that runs for sixty minutes', 'name': 'hourglass'}, {'frequency': 'r', 'synset': 'houseboat.n.01', 'synonyms': ['houseboat'], 'id': 575, 'def': 'a barge that is designed and equipped for use as a dwelling', 'name': 'houseboat'}, {'frequency': 'c', 'synset': 'hummingbird.n.01', 'synonyms': ['hummingbird'], 'id': 576, 'def': 'tiny American bird having brilliant iridescent plumage and long slender bills', 'name': 'hummingbird'}, {'frequency': 'r', 'synset': 'hummus.n.01', 'synonyms': ['hummus', 'humus', 'hommos', 'hoummos', 
'humous'], 'id': 577, 'def': 'a thick spread made from mashed chickpeas', 'name': 'hummus'}, {'frequency': 'f', 'synset': 'ice_bear.n.01', 'synonyms': ['polar_bear'], 'id': 578, 'def': 'white bear of Arctic regions', 'name': 'polar_bear'}, {'frequency': 'c', 'synset': 'ice_cream.n.01', 'synonyms': ['icecream'], 'id': 579, 'def': 'frozen dessert containing cream and sugar and flavoring', 'name': 'icecream'}, {'frequency': 'r', 'synset': 'ice_lolly.n.01', 'synonyms': ['popsicle'], 'id': 580, 'def': 'ice cream or water ice on a small wooden stick', 'name': 'popsicle'}, {'frequency': 'c', 'synset': 'ice_maker.n.01', 'synonyms': ['ice_maker'], 'id': 581, 'def': 'an appliance included in some electric refrigerators for making ice cubes', 'name': 'ice_maker'}, {'frequency': 'r', 'synset': 'ice_pack.n.01', 'synonyms': ['ice_pack', 'ice_bag'], 'id': 582, 'def': 'a waterproof bag filled with ice: applied to the body (especially the head) to cool or reduce swelling', 'name': 'ice_pack'}, {'frequency': 'r', 'synset': 'ice_skate.n.01', 'synonyms': ['ice_skate'], 'id': 583, 'def': 'skate consisting of a boot with a steel blade fitted to the sole', 'name': 'ice_skate'}, {'frequency': 'c', 'synset': 'igniter.n.01', 'synonyms': ['igniter', 'ignitor', 'lighter'], 'id': 584, 'def': 'a substance or device used to start a fire', 'name': 'igniter'}, {'frequency': 'r', 'synset': 'inhaler.n.01', 'synonyms': ['inhaler', 'inhalator'], 'id': 585, 'def': 'a dispenser that produces a chemical vapor to be inhaled through mouth or nose', 'name': 'inhaler'}, {'frequency': 'f', 'synset': 'ipod.n.01', 'synonyms': ['iPod'], 'id': 586, 'def': 'a pocket-sized device used to play music files', 'name': 'iPod'}, {'frequency': 'c', 'synset': 'iron.n.04', 'synonyms': ['iron_(for_clothing)', 'smoothing_iron_(for_clothing)'], 'id': 587, 'def': 'home appliance consisting of a flat metal base that is heated and used to smooth cloth', 'name': 'iron_(for_clothing)'}, {'frequency': 'c', 'synset': 
'ironing_board.n.01', 'synonyms': ['ironing_board'], 'id': 588, 'def': 'narrow padded board on collapsible supports; used for ironing clothes', 'name': 'ironing_board'}, {'frequency': 'f', 'synset': 'jacket.n.01', 'synonyms': ['jacket'], 'id': 589, 'def': 'a waist-length coat', 'name': 'jacket'}, {'frequency': 'c', 'synset': 'jam.n.01', 'synonyms': ['jam'], 'id': 590, 'def': 'preserve of crushed fruit', 'name': 'jam'}, {'frequency': 'f', 'synset': 'jar.n.01', 'synonyms': ['jar'], 'id': 591, 'def': 'a vessel (usually cylindrical) with a wide mouth and without handles', 'name': 'jar'}, {'frequency': 'f', 'synset': 'jean.n.01', 'synonyms': ['jean', 'blue_jean', 'denim'], 'id': 592, 'def': '(usually plural) close-fitting trousers of heavy denim for manual work or casual wear', 'name': 'jean'}, {'frequency': 'c', 'synset': 'jeep.n.01', 'synonyms': ['jeep', 'landrover'], 'id': 593, 'def': 'a car suitable for traveling over rough terrain', 'name': 'jeep'}, {'frequency': 'r', 'synset': 'jelly_bean.n.01', 'synonyms': ['jelly_bean', 'jelly_egg'], 'id': 594, 'def': 'sugar-glazed jellied candy', 'name': 'jelly_bean'}, {'frequency': 'f', 'synset': 'jersey.n.03', 'synonyms': ['jersey', 'T-shirt', 'tee_shirt'], 'id': 595, 'def': 'a close-fitting pullover shirt', 'name': 'jersey'}, {'frequency': 'c', 'synset': 'jet.n.01', 'synonyms': ['jet_plane', 'jet-propelled_plane'], 'id': 596, 'def': 'an airplane powered by one or more jet engines', 'name': 'jet_plane'}, {'frequency': 'r', 'synset': 'jewel.n.01', 'synonyms': ['jewel', 'gem', 'precious_stone'], 'id': 597, 'def': 'a precious or semiprecious stone incorporated into a piece of jewelry', 'name': 'jewel'}, {'frequency': 'c', 'synset': 'jewelry.n.01', 'synonyms': ['jewelry', 'jewellery'], 'id': 598, 'def': 'an adornment (as a bracelet or ring or necklace) made of precious metals and set with gems (or imitation gems)', 'name': 'jewelry'}, {'frequency': 'r', 'synset': 'joystick.n.02', 'synonyms': ['joystick'], 'id': 599, 'def': 'a 
control device for computers consisting of a vertical handle that can move freely in two directions', 'name': 'joystick'}, {'frequency': 'c', 'synset': 'jump_suit.n.01', 'synonyms': ['jumpsuit'], 'id': 600, 'def': "one-piece garment fashioned after a parachutist's uniform", 'name': 'jumpsuit'}, {'frequency': 'c', 'synset': 'kayak.n.01', 'synonyms': ['kayak'], 'id': 601, 'def': 'a small canoe consisting of a light frame made watertight with animal skins', 'name': 'kayak'}, {'frequency': 'r', 'synset': 'keg.n.02', 'synonyms': ['keg'], 'id': 602, 'def': 'small cask or barrel', 'name': 'keg'}, {'frequency': 'r', 'synset': 'kennel.n.01', 'synonyms': ['kennel', 'doghouse'], 'id': 603, 'def': 'outbuilding that serves as a shelter for a dog', 'name': 'kennel'}, {'frequency': 'c', 'synset': 'kettle.n.01', 'synonyms': ['kettle', 'boiler'], 'id': 604, 'def': 'a metal pot for stewing or boiling; usually has a lid', 'name': 'kettle'}, {'frequency': 'f', 'synset': 'key.n.01', 'synonyms': ['key'], 'id': 605, 'def': 'metal instrument used to unlock a lock', 'name': 'key'}, {'frequency': 'r', 'synset': 'keycard.n.01', 'synonyms': ['keycard'], 'id': 606, 'def': 'a plastic card used to gain access typically to a door', 'name': 'keycard'}, {'frequency': 'c', 'synset': 'kilt.n.01', 'synonyms': ['kilt'], 'id': 607, 'def': 'a knee-length pleated tartan skirt worn by men as part of the traditional dress in the Highlands of northern Scotland', 'name': 'kilt'}, {'frequency': 'c', 'synset': 'kimono.n.01', 'synonyms': ['kimono'], 'id': 608, 'def': 'a loose robe; imitated from robes originally worn by Japanese', 'name': 'kimono'}, {'frequency': 'f', 'synset': 'kitchen_sink.n.01', 'synonyms': ['kitchen_sink'], 'id': 609, 'def': 'a sink in a kitchen', 'name': 'kitchen_sink'}, {'frequency': 'r', 'synset': 'kitchen_table.n.01', 'synonyms': ['kitchen_table'], 'id': 610, 'def': 'a table in the kitchen', 'name': 'kitchen_table'}, {'frequency': 'f', 'synset': 'kite.n.03', 'synonyms': ['kite'], 'id': 
611, 'def': 'plaything consisting of a light frame covered with tissue paper; flown in wind at end of a string', 'name': 'kite'}, {'frequency': 'c', 'synset': 'kitten.n.01', 'synonyms': ['kitten', 'kitty'], 'id': 612, 'def': 'young domestic cat', 'name': 'kitten'}, {'frequency': 'c', 'synset': 'kiwi.n.03', 'synonyms': ['kiwi_fruit'], 'id': 613, 'def': 'fuzzy brown egg-shaped fruit with slightly tart green flesh', 'name': 'kiwi_fruit'}, {'frequency': 'f', 'synset': 'knee_pad.n.01', 'synonyms': ['knee_pad'], 'id': 614, 'def': 'protective garment consisting of a pad worn by football or baseball or hockey players', 'name': 'knee_pad'}, {'frequency': 'f', 'synset': 'knife.n.01', 'synonyms': ['knife'], 'id': 615, 'def': 'tool with a blade and point used as a cutting instrument', 'name': 'knife'}, {'frequency': 'r', 'synset': 'knitting_needle.n.01', 'synonyms': ['knitting_needle'], 'id': 616, 'def': 'needle consisting of a slender rod with pointed ends; usually used in pairs', 'name': 'knitting_needle'}, {'frequency': 'f', 'synset': 'knob.n.02', 'synonyms': ['knob'], 'id': 617, 'def': 'a round handle often found on a door', 'name': 'knob'}, {'frequency': 'r', 'synset': 'knocker.n.05', 'synonyms': ['knocker_(on_a_door)', 'doorknocker'], 'id': 618, 'def': 'a device (usually metal and ornamental) attached by a hinge to a door', 'name': 'knocker_(on_a_door)'}, {'frequency': 'r', 'synset': 'koala.n.01', 'synonyms': ['koala', 'koala_bear'], 'id': 619, 'def': 'sluggish tailless Australian marsupial with grey furry ears and coat', 'name': 'koala'}, {'frequency': 'r', 'synset': 'lab_coat.n.01', 'synonyms': ['lab_coat', 'laboratory_coat'], 'id': 620, 'def': 'a light coat worn to protect clothing from substances used while working in a laboratory', 'name': 'lab_coat'}, {'frequency': 'f', 'synset': 'ladder.n.01', 'synonyms': ['ladder'], 'id': 621, 'def': 'steps consisting of two parallel members connected by rungs', 'name': 'ladder'}, {'frequency': 'c', 'synset': 'ladle.n.01', 
'synonyms': ['ladle'], 'id': 622, 'def': 'a spoon-shaped vessel with a long handle frequently used to transfer liquids', 'name': 'ladle'}, {'frequency': 'c', 'synset': 'ladybug.n.01', 'synonyms': ['ladybug', 'ladybeetle', 'ladybird_beetle'], 'id': 623, 'def': 'small round bright-colored and spotted beetle, typically red and black', 'name': 'ladybug'}, {'frequency': 'f', 'synset': 'lamb.n.01', 'synonyms': ['lamb_(animal)'], 'id': 624, 'def': 'young sheep', 'name': 'lamb_(animal)'}, {'frequency': 'r', 'synset': 'lamb_chop.n.01', 'synonyms': ['lamb-chop', 'lambchop'], 'id': 625, 'def': 'chop cut from a lamb', 'name': 'lamb-chop'}, {'frequency': 'f', 'synset': 'lamp.n.02', 'synonyms': ['lamp'], 'id': 626, 'def': 'a piece of furniture holding one or more electric light bulbs', 'name': 'lamp'}, {'frequency': 'f', 'synset': 'lamppost.n.01', 'synonyms': ['lamppost'], 'id': 627, 'def': 'a metal post supporting an outdoor lamp (such as a streetlight)', 'name': 'lamppost'}, {'frequency': 'f', 'synset': 'lampshade.n.01', 'synonyms': ['lampshade'], 'id': 628, 'def': 'a protective ornamental shade used to screen a light bulb from direct view', 'name': 'lampshade'}, {'frequency': 'c', 'synset': 'lantern.n.01', 'synonyms': ['lantern'], 'id': 629, 'def': 'light in a transparent protective case', 'name': 'lantern'}, {'frequency': 'f', 'synset': 'lanyard.n.02', 'synonyms': ['lanyard', 'laniard'], 'id': 630, 'def': 'a cord worn around the neck to hold a knife or whistle, etc.', 'name': 'lanyard'}, {'frequency': 'f', 'synset': 'laptop.n.01', 'synonyms': ['laptop_computer', 'notebook_computer'], 'id': 631, 'def': 'a portable computer small enough to use in your lap', 'name': 'laptop_computer'}, {'frequency': 'r', 'synset': 'lasagna.n.01', 'synonyms': ['lasagna', 'lasagne'], 'id': 632, 'def': 'baked dish of layers of lasagna pasta with sauce and cheese and meat or vegetables', 'name': 'lasagna'}, {'frequency': 'f', 'synset': 'latch.n.02', 'synonyms': ['latch'], 'id': 633, 'def': 'a bar 
that can be lowered or slid into a groove to fasten a door or gate', 'name': 'latch'}, {'frequency': 'r', 'synset': 'lawn_mower.n.01', 'synonyms': ['lawn_mower'], 'id': 634, 'def': 'garden tool for mowing grass on lawns', 'name': 'lawn_mower'}, {'frequency': 'r', 'synset': 'leather.n.01', 'synonyms': ['leather'], 'id': 635, 'def': 'an animal skin made smooth and flexible by removing the hair and then tanning', 'name': 'leather'}, {'frequency': 'c', 'synset': 'legging.n.01', 'synonyms': ['legging_(clothing)', 'leging_(clothing)', 'leg_covering'], 'id': 636, 'def': 'a garment covering the leg (usually extending from the knee to the ankle)', 'name': 'legging_(clothing)'}, {'frequency': 'c', 'synset': 'lego.n.01', 'synonyms': ['Lego', 'Lego_set'], 'id': 637, 'def': "a child's plastic construction set for making models from blocks", 'name': 'Lego'}, {'frequency': 'r', 'synset': 'legume.n.02', 'synonyms': ['legume'], 'id': 638, 'def': 'the fruit or seed of bean or pea plants', 'name': 'legume'}, {'frequency': 'f', 'synset': 'lemon.n.01', 'synonyms': ['lemon'], 'id': 639, 'def': 'yellow oval fruit with juicy acidic flesh', 'name': 'lemon'}, {'frequency': 'r', 'synset': 'lemonade.n.01', 'synonyms': ['lemonade'], 'id': 640, 'def': 'sweetened beverage of diluted lemon juice', 'name': 'lemonade'}, {'frequency': 'f', 'synset': 'lettuce.n.02', 'synonyms': ['lettuce'], 'id': 641, 'def': 'leafy plant commonly eaten in salad or on sandwiches', 'name': 'lettuce'}, {'frequency': 'f', 'synset': 'license_plate.n.01', 'synonyms': ['license_plate', 'numberplate'], 'id': 642, 'def': "a plate mounted on the front and back of car and bearing the car's registration number", 'name': 'license_plate'}, {'frequency': 'f', 'synset': 'life_buoy.n.01', 'synonyms': ['life_buoy', 'lifesaver', 'life_belt', 'life_ring'], 'id': 643, 'def': 'a ring-shaped life preserver used to prevent drowning (NOT a life-jacket or vest)', 'name': 'life_buoy'}, {'frequency': 'f', 'synset': 'life_jacket.n.01', 
'synonyms': ['life_jacket', 'life_vest'], 'id': 644, 'def': 'life preserver consisting of a sleeveless jacket of buoyant or inflatable design', 'name': 'life_jacket'}, {'frequency': 'f', 'synset': 'light_bulb.n.01', 'synonyms': ['lightbulb'], 'id': 645, 'def': 'lightblub/source of light', 'name': 'lightbulb'}, {'frequency': 'r', 'synset': 'lightning_rod.n.02', 'synonyms': ['lightning_rod', 'lightning_conductor'], 'id': 646, 'def': 'a metallic conductor that is attached to a high point and leads to the ground', 'name': 'lightning_rod'}, {'frequency': 'f', 'synset': 'lime.n.06', 'synonyms': ['lime'], 'id': 647, 'def': 'the green acidic fruit of any of various lime trees', 'name': 'lime'}, {'frequency': 'r', 'synset': 'limousine.n.01', 'synonyms': ['limousine'], 'id': 648, 'def': 'long luxurious car; usually driven by a chauffeur', 'name': 'limousine'}, {'frequency': 'c', 'synset': 'lion.n.01', 'synonyms': ['lion'], 'id': 649, 'def': 'large gregarious predatory cat of Africa and India', 'name': 'lion'}, {'frequency': 'c', 'synset': 'lip_balm.n.01', 'synonyms': ['lip_balm'], 'id': 650, 'def': 'a balm applied to the lips', 'name': 'lip_balm'}, {'frequency': 'r', 'synset': 'liquor.n.01', 'synonyms': ['liquor', 'spirits', 'hard_liquor', 'liqueur', 'cordial'], 'id': 651, 'def': 'liquor or beer', 'name': 'liquor'}, {'frequency': 'c', 'synset': 'lizard.n.01', 'synonyms': ['lizard'], 'id': 652, 'def': 'a reptile with usually two pairs of legs and a tapering tail', 'name': 'lizard'}, {'frequency': 'f', 'synset': 'log.n.01', 'synonyms': ['log'], 'id': 653, 'def': 'a segment of the trunk of a tree when stripped of branches', 'name': 'log'}, {'frequency': 'c', 'synset': 'lollipop.n.02', 'synonyms': ['lollipop'], 'id': 654, 'def': 'hard candy on a stick', 'name': 'lollipop'}, {'frequency': 'f', 'synset': 'loudspeaker.n.01', 'synonyms': ['speaker_(stero_equipment)'], 'id': 655, 'def': 'electronic device that produces sound often as part of a stereo system', 'name': 
'speaker_(stero_equipment)'}, {'frequency': 'c', 'synset': 'love_seat.n.01', 'synonyms': ['loveseat'], 'id': 656, 'def': 'small sofa that seats two people', 'name': 'loveseat'}, {'frequency': 'r', 'synset': 'machine_gun.n.01', 'synonyms': ['machine_gun'], 'id': 657, 'def': 'a rapidly firing automatic gun', 'name': 'machine_gun'}, {'frequency': 'f', 'synset': 'magazine.n.02', 'synonyms': ['magazine'], 'id': 658, 'def': 'a paperback periodic publication', 'name': 'magazine'}, {'frequency': 'f', 'synset': 'magnet.n.01', 'synonyms': ['magnet'], 'id': 659, 'def': 'a device that attracts iron and produces a magnetic field', 'name': 'magnet'}, {'frequency': 'c', 'synset': 'mail_slot.n.01', 'synonyms': ['mail_slot'], 'id': 660, 'def': 'a slot (usually in a door) through which mail can be delivered', 'name': 'mail_slot'}, {'frequency': 'f', 'synset': 'mailbox.n.01', 'synonyms': ['mailbox_(at_home)', 'letter_box_(at_home)'], 'id': 661, 'def': 'a private box for delivery of mail', 'name': 'mailbox_(at_home)'}, {'frequency': 'r', 'synset': 'mallard.n.01', 'synonyms': ['mallard'], 'id': 662, 'def': 'wild dabbling duck from which domestic ducks are descended', 'name': 'mallard'}, {'frequency': 'r', 'synset': 'mallet.n.01', 'synonyms': ['mallet'], 'id': 663, 'def': 'a sports implement with a long handle and a hammer-like head used to hit a ball', 'name': 'mallet'}, {'frequency': 'r', 'synset': 'mammoth.n.01', 'synonyms': ['mammoth'], 'id': 664, 'def': 'any of numerous extinct elephants widely distributed in the Pleistocene', 'name': 'mammoth'}, {'frequency': 'r', 'synset': 'manatee.n.01', 'synonyms': ['manatee'], 'id': 665, 'def': 'sirenian mammal of tropical coastal waters of America', 'name': 'manatee'}, {'frequency': 'c', 'synset': 'mandarin.n.05', 'synonyms': ['mandarin_orange'], 'id': 666, 'def': 'a somewhat flat reddish-orange loose skinned citrus of China', 'name': 'mandarin_orange'}, {'frequency': 'c', 'synset': 'manger.n.01', 'synonyms': ['manger', 'trough'], 'id': 667, 
'def': 'a container (usually in a barn or stable) from which cattle or horses feed', 'name': 'manger'}, {'frequency': 'f', 'synset': 'manhole.n.01', 'synonyms': ['manhole'], 'id': 668, 'def': 'a hole (usually with a flush cover) through which a person can gain access to an underground structure', 'name': 'manhole'}, {'frequency': 'f', 'synset': 'map.n.01', 'synonyms': ['map'], 'id': 669, 'def': "a diagrammatic representation of the earth's surface (or part of it)", 'name': 'map'}, {'frequency': 'f', 'synset': 'marker.n.03', 'synonyms': ['marker'], 'id': 670, 'def': 'a writing implement for making a mark', 'name': 'marker'}, {'frequency': 'r', 'synset': 'martini.n.01', 'synonyms': ['martini'], 'id': 671, 'def': 'a cocktail made of gin (or vodka) with dry vermouth', 'name': 'martini'}, {'frequency': 'r', 'synset': 'mascot.n.01', 'synonyms': ['mascot'], 'id': 672, 'def': 'a person or animal that is adopted by a team or other group as a symbolic figure', 'name': 'mascot'}, {'frequency': 'c', 'synset': 'mashed_potato.n.01', 'synonyms': ['mashed_potato'], 'id': 673, 'def': 'potato that has been peeled and boiled and then mashed', 'name': 'mashed_potato'}, {'frequency': 'r', 'synset': 'masher.n.02', 'synonyms': ['masher'], 'id': 674, 'def': 'a kitchen utensil used for mashing (e.g. 
potatoes)', 'name': 'masher'}, {'frequency': 'f', 'synset': 'mask.n.04', 'synonyms': ['mask', 'facemask'], 'id': 675, 'def': 'a protective covering worn over the face', 'name': 'mask'}, {'frequency': 'f', 'synset': 'mast.n.01', 'synonyms': ['mast'], 'id': 676, 'def': 'a vertical spar for supporting sails', 'name': 'mast'}, {'frequency': 'c', 'synset': 'mat.n.03', 'synonyms': ['mat_(gym_equipment)', 'gym_mat'], 'id': 677, 'def': 'sports equipment consisting of a piece of thick padding on the floor for gymnastics', 'name': 'mat_(gym_equipment)'}, {'frequency': 'r', 'synset': 'matchbox.n.01', 'synonyms': ['matchbox'], 'id': 678, 'def': 'a box for holding matches', 'name': 'matchbox'}, {'frequency': 'f', 'synset': 'mattress.n.01', 'synonyms': ['mattress'], 'id': 679, 'def': 'a thick pad filled with resilient material used as a bed or part of a bed', 'name': 'mattress'}, {'frequency': 'c', 'synset': 'measuring_cup.n.01', 'synonyms': ['measuring_cup'], 'id': 680, 'def': 'graduated cup used to measure liquid or granular ingredients', 'name': 'measuring_cup'}, {'frequency': 'c', 'synset': 'measuring_stick.n.01', 'synonyms': ['measuring_stick', 'ruler_(measuring_stick)', 'measuring_rod'], 'id': 681, 'def': 'measuring instrument having a sequence of marks at regular intervals', 'name': 'measuring_stick'}, {'frequency': 'c', 'synset': 'meatball.n.01', 'synonyms': ['meatball'], 'id': 682, 'def': 'ground meat formed into a ball and fried or simmered in broth', 'name': 'meatball'}, {'frequency': 'c', 'synset': 'medicine.n.02', 'synonyms': ['medicine'], 'id': 683, 'def': 'something that treats or prevents or alleviates the symptoms of disease', 'name': 'medicine'}, {'frequency': 'c', 'synset': 'melon.n.01', 'synonyms': ['melon'], 'id': 684, 'def': 'fruit of the gourd family having a hard rind and sweet juicy flesh', 'name': 'melon'}, {'frequency': 'f', 'synset': 'microphone.n.01', 'synonyms': ['microphone'], 'id': 685, 'def': 'device for converting sound waves into electrical 
energy', 'name': 'microphone'}, {'frequency': 'r', 'synset': 'microscope.n.01', 'synonyms': ['microscope'], 'id': 686, 'def': 'magnifier of the image of small objects', 'name': 'microscope'}, {'frequency': 'f', 'synset': 'microwave.n.02', 'synonyms': ['microwave_oven'], 'id': 687, 'def': 'kitchen appliance that cooks food by passing an electromagnetic wave through it', 'name': 'microwave_oven'}, {'frequency': 'r', 'synset': 'milestone.n.01', 'synonyms': ['milestone', 'milepost'], 'id': 688, 'def': 'stone post at side of a road to show distances', 'name': 'milestone'}, {'frequency': 'f', 'synset': 'milk.n.01', 'synonyms': ['milk'], 'id': 689, 'def': 'a white nutritious liquid secreted by mammals and used as food by human beings', 'name': 'milk'}, {'frequency': 'r', 'synset': 'milk_can.n.01', 'synonyms': ['milk_can'], 'id': 690, 'def': 'can for transporting milk', 'name': 'milk_can'}, {'frequency': 'r', 'synset': 'milkshake.n.01', 'synonyms': ['milkshake'], 'id': 691, 'def': 'frothy drink of milk and flavoring and sometimes fruit or ice cream', 'name': 'milkshake'}, {'frequency': 'f', 'synset': 'minivan.n.01', 'synonyms': ['minivan'], 'id': 692, 'def': 'a small box-shaped passenger van', 'name': 'minivan'}, {'frequency': 'r', 'synset': 'mint.n.05', 'synonyms': ['mint_candy'], 'id': 693, 'def': 'a candy that is flavored with a mint oil', 'name': 'mint_candy'}, {'frequency': 'f', 'synset': 'mirror.n.01', 'synonyms': ['mirror'], 'id': 694, 'def': 'polished surface that forms images by reflecting light', 'name': 'mirror'}, {'frequency': 'c', 'synset': 'mitten.n.01', 'synonyms': ['mitten'], 'id': 695, 'def': 'glove that encases the thumb separately and the other four fingers together', 'name': 'mitten'}, {'frequency': 'c', 'synset': 'mixer.n.04', 'synonyms': ['mixer_(kitchen_tool)', 'stand_mixer'], 'id': 696, 'def': 'a kitchen utensil that is used for mixing foods', 'name': 'mixer_(kitchen_tool)'}, {'frequency': 'c', 'synset': 'money.n.03', 'synonyms': ['money'], 'id': 
697, 'def': 'the official currency issued by a government or national bank', 'name': 'money'}, {'frequency': 'f', 'synset': 'monitor.n.04', 'synonyms': ['monitor_(computer_equipment) computer_monitor'], 'id': 698, 'def': 'a computer monitor', 'name': 'monitor_(computer_equipment) computer_monitor'}, {'frequency': 'c', 'synset': 'monkey.n.01', 'synonyms': ['monkey'], 'id': 699, 'def': 'any of various long-tailed primates', 'name': 'monkey'}, {'frequency': 'f', 'synset': 'motor.n.01', 'synonyms': ['motor'], 'id': 700, 'def': 'machine that converts other forms of energy into mechanical energy and so imparts motion', 'name': 'motor'}, {'frequency': 'f', 'synset': 'motor_scooter.n.01', 'synonyms': ['motor_scooter', 'scooter'], 'id': 701, 'def': 'a wheeled vehicle with small wheels and a low-powered engine', 'name': 'motor_scooter'}, {'frequency': 'r', 'synset': 'motor_vehicle.n.01', 'synonyms': ['motor_vehicle', 'automotive_vehicle'], 'id': 702, 'def': 'a self-propelled wheeled vehicle that does not run on rails', 'name': 'motor_vehicle'}, {'frequency': 'f', 'synset': 'motorcycle.n.01', 'synonyms': ['motorcycle'], 'id': 703, 'def': 'a motor vehicle with two wheels and a strong frame', 'name': 'motorcycle'}, {'frequency': 'f', 'synset': 'mound.n.01', 'synonyms': ['mound_(baseball)', "pitcher's_mound"], 'id': 704, 'def': '(baseball) the slight elevation on which the pitcher stands', 'name': 'mound_(baseball)'}, {'frequency': 'f', 'synset': 'mouse.n.04', 'synonyms': ['mouse_(computer_equipment)', 'computer_mouse'], 'id': 705, 'def': 'a computer input device that controls an on-screen pointer (does not include trackpads / touchpads)', 'name': 'mouse_(computer_equipment)'}, {'frequency': 'f', 'synset': 'mousepad.n.01', 'synonyms': ['mousepad'], 'id': 706, 'def': 'a small portable pad that provides an operating surface for a computer mouse', 'name': 'mousepad'}, {'frequency': 'c', 'synset': 'muffin.n.01', 'synonyms': ['muffin'], 'id': 707, 'def': 'a sweet quick bread baked in 
a cup-shaped pan', 'name': 'muffin'}, {'frequency': 'f', 'synset': 'mug.n.04', 'synonyms': ['mug'], 'id': 708, 'def': 'with handle and usually cylindrical', 'name': 'mug'}, {'frequency': 'f', 'synset': 'mushroom.n.02', 'synonyms': ['mushroom'], 'id': 709, 'def': 'a common mushroom', 'name': 'mushroom'}, {'frequency': 'r', 'synset': 'music_stool.n.01', 'synonyms': ['music_stool', 'piano_stool'], 'id': 710, 'def': 'a stool for piano players; usually adjustable in height', 'name': 'music_stool'}, {'frequency': 'c', 'synset': 'musical_instrument.n.01', 'synonyms': ['musical_instrument', 'instrument_(musical)'], 'id': 711, 'def': 'any of various devices or contrivances that can be used to produce musical tones or sounds', 'name': 'musical_instrument'}, {'frequency': 'r', 'synset': 'nailfile.n.01', 'synonyms': ['nailfile'], 'id': 712, 'def': 'a small flat file for shaping the nails', 'name': 'nailfile'}, {'frequency': 'f', 'synset': 'napkin.n.01', 'synonyms': ['napkin', 'table_napkin', 'serviette'], 'id': 713, 'def': 'a small piece of table linen or paper that is used to wipe the mouth and to cover the lap in order to protect clothing', 'name': 'napkin'}, {'frequency': 'r', 'synset': 'neckerchief.n.01', 'synonyms': ['neckerchief'], 'id': 714, 'def': 'a kerchief worn around the neck', 'name': 'neckerchief'}, {'frequency': 'f', 'synset': 'necklace.n.01', 'synonyms': ['necklace'], 'id': 715, 'def': 'jewelry consisting of a cord or chain (often bearing gems) worn about the neck as an ornament', 'name': 'necklace'}, {'frequency': 'f', 'synset': 'necktie.n.01', 'synonyms': ['necktie', 'tie_(necktie)'], 'id': 716, 'def': 'neckwear consisting of a long narrow piece of material worn under a collar and tied in knot at the front', 'name': 'necktie'}, {'frequency': 'c', 'synset': 'needle.n.03', 'synonyms': ['needle'], 'id': 717, 'def': 'a sharp pointed implement (usually metal)', 'name': 'needle'}, {'frequency': 'c', 'synset': 'nest.n.01', 'synonyms': ['nest'], 'id': 718, 'def': 'a 
structure in which animals lay eggs or give birth to their young', 'name': 'nest'}, {'frequency': 'f', 'synset': 'newspaper.n.01', 'synonyms': ['newspaper', 'paper_(newspaper)'], 'id': 719, 'def': 'a daily or weekly publication on folded sheets containing news, articles, and advertisements', 'name': 'newspaper'}, {'frequency': 'c', 'synset': 'newsstand.n.01', 'synonyms': ['newsstand'], 'id': 720, 'def': 'a stall where newspapers and other periodicals are sold', 'name': 'newsstand'}, {'frequency': 'c', 'synset': 'nightwear.n.01', 'synonyms': ['nightshirt', 'nightwear', 'sleepwear', 'nightclothes'], 'id': 721, 'def': 'garments designed to be worn in bed', 'name': 'nightshirt'}, {'frequency': 'r', 'synset': 'nosebag.n.01', 'synonyms': ['nosebag_(for_animals)', 'feedbag'], 'id': 722, 'def': 'a canvas bag that is used to feed an animal (such as a horse); covers the muzzle and fastens at the top of the head', 'name': 'nosebag_(for_animals)'}, {'frequency': 'c', 'synset': 'noseband.n.01', 'synonyms': ['noseband_(for_animals)', 'nosepiece_(for_animals)'], 'id': 723, 'def': "a strap that is the part of a bridle that goes over the animal's nose", 'name': 'noseband_(for_animals)'}, {'frequency': 'f', 'synset': 'notebook.n.01', 'synonyms': ['notebook'], 'id': 724, 'def': 'a book with blank pages for recording notes or memoranda', 'name': 'notebook'}, {'frequency': 'c', 'synset': 'notepad.n.01', 'synonyms': ['notepad'], 'id': 725, 'def': 'a pad of paper for keeping notes', 'name': 'notepad'}, {'frequency': 'f', 'synset': 'nut.n.03', 'synonyms': ['nut'], 'id': 726, 'def': 'a small metal block (usually square or hexagonal) with internal screw thread to be fitted onto a bolt', 'name': 'nut'}, {'frequency': 'r', 'synset': 'nutcracker.n.01', 'synonyms': ['nutcracker'], 'id': 727, 'def': 'a hand tool used to crack nuts open', 'name': 'nutcracker'}, {'frequency': 'f', 'synset': 'oar.n.01', 'synonyms': ['oar'], 'id': 728, 'def': 'an implement used to propel or steer a boat', 'name': 
'oar'}, {'frequency': 'r', 'synset': 'octopus.n.01', 'synonyms': ['octopus_(food)'], 'id': 729, 'def': 'tentacles of octopus prepared as food', 'name': 'octopus_(food)'}, {'frequency': 'r', 'synset': 'octopus.n.02', 'synonyms': ['octopus_(animal)'], 'id': 730, 'def': 'bottom-living cephalopod having a soft oval body with eight long tentacles', 'name': 'octopus_(animal)'}, {'frequency': 'c', 'synset': 'oil_lamp.n.01', 'synonyms': ['oil_lamp', 'kerosene_lamp', 'kerosine_lamp'], 'id': 731, 'def': 'a lamp that burns oil (as kerosine) for light', 'name': 'oil_lamp'}, {'frequency': 'c', 'synset': 'olive_oil.n.01', 'synonyms': ['olive_oil'], 'id': 732, 'def': 'oil from olives', 'name': 'olive_oil'}, {'frequency': 'r', 'synset': 'omelet.n.01', 'synonyms': ['omelet', 'omelette'], 'id': 733, 'def': 'beaten eggs cooked until just set; may be folded around e.g. ham or cheese or jelly', 'name': 'omelet'}, {'frequency': 'f', 'synset': 'onion.n.01', 'synonyms': ['onion'], 'id': 734, 'def': 'the bulb of an onion plant', 'name': 'onion'}, {'frequency': 'f', 'synset': 'orange.n.01', 'synonyms': ['orange_(fruit)'], 'id': 735, 'def': 'orange (FRUIT of an orange tree)', 'name': 'orange_(fruit)'}, {'frequency': 'c', 'synset': 'orange_juice.n.01', 'synonyms': ['orange_juice'], 'id': 736, 'def': 'bottled or freshly squeezed juice of oranges', 'name': 'orange_juice'}, {'frequency': 'c', 'synset': 'ostrich.n.02', 'synonyms': ['ostrich'], 'id': 737, 'def': 'fast-running African flightless bird with two-toed feet; largest living bird', 'name': 'ostrich'}, {'frequency': 'f', 'synset': 'ottoman.n.03', 'synonyms': ['ottoman', 'pouf', 'pouffe', 'hassock'], 'id': 738, 'def': 'a thick standalone cushion used as a seat or footrest, often next to a chair', 'name': 'ottoman'}, {'frequency': 'f', 'synset': 'oven.n.01', 'synonyms': ['oven'], 'id': 739, 'def': 'kitchen appliance used for baking or roasting', 'name': 'oven'}, {'frequency': 'c', 'synset': 'overall.n.01', 'synonyms': 
['overalls_(clothing)'], 'id': 740, 'def': 'work clothing consisting of denim trousers usually with a bib and shoulder straps', 'name': 'overalls_(clothing)'}, {'frequency': 'c', 'synset': 'owl.n.01', 'synonyms': ['owl'], 'id': 741, 'def': 'nocturnal bird of prey with hawk-like beak and claws and large head with front-facing eyes', 'name': 'owl'}, {'frequency': 'c', 'synset': 'packet.n.03', 'synonyms': ['packet'], 'id': 742, 'def': 'a small package or bundle', 'name': 'packet'}, {'frequency': 'r', 'synset': 'pad.n.03', 'synonyms': ['inkpad', 'inking_pad', 'stamp_pad'], 'id': 743, 'def': 'absorbent material saturated with ink used to transfer ink evenly to a rubber stamp', 'name': 'inkpad'}, {'frequency': 'c', 'synset': 'pad.n.04', 'synonyms': ['pad'], 'id': 744, 'def': 'mostly arm/knee pads labeled', 'name': 'pad'}, {'frequency': 'f', 'synset': 'paddle.n.04', 'synonyms': ['paddle', 'boat_paddle'], 'id': 745, 'def': 'a short light oar used without an oarlock to propel a canoe or small boat', 'name': 'paddle'}, {'frequency': 'c', 'synset': 'padlock.n.01', 'synonyms': ['padlock'], 'id': 746, 'def': 'a detachable, portable lock', 'name': 'padlock'}, {'frequency': 'c', 'synset': 'paintbrush.n.01', 'synonyms': ['paintbrush'], 'id': 747, 'def': 'a brush used as an applicator to apply paint', 'name': 'paintbrush'}, {'frequency': 'f', 'synset': 'painting.n.01', 'synonyms': ['painting'], 'id': 748, 'def': 'graphic art consisting of an artistic composition made by applying paints to a surface', 'name': 'painting'}, {'frequency': 'f', 'synset': 'pajama.n.02', 'synonyms': ['pajamas', 'pyjamas'], 'id': 749, 'def': 'loose-fitting nightclothes worn for sleeping or lounging', 'name': 'pajamas'}, {'frequency': 'c', 'synset': 'palette.n.02', 'synonyms': ['palette', 'pallet'], 'id': 750, 'def': 'board that provides a flat surface on which artists mix paints and the range of colors used', 'name': 'palette'}, {'frequency': 'f', 'synset': 'pan.n.01', 'synonyms': ['pan_(for_cooking)', 
'cooking_pan'], 'id': 751, 'def': 'cooking utensil consisting of a wide metal vessel', 'name': 'pan_(for_cooking)'}, {'frequency': 'r', 'synset': 'pan.n.03', 'synonyms': ['pan_(metal_container)'], 'id': 752, 'def': 'shallow container made of metal', 'name': 'pan_(metal_container)'}, {'frequency': 'c', 'synset': 'pancake.n.01', 'synonyms': ['pancake'], 'id': 753, 'def': 'a flat cake of thin batter fried on both sides on a griddle', 'name': 'pancake'}, {'frequency': 'r', 'synset': 'pantyhose.n.01', 'synonyms': ['pantyhose'], 'id': 754, 'def': "a woman's tights consisting of underpants and stockings", 'name': 'pantyhose'}, {'frequency': 'r', 'synset': 'papaya.n.02', 'synonyms': ['papaya'], 'id': 755, 'def': 'large oval melon-like tropical fruit with yellowish flesh', 'name': 'papaya'}, {'frequency': 'f', 'synset': 'paper_plate.n.01', 'synonyms': ['paper_plate'], 'id': 756, 'def': 'a disposable plate made of cardboard', 'name': 'paper_plate'}, {'frequency': 'f', 'synset': 'paper_towel.n.01', 'synonyms': ['paper_towel'], 'id': 757, 'def': 'a disposable towel made of absorbent paper', 'name': 'paper_towel'}, {'frequency': 'r', 'synset': 'paperback_book.n.01', 'synonyms': ['paperback_book', 'paper-back_book', 'softback_book', 'soft-cover_book'], 'id': 758, 'def': 'a book with paper covers', 'name': 'paperback_book'}, {'frequency': 'r', 'synset': 'paperweight.n.01', 'synonyms': ['paperweight'], 'id': 759, 'def': 'a weight used to hold down a stack of papers', 'name': 'paperweight'}, {'frequency': 'c', 'synset': 'parachute.n.01', 'synonyms': ['parachute'], 'id': 760, 'def': 'rescue equipment consisting of a device that fills with air and retards your fall', 'name': 'parachute'}, {'frequency': 'c', 'synset': 'parakeet.n.01', 'synonyms': ['parakeet', 'parrakeet', 'parroket', 'paraquet', 'paroquet', 'parroquet'], 'id': 761, 'def': 'any of numerous small slender long-tailed parrots', 'name': 'parakeet'}, {'frequency': 'c', 'synset': 'parasail.n.01', 'synonyms': 
['parasail_(sports)'], 'id': 762, 'def': 'parachute that will lift a person up into the air when it is towed by a motorboat or a car', 'name': 'parasail_(sports)'}, {'frequency': 'c', 'synset': 'parasol.n.01', 'synonyms': ['parasol', 'sunshade'], 'id': 763, 'def': 'a handheld collapsible source of shade', 'name': 'parasol'}, {'frequency': 'r', 'synset': 'parchment.n.01', 'synonyms': ['parchment'], 'id': 764, 'def': 'a superior paper resembling sheepskin', 'name': 'parchment'}, {'frequency': 'c', 'synset': 'parka.n.01', 'synonyms': ['parka', 'anorak'], 'id': 765, 'def': "a kind of heavy jacket (`windcheater' is a British term)", 'name': 'parka'}, {'frequency': 'f', 'synset': 'parking_meter.n.01', 'synonyms': ['parking_meter'], 'id': 766, 'def': 'a coin-operated timer located next to a parking space', 'name': 'parking_meter'}, {'frequency': 'c', 'synset': 'parrot.n.01', 'synonyms': ['parrot'], 'id': 767, 'def': 'usually brightly colored tropical birds with short hooked beaks and the ability to mimic sounds', 'name': 'parrot'}, {'frequency': 'c', 'synset': 'passenger_car.n.01', 'synonyms': ['passenger_car_(part_of_a_train)', 'coach_(part_of_a_train)'], 'id': 768, 'def': 'a railcar where passengers ride', 'name': 'passenger_car_(part_of_a_train)'}, {'frequency': 'r', 'synset': 'passenger_ship.n.01', 'synonyms': ['passenger_ship'], 'id': 769, 'def': 'a ship built to carry passengers', 'name': 'passenger_ship'}, {'frequency': 'c', 'synset': 'passport.n.02', 'synonyms': ['passport'], 'id': 770, 'def': 'a document issued by a country to a citizen allowing that person to travel abroad and re-enter the home country', 'name': 'passport'}, {'frequency': 'f', 'synset': 'pastry.n.02', 'synonyms': ['pastry'], 'id': 771, 'def': 'any of various baked foods made of dough or batter', 'name': 'pastry'}, {'frequency': 'r', 'synset': 'patty.n.01', 'synonyms': ['patty_(food)'], 'id': 772, 'def': 'small flat mass of chopped food', 'name': 'patty_(food)'}, {'frequency': 'c', 'synset': 
'pea.n.01', 'synonyms': ['pea_(food)'], 'id': 773, 'def': 'seed of a pea plant used for food', 'name': 'pea_(food)'}, {'frequency': 'c', 'synset': 'peach.n.03', 'synonyms': ['peach'], 'id': 774, 'def': 'downy juicy fruit with sweet yellowish or whitish flesh', 'name': 'peach'}, {'frequency': 'c', 'synset': 'peanut_butter.n.01', 'synonyms': ['peanut_butter'], 'id': 775, 'def': 'a spread made from ground peanuts', 'name': 'peanut_butter'}, {'frequency': 'f', 'synset': 'pear.n.01', 'synonyms': ['pear'], 'id': 776, 'def': 'sweet juicy gritty-textured fruit available in many varieties', 'name': 'pear'}, {'frequency': 'c', 'synset': 'peeler.n.03', 'synonyms': ['peeler_(tool_for_fruit_and_vegetables)'], 'id': 777, 'def': 'a device for peeling vegetables or fruits', 'name': 'peeler_(tool_for_fruit_and_vegetables)'}, {'frequency': 'r', 'synset': 'peg.n.04', 'synonyms': ['wooden_leg', 'pegleg'], 'id': 778, 'def': 'a prosthesis that replaces a missing leg', 'name': 'wooden_leg'}, {'frequency': 'r', 'synset': 'pegboard.n.01', 'synonyms': ['pegboard'], 'id': 779, 'def': 'a board perforated with regularly spaced holes into which pegs can be fitted', 'name': 'pegboard'}, {'frequency': 'c', 'synset': 'pelican.n.01', 'synonyms': ['pelican'], 'id': 780, 'def': 'large long-winged warm-water seabird having a large bill with a distensible pouch for fish', 'name': 'pelican'}, {'frequency': 'f', 'synset': 'pen.n.01', 'synonyms': ['pen'], 'id': 781, 'def': 'a writing implement with a point from which ink flows', 'name': 'pen'}, {'frequency': 'f', 'synset': 'pencil.n.01', 'synonyms': ['pencil'], 'id': 782, 'def': 'a thin cylindrical pointed writing implement made of wood and graphite', 'name': 'pencil'}, {'frequency': 'r', 'synset': 'pencil_box.n.01', 'synonyms': ['pencil_box', 'pencil_case'], 'id': 783, 'def': 'a box for holding pencils', 'name': 'pencil_box'}, {'frequency': 'r', 'synset': 'pencil_sharpener.n.01', 'synonyms': ['pencil_sharpener'], 'id': 784, 'def': 'a rotary implement for 
sharpening the point on pencils', 'name': 'pencil_sharpener'}, {'frequency': 'r', 'synset': 'pendulum.n.01', 'synonyms': ['pendulum'], 'id': 785, 'def': 'an apparatus consisting of an object mounted so that it swings freely under the influence of gravity', 'name': 'pendulum'}, {'frequency': 'c', 'synset': 'penguin.n.01', 'synonyms': ['penguin'], 'id': 786, 'def': 'short-legged flightless birds of cold southern regions having webbed feet and wings modified as flippers', 'name': 'penguin'}, {'frequency': 'r', 'synset': 'pennant.n.02', 'synonyms': ['pennant'], 'id': 787, 'def': 'a flag longer than it is wide (and often tapering)', 'name': 'pennant'}, {'frequency': 'r', 'synset': 'penny.n.02', 'synonyms': ['penny_(coin)'], 'id': 788, 'def': 'a coin worth one-hundredth of the value of the basic unit', 'name': 'penny_(coin)'}, {'frequency': 'f', 'synset': 'pepper.n.03', 'synonyms': ['pepper', 'peppercorn'], 'id': 789, 'def': 'pungent seasoning from the berry of the common pepper plant; whole or ground', 'name': 'pepper'}, {'frequency': 'c', 'synset': 'pepper_mill.n.01', 'synonyms': ['pepper_mill', 'pepper_grinder'], 'id': 790, 'def': 'a mill for grinding pepper', 'name': 'pepper_mill'}, {'frequency': 'c', 'synset': 'perfume.n.02', 'synonyms': ['perfume'], 'id': 791, 'def': 'a toiletry that emits and diffuses a fragrant odor', 'name': 'perfume'}, {'frequency': 'r', 'synset': 'persimmon.n.02', 'synonyms': ['persimmon'], 'id': 792, 'def': 'orange fruit resembling a plum; edible when fully ripe', 'name': 'persimmon'}, {'frequency': 'f', 'synset': 'person.n.01', 'synonyms': ['person', 'baby', 'child', 'boy', 'girl', 'man', 'woman', 'human'], 'id': 793, 'def': 'a human being', 'name': 'person'}, {'frequency': 'c', 'synset': 'pet.n.01', 'synonyms': ['pet'], 'id': 794, 'def': 'a domesticated animal kept for companionship or amusement', 'name': 'pet'}, {'frequency': 'c', 'synset': 'pew.n.01', 'synonyms': ['pew_(church_bench)', 'church_bench'], 'id': 795, 'def': 'long bench with 
backs; used in church by the congregation', 'name': 'pew_(church_bench)'}, {'frequency': 'r', 'synset': 'phonebook.n.01', 'synonyms': ['phonebook', 'telephone_book', 'telephone_directory'], 'id': 796, 'def': 'a directory containing an alphabetical list of telephone subscribers and their telephone numbers', 'name': 'phonebook'}, {'frequency': 'c', 'synset': 'phonograph_record.n.01', 'synonyms': ['phonograph_record', 'phonograph_recording', 'record_(phonograph_recording)'], 'id': 797, 'def': 'sound recording consisting of a typically black disk with a continuous groove', 'name': 'phonograph_record'}, {'frequency': 'f', 'synset': 'piano.n.01', 'synonyms': ['piano'], 'id': 798, 'def': 'a keyboard instrument that is played by depressing keys that cause hammers to strike tuned strings and produce sounds', 'name': 'piano'}, {'frequency': 'f', 'synset': 'pickle.n.01', 'synonyms': ['pickle'], 'id': 799, 'def': 'vegetables (especially cucumbers) preserved in brine or vinegar', 'name': 'pickle'}, {'frequency': 'f', 'synset': 'pickup.n.01', 'synonyms': ['pickup_truck'], 'id': 800, 'def': 'a light truck with an open body and low sides and a tailboard', 'name': 'pickup_truck'}, {'frequency': 'c', 'synset': 'pie.n.01', 'synonyms': ['pie'], 'id': 801, 'def': 'dish baked in pastry-lined pan often with a pastry top', 'name': 'pie'}, {'frequency': 'c', 'synset': 'pigeon.n.01', 'synonyms': ['pigeon'], 'id': 802, 'def': 'wild and domesticated birds having a heavy body and short legs', 'name': 'pigeon'}, {'frequency': 'r', 'synset': 'piggy_bank.n.01', 'synonyms': ['piggy_bank', 'penny_bank'], 'id': 803, 'def': "a child's coin bank (often shaped like a pig)", 'name': 'piggy_bank'}, {'frequency': 'f', 'synset': 'pillow.n.01', 'synonyms': ['pillow'], 'id': 804, 'def': 'a cushion to support the head of a sleeping person', 'name': 'pillow'}, {'frequency': 'r', 'synset': 'pin.n.09', 'synonyms': ['pin_(non_jewelry)'], 'id': 805, 'def': 'a small slender (often pointed) piece of wood or metal 
used to support or fasten or attach things', 'name': 'pin_(non_jewelry)'}, {'frequency': 'f', 'synset': 'pineapple.n.02', 'synonyms': ['pineapple'], 'id': 806, 'def': 'large sweet fleshy tropical fruit with a tuft of stiff leaves', 'name': 'pineapple'}, {'frequency': 'c', 'synset': 'pinecone.n.01', 'synonyms': ['pinecone'], 'id': 807, 'def': 'the seed-producing cone of a pine tree', 'name': 'pinecone'}, {'frequency': 'r', 'synset': 'ping-pong_ball.n.01', 'synonyms': ['ping-pong_ball'], 'id': 808, 'def': 'light hollow ball used in playing table tennis', 'name': 'ping-pong_ball'}, {'frequency': 'r', 'synset': 'pinwheel.n.03', 'synonyms': ['pinwheel'], 'id': 809, 'def': 'a toy consisting of vanes of colored paper or plastic that is pinned to a stick and spins when it is pointed into the wind', 'name': 'pinwheel'}, {'frequency': 'r', 'synset': 'pipe.n.01', 'synonyms': ['tobacco_pipe'], 'id': 810, 'def': 'a tube with a small bowl at one end; used for smoking tobacco', 'name': 'tobacco_pipe'}, {'frequency': 'f', 'synset': 'pipe.n.02', 'synonyms': ['pipe', 'piping'], 'id': 811, 'def': 'a long tube made of metal or plastic that is used to carry water or oil or gas etc.', 'name': 'pipe'}, {'frequency': 'r', 'synset': 'pistol.n.01', 'synonyms': ['pistol', 'handgun'], 'id': 812, 'def': 'a firearm that is held and fired with one hand', 'name': 'pistol'}, {'frequency': 'c', 'synset': 'pita.n.01', 'synonyms': ['pita_(bread)', 'pocket_bread'], 'id': 813, 'def': 'usually small round bread that can open into a pocket for filling', 'name': 'pita_(bread)'}, {'frequency': 'f', 'synset': 'pitcher.n.02', 'synonyms': ['pitcher_(vessel_for_liquid)', 'ewer'], 'id': 814, 'def': 'an open vessel with a handle and a spout for pouring', 'name': 'pitcher_(vessel_for_liquid)'}, {'frequency': 'r', 'synset': 'pitchfork.n.01', 'synonyms': ['pitchfork'], 'id': 815, 'def': 'a long-handled hand tool with sharp widely spaced prongs for lifting and pitching hay', 'name': 'pitchfork'}, {'frequency': 'f', 
'synset': 'pizza.n.01', 'synonyms': ['pizza'], 'id': 816, 'def': 'Italian open pie made of thin bread dough spread with a spiced mixture of e.g. tomato sauce and cheese', 'name': 'pizza'}, {'frequency': 'f', 'synset': 'place_mat.n.01', 'synonyms': ['place_mat'], 'id': 817, 'def': 'a mat placed on a table for an individual place setting', 'name': 'place_mat'}, {'frequency': 'f', 'synset': 'plate.n.04', 'synonyms': ['plate'], 'id': 818, 'def': 'dish on which food is served or from which food is eaten', 'name': 'plate'}, {'frequency': 'c', 'synset': 'platter.n.01', 'synonyms': ['platter'], 'id': 819, 'def': 'a large shallow dish used for serving food', 'name': 'platter'}, {'frequency': 'r', 'synset': 'playpen.n.01', 'synonyms': ['playpen'], 'id': 820, 'def': 'a portable enclosure in which babies may be left to play', 'name': 'playpen'}, {'frequency': 'c', 'synset': 'pliers.n.01', 'synonyms': ['pliers', 'plyers'], 'id': 821, 'def': 'a gripping hand tool with two hinged arms and (usually) serrated jaws', 'name': 'pliers'}, {'frequency': 'r', 'synset': 'plow.n.01', 'synonyms': ['plow_(farm_equipment)', 'plough_(farm_equipment)'], 'id': 822, 'def': 'a farm tool having one or more heavy blades to break the soil and cut a furrow prior to sowing', 'name': 'plow_(farm_equipment)'}, {'frequency': 'r', 'synset': 'plume.n.02', 'synonyms': ['plume'], 'id': 823, 'def': 'a feather or cluster of feathers worn as an ornament', 'name': 'plume'}, {'frequency': 'r', 'synset': 'pocket_watch.n.01', 'synonyms': ['pocket_watch'], 'id': 824, 'def': 'a watch that is carried in a small watch pocket', 'name': 'pocket_watch'}, {'frequency': 'c', 'synset': 'pocketknife.n.01', 'synonyms': ['pocketknife'], 'id': 825, 'def': 'a knife with a blade that folds into the handle; suitable for carrying in the pocket', 'name': 'pocketknife'}, {'frequency': 'c', 'synset': 'poker.n.01', 'synonyms': ['poker_(fire_stirring_tool)', 'stove_poker', 'fire_hook'], 'id': 826, 'def': 'fire iron consisting of a metal 
rod with a handle; used to stir a fire', 'name': 'poker_(fire_stirring_tool)'}, {'frequency': 'f', 'synset': 'pole.n.01', 'synonyms': ['pole', 'post'], 'id': 827, 'def': 'a long (usually round) rod of wood or metal or plastic', 'name': 'pole'}, {'frequency': 'f', 'synset': 'polo_shirt.n.01', 'synonyms': ['polo_shirt', 'sport_shirt'], 'id': 828, 'def': 'a shirt with short sleeves designed for comfort and casual wear', 'name': 'polo_shirt'}, {'frequency': 'r', 'synset': 'poncho.n.01', 'synonyms': ['poncho'], 'id': 829, 'def': 'a blanket-like cloak with a hole in the center for the head', 'name': 'poncho'}, {'frequency': 'c', 'synset': 'pony.n.05', 'synonyms': ['pony'], 'id': 830, 'def': 'any of various breeds of small gentle horses usually less than five feet high at the shoulder', 'name': 'pony'}, {'frequency': 'r', 'synset': 'pool_table.n.01', 'synonyms': ['pool_table', 'billiard_table', 'snooker_table'], 'id': 831, 'def': 'game equipment consisting of a heavy table on which pool is played', 'name': 'pool_table'}, {'frequency': 'f', 'synset': 'pop.n.02', 'synonyms': ['pop_(soda)', 'soda_(pop)', 'tonic', 'soft_drink'], 'id': 832, 'def': 'a sweet drink containing carbonated water and flavoring', 'name': 'pop_(soda)'}, {'frequency': 'c', 'synset': 'postbox.n.01', 'synonyms': ['postbox_(public)', 'mailbox_(public)'], 'id': 833, 'def': 'public box for deposit of mail', 'name': 'postbox_(public)'}, {'frequency': 'c', 'synset': 'postcard.n.01', 'synonyms': ['postcard', 'postal_card', 'mailing-card'], 'id': 834, 'def': 'a card for sending messages by post without an envelope', 'name': 'postcard'}, {'frequency': 'f', 'synset': 'poster.n.01', 'synonyms': ['poster', 'placard'], 'id': 835, 'def': 'a sign posted in a public place as an advertisement', 'name': 'poster'}, {'frequency': 'f', 'synset': 'pot.n.01', 'synonyms': ['pot'], 'id': 836, 'def': 'metal or earthenware cooking vessel that is usually round and deep; often has a handle and lid', 'name': 'pot'}, {'frequency': 
'f', 'synset': 'pot.n.04', 'synonyms': ['flowerpot'], 'id': 837, 'def': 'a container in which plants are cultivated', 'name': 'flowerpot'}, {'frequency': 'f', 'synset': 'potato.n.01', 'synonyms': ['potato'], 'id': 838, 'def': 'an edible tuber native to South America', 'name': 'potato'}, {'frequency': 'c', 'synset': 'potholder.n.01', 'synonyms': ['potholder'], 'id': 839, 'def': 'an insulated pad for holding hot pots', 'name': 'potholder'}, {'frequency': 'c', 'synset': 'pottery.n.01', 'synonyms': ['pottery', 'clayware'], 'id': 840, 'def': 'ceramic ware made from clay and baked in a kiln', 'name': 'pottery'}, {'frequency': 'c', 'synset': 'pouch.n.01', 'synonyms': ['pouch'], 'id': 841, 'def': 'a small or medium size container for holding or carrying things', 'name': 'pouch'}, {'frequency': 'c', 'synset': 'power_shovel.n.01', 'synonyms': ['power_shovel', 'excavator', 'digger'], 'id': 842, 'def': 'a machine for excavating', 'name': 'power_shovel'}, {'frequency': 'c', 'synset': 'prawn.n.01', 'synonyms': ['prawn', 'shrimp'], 'id': 843, 'def': 'any of various edible decapod crustaceans', 'name': 'prawn'}, {'frequency': 'c', 'synset': 'pretzel.n.01', 'synonyms': ['pretzel'], 'id': 844, 'def': 'glazed and salted cracker typically in the shape of a loose knot', 'name': 'pretzel'}, {'frequency': 'f', 'synset': 'printer.n.03', 'synonyms': ['printer', 'printing_machine'], 'id': 845, 'def': 'a machine that prints', 'name': 'printer'}, {'frequency': 'c', 'synset': 'projectile.n.01', 'synonyms': ['projectile_(weapon)', 'missile'], 'id': 846, 'def': 'a weapon that is forcibly thrown or projected at a targets', 'name': 'projectile_(weapon)'}, {'frequency': 'c', 'synset': 'projector.n.02', 'synonyms': ['projector'], 'id': 847, 'def': 'an optical instrument that projects an enlarged image onto a screen', 'name': 'projector'}, {'frequency': 'f', 'synset': 'propeller.n.01', 'synonyms': ['propeller', 'propellor'], 'id': 848, 'def': 'a mechanical device that rotates to push against air or 
water', 'name': 'propeller'}, {'frequency': 'r', 'synset': 'prune.n.01', 'synonyms': ['prune'], 'id': 849, 'def': 'dried plum', 'name': 'prune'}, {'frequency': 'r', 'synset': 'pudding.n.01', 'synonyms': ['pudding'], 'id': 850, 'def': 'any of various soft thick unsweetened baked dishes', 'name': 'pudding'}, {'frequency': 'r', 'synset': 'puffer.n.02', 'synonyms': ['puffer_(fish)', 'pufferfish', 'blowfish', 'globefish'], 'id': 851, 'def': 'fishes whose elongated spiny body can inflate itself with water or air to form a globe', 'name': 'puffer_(fish)'}, {'frequency': 'r', 'synset': 'puffin.n.01', 'synonyms': ['puffin'], 'id': 852, 'def': 'seabirds having short necks and brightly colored compressed bills', 'name': 'puffin'}, {'frequency': 'r', 'synset': 'pug.n.01', 'synonyms': ['pug-dog'], 'id': 853, 'def': 'small compact smooth-coated breed of Asiatic origin having a tightly curled tail and broad flat wrinkled muzzle', 'name': 'pug-dog'}, {'frequency': 'c', 'synset': 'pumpkin.n.02', 'synonyms': ['pumpkin'], 'id': 854, 'def': 'usually large pulpy deep-yellow round fruit of the squash family maturing in late summer or early autumn', 'name': 'pumpkin'}, {'frequency': 'r', 'synset': 'punch.n.03', 'synonyms': ['puncher'], 'id': 855, 'def': 'a tool for making holes or indentations', 'name': 'puncher'}, {'frequency': 'r', 'synset': 'puppet.n.01', 'synonyms': ['puppet', 'marionette'], 'id': 856, 'def': 'a small figure of a person operated from above with strings by a puppeteer', 'name': 'puppet'}, {'frequency': 'c', 'synset': 'puppy.n.01', 'synonyms': ['puppy'], 'id': 857, 'def': 'a young dog', 'name': 'puppy'}, {'frequency': 'r', 'synset': 'quesadilla.n.01', 'synonyms': ['quesadilla'], 'id': 858, 'def': 'a tortilla that is filled with cheese and heated', 'name': 'quesadilla'}, {'frequency': 'r', 'synset': 'quiche.n.02', 'synonyms': ['quiche'], 'id': 859, 'def': 'a tart filled with rich unsweetened custard; often contains other ingredients (as cheese or ham or seafood or 
vegetables)', 'name': 'quiche'}, {'frequency': 'f', 'synset': 'quilt.n.01', 'synonyms': ['quilt', 'comforter'], 'id': 860, 'def': 'bedding made of two layers of cloth filled with stuffing and stitched together', 'name': 'quilt'}, {'frequency': 'c', 'synset': 'rabbit.n.01', 'synonyms': ['rabbit'], 'id': 861, 'def': 'any of various burrowing animals of the family Leporidae having long ears and short tails', 'name': 'rabbit'}, {'frequency': 'r', 'synset': 'racer.n.02', 'synonyms': ['race_car', 'racing_car'], 'id': 862, 'def': 'a fast car that competes in races', 'name': 'race_car'}, {'frequency': 'c', 'synset': 'racket.n.04', 'synonyms': ['racket', 'racquet'], 'id': 863, 'def': 'a sports implement used to strike a ball in various games', 'name': 'racket'}, {'frequency': 'r', 'synset': 'radar.n.01', 'synonyms': ['radar'], 'id': 864, 'def': 'measuring instrument in which the echo of a pulse of microwave radiation is used to detect and locate distant objects', 'name': 'radar'}, {'frequency': 'f', 'synset': 'radiator.n.03', 'synonyms': ['radiator'], 'id': 865, 'def': 'a mechanism consisting of a metal honeycomb through which hot fluids circulate', 'name': 'radiator'}, {'frequency': 'c', 'synset': 'radio_receiver.n.01', 'synonyms': ['radio_receiver', 'radio_set', 'radio', 'tuner_(radio)'], 'id': 866, 'def': 'an electronic receiver that detects and demodulates and amplifies transmitted radio signals', 'name': 'radio_receiver'}, {'frequency': 'c', 'synset': 'radish.n.03', 'synonyms': ['radish', 'daikon'], 'id': 867, 'def': 'pungent edible root of any of various cultivated radish plants', 'name': 'radish'}, {'frequency': 'c', 'synset': 'raft.n.01', 'synonyms': ['raft'], 'id': 868, 'def': 'a flat float (usually made of logs or planks) that can be used for transport or as a platform for swimmers', 'name': 'raft'}, {'frequency': 'r', 'synset': 'rag_doll.n.01', 'synonyms': ['rag_doll'], 'id': 869, 'def': 'a cloth doll that is stuffed and (usually) painted', 'name': 'rag_doll'}, 
{'frequency': 'c', 'synset': 'raincoat.n.01', 'synonyms': ['raincoat', 'waterproof_jacket'], 'id': 870, 'def': 'a water-resistant coat', 'name': 'raincoat'}, {'frequency': 'c', 'synset': 'ram.n.05', 'synonyms': ['ram_(animal)'], 'id': 871, 'def': 'uncastrated adult male sheep', 'name': 'ram_(animal)'}, {'frequency': 'c', 'synset': 'raspberry.n.02', 'synonyms': ['raspberry'], 'id': 872, 'def': 'red or black edible aggregate berries usually smaller than the related blackberries', 'name': 'raspberry'}, {'frequency': 'r', 'synset': 'rat.n.01', 'synonyms': ['rat'], 'id': 873, 'def': 'any of various long-tailed rodents similar to but larger than a mouse', 'name': 'rat'}, {'frequency': 'c', 'synset': 'razorblade.n.01', 'synonyms': ['razorblade'], 'id': 874, 'def': 'a blade that has very sharp edge', 'name': 'razorblade'}, {'frequency': 'c', 'synset': 'reamer.n.01', 'synonyms': ['reamer_(juicer)', 'juicer', 'juice_reamer'], 'id': 875, 'def': 'a squeezer with a conical ridged center that is used for squeezing juice from citrus fruit', 'name': 'reamer_(juicer)'}, {'frequency': 'f', 'synset': 'rearview_mirror.n.01', 'synonyms': ['rearview_mirror'], 'id': 876, 'def': 'vehicle mirror (side or rearview)', 'name': 'rearview_mirror'}, {'frequency': 'c', 'synset': 'receipt.n.02', 'synonyms': ['receipt'], 'id': 877, 'def': 'an acknowledgment (usually tangible) that payment has been made', 'name': 'receipt'}, {'frequency': 'c', 'synset': 'recliner.n.01', 'synonyms': ['recliner', 'reclining_chair', 'lounger_(chair)'], 'id': 878, 'def': 'an armchair whose back can be lowered and foot can be raised to allow the sitter to recline in it', 'name': 'recliner'}, {'frequency': 'c', 'synset': 'record_player.n.01', 'synonyms': ['record_player', 'phonograph_(record_player)', 'turntable'], 'id': 879, 'def': 'machine in which rotating records cause a stylus to vibrate and the vibrations are amplified acoustically or electronically', 'name': 'record_player'}, {'frequency': 'f', 'synset': 
'reflector.n.01', 'synonyms': ['reflector'], 'id': 880, 'def': 'device that reflects light, radiation, etc.', 'name': 'reflector'}, {'frequency': 'f', 'synset': 'remote_control.n.01', 'synonyms': ['remote_control'], 'id': 881, 'def': 'a device that can be used to control a machine or apparatus from a distance', 'name': 'remote_control'}, {'frequency': 'c', 'synset': 'rhinoceros.n.01', 'synonyms': ['rhinoceros'], 'id': 882, 'def': 'massive powerful herbivorous odd-toed ungulate of southeast Asia and Africa having very thick skin and one or two horns on the snout', 'name': 'rhinoceros'}, {'frequency': 'r', 'synset': 'rib.n.03', 'synonyms': ['rib_(food)'], 'id': 883, 'def': 'cut of meat including one or more ribs', 'name': 'rib_(food)'}, {'frequency': 'c', 'synset': 'rifle.n.01', 'synonyms': ['rifle'], 'id': 884, 'def': 'a shoulder firearm with a long barrel', 'name': 'rifle'}, {'frequency': 'f', 'synset': 'ring.n.08', 'synonyms': ['ring'], 'id': 885, 'def': 'jewelry consisting of a circlet of precious metal (often set with jewels) worn on the finger', 'name': 'ring'}, {'frequency': 'r', 'synset': 'river_boat.n.01', 'synonyms': ['river_boat'], 'id': 886, 'def': 'a boat used on rivers or to ply a river', 'name': 'river_boat'}, {'frequency': 'r', 'synset': 'road_map.n.02', 'synonyms': ['road_map'], 'id': 887, 'def': '(NOT A ROAD) a MAP showing roads (for automobile travel)', 'name': 'road_map'}, {'frequency': 'c', 'synset': 'robe.n.01', 'synonyms': ['robe'], 'id': 888, 'def': 'any loose flowing garment', 'name': 'robe'}, {'frequency': 'c', 'synset': 'rocking_chair.n.01', 'synonyms': ['rocking_chair'], 'id': 889, 'def': 'a chair mounted on rockers', 'name': 'rocking_chair'}, {'frequency': 'r', 'synset': 'rodent.n.01', 'synonyms': ['rodent'], 'id': 890, 'def': 'relatively small placental mammals having a single pair of constantly growing incisor teeth specialized for gnawing', 'name': 'rodent'}, {'frequency': 'r', 'synset': 'roller_skate.n.01', 'synonyms': 
['roller_skate'], 'id': 891, 'def': 'a shoe with pairs of rollers (small hard wheels) fixed to the sole', 'name': 'roller_skate'}, {'frequency': 'r', 'synset': 'rollerblade.n.01', 'synonyms': ['Rollerblade'], 'id': 892, 'def': 'an in-line variant of a roller skate', 'name': 'Rollerblade'}, {'frequency': 'c', 'synset': 'rolling_pin.n.01', 'synonyms': ['rolling_pin'], 'id': 893, 'def': 'utensil consisting of a cylinder (usually of wood) with a handle at each end; used to roll out dough', 'name': 'rolling_pin'}, {'frequency': 'r', 'synset': 'root_beer.n.01', 'synonyms': ['root_beer'], 'id': 894, 'def': 'carbonated drink containing extracts of roots and herbs', 'name': 'root_beer'}, {'frequency': 'c', 'synset': 'router.n.02', 'synonyms': ['router_(computer_equipment)'], 'id': 895, 'def': 'a device that forwards data packets between computer networks', 'name': 'router_(computer_equipment)'}, {'frequency': 'f', 'synset': 'rubber_band.n.01', 'synonyms': ['rubber_band', 'elastic_band'], 'id': 896, 'def': 'a narrow band of elastic rubber used to hold things (such as papers) together', 'name': 'rubber_band'}, {'frequency': 'c', 'synset': 'runner.n.08', 'synonyms': ['runner_(carpet)'], 'id': 897, 'def': 'a long narrow carpet', 'name': 'runner_(carpet)'}, {'frequency': 'f', 'synset': 'sack.n.01', 'synonyms': ['plastic_bag', 'paper_bag'], 'id': 898, 'def': "a bag made of paper or plastic for holding customer's purchases", 'name': 'plastic_bag'}, {'frequency': 'f', 'synset': 'saddle.n.01', 'synonyms': ['saddle_(on_an_animal)'], 'id': 899, 'def': 'a seat for the rider of a horse or camel', 'name': 'saddle_(on_an_animal)'}, {'frequency': 'f', 'synset': 'saddle_blanket.n.01', 'synonyms': ['saddle_blanket', 'saddlecloth', 'horse_blanket'], 'id': 900, 'def': 'stable gear consisting of a blanket placed under the saddle', 'name': 'saddle_blanket'}, {'frequency': 'c', 'synset': 'saddlebag.n.01', 'synonyms': ['saddlebag'], 'id': 901, 'def': 'a large bag (or pair of bags) hung over a 
saddle', 'name': 'saddlebag'}, {'frequency': 'r', 'synset': 'safety_pin.n.01', 'synonyms': ['safety_pin'], 'id': 902, 'def': 'a pin in the form of a clasp; has a guard so the point of the pin will not stick the user', 'name': 'safety_pin'}, {'frequency': 'f', 'synset': 'sail.n.01', 'synonyms': ['sail'], 'id': 903, 'def': 'a large piece of fabric by means of which wind is used to propel a sailing vessel', 'name': 'sail'}, {'frequency': 'f', 'synset': 'salad.n.01', 'synonyms': ['salad'], 'id': 904, 'def': 'food mixtures either arranged on a plate or tossed and served with a moist dressing; usually consisting of or including greens', 'name': 'salad'}, {'frequency': 'r', 'synset': 'salad_plate.n.01', 'synonyms': ['salad_plate', 'salad_bowl'], 'id': 905, 'def': 'a plate or bowl for individual servings of salad', 'name': 'salad_plate'}, {'frequency': 'c', 'synset': 'salami.n.01', 'synonyms': ['salami'], 'id': 906, 'def': 'highly seasoned fatty sausage of pork and beef usually dried', 'name': 'salami'}, {'frequency': 'c', 'synset': 'salmon.n.01', 'synonyms': ['salmon_(fish)'], 'id': 907, 'def': 'any of various large food and game fishes of northern waters', 'name': 'salmon_(fish)'}, {'frequency': 'r', 'synset': 'salmon.n.03', 'synonyms': ['salmon_(food)'], 'id': 908, 'def': 'flesh of any of various marine or freshwater fish of the family Salmonidae', 'name': 'salmon_(food)'}, {'frequency': 'c', 'synset': 'salsa.n.01', 'synonyms': ['salsa'], 'id': 909, 'def': 'spicy sauce of tomatoes and onions and chili peppers to accompany Mexican foods', 'name': 'salsa'}, {'frequency': 'f', 'synset': 'saltshaker.n.01', 'synonyms': ['saltshaker'], 'id': 910, 'def': 'a shaker with a perforated top for sprinkling salt', 'name': 'saltshaker'}, {'frequency': 'f', 'synset': 'sandal.n.01', 'synonyms': ['sandal_(type_of_shoe)'], 'id': 911, 'def': 'a shoe consisting of a sole fastened by straps to the foot', 'name': 'sandal_(type_of_shoe)'}, {'frequency': 'f', 'synset': 'sandwich.n.01', 
'synonyms': ['sandwich'], 'id': 912, 'def': 'two (or more) slices of bread with a filling between them', 'name': 'sandwich'}, {'frequency': 'r', 'synset': 'satchel.n.01', 'synonyms': ['satchel'], 'id': 913, 'def': 'luggage consisting of a small case with a flat bottom and (usually) a shoulder strap', 'name': 'satchel'}, {'frequency': 'r', 'synset': 'saucepan.n.01', 'synonyms': ['saucepan'], 'id': 914, 'def': 'a deep pan with a handle; used for stewing or boiling', 'name': 'saucepan'}, {'frequency': 'f', 'synset': 'saucer.n.02', 'synonyms': ['saucer'], 'id': 915, 'def': 'a small shallow dish for holding a cup at the table', 'name': 'saucer'}, {'frequency': 'f', 'synset': 'sausage.n.01', 'synonyms': ['sausage'], 'id': 916, 'def': 'highly seasoned minced meat stuffed in casings', 'name': 'sausage'}, {'frequency': 'r', 'synset': 'sawhorse.n.01', 'synonyms': ['sawhorse', 'sawbuck'], 'id': 917, 'def': 'a framework for holding wood that is being sawed', 'name': 'sawhorse'}, {'frequency': 'r', 'synset': 'sax.n.02', 'synonyms': ['saxophone'], 'id': 918, 'def': "a wind instrument with a `J'-shaped form typically made of brass", 'name': 'saxophone'}, {'frequency': 'f', 'synset': 'scale.n.07', 'synonyms': ['scale_(measuring_instrument)'], 'id': 919, 'def': 'a measuring instrument for weighing; shows amount of mass', 'name': 'scale_(measuring_instrument)'}, {'frequency': 'r', 'synset': 'scarecrow.n.01', 'synonyms': ['scarecrow', 'strawman'], 'id': 920, 'def': 'an effigy in the shape of a man to frighten birds away from seeds', 'name': 'scarecrow'}, {'frequency': 'f', 'synset': 'scarf.n.01', 'synonyms': ['scarf'], 'id': 921, 'def': 'a garment worn around the head or neck or shoulders for warmth or decoration', 'name': 'scarf'}, {'frequency': 'c', 'synset': 'school_bus.n.01', 'synonyms': ['school_bus'], 'id': 922, 'def': 'a bus used to transport children to or from school', 'name': 'school_bus'}, {'frequency': 'f', 'synset': 'scissors.n.01', 'synonyms': ['scissors'], 'id': 923, 
'def': 'a tool having two crossed pivoting blades with looped handles', 'name': 'scissors'}, {'frequency': 'f', 'synset': 'scoreboard.n.01', 'synonyms': ['scoreboard'], 'id': 924, 'def': 'a large board for displaying the score of a contest (and some other information)', 'name': 'scoreboard'}, {'frequency': 'r', 'synset': 'scraper.n.01', 'synonyms': ['scraper'], 'id': 925, 'def': 'any of various hand tools for scraping', 'name': 'scraper'}, {'frequency': 'c', 'synset': 'screwdriver.n.01', 'synonyms': ['screwdriver'], 'id': 926, 'def': 'a hand tool for driving screws; has a tip that fits into the head of a screw', 'name': 'screwdriver'}, {'frequency': 'f', 'synset': 'scrub_brush.n.01', 'synonyms': ['scrubbing_brush'], 'id': 927, 'def': 'a brush with short stiff bristles for heavy cleaning', 'name': 'scrubbing_brush'}, {'frequency': 'c', 'synset': 'sculpture.n.01', 'synonyms': ['sculpture'], 'id': 928, 'def': 'a three-dimensional work of art', 'name': 'sculpture'}, {'frequency': 'c', 'synset': 'seabird.n.01', 'synonyms': ['seabird', 'seafowl'], 'id': 929, 'def': 'a bird that frequents coastal waters and the open ocean: gulls; pelicans; gannets; cormorants; albatrosses; petrels; etc.', 'name': 'seabird'}, {'frequency': 'c', 'synset': 'seahorse.n.02', 'synonyms': ['seahorse'], 'id': 930, 'def': 'small fish with horse-like heads bent sharply downward and curled tails', 'name': 'seahorse'}, {'frequency': 'r', 'synset': 'seaplane.n.01', 'synonyms': ['seaplane', 'hydroplane'], 'id': 931, 'def': 'an airplane that can land on or take off from water', 'name': 'seaplane'}, {'frequency': 'c', 'synset': 'seashell.n.01', 'synonyms': ['seashell'], 'id': 932, 'def': 'the shell of a marine organism', 'name': 'seashell'}, {'frequency': 'c', 'synset': 'sewing_machine.n.01', 'synonyms': ['sewing_machine'], 'id': 933, 'def': 'a textile machine used as a home appliance for sewing', 'name': 'sewing_machine'}, {'frequency': 'c', 'synset': 'shaker.n.03', 'synonyms': ['shaker'], 'id': 934, 
'def': 'a container in which something can be shaken', 'name': 'shaker'}, {'frequency': 'c', 'synset': 'shampoo.n.01', 'synonyms': ['shampoo'], 'id': 935, 'def': 'cleansing agent consisting of soaps or detergents used for washing the hair', 'name': 'shampoo'}, {'frequency': 'c', 'synset': 'shark.n.01', 'synonyms': ['shark'], 'id': 936, 'def': 'typically large carnivorous fishes with sharpe teeth', 'name': 'shark'}, {'frequency': 'r', 'synset': 'sharpener.n.01', 'synonyms': ['sharpener'], 'id': 937, 'def': 'any implement that is used to make something (an edge or a point) sharper', 'name': 'sharpener'}, {'frequency': 'r', 'synset': 'sharpie.n.03', 'synonyms': ['Sharpie'], 'id': 938, 'def': 'a pen with indelible ink that will write on any surface', 'name': 'Sharpie'}, {'frequency': 'r', 'synset': 'shaver.n.03', 'synonyms': ['shaver_(electric)', 'electric_shaver', 'electric_razor'], 'id': 939, 'def': 'a razor powered by an electric motor', 'name': 'shaver_(electric)'}, {'frequency': 'c', 'synset': 'shaving_cream.n.01', 'synonyms': ['shaving_cream', 'shaving_soap'], 'id': 940, 'def': 'toiletry consisting that forms a rich lather for softening the beard before shaving', 'name': 'shaving_cream'}, {'frequency': 'r', 'synset': 'shawl.n.01', 'synonyms': ['shawl'], 'id': 941, 'def': 'cloak consisting of an oblong piece of cloth used to cover the head and shoulders', 'name': 'shawl'}, {'frequency': 'r', 'synset': 'shears.n.01', 'synonyms': ['shears'], 'id': 942, 'def': 'large scissors with strong blades', 'name': 'shears'}, {'frequency': 'f', 'synset': 'sheep.n.01', 'synonyms': ['sheep'], 'id': 943, 'def': 'woolly usually horned ruminant mammal related to the goat', 'name': 'sheep'}, {'frequency': 'r', 'synset': 'shepherd_dog.n.01', 'synonyms': ['shepherd_dog', 'sheepdog'], 'id': 944, 'def': 'any of various usually long-haired breeds of dog reared to herd and guard sheep', 'name': 'shepherd_dog'}, {'frequency': 'r', 'synset': 'sherbert.n.01', 'synonyms': ['sherbert', 
'sherbet'], 'id': 945, 'def': 'a frozen dessert made primarily of fruit juice and sugar', 'name': 'sherbert'}, {'frequency': 'c', 'synset': 'shield.n.02', 'synonyms': ['shield'], 'id': 946, 'def': 'armor carried on the arm to intercept blows', 'name': 'shield'}, {'frequency': 'f', 'synset': 'shirt.n.01', 'synonyms': ['shirt'], 'id': 947, 'def': 'a garment worn on the upper half of the body', 'name': 'shirt'}, {'frequency': 'f', 'synset': 'shoe.n.01', 'synonyms': ['shoe', 'sneaker_(type_of_shoe)', 'tennis_shoe'], 'id': 948, 'def': 'common footwear covering the foot', 'name': 'shoe'}, {'frequency': 'f', 'synset': 'shopping_bag.n.01', 'synonyms': ['shopping_bag'], 'id': 949, 'def': 'a bag made of plastic or strong paper (often with handles); used to transport goods after shopping', 'name': 'shopping_bag'}, {'frequency': 'c', 'synset': 'shopping_cart.n.01', 'synonyms': ['shopping_cart'], 'id': 950, 'def': 'a handcart that holds groceries or other goods while shopping', 'name': 'shopping_cart'}, {'frequency': 'f', 'synset': 'short_pants.n.01', 'synonyms': ['short_pants', 'shorts_(clothing)', 'trunks_(clothing)'], 'id': 951, 'def': 'trousers that end at or above the knee', 'name': 'short_pants'}, {'frequency': 'r', 'synset': 'shot_glass.n.01', 'synonyms': ['shot_glass'], 'id': 952, 'def': 'a small glass adequate to hold a single swallow of whiskey', 'name': 'shot_glass'}, {'frequency': 'f', 'synset': 'shoulder_bag.n.01', 'synonyms': ['shoulder_bag'], 'id': 953, 'def': 'a large handbag that can be carried by a strap looped over the shoulder', 'name': 'shoulder_bag'}, {'frequency': 'c', 'synset': 'shovel.n.01', 'synonyms': ['shovel'], 'id': 954, 'def': 'a hand tool for lifting loose material such as snow, dirt, etc.', 'name': 'shovel'}, {'frequency': 'f', 'synset': 'shower.n.01', 'synonyms': ['shower_head'], 'id': 955, 'def': 'a plumbing fixture that sprays water over you', 'name': 'shower_head'}, {'frequency': 'r', 'synset': 'shower_cap.n.01', 'synonyms': ['shower_cap'], 
'id': 956, 'def': 'a tight cap worn to keep hair dry while showering', 'name': 'shower_cap'}, {'frequency': 'f', 'synset': 'shower_curtain.n.01', 'synonyms': ['shower_curtain'], 'id': 957, 'def': 'a curtain that keeps water from splashing out of the shower area', 'name': 'shower_curtain'}, {'frequency': 'r', 'synset': 'shredder.n.01', 'synonyms': ['shredder_(for_paper)'], 'id': 958, 'def': 'a device that shreds documents', 'name': 'shredder_(for_paper)'}, {'frequency': 'f', 'synset': 'signboard.n.01', 'synonyms': ['signboard'], 'id': 959, 'def': 'structure displaying a board on which advertisements can be posted', 'name': 'signboard'}, {'frequency': 'c', 'synset': 'silo.n.01', 'synonyms': ['silo'], 'id': 960, 'def': 'a cylindrical tower used for storing goods', 'name': 'silo'}, {'frequency': 'f', 'synset': 'sink.n.01', 'synonyms': ['sink'], 'id': 961, 'def': 'plumbing fixture consisting of a water basin fixed to a wall or floor and having a drainpipe', 'name': 'sink'}, {'frequency': 'f', 'synset': 'skateboard.n.01', 'synonyms': ['skateboard'], 'id': 962, 'def': 'a board with wheels that is ridden in a standing or crouching position and propelled by foot', 'name': 'skateboard'}, {'frequency': 'c', 'synset': 'skewer.n.01', 'synonyms': ['skewer'], 'id': 963, 'def': 'a long pin for holding meat in position while it is being roasted', 'name': 'skewer'}, {'frequency': 'f', 'synset': 'ski.n.01', 'synonyms': ['ski'], 'id': 964, 'def': 'sports equipment for skiing on snow', 'name': 'ski'}, {'frequency': 'f', 'synset': 'ski_boot.n.01', 'synonyms': ['ski_boot'], 'id': 965, 'def': 'a stiff boot that is fastened to a ski with a ski binding', 'name': 'ski_boot'}, {'frequency': 'f', 'synset': 'ski_parka.n.01', 'synonyms': ['ski_parka', 'ski_jacket'], 'id': 966, 'def': 'a parka to be worn while skiing', 'name': 'ski_parka'}, {'frequency': 'f', 'synset': 'ski_pole.n.01', 'synonyms': ['ski_pole'], 'id': 967, 'def': 'a pole with metal points used as an aid in skiing', 'name': 
'ski_pole'}, {'frequency': 'f', 'synset': 'skirt.n.02', 'synonyms': ['skirt'], 'id': 968, 'def': 'a garment hanging from the waist; worn mainly by girls and women', 'name': 'skirt'}, {'frequency': 'r', 'synset': 'skullcap.n.01', 'synonyms': ['skullcap'], 'id': 969, 'def': 'rounded brimless cap fitting the crown of the head', 'name': 'skullcap'}, {'frequency': 'c', 'synset': 'sled.n.01', 'synonyms': ['sled', 'sledge', 'sleigh'], 'id': 970, 'def': 'a vehicle or flat object for transportation over snow by sliding or pulled by dogs, etc.', 'name': 'sled'}, {'frequency': 'c', 'synset': 'sleeping_bag.n.01', 'synonyms': ['sleeping_bag'], 'id': 971, 'def': 'large padded bag designed to be slept in outdoors', 'name': 'sleeping_bag'}, {'frequency': 'r', 'synset': 'sling.n.05', 'synonyms': ['sling_(bandage)', 'triangular_bandage'], 'id': 972, 'def': 'bandage to support an injured forearm; slung over the shoulder or neck', 'name': 'sling_(bandage)'}, {'frequency': 'c', 'synset': 'slipper.n.01', 'synonyms': ['slipper_(footwear)', 'carpet_slipper_(footwear)'], 'id': 973, 'def': 'low footwear that can be slipped on and off easily; usually worn indoors', 'name': 'slipper_(footwear)'}, {'frequency': 'r', 'synset': 'smoothie.n.02', 'synonyms': ['smoothie'], 'id': 974, 'def': 'a thick smooth drink consisting of fresh fruit pureed with ice cream or yoghurt or milk', 'name': 'smoothie'}, {'frequency': 'r', 'synset': 'snake.n.01', 'synonyms': ['snake', 'serpent'], 'id': 975, 'def': 'limbless scaly elongate reptile; some are venomous', 'name': 'snake'}, {'frequency': 'f', 'synset': 'snowboard.n.01', 'synonyms': ['snowboard'], 'id': 976, 'def': 'a board that resembles a broad ski or a small surfboard; used in a standing position to slide down snow-covered slopes', 'name': 'snowboard'}, {'frequency': 'c', 'synset': 'snowman.n.01', 'synonyms': ['snowman'], 'id': 977, 'def': 'a figure of a person made of packed snow', 'name': 'snowman'}, {'frequency': 'c', 'synset': 'snowmobile.n.01', 
'synonyms': ['snowmobile'], 'id': 978, 'def': 'tracked vehicle for travel on snow having skis in front', 'name': 'snowmobile'}, {'frequency': 'f', 'synset': 'soap.n.01', 'synonyms': ['soap'], 'id': 979, 'def': 'a cleansing agent made from the salts of vegetable or animal fats', 'name': 'soap'}, {'frequency': 'f', 'synset': 'soccer_ball.n.01', 'synonyms': ['soccer_ball'], 'id': 980, 'def': "an inflated ball used in playing soccer (called `football' outside of the United States)", 'name': 'soccer_ball'}, {'frequency': 'f', 'synset': 'sock.n.01', 'synonyms': ['sock'], 'id': 981, 'def': 'cloth covering for the foot; worn inside the shoe; reaches to between the ankle and the knee', 'name': 'sock'}, {'frequency': 'f', 'synset': 'sofa.n.01', 'synonyms': ['sofa', 'couch', 'lounge'], 'id': 982, 'def': 'an upholstered seat for more than one person', 'name': 'sofa'}, {'frequency': 'r', 'synset': 'softball.n.01', 'synonyms': ['softball'], 'id': 983, 'def': 'ball used in playing softball', 'name': 'softball'}, {'frequency': 'c', 'synset': 'solar_array.n.01', 'synonyms': ['solar_array', 'solar_battery', 'solar_panel'], 'id': 984, 'def': 'electrical device consisting of a large array of connected solar cells', 'name': 'solar_array'}, {'frequency': 'r', 'synset': 'sombrero.n.02', 'synonyms': ['sombrero'], 'id': 985, 'def': 'a straw hat with a tall crown and broad brim; worn in American southwest and in Mexico', 'name': 'sombrero'}, {'frequency': 'f', 'synset': 'soup.n.01', 'synonyms': ['soup'], 'id': 986, 'def': 'liquid food especially of meat or fish or vegetable stock often containing pieces of solid food', 'name': 'soup'}, {'frequency': 'r', 'synset': 'soup_bowl.n.01', 'synonyms': ['soup_bowl'], 'id': 987, 'def': 'a bowl for serving soup', 'name': 'soup_bowl'}, {'frequency': 'c', 'synset': 'soupspoon.n.01', 'synonyms': ['soupspoon'], 'id': 988, 'def': 'a spoon with a rounded bowl for eating soup', 'name': 'soupspoon'}, {'frequency': 'c', 'synset': 'sour_cream.n.01', 'synonyms': 
['sour_cream', 'soured_cream'], 'id': 989, 'def': 'soured light cream', 'name': 'sour_cream'}, {'frequency': 'r', 'synset': 'soya_milk.n.01', 'synonyms': ['soya_milk', 'soybean_milk', 'soymilk'], 'id': 990, 'def': 'a milk substitute containing soybean flour and water; used in some infant formulas and in making tofu', 'name': 'soya_milk'}, {'frequency': 'r', 'synset': 'space_shuttle.n.01', 'synonyms': ['space_shuttle'], 'id': 991, 'def': "a reusable spacecraft with wings for a controlled descent through the Earth's atmosphere", 'name': 'space_shuttle'}, {'frequency': 'r', 'synset': 'sparkler.n.02', 'synonyms': ['sparkler_(fireworks)'], 'id': 992, 'def': 'a firework that burns slowly and throws out a shower of sparks', 'name': 'sparkler_(fireworks)'}, {'frequency': 'f', 'synset': 'spatula.n.02', 'synonyms': ['spatula'], 'id': 993, 'def': 'a hand tool with a thin flexible blade used to mix or spread soft substances', 'name': 'spatula'}, {'frequency': 'r', 'synset': 'spear.n.01', 'synonyms': ['spear', 'lance'], 'id': 994, 'def': 'a long pointed rod used as a tool or weapon', 'name': 'spear'}, {'frequency': 'f', 'synset': 'spectacles.n.01', 'synonyms': ['spectacles', 'specs', 'eyeglasses', 'glasses'], 'id': 995, 'def': 'optical instrument consisting of a frame that holds a pair of lenses for correcting defective vision', 'name': 'spectacles'}, {'frequency': 'c', 'synset': 'spice_rack.n.01', 'synonyms': ['spice_rack'], 'id': 996, 'def': 'a rack for displaying containers filled with spices', 'name': 'spice_rack'}, {'frequency': 'c', 'synset': 'spider.n.01', 'synonyms': ['spider'], 'id': 997, 'def': 'predatory arachnid with eight legs, two poison fangs, two feelers, and usually two silk-spinning organs at the back end of the body', 'name': 'spider'}, {'frequency': 'r', 'synset': 'spiny_lobster.n.02', 'synonyms': ['crawfish', 'crayfish'], 'id': 998, 'def': 'large edible marine crustacean having a spiny carapace but lacking the large pincers of true lobsters', 'name': 
'crawfish'}, {'frequency': 'c', 'synset': 'sponge.n.01', 'synonyms': ['sponge'], 'id': 999, 'def': 'a porous mass usable to absorb water typically used for cleaning', 'name': 'sponge'}, {'frequency': 'f', 'synset': 'spoon.n.01', 'synonyms': ['spoon'], 'id': 1000, 'def': 'a piece of cutlery with a shallow bowl-shaped container and a handle', 'name': 'spoon'}, {'frequency': 'c', 'synset': 'sportswear.n.01', 'synonyms': ['sportswear', 'athletic_wear', 'activewear'], 'id': 1001, 'def': 'attire worn for sport or for casual wear', 'name': 'sportswear'}, {'frequency': 'c', 'synset': 'spotlight.n.02', 'synonyms': ['spotlight'], 'id': 1002, 'def': 'a lamp that produces a strong beam of light to illuminate a restricted area; used to focus attention of a stage performer', 'name': 'spotlight'}, {'frequency': 'r', 'synset': 'squid.n.01', 'synonyms': ['squid_(food)', 'calamari', 'calamary'], 'id': 1003, 'def': '(Italian cuisine) squid prepared as food', 'name': 'squid_(food)'}, {'frequency': 'c', 'synset': 'squirrel.n.01', 'synonyms': ['squirrel'], 'id': 1004, 'def': 'a kind of arboreal rodent having a long bushy tail', 'name': 'squirrel'}, {'frequency': 'r', 'synset': 'stagecoach.n.01', 'synonyms': ['stagecoach'], 'id': 1005, 'def': 'a large coach-and-four formerly used to carry passengers and mail on regular routes between towns', 'name': 'stagecoach'}, {'frequency': 'c', 'synset': 'stapler.n.01', 'synonyms': ['stapler_(stapling_machine)'], 'id': 1006, 'def': 'a machine that inserts staples into sheets of paper in order to fasten them together', 'name': 'stapler_(stapling_machine)'}, {'frequency': 'c', 'synset': 'starfish.n.01', 'synonyms': ['starfish', 'sea_star'], 'id': 1007, 'def': 'echinoderms characterized by five arms extending from a central disk', 'name': 'starfish'}, {'frequency': 'f', 'synset': 'statue.n.01', 'synonyms': ['statue_(sculpture)'], 'id': 1008, 'def': 'a sculpture representing a human or animal', 'name': 'statue_(sculpture)'}, {'frequency': 'c', 'synset': 
'steak.n.01', 'synonyms': ['steak_(food)'], 'id': 1009, 'def': 'a slice of meat cut from the fleshy part of an animal or large fish', 'name': 'steak_(food)'}, {'frequency': 'r', 'synset': 'steak_knife.n.01', 'synonyms': ['steak_knife'], 'id': 1010, 'def': 'a sharp table knife used in eating steak', 'name': 'steak_knife'}, {'frequency': 'f', 'synset': 'steering_wheel.n.01', 'synonyms': ['steering_wheel'], 'id': 1011, 'def': 'a handwheel that is used for steering', 'name': 'steering_wheel'}, {'frequency': 'r', 'synset': 'step_ladder.n.01', 'synonyms': ['stepladder'], 'id': 1012, 'def': 'a folding portable ladder hinged at the top', 'name': 'stepladder'}, {'frequency': 'c', 'synset': 'step_stool.n.01', 'synonyms': ['step_stool'], 'id': 1013, 'def': 'a stool that has one or two steps that fold under the seat', 'name': 'step_stool'}, {'frequency': 'c', 'synset': 'stereo.n.01', 'synonyms': ['stereo_(sound_system)'], 'id': 1014, 'def': 'electronic device for playing audio', 'name': 'stereo_(sound_system)'}, {'frequency': 'r', 'synset': 'stew.n.02', 'synonyms': ['stew'], 'id': 1015, 'def': 'food prepared by stewing especially meat or fish with vegetables', 'name': 'stew'}, {'frequency': 'r', 'synset': 'stirrer.n.02', 'synonyms': ['stirrer'], 'id': 1016, 'def': 'an implement used for stirring', 'name': 'stirrer'}, {'frequency': 'f', 'synset': 'stirrup.n.01', 'synonyms': ['stirrup'], 'id': 1017, 'def': "support consisting of metal loops into which rider's feet go", 'name': 'stirrup'}, {'frequency': 'f', 'synset': 'stool.n.01', 'synonyms': ['stool'], 'id': 1018, 'def': 'a simple seat without a back or arms', 'name': 'stool'}, {'frequency': 'f', 'synset': 'stop_sign.n.01', 'synonyms': ['stop_sign'], 'id': 1019, 'def': 'a traffic sign to notify drivers that they must come to a complete stop', 'name': 'stop_sign'}, {'frequency': 'f', 'synset': 'stoplight.n.01', 'synonyms': ['brake_light'], 'id': 1020, 'def': 'a red light on the rear of a motor vehicle that signals when the 
brakes are applied', 'name': 'brake_light'}, {'frequency': 'f', 'synset': 'stove.n.01', 'synonyms': ['stove', 'kitchen_stove', 'range_(kitchen_appliance)', 'kitchen_range', 'cooking_stove'], 'id': 1021, 'def': 'a kitchen appliance used for cooking food', 'name': 'stove'}, {'frequency': 'c', 'synset': 'strainer.n.01', 'synonyms': ['strainer'], 'id': 1022, 'def': 'a filter to retain larger pieces while smaller pieces and liquids pass through', 'name': 'strainer'}, {'frequency': 'f', 'synset': 'strap.n.01', 'synonyms': ['strap'], 'id': 1023, 'def': 'an elongated strip of material for binding things together or holding', 'name': 'strap'}, {'frequency': 'f', 'synset': 'straw.n.04', 'synonyms': ['straw_(for_drinking)', 'drinking_straw'], 'id': 1024, 'def': 'a thin paper or plastic tube used to suck liquids into the mouth', 'name': 'straw_(for_drinking)'}, {'frequency': 'f', 'synset': 'strawberry.n.01', 'synonyms': ['strawberry'], 'id': 1025, 'def': 'sweet fleshy red fruit', 'name': 'strawberry'}, {'frequency': 'f', 'synset': 'street_sign.n.01', 'synonyms': ['street_sign'], 'id': 1026, 'def': 'a sign visible from the street', 'name': 'street_sign'}, {'frequency': 'f', 'synset': 'streetlight.n.01', 'synonyms': ['streetlight', 'street_lamp'], 'id': 1027, 'def': 'a lamp supported on a lamppost; for illuminating a street', 'name': 'streetlight'}, {'frequency': 'r', 'synset': 'string_cheese.n.01', 'synonyms': ['string_cheese'], 'id': 1028, 'def': 'cheese formed in long strings twisted together', 'name': 'string_cheese'}, {'frequency': 'r', 'synset': 'stylus.n.02', 'synonyms': ['stylus'], 'id': 1029, 'def': 'a pointed tool for writing or drawing or engraving, including pens', 'name': 'stylus'}, {'frequency': 'r', 'synset': 'subwoofer.n.01', 'synonyms': ['subwoofer'], 'id': 1030, 'def': 'a loudspeaker that is designed to reproduce very low bass frequencies', 'name': 'subwoofer'}, {'frequency': 'r', 'synset': 'sugar_bowl.n.01', 'synonyms': ['sugar_bowl'], 'id': 1031, 'def': 'a 
dish in which sugar is served', 'name': 'sugar_bowl'}, {'frequency': 'r', 'synset': 'sugarcane.n.01', 'synonyms': ['sugarcane_(plant)'], 'id': 1032, 'def': 'juicy canes whose sap is a source of molasses and commercial sugar; fresh canes are sometimes chewed for the juice', 'name': 'sugarcane_(plant)'}, {'frequency': 'f', 'synset': 'suit.n.01', 'synonyms': ['suit_(clothing)'], 'id': 1033, 'def': 'a set of garments (usually including a jacket and trousers or skirt) for outerwear all of the same fabric and color', 'name': 'suit_(clothing)'}, {'frequency': 'c', 'synset': 'sunflower.n.01', 'synonyms': ['sunflower'], 'id': 1034, 'def': 'any plant of the genus Helianthus having large flower heads with dark disk florets and showy yellow rays', 'name': 'sunflower'}, {'frequency': 'f', 'synset': 'sunglasses.n.01', 'synonyms': ['sunglasses'], 'id': 1035, 'def': 'spectacles that are darkened or polarized to protect the eyes from the glare of the sun', 'name': 'sunglasses'}, {'frequency': 'c', 'synset': 'sunhat.n.01', 'synonyms': ['sunhat'], 'id': 1036, 'def': 'a hat with a broad brim that protects the face from direct exposure to the sun', 'name': 'sunhat'}, {'frequency': 'f', 'synset': 'surfboard.n.01', 'synonyms': ['surfboard'], 'id': 1037, 'def': 'a narrow buoyant board for riding surf', 'name': 'surfboard'}, {'frequency': 'c', 'synset': 'sushi.n.01', 'synonyms': ['sushi'], 'id': 1038, 'def': 'rice (with raw fish) wrapped in seaweed', 'name': 'sushi'}, {'frequency': 'c', 'synset': 'swab.n.02', 'synonyms': ['mop'], 'id': 1039, 'def': 'cleaning implement consisting of absorbent material fastened to a handle; for cleaning floors', 'name': 'mop'}, {'frequency': 'c', 'synset': 'sweat_pants.n.01', 'synonyms': ['sweat_pants'], 'id': 1040, 'def': 'loose-fitting trousers with elastic cuffs; worn by athletes', 'name': 'sweat_pants'}, {'frequency': 'c', 'synset': 'sweatband.n.02', 'synonyms': ['sweatband'], 'id': 1041, 'def': 'a band of material tied around the forehead or wrist to 
absorb sweat', 'name': 'sweatband'}, {'frequency': 'f', 'synset': 'sweater.n.01', 'synonyms': ['sweater'], 'id': 1042, 'def': 'a crocheted or knitted garment covering the upper part of the body', 'name': 'sweater'}, {'frequency': 'f', 'synset': 'sweatshirt.n.01', 'synonyms': ['sweatshirt'], 'id': 1043, 'def': 'cotton knit pullover with long sleeves worn during athletic activity', 'name': 'sweatshirt'}, {'frequency': 'c', 'synset': 'sweet_potato.n.02', 'synonyms': ['sweet_potato'], 'id': 1044, 'def': 'the edible tuberous root of the sweet potato vine', 'name': 'sweet_potato'}, {'frequency': 'f', 'synset': 'swimsuit.n.01', 'synonyms': ['swimsuit', 'swimwear', 'bathing_suit', 'swimming_costume', 'bathing_costume', 'swimming_trunks', 'bathing_trunks'], 'id': 1045, 'def': 'garment worn for swimming', 'name': 'swimsuit'}, {'frequency': 'c', 'synset': 'sword.n.01', 'synonyms': ['sword'], 'id': 1046, 'def': 'a cutting or thrusting weapon that has a long metal blade', 'name': 'sword'}, {'frequency': 'r', 'synset': 'syringe.n.01', 'synonyms': ['syringe'], 'id': 1047, 'def': 'a medical instrument used to inject or withdraw fluids', 'name': 'syringe'}, {'frequency': 'r', 'synset': 'tabasco.n.02', 'synonyms': ['Tabasco_sauce'], 'id': 1048, 'def': 'very spicy sauce (trade name Tabasco) made from fully-aged red peppers', 'name': 'Tabasco_sauce'}, {'frequency': 'r', 'synset': 'table-tennis_table.n.01', 'synonyms': ['table-tennis_table', 'ping-pong_table'], 'id': 1049, 'def': 'a table used for playing table tennis', 'name': 'table-tennis_table'}, {'frequency': 'f', 'synset': 'table.n.02', 'synonyms': ['table'], 'id': 1050, 'def': 'a piece of furniture having a smooth flat top that is usually supported by one or more vertical legs', 'name': 'table'}, {'frequency': 'c', 'synset': 'table_lamp.n.01', 'synonyms': ['table_lamp'], 'id': 1051, 'def': 'a lamp that sits on a table', 'name': 'table_lamp'}, {'frequency': 'f', 'synset': 'tablecloth.n.01', 'synonyms': ['tablecloth'], 'id': 1052, 
'def': 'a covering spread over a dining table', 'name': 'tablecloth'}, {'frequency': 'r', 'synset': 'tachometer.n.01', 'synonyms': ['tachometer'], 'id': 1053, 'def': 'measuring instrument for indicating speed of rotation', 'name': 'tachometer'}, {'frequency': 'r', 'synset': 'taco.n.02', 'synonyms': ['taco'], 'id': 1054, 'def': 'a small tortilla cupped around a filling', 'name': 'taco'}, {'frequency': 'f', 'synset': 'tag.n.02', 'synonyms': ['tag'], 'id': 1055, 'def': 'a label associated with something for the purpose of identification or information', 'name': 'tag'}, {'frequency': 'f', 'synset': 'taillight.n.01', 'synonyms': ['taillight', 'rear_light'], 'id': 1056, 'def': 'lamp (usually red) mounted at the rear of a motor vehicle', 'name': 'taillight'}, {'frequency': 'r', 'synset': 'tambourine.n.01', 'synonyms': ['tambourine'], 'id': 1057, 'def': 'a shallow drum with a single drumhead and with metallic disks in the sides', 'name': 'tambourine'}, {'frequency': 'r', 'synset': 'tank.n.01', 'synonyms': ['army_tank', 'armored_combat_vehicle', 'armoured_combat_vehicle'], 'id': 1058, 'def': 'an enclosed armored military vehicle; has a cannon and moves on caterpillar treads', 'name': 'army_tank'}, {'frequency': 'f', 'synset': 'tank.n.02', 'synonyms': ['tank_(storage_vessel)', 'storage_tank'], 'id': 1059, 'def': 'a large (usually metallic) vessel for holding gases or liquids', 'name': 'tank_(storage_vessel)'}, {'frequency': 'f', 'synset': 'tank_top.n.01', 'synonyms': ['tank_top_(clothing)'], 'id': 1060, 'def': 'a tight-fitting sleeveless shirt with wide shoulder straps and low neck and no front opening', 'name': 'tank_top_(clothing)'}, {'frequency': 'f', 'synset': 'tape.n.01', 'synonyms': ['tape_(sticky_cloth_or_paper)'], 'id': 1061, 'def': 'a long thin piece of cloth or paper as used for binding or fastening', 'name': 'tape_(sticky_cloth_or_paper)'}, {'frequency': 'c', 'synset': 'tape.n.04', 'synonyms': ['tape_measure', 'measuring_tape'], 'id': 1062, 'def': 'measuring 
instrument consisting of a narrow strip (cloth or metal) marked in inches or centimeters and used for measuring lengths', 'name': 'tape_measure'}, {'frequency': 'c', 'synset': 'tapestry.n.02', 'synonyms': ['tapestry'], 'id': 1063, 'def': 'a heavy textile with a woven design; used for curtains and upholstery', 'name': 'tapestry'}, {'frequency': 'f', 'synset': 'tarpaulin.n.01', 'synonyms': ['tarp'], 'id': 1064, 'def': 'waterproofed canvas', 'name': 'tarp'}, {'frequency': 'c', 'synset': 'tartan.n.01', 'synonyms': ['tartan', 'plaid'], 'id': 1065, 'def': 'a cloth having a crisscross design', 'name': 'tartan'}, {'frequency': 'c', 'synset': 'tassel.n.01', 'synonyms': ['tassel'], 'id': 1066, 'def': 'adornment consisting of a bunch of cords fastened at one end', 'name': 'tassel'}, {'frequency': 'c', 'synset': 'tea_bag.n.01', 'synonyms': ['tea_bag'], 'id': 1067, 'def': 'a measured amount of tea in a bag for an individual serving of tea', 'name': 'tea_bag'}, {'frequency': 'c', 'synset': 'teacup.n.02', 'synonyms': ['teacup'], 'id': 1068, 'def': 'a cup from which tea is drunk', 'name': 'teacup'}, {'frequency': 'c', 'synset': 'teakettle.n.01', 'synonyms': ['teakettle'], 'id': 1069, 'def': 'kettle for boiling water to make tea', 'name': 'teakettle'}, {'frequency': 'f', 'synset': 'teapot.n.01', 'synonyms': ['teapot'], 'id': 1070, 'def': 'pot for brewing tea; usually has a spout and handle', 'name': 'teapot'}, {'frequency': 'f', 'synset': 'teddy.n.01', 'synonyms': ['teddy_bear'], 'id': 1071, 'def': "plaything consisting of a child's toy bear (usually plush and stuffed with soft materials)", 'name': 'teddy_bear'}, {'frequency': 'f', 'synset': 'telephone.n.01', 'synonyms': ['telephone', 'phone', 'telephone_set'], 'id': 1072, 'def': 'electronic device for communicating by voice over long distances (includes wired and wireless/cell phones)', 'name': 'telephone'}, {'frequency': 'c', 'synset': 'telephone_booth.n.01', 'synonyms': ['telephone_booth', 'phone_booth', 'call_box', 
'telephone_box', 'telephone_kiosk'], 'id': 1073, 'def': 'booth for using a telephone', 'name': 'telephone_booth'}, {'frequency': 'f', 'synset': 'telephone_pole.n.01', 'synonyms': ['telephone_pole', 'telegraph_pole', 'telegraph_post'], 'id': 1074, 'def': 'tall pole supporting telephone wires', 'name': 'telephone_pole'}, {'frequency': 'r', 'synset': 'telephoto_lens.n.01', 'synonyms': ['telephoto_lens', 'zoom_lens'], 'id': 1075, 'def': 'a camera lens that magnifies the image', 'name': 'telephoto_lens'}, {'frequency': 'c', 'synset': 'television_camera.n.01', 'synonyms': ['television_camera', 'tv_camera'], 'id': 1076, 'def': 'television equipment for capturing and recording video', 'name': 'television_camera'}, {'frequency': 'f', 'synset': 'television_receiver.n.01', 'synonyms': ['television_set', 'tv', 'tv_set'], 'id': 1077, 'def': 'an electronic device that receives television signals and displays them on a screen', 'name': 'television_set'}, {'frequency': 'f', 'synset': 'tennis_ball.n.01', 'synonyms': ['tennis_ball'], 'id': 1078, 'def': 'ball about the size of a fist used in playing tennis', 'name': 'tennis_ball'}, {'frequency': 'f', 'synset': 'tennis_racket.n.01', 'synonyms': ['tennis_racket'], 'id': 1079, 'def': 'a racket used to play tennis', 'name': 'tennis_racket'}, {'frequency': 'r', 'synset': 'tequila.n.01', 'synonyms': ['tequila'], 'id': 1080, 'def': 'Mexican liquor made from fermented juices of an agave plant', 'name': 'tequila'}, {'frequency': 'c', 'synset': 'thermometer.n.01', 'synonyms': ['thermometer'], 'id': 1081, 'def': 'measuring instrument for measuring temperature', 'name': 'thermometer'}, {'frequency': 'c', 'synset': 'thermos.n.01', 'synonyms': ['thermos_bottle'], 'id': 1082, 'def': 'vacuum flask that preserves temperature of hot or cold drinks', 'name': 'thermos_bottle'}, {'frequency': 'f', 'synset': 'thermostat.n.01', 'synonyms': ['thermostat'], 'id': 1083, 'def': 'a regulator for automatically regulating temperature by starting or stopping the 
supply of heat', 'name': 'thermostat'}, {'frequency': 'r', 'synset': 'thimble.n.02', 'synonyms': ['thimble'], 'id': 1084, 'def': 'a small metal cap to protect the finger while sewing; can be used as a small container', 'name': 'thimble'}, {'frequency': 'c', 'synset': 'thread.n.01', 'synonyms': ['thread', 'yarn'], 'id': 1085, 'def': 'a fine cord of twisted fibers (of cotton or silk or wool or nylon etc.) used in sewing and weaving', 'name': 'thread'}, {'frequency': 'c', 'synset': 'thumbtack.n.01', 'synonyms': ['thumbtack', 'drawing_pin', 'pushpin'], 'id': 1086, 'def': 'a tack for attaching papers to a bulletin board or drawing board', 'name': 'thumbtack'}, {'frequency': 'c', 'synset': 'tiara.n.01', 'synonyms': ['tiara'], 'id': 1087, 'def': 'a jeweled headdress worn by women on formal occasions', 'name': 'tiara'}, {'frequency': 'c', 'synset': 'tiger.n.02', 'synonyms': ['tiger'], 'id': 1088, 'def': 'large feline of forests in most of Asia having a tawny coat with black stripes', 'name': 'tiger'}, {'frequency': 'c', 'synset': 'tights.n.01', 'synonyms': ['tights_(clothing)', 'leotards'], 'id': 1089, 'def': 'skintight knit hose covering the body from the waist to the feet worn by acrobats and dancers and as stockings by women and girls', 'name': 'tights_(clothing)'}, {'frequency': 'c', 'synset': 'timer.n.01', 'synonyms': ['timer', 'stopwatch'], 'id': 1090, 'def': 'a timepiece that measures a time interval and signals its end', 'name': 'timer'}, {'frequency': 'f', 'synset': 'tinfoil.n.01', 'synonyms': ['tinfoil'], 'id': 1091, 'def': 'foil made of tin or an alloy of tin and lead', 'name': 'tinfoil'}, {'frequency': 'c', 'synset': 'tinsel.n.01', 'synonyms': ['tinsel'], 'id': 1092, 'def': 'a showy decoration that is basically valueless', 'name': 'tinsel'}, {'frequency': 'f', 'synset': 'tissue.n.02', 'synonyms': ['tissue_paper'], 'id': 1093, 'def': 'a soft thin (usually translucent) paper', 'name': 'tissue_paper'}, {'frequency': 'c', 'synset': 'toast.n.01', 'synonyms': 
['toast_(food)'], 'id': 1094, 'def': 'slice of bread that has been toasted', 'name': 'toast_(food)'}, {'frequency': 'f', 'synset': 'toaster.n.02', 'synonyms': ['toaster'], 'id': 1095, 'def': 'a kitchen appliance (usually electric) for toasting bread', 'name': 'toaster'}, {'frequency': 'f', 'synset': 'toaster_oven.n.01', 'synonyms': ['toaster_oven'], 'id': 1096, 'def': 'kitchen appliance consisting of a small electric oven for toasting or warming food', 'name': 'toaster_oven'}, {'frequency': 'f', 'synset': 'toilet.n.02', 'synonyms': ['toilet'], 'id': 1097, 'def': 'a plumbing fixture for defecation and urination', 'name': 'toilet'}, {'frequency': 'f', 'synset': 'toilet_tissue.n.01', 'synonyms': ['toilet_tissue', 'toilet_paper', 'bathroom_tissue'], 'id': 1098, 'def': 'a soft thin absorbent paper for use in toilets', 'name': 'toilet_tissue'}, {'frequency': 'f', 'synset': 'tomato.n.01', 'synonyms': ['tomato'], 'id': 1099, 'def': 'mildly acid red or yellow pulpy fruit eaten as a vegetable', 'name': 'tomato'}, {'frequency': 'f', 'synset': 'tongs.n.01', 'synonyms': ['tongs'], 'id': 1100, 'def': 'any of various devices for taking hold of objects; usually have two hinged legs with handles above and pointed hooks below', 'name': 'tongs'}, {'frequency': 'c', 'synset': 'toolbox.n.01', 'synonyms': ['toolbox'], 'id': 1101, 'def': 'a box or chest or cabinet for holding hand tools', 'name': 'toolbox'}, {'frequency': 'f', 'synset': 'toothbrush.n.01', 'synonyms': ['toothbrush'], 'id': 1102, 'def': 'small brush; has long handle; used to clean teeth', 'name': 'toothbrush'}, {'frequency': 'f', 'synset': 'toothpaste.n.01', 'synonyms': ['toothpaste'], 'id': 1103, 'def': 'a dentifrice in the form of a paste', 'name': 'toothpaste'}, {'frequency': 'f', 'synset': 'toothpick.n.01', 'synonyms': ['toothpick'], 'id': 1104, 'def': 'pick consisting of a small strip of wood or plastic; used to pick food from between the teeth', 'name': 'toothpick'}, {'frequency': 'f', 'synset': 'top.n.09', 
'synonyms': ['cover'], 'id': 1105, 'def': 'covering for a hole (especially a hole in the top of a container)', 'name': 'cover'}, {'frequency': 'c', 'synset': 'tortilla.n.01', 'synonyms': ['tortilla'], 'id': 1106, 'def': 'thin unleavened pancake made from cornmeal or wheat flour', 'name': 'tortilla'}, {'frequency': 'c', 'synset': 'tow_truck.n.01', 'synonyms': ['tow_truck'], 'id': 1107, 'def': 'a truck equipped to hoist and pull wrecked cars (or to remove cars from no-parking zones)', 'name': 'tow_truck'}, {'frequency': 'f', 'synset': 'towel.n.01', 'synonyms': ['towel'], 'id': 1108, 'def': 'a rectangular piece of absorbent cloth (or paper) for drying or wiping', 'name': 'towel'}, {'frequency': 'f', 'synset': 'towel_rack.n.01', 'synonyms': ['towel_rack', 'towel_rail', 'towel_bar'], 'id': 1109, 'def': 'a rack consisting of one or more bars on which towels can be hung', 'name': 'towel_rack'}, {'frequency': 'f', 'synset': 'toy.n.03', 'synonyms': ['toy'], 'id': 1110, 'def': 'a device regarded as providing amusement', 'name': 'toy'}, {'frequency': 'c', 'synset': 'tractor.n.01', 'synonyms': ['tractor_(farm_equipment)'], 'id': 1111, 'def': 'a wheeled vehicle with large wheels; used in farming and other applications', 'name': 'tractor_(farm_equipment)'}, {'frequency': 'f', 'synset': 'traffic_light.n.01', 'synonyms': ['traffic_light'], 'id': 1112, 'def': 'a device to control vehicle traffic often consisting of three or more lights', 'name': 'traffic_light'}, {'frequency': 'c', 'synset': 'trail_bike.n.01', 'synonyms': ['dirt_bike'], 'id': 1113, 'def': 'a lightweight motorcycle equipped with rugged tires and suspension for off-road use', 'name': 'dirt_bike'}, {'frequency': 'f', 'synset': 'trailer_truck.n.01', 'synonyms': ['trailer_truck', 'tractor_trailer', 'trucking_rig', 'articulated_lorry', 'semi_truck'], 'id': 1114, 'def': 'a truck consisting of a tractor and trailer together', 'name': 'trailer_truck'}, {'frequency': 'f', 'synset': 'train.n.01', 'synonyms': 
['train_(railroad_vehicle)', 'railroad_train'], 'id': 1115, 'def': 'public or private transport provided by a line of railway cars coupled together and drawn by a locomotive', 'name': 'train_(railroad_vehicle)'}, {'frequency': 'r', 'synset': 'trampoline.n.01', 'synonyms': ['trampoline'], 'id': 1116, 'def': 'gymnastic apparatus consisting of a strong canvas sheet attached with springs to a metal frame', 'name': 'trampoline'}, {'frequency': 'f', 'synset': 'tray.n.01', 'synonyms': ['tray'], 'id': 1117, 'def': 'an open receptacle for holding or displaying or serving articles or food', 'name': 'tray'}, {'frequency': 'r', 'synset': 'trench_coat.n.01', 'synonyms': ['trench_coat'], 'id': 1118, 'def': 'a military style raincoat; belted with deep pockets', 'name': 'trench_coat'}, {'frequency': 'r', 'synset': 'triangle.n.05', 'synonyms': ['triangle_(musical_instrument)'], 'id': 1119, 'def': 'a percussion instrument consisting of a metal bar bent in the shape of an open triangle', 'name': 'triangle_(musical_instrument)'}, {'frequency': 'c', 'synset': 'tricycle.n.01', 'synonyms': ['tricycle'], 'id': 1120, 'def': 'a vehicle with three wheels that is moved by foot pedals', 'name': 'tricycle'}, {'frequency': 'f', 'synset': 'tripod.n.01', 'synonyms': ['tripod'], 'id': 1121, 'def': 'a three-legged rack used for support', 'name': 'tripod'}, {'frequency': 'f', 'synset': 'trouser.n.01', 'synonyms': ['trousers', 'pants_(clothing)'], 'id': 1122, 'def': 'a garment extending from the waist to the knee or ankle, covering each leg separately', 'name': 'trousers'}, {'frequency': 'f', 'synset': 'truck.n.01', 'synonyms': ['truck'], 'id': 1123, 'def': 'an automotive vehicle suitable for hauling', 'name': 'truck'}, {'frequency': 'r', 'synset': 'truffle.n.03', 'synonyms': ['truffle_(chocolate)', 'chocolate_truffle'], 'id': 1124, 'def': 'creamy chocolate candy', 'name': 'truffle_(chocolate)'}, {'frequency': 'c', 'synset': 'trunk.n.02', 'synonyms': ['trunk'], 'id': 1125, 'def': 'luggage consisting 
of a large strong case used when traveling or for storage', 'name': 'trunk'}, {'frequency': 'r', 'synset': 'tub.n.02', 'synonyms': ['vat'], 'id': 1126, 'def': 'a large vessel for holding or storing liquids', 'name': 'vat'}, {'frequency': 'c', 'synset': 'turban.n.01', 'synonyms': ['turban'], 'id': 1127, 'def': 'a traditional headdress consisting of a long scarf wrapped around the head', 'name': 'turban'}, {'frequency': 'c', 'synset': 'turkey.n.04', 'synonyms': ['turkey_(food)'], 'id': 1128, 'def': 'flesh of large domesticated fowl usually roasted', 'name': 'turkey_(food)'}, {'frequency': 'r', 'synset': 'turnip.n.01', 'synonyms': ['turnip'], 'id': 1129, 'def': 'widely cultivated plant having a large fleshy edible white or yellow root', 'name': 'turnip'}, {'frequency': 'c', 'synset': 'turtle.n.02', 'synonyms': ['turtle'], 'id': 1130, 'def': 'any of various aquatic and land reptiles having a bony shell and flipper-like limbs for swimming', 'name': 'turtle'}, {'frequency': 'c', 'synset': 'turtleneck.n.01', 'synonyms': ['turtleneck_(clothing)', 'polo-neck'], 'id': 1131, 'def': 'a sweater or jersey with a high close-fitting collar', 'name': 'turtleneck_(clothing)'}, {'frequency': 'c', 'synset': 'typewriter.n.01', 'synonyms': ['typewriter'], 'id': 1132, 'def': 'hand-operated character printer for printing written messages one character at a time', 'name': 'typewriter'}, {'frequency': 'f', 'synset': 'umbrella.n.01', 'synonyms': ['umbrella'], 'id': 1133, 'def': 'a lightweight handheld collapsible canopy', 'name': 'umbrella'}, {'frequency': 'f', 'synset': 'underwear.n.01', 'synonyms': ['underwear', 'underclothes', 'underclothing', 'underpants'], 'id': 1134, 'def': 'undergarment worn next to the skin and under the outer garments', 'name': 'underwear'}, {'frequency': 'r', 'synset': 'unicycle.n.01', 'synonyms': ['unicycle'], 'id': 1135, 'def': 'a vehicle with a single wheel that is driven by pedals', 'name': 'unicycle'}, {'frequency': 'f', 'synset': 'urinal.n.01', 'synonyms': 
['urinal'], 'id': 1136, 'def': 'a plumbing fixture (usually attached to the wall) used by men to urinate', 'name': 'urinal'}, {'frequency': 'c', 'synset': 'urn.n.01', 'synonyms': ['urn'], 'id': 1137, 'def': 'a large vase that usually has a pedestal or feet', 'name': 'urn'}, {'frequency': 'c', 'synset': 'vacuum.n.04', 'synonyms': ['vacuum_cleaner'], 'id': 1138, 'def': 'an electrical home appliance that cleans by suction', 'name': 'vacuum_cleaner'}, {'frequency': 'f', 'synset': 'vase.n.01', 'synonyms': ['vase'], 'id': 1139, 'def': 'an open jar of glass or porcelain used as an ornament or to hold flowers', 'name': 'vase'}, {'frequency': 'c', 'synset': 'vending_machine.n.01', 'synonyms': ['vending_machine'], 'id': 1140, 'def': 'a slot machine for selling goods', 'name': 'vending_machine'}, {'frequency': 'f', 'synset': 'vent.n.01', 'synonyms': ['vent', 'blowhole', 'air_vent'], 'id': 1141, 'def': 'a hole for the escape of gas or air', 'name': 'vent'}, {'frequency': 'f', 'synset': 'vest.n.01', 'synonyms': ['vest', 'waistcoat'], 'id': 1142, 'def': "a man's sleeveless garment worn underneath a coat", 'name': 'vest'}, {'frequency': 'c', 'synset': 'videotape.n.01', 'synonyms': ['videotape'], 'id': 1143, 'def': 'a video recording made on magnetic tape', 'name': 'videotape'}, {'frequency': 'r', 'synset': 'vinegar.n.01', 'synonyms': ['vinegar'], 'id': 1144, 'def': 'sour-tasting liquid produced usually by oxidation of the alcohol in wine or cider and used as a condiment or food preservative', 'name': 'vinegar'}, {'frequency': 'r', 'synset': 'violin.n.01', 'synonyms': ['violin', 'fiddle'], 'id': 1145, 'def': 'bowed stringed instrument that is the highest member of the violin family', 'name': 'violin'}, {'frequency': 'r', 'synset': 'vodka.n.01', 'synonyms': ['vodka'], 'id': 1146, 'def': 'unaged colorless liquor originating in Russia', 'name': 'vodka'}, {'frequency': 'c', 'synset': 'volleyball.n.02', 'synonyms': ['volleyball'], 'id': 1147, 'def': 'an inflated ball used in playing 
volleyball', 'name': 'volleyball'}, {'frequency': 'r', 'synset': 'vulture.n.01', 'synonyms': ['vulture'], 'id': 1148, 'def': 'any of various large birds of prey having naked heads and weak claws and feeding chiefly on carrion', 'name': 'vulture'}, {'frequency': 'c', 'synset': 'waffle.n.01', 'synonyms': ['waffle'], 'id': 1149, 'def': 'pancake batter baked in a waffle iron', 'name': 'waffle'}, {'frequency': 'r', 'synset': 'waffle_iron.n.01', 'synonyms': ['waffle_iron'], 'id': 1150, 'def': 'a kitchen appliance for baking waffles', 'name': 'waffle_iron'}, {'frequency': 'c', 'synset': 'wagon.n.01', 'synonyms': ['wagon'], 'id': 1151, 'def': 'any of various kinds of wheeled vehicles drawn by an animal or a tractor', 'name': 'wagon'}, {'frequency': 'c', 'synset': 'wagon_wheel.n.01', 'synonyms': ['wagon_wheel'], 'id': 1152, 'def': 'a wheel of a wagon', 'name': 'wagon_wheel'}, {'frequency': 'c', 'synset': 'walking_stick.n.01', 'synonyms': ['walking_stick'], 'id': 1153, 'def': 'a stick carried in the hand for support in walking', 'name': 'walking_stick'}, {'frequency': 'c', 'synset': 'wall_clock.n.01', 'synonyms': ['wall_clock'], 'id': 1154, 'def': 'a clock mounted on a wall', 'name': 'wall_clock'}, {'frequency': 'f', 'synset': 'wall_socket.n.01', 'synonyms': ['wall_socket', 'wall_plug', 'electric_outlet', 'electrical_outlet', 'outlet', 'electric_receptacle'], 'id': 1155, 'def': 'receptacle providing a place in a wiring system where current can be taken to run electrical devices', 'name': 'wall_socket'}, {'frequency': 'f', 'synset': 'wallet.n.01', 'synonyms': ['wallet', 'billfold'], 'id': 1156, 'def': 'a pocket-size case for holding papers and paper money', 'name': 'wallet'}, {'frequency': 'r', 'synset': 'walrus.n.01', 'synonyms': ['walrus'], 'id': 1157, 'def': 'either of two large northern marine mammals having ivory tusks and tough hide over thick blubber', 'name': 'walrus'}, {'frequency': 'r', 'synset': 'wardrobe.n.01', 'synonyms': ['wardrobe'], 'id': 1158, 'def': 'a tall 
piece of furniture that provides storage space for clothes; has a door and rails or hooks for hanging clothes', 'name': 'wardrobe'}, {'frequency': 'r', 'synset': 'washbasin.n.01', 'synonyms': ['washbasin', 'basin_(for_washing)', 'washbowl', 'washstand', 'handbasin'], 'id': 1159, 'def': 'a bathroom sink that is permanently installed and connected to a water supply and drainpipe; where you can wash your hands and face', 'name': 'washbasin'}, {'frequency': 'c', 'synset': 'washer.n.03', 'synonyms': ['automatic_washer', 'washing_machine'], 'id': 1160, 'def': 'a home appliance for washing clothes and linens automatically', 'name': 'automatic_washer'}, {'frequency': 'f', 'synset': 'watch.n.01', 'synonyms': ['watch', 'wristwatch'], 'id': 1161, 'def': 'a small, portable timepiece', 'name': 'watch'}, {'frequency': 'f', 'synset': 'water_bottle.n.01', 'synonyms': ['water_bottle'], 'id': 1162, 'def': 'a bottle for holding water', 'name': 'water_bottle'}, {'frequency': 'c', 'synset': 'water_cooler.n.01', 'synonyms': ['water_cooler'], 'id': 1163, 'def': 'a device for cooling and dispensing drinking water', 'name': 'water_cooler'}, {'frequency': 'c', 'synset': 'water_faucet.n.01', 'synonyms': ['water_faucet', 'water_tap', 'tap_(water_faucet)'], 'id': 1164, 'def': 'a faucet for drawing water from a pipe or cask', 'name': 'water_faucet'}, {'frequency': 'r', 'synset': 'water_heater.n.01', 'synonyms': ['water_heater', 'hot-water_heater'], 'id': 1165, 'def': 'a heater and storage tank to supply heated water', 'name': 'water_heater'}, {'frequency': 'c', 'synset': 'water_jug.n.01', 'synonyms': ['water_jug'], 'id': 1166, 'def': 'a jug that holds water', 'name': 'water_jug'}, {'frequency': 'r', 'synset': 'water_pistol.n.01', 'synonyms': ['water_gun', 'squirt_gun'], 'id': 1167, 'def': 'plaything consisting of a toy pistol that squirts water', 'name': 'water_gun'}, {'frequency': 'c', 'synset': 'water_scooter.n.01', 'synonyms': ['water_scooter', 'sea_scooter', 'jet_ski'], 'id': 1168, 'def': 
'a motorboat resembling a motor scooter (NOT A SURFBOARD OR WATER SKI)', 'name': 'water_scooter'}, {'frequency': 'c', 'synset': 'water_ski.n.01', 'synonyms': ['water_ski'], 'id': 1169, 'def': 'broad ski for skimming over water towed by a speedboat (DO NOT MARK WATER)', 'name': 'water_ski'}, {'frequency': 'c', 'synset': 'water_tower.n.01', 'synonyms': ['water_tower'], 'id': 1170, 'def': 'a large reservoir for water', 'name': 'water_tower'}, {'frequency': 'c', 'synset': 'watering_can.n.01', 'synonyms': ['watering_can'], 'id': 1171, 'def': 'a container with a handle and a spout with a perforated nozzle; used to sprinkle water over plants', 'name': 'watering_can'}, {'frequency': 'f', 'synset': 'watermelon.n.02', 'synonyms': ['watermelon'], 'id': 1172, 'def': 'large oblong or roundish melon with a hard green rind and sweet watery red or occasionally yellowish pulp', 'name': 'watermelon'}, {'frequency': 'f', 'synset': 'weathervane.n.01', 'synonyms': ['weathervane', 'vane_(weathervane)', 'wind_vane'], 'id': 1173, 'def': 'mechanical device attached to an elevated structure; rotates freely to show the direction of the wind', 'name': 'weathervane'}, {'frequency': 'c', 'synset': 'webcam.n.01', 'synonyms': ['webcam'], 'id': 1174, 'def': 'a digital camera designed to take digital photographs and transmit them over the internet', 'name': 'webcam'}, {'frequency': 'c', 'synset': 'wedding_cake.n.01', 'synonyms': ['wedding_cake', 'bridecake'], 'id': 1175, 'def': 'a rich cake with two or more tiers and covered with frosting and decorations; served at a wedding reception', 'name': 'wedding_cake'}, {'frequency': 'c', 'synset': 'wedding_ring.n.01', 'synonyms': ['wedding_ring', 'wedding_band'], 'id': 1176, 'def': 'a ring given to the bride and/or groom at the wedding', 'name': 'wedding_ring'}, {'frequency': 'f', 'synset': 'wet_suit.n.01', 'synonyms': ['wet_suit'], 'id': 1177, 'def': 'a close-fitting garment made of a permeable material; worn in cold water to retain body heat', 'name': 
'wet_suit'}, {'frequency': 'f', 'synset': 'wheel.n.01', 'synonyms': ['wheel'], 'id': 1178, 'def': 'a circular frame with spokes (or a solid disc) that can rotate on a shaft or axle', 'name': 'wheel'}, {'frequency': 'c', 'synset': 'wheelchair.n.01', 'synonyms': ['wheelchair'], 'id': 1179, 'def': 'a movable chair mounted on large wheels', 'name': 'wheelchair'}, {'frequency': 'c', 'synset': 'whipped_cream.n.01', 'synonyms': ['whipped_cream'], 'id': 1180, 'def': 'cream that has been beaten until light and fluffy', 'name': 'whipped_cream'}, {'frequency': 'c', 'synset': 'whistle.n.03', 'synonyms': ['whistle'], 'id': 1181, 'def': 'a small wind instrument that produces a whistling sound by blowing into it', 'name': 'whistle'}, {'frequency': 'c', 'synset': 'wig.n.01', 'synonyms': ['wig'], 'id': 1182, 'def': 'hairpiece covering the head and made of real or synthetic hair', 'name': 'wig'}, {'frequency': 'c', 'synset': 'wind_chime.n.01', 'synonyms': ['wind_chime'], 'id': 1183, 'def': 'a decorative arrangement of pieces of metal or glass or pottery that hang together loosely so the wind can cause them to tinkle', 'name': 'wind_chime'}, {'frequency': 'c', 'synset': 'windmill.n.01', 'synonyms': ['windmill'], 'id': 1184, 'def': 'A mill or turbine that is powered by wind', 'name': 'windmill'}, {'frequency': 'c', 'synset': 'window_box.n.01', 'synonyms': ['window_box_(for_plants)'], 'id': 1185, 'def': 'a container for growing plants on a windowsill', 'name': 'window_box_(for_plants)'}, {'frequency': 'f', 'synset': 'windshield_wiper.n.01', 'synonyms': ['windshield_wiper', 'windscreen_wiper', 'wiper_(for_windshield/screen)'], 'id': 1186, 'def': 'a mechanical device that cleans the windshield', 'name': 'windshield_wiper'}, {'frequency': 'c', 'synset': 'windsock.n.01', 'synonyms': ['windsock', 'air_sock', 'air-sleeve', 'wind_sleeve', 'wind_cone'], 'id': 1187, 'def': 'a truncated cloth cone mounted on a mast/pole; shows wind direction', 'name': 'windsock'}, {'frequency': 'f', 'synset': 
'wine_bottle.n.01', 'synonyms': ['wine_bottle'], 'id': 1188, 'def': 'a bottle for holding wine', 'name': 'wine_bottle'}, {'frequency': 'c', 'synset': 'wine_bucket.n.01', 'synonyms': ['wine_bucket', 'wine_cooler'], 'id': 1189, 'def': 'a bucket of ice used to chill a bottle of wine', 'name': 'wine_bucket'}, {'frequency': 'f', 'synset': 'wineglass.n.01', 'synonyms': ['wineglass'], 'id': 1190, 'def': 'a glass that has a stem and in which wine is served', 'name': 'wineglass'}, {'frequency': 'f', 'synset': 'winker.n.02', 'synonyms': ['blinder_(for_horses)'], 'id': 1191, 'def': 'blinds that prevent a horse from seeing something on either side', 'name': 'blinder_(for_horses)'}, {'frequency': 'c', 'synset': 'wok.n.01', 'synonyms': ['wok'], 'id': 1192, 'def': 'pan with a convex bottom; used for frying in Chinese cooking', 'name': 'wok'}, {'frequency': 'r', 'synset': 'wolf.n.01', 'synonyms': ['wolf'], 'id': 1193, 'def': 'a wild carnivorous mammal of the dog family, living and hunting in packs', 'name': 'wolf'}, {'frequency': 'c', 'synset': 'wooden_spoon.n.02', 'synonyms': ['wooden_spoon'], 'id': 1194, 'def': 'a spoon made of wood', 'name': 'wooden_spoon'}, {'frequency': 'c', 'synset': 'wreath.n.01', 'synonyms': ['wreath'], 'id': 1195, 'def': 'an arrangement of flowers, leaves, or stems fastened in a ring', 'name': 'wreath'}, {'frequency': 'c', 'synset': 'wrench.n.03', 'synonyms': ['wrench', 'spanner'], 'id': 1196, 'def': 'a hand tool that is used to hold or twist a nut or bolt', 'name': 'wrench'}, {'frequency': 'f', 'synset': 'wristband.n.01', 'synonyms': ['wristband'], 'id': 1197, 'def': 'band consisting of a part of a sleeve that covers the wrist', 'name': 'wristband'}, {'frequency': 'f', 'synset': 'wristlet.n.01', 'synonyms': ['wristlet', 'wrist_band'], 'id': 1198, 'def': 'a band or bracelet worn around the wrist', 'name': 'wristlet'}, {'frequency': 'c', 'synset': 'yacht.n.01', 'synonyms': ['yacht'], 'id': 1199, 'def': 'an expensive vessel propelled by sail or power and 
used for cruising or racing', 'name': 'yacht'}, {'frequency': 'c', 'synset': 'yogurt.n.01', 'synonyms': ['yogurt', 'yoghurt', 'yoghourt'], 'id': 1200, 'def': 'a custard-like food made from curdled milk', 'name': 'yogurt'}, {'frequency': 'c', 'synset': 'yoke.n.07', 'synonyms': ['yoke_(animal_equipment)'], 'id': 1201, 'def': 'gear joining two animals at the neck; NOT egg yolk', 'name': 'yoke_(animal_equipment)'}, {'frequency': 'f', 'synset': 'zebra.n.01', 'synonyms': ['zebra'], 'id': 1202, 'def': 'any of several fleet black-and-white striped African equines', 'name': 'zebra'}, {'frequency': 'c', 'synset': 'zucchini.n.02', 'synonyms': ['zucchini', 'courgette'], 'id': 1203, 'def': 'small cucumber-shaped vegetable marrow; typically dark green', 'name': 'zucchini'}] # noqa +# fmt: on diff --git a/data_processing/detectron2/detectron2/data/datasets/lvis_v1_category_image_count.py b/data_processing/detectron2/detectron2/data/datasets/lvis_v1_category_image_count.py new file mode 100644 index 0000000..31bf0cf --- /dev/null +++ b/data_processing/detectron2/detectron2/data/datasets/lvis_v1_category_image_count.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# Autogen with +# with open("lvis_v1_train.json", "r") as f: +# a = json.load(f) +# c = a["categories"] +# for x in c: +# del x["name"] +# del x["instance_count"] +# del x["def"] +# del x["synonyms"] +# del x["frequency"] +# del x["synset"] +# LVIS_CATEGORY_IMAGE_COUNT = repr(c) + " # noqa" +# with open("/tmp/lvis_category_image_count.py", "wt") as f: +# f.write(f"LVIS_CATEGORY_IMAGE_COUNT = {LVIS_CATEGORY_IMAGE_COUNT}") +# Then paste the contents of that file below + +# fmt: off +LVIS_CATEGORY_IMAGE_COUNT = [{'id': 1, 'image_count': 64}, {'id': 2, 'image_count': 364}, {'id': 3, 'image_count': 1911}, {'id': 4, 'image_count': 149}, {'id': 5, 'image_count': 29}, {'id': 6, 'image_count': 26}, {'id': 7, 'image_count': 59}, {'id': 8, 'image_count': 22}, {'id': 9, 'image_count': 12}, {'id': 10, 'image_count': 28}, {'id': 11, 'image_count': 505}, {'id': 12, 'image_count': 1207}, {'id': 13, 'image_count': 4}, {'id': 14, 'image_count': 10}, {'id': 15, 'image_count': 500}, {'id': 16, 'image_count': 33}, {'id': 17, 'image_count': 3}, {'id': 18, 'image_count': 44}, {'id': 19, 'image_count': 561}, {'id': 20, 'image_count': 8}, {'id': 21, 'image_count': 9}, {'id': 22, 'image_count': 33}, {'id': 23, 'image_count': 1883}, {'id': 24, 'image_count': 98}, {'id': 25, 'image_count': 70}, {'id': 26, 'image_count': 46}, {'id': 27, 'image_count': 117}, {'id': 28, 'image_count': 41}, {'id': 29, 'image_count': 1395}, {'id': 30, 'image_count': 7}, {'id': 31, 'image_count': 1}, {'id': 32, 'image_count': 314}, {'id': 33, 'image_count': 31}, {'id': 34, 'image_count': 1905}, {'id': 35, 'image_count': 1859}, {'id': 36, 'image_count': 1623}, {'id': 37, 'image_count': 47}, {'id': 38, 'image_count': 3}, {'id': 39, 'image_count': 3}, {'id': 40, 'image_count': 1}, {'id': 41, 'image_count': 305}, {'id': 42, 'image_count': 6}, {'id': 43, 'image_count': 210}, {'id': 44, 'image_count': 36}, {'id': 45, 'image_count': 1787}, {'id': 46, 'image_count': 17}, {'id': 47, 'image_count': 51}, {'id': 48, 
'image_count': 138}, {'id': 49, 'image_count': 3}, {'id': 50, 'image_count': 1470}, {'id': 51, 'image_count': 3}, {'id': 52, 'image_count': 2}, {'id': 53, 'image_count': 186}, {'id': 54, 'image_count': 76}, {'id': 55, 'image_count': 26}, {'id': 56, 'image_count': 303}, {'id': 57, 'image_count': 738}, {'id': 58, 'image_count': 1799}, {'id': 59, 'image_count': 1934}, {'id': 60, 'image_count': 1609}, {'id': 61, 'image_count': 1622}, {'id': 62, 'image_count': 41}, {'id': 63, 'image_count': 4}, {'id': 64, 'image_count': 11}, {'id': 65, 'image_count': 270}, {'id': 66, 'image_count': 349}, {'id': 67, 'image_count': 42}, {'id': 68, 'image_count': 823}, {'id': 69, 'image_count': 6}, {'id': 70, 'image_count': 48}, {'id': 71, 'image_count': 3}, {'id': 72, 'image_count': 42}, {'id': 73, 'image_count': 24}, {'id': 74, 'image_count': 16}, {'id': 75, 'image_count': 605}, {'id': 76, 'image_count': 646}, {'id': 77, 'image_count': 1765}, {'id': 78, 'image_count': 2}, {'id': 79, 'image_count': 125}, {'id': 80, 'image_count': 1420}, {'id': 81, 'image_count': 140}, {'id': 82, 'image_count': 4}, {'id': 83, 'image_count': 322}, {'id': 84, 'image_count': 60}, {'id': 85, 'image_count': 2}, {'id': 86, 'image_count': 231}, {'id': 87, 'image_count': 333}, {'id': 88, 'image_count': 1941}, {'id': 89, 'image_count': 367}, {'id': 90, 'image_count': 1922}, {'id': 91, 'image_count': 18}, {'id': 92, 'image_count': 81}, {'id': 93, 'image_count': 1}, {'id': 94, 'image_count': 1852}, {'id': 95, 'image_count': 430}, {'id': 96, 'image_count': 247}, {'id': 97, 'image_count': 94}, {'id': 98, 'image_count': 21}, {'id': 99, 'image_count': 1821}, {'id': 100, 'image_count': 16}, {'id': 101, 'image_count': 12}, {'id': 102, 'image_count': 25}, {'id': 103, 'image_count': 41}, {'id': 104, 'image_count': 244}, {'id': 105, 'image_count': 7}, {'id': 106, 'image_count': 1}, {'id': 107, 'image_count': 40}, {'id': 108, 'image_count': 40}, {'id': 109, 'image_count': 104}, {'id': 110, 'image_count': 1671}, {'id': 111, 
'image_count': 49}, {'id': 112, 'image_count': 243}, {'id': 113, 'image_count': 2}, {'id': 114, 'image_count': 242}, {'id': 115, 'image_count': 271}, {'id': 116, 'image_count': 104}, {'id': 117, 'image_count': 8}, {'id': 118, 'image_count': 1758}, {'id': 119, 'image_count': 1}, {'id': 120, 'image_count': 48}, {'id': 121, 'image_count': 14}, {'id': 122, 'image_count': 40}, {'id': 123, 'image_count': 1}, {'id': 124, 'image_count': 37}, {'id': 125, 'image_count': 1510}, {'id': 126, 'image_count': 6}, {'id': 127, 'image_count': 1903}, {'id': 128, 'image_count': 70}, {'id': 129, 'image_count': 86}, {'id': 130, 'image_count': 7}, {'id': 131, 'image_count': 5}, {'id': 132, 'image_count': 1406}, {'id': 133, 'image_count': 1901}, {'id': 134, 'image_count': 15}, {'id': 135, 'image_count': 28}, {'id': 136, 'image_count': 6}, {'id': 137, 'image_count': 494}, {'id': 138, 'image_count': 234}, {'id': 139, 'image_count': 1922}, {'id': 140, 'image_count': 1}, {'id': 141, 'image_count': 35}, {'id': 142, 'image_count': 5}, {'id': 143, 'image_count': 1828}, {'id': 144, 'image_count': 8}, {'id': 145, 'image_count': 63}, {'id': 146, 'image_count': 1668}, {'id': 147, 'image_count': 4}, {'id': 148, 'image_count': 95}, {'id': 149, 'image_count': 17}, {'id': 150, 'image_count': 1567}, {'id': 151, 'image_count': 2}, {'id': 152, 'image_count': 103}, {'id': 153, 'image_count': 50}, {'id': 154, 'image_count': 1309}, {'id': 155, 'image_count': 6}, {'id': 156, 'image_count': 92}, {'id': 157, 'image_count': 19}, {'id': 158, 'image_count': 37}, {'id': 159, 'image_count': 4}, {'id': 160, 'image_count': 709}, {'id': 161, 'image_count': 9}, {'id': 162, 'image_count': 82}, {'id': 163, 'image_count': 15}, {'id': 164, 'image_count': 3}, {'id': 165, 'image_count': 61}, {'id': 166, 'image_count': 51}, {'id': 167, 'image_count': 5}, {'id': 168, 'image_count': 13}, {'id': 169, 'image_count': 642}, {'id': 170, 'image_count': 24}, {'id': 171, 'image_count': 255}, {'id': 172, 'image_count': 9}, {'id': 173, 
'image_count': 1808}, {'id': 174, 'image_count': 31}, {'id': 175, 'image_count': 158}, {'id': 176, 'image_count': 80}, {'id': 177, 'image_count': 1884}, {'id': 178, 'image_count': 158}, {'id': 179, 'image_count': 2}, {'id': 180, 'image_count': 12}, {'id': 181, 'image_count': 1659}, {'id': 182, 'image_count': 7}, {'id': 183, 'image_count': 834}, {'id': 184, 'image_count': 57}, {'id': 185, 'image_count': 174}, {'id': 186, 'image_count': 95}, {'id': 187, 'image_count': 27}, {'id': 188, 'image_count': 22}, {'id': 189, 'image_count': 1391}, {'id': 190, 'image_count': 90}, {'id': 191, 'image_count': 40}, {'id': 192, 'image_count': 445}, {'id': 193, 'image_count': 21}, {'id': 194, 'image_count': 1132}, {'id': 195, 'image_count': 177}, {'id': 196, 'image_count': 4}, {'id': 197, 'image_count': 17}, {'id': 198, 'image_count': 84}, {'id': 199, 'image_count': 55}, {'id': 200, 'image_count': 30}, {'id': 201, 'image_count': 25}, {'id': 202, 'image_count': 2}, {'id': 203, 'image_count': 125}, {'id': 204, 'image_count': 1135}, {'id': 205, 'image_count': 19}, {'id': 206, 'image_count': 72}, {'id': 207, 'image_count': 1926}, {'id': 208, 'image_count': 159}, {'id': 209, 'image_count': 7}, {'id': 210, 'image_count': 1}, {'id': 211, 'image_count': 13}, {'id': 212, 'image_count': 35}, {'id': 213, 'image_count': 18}, {'id': 214, 'image_count': 8}, {'id': 215, 'image_count': 6}, {'id': 216, 'image_count': 35}, {'id': 217, 'image_count': 1222}, {'id': 218, 'image_count': 103}, {'id': 219, 'image_count': 28}, {'id': 220, 'image_count': 63}, {'id': 221, 'image_count': 28}, {'id': 222, 'image_count': 5}, {'id': 223, 'image_count': 7}, {'id': 224, 'image_count': 14}, {'id': 225, 'image_count': 1918}, {'id': 226, 'image_count': 133}, {'id': 227, 'image_count': 16}, {'id': 228, 'image_count': 27}, {'id': 229, 'image_count': 110}, {'id': 230, 'image_count': 1895}, {'id': 231, 'image_count': 4}, {'id': 232, 'image_count': 1927}, {'id': 233, 'image_count': 8}, {'id': 234, 'image_count': 1}, {'id': 
235, 'image_count': 263}, {'id': 236, 'image_count': 10}, {'id': 237, 'image_count': 2}, {'id': 238, 'image_count': 3}, {'id': 239, 'image_count': 87}, {'id': 240, 'image_count': 9}, {'id': 241, 'image_count': 71}, {'id': 242, 'image_count': 13}, {'id': 243, 'image_count': 18}, {'id': 244, 'image_count': 2}, {'id': 245, 'image_count': 5}, {'id': 246, 'image_count': 45}, {'id': 247, 'image_count': 1}, {'id': 248, 'image_count': 23}, {'id': 249, 'image_count': 32}, {'id': 250, 'image_count': 4}, {'id': 251, 'image_count': 1}, {'id': 252, 'image_count': 858}, {'id': 253, 'image_count': 661}, {'id': 254, 'image_count': 168}, {'id': 255, 'image_count': 210}, {'id': 256, 'image_count': 65}, {'id': 257, 'image_count': 4}, {'id': 258, 'image_count': 2}, {'id': 259, 'image_count': 159}, {'id': 260, 'image_count': 31}, {'id': 261, 'image_count': 811}, {'id': 262, 'image_count': 1}, {'id': 263, 'image_count': 42}, {'id': 264, 'image_count': 27}, {'id': 265, 'image_count': 2}, {'id': 266, 'image_count': 5}, {'id': 267, 'image_count': 95}, {'id': 268, 'image_count': 32}, {'id': 269, 'image_count': 1}, {'id': 270, 'image_count': 1}, {'id': 271, 'image_count': 1844}, {'id': 272, 'image_count': 897}, {'id': 273, 'image_count': 31}, {'id': 274, 'image_count': 23}, {'id': 275, 'image_count': 1}, {'id': 276, 'image_count': 202}, {'id': 277, 'image_count': 746}, {'id': 278, 'image_count': 44}, {'id': 279, 'image_count': 14}, {'id': 280, 'image_count': 26}, {'id': 281, 'image_count': 1}, {'id': 282, 'image_count': 2}, {'id': 283, 'image_count': 25}, {'id': 284, 'image_count': 238}, {'id': 285, 'image_count': 592}, {'id': 286, 'image_count': 26}, {'id': 287, 'image_count': 5}, {'id': 288, 'image_count': 42}, {'id': 289, 'image_count': 13}, {'id': 290, 'image_count': 46}, {'id': 291, 'image_count': 1}, {'id': 292, 'image_count': 8}, {'id': 293, 'image_count': 34}, {'id': 294, 'image_count': 5}, {'id': 295, 'image_count': 1}, {'id': 296, 'image_count': 1871}, {'id': 297, 'image_count': 
717}, {'id': 298, 'image_count': 1010}, {'id': 299, 'image_count': 679}, {'id': 300, 'image_count': 3}, {'id': 301, 'image_count': 4}, {'id': 302, 'image_count': 1}, {'id': 303, 'image_count': 166}, {'id': 304, 'image_count': 2}, {'id': 305, 'image_count': 266}, {'id': 306, 'image_count': 101}, {'id': 307, 'image_count': 6}, {'id': 308, 'image_count': 14}, {'id': 309, 'image_count': 133}, {'id': 310, 'image_count': 2}, {'id': 311, 'image_count': 38}, {'id': 312, 'image_count': 95}, {'id': 313, 'image_count': 1}, {'id': 314, 'image_count': 12}, {'id': 315, 'image_count': 49}, {'id': 316, 'image_count': 5}, {'id': 317, 'image_count': 5}, {'id': 318, 'image_count': 16}, {'id': 319, 'image_count': 216}, {'id': 320, 'image_count': 12}, {'id': 321, 'image_count': 1}, {'id': 322, 'image_count': 54}, {'id': 323, 'image_count': 5}, {'id': 324, 'image_count': 245}, {'id': 325, 'image_count': 12}, {'id': 326, 'image_count': 7}, {'id': 327, 'image_count': 35}, {'id': 328, 'image_count': 36}, {'id': 329, 'image_count': 32}, {'id': 330, 'image_count': 1027}, {'id': 331, 'image_count': 10}, {'id': 332, 'image_count': 12}, {'id': 333, 'image_count': 1}, {'id': 334, 'image_count': 67}, {'id': 335, 'image_count': 71}, {'id': 336, 'image_count': 30}, {'id': 337, 'image_count': 48}, {'id': 338, 'image_count': 249}, {'id': 339, 'image_count': 13}, {'id': 340, 'image_count': 29}, {'id': 341, 'image_count': 14}, {'id': 342, 'image_count': 236}, {'id': 343, 'image_count': 15}, {'id': 344, 'image_count': 1521}, {'id': 345, 'image_count': 25}, {'id': 346, 'image_count': 249}, {'id': 347, 'image_count': 139}, {'id': 348, 'image_count': 2}, {'id': 349, 'image_count': 2}, {'id': 350, 'image_count': 1890}, {'id': 351, 'image_count': 1240}, {'id': 352, 'image_count': 1}, {'id': 353, 'image_count': 9}, {'id': 354, 'image_count': 1}, {'id': 355, 'image_count': 3}, {'id': 356, 'image_count': 11}, {'id': 357, 'image_count': 4}, {'id': 358, 'image_count': 236}, {'id': 359, 'image_count': 44}, {'id': 
360, 'image_count': 19}, {'id': 361, 'image_count': 1100}, {'id': 362, 'image_count': 7}, {'id': 363, 'image_count': 69}, {'id': 364, 'image_count': 2}, {'id': 365, 'image_count': 8}, {'id': 366, 'image_count': 5}, {'id': 367, 'image_count': 227}, {'id': 368, 'image_count': 6}, {'id': 369, 'image_count': 106}, {'id': 370, 'image_count': 81}, {'id': 371, 'image_count': 17}, {'id': 372, 'image_count': 134}, {'id': 373, 'image_count': 312}, {'id': 374, 'image_count': 8}, {'id': 375, 'image_count': 271}, {'id': 376, 'image_count': 2}, {'id': 377, 'image_count': 103}, {'id': 378, 'image_count': 1938}, {'id': 379, 'image_count': 574}, {'id': 380, 'image_count': 120}, {'id': 381, 'image_count': 2}, {'id': 382, 'image_count': 2}, {'id': 383, 'image_count': 13}, {'id': 384, 'image_count': 29}, {'id': 385, 'image_count': 1710}, {'id': 386, 'image_count': 66}, {'id': 387, 'image_count': 1008}, {'id': 388, 'image_count': 1}, {'id': 389, 'image_count': 3}, {'id': 390, 'image_count': 1942}, {'id': 391, 'image_count': 19}, {'id': 392, 'image_count': 1488}, {'id': 393, 'image_count': 46}, {'id': 394, 'image_count': 106}, {'id': 395, 'image_count': 115}, {'id': 396, 'image_count': 19}, {'id': 397, 'image_count': 2}, {'id': 398, 'image_count': 1}, {'id': 399, 'image_count': 28}, {'id': 400, 'image_count': 9}, {'id': 401, 'image_count': 192}, {'id': 402, 'image_count': 12}, {'id': 403, 'image_count': 21}, {'id': 404, 'image_count': 247}, {'id': 405, 'image_count': 6}, {'id': 406, 'image_count': 64}, {'id': 407, 'image_count': 7}, {'id': 408, 'image_count': 40}, {'id': 409, 'image_count': 542}, {'id': 410, 'image_count': 2}, {'id': 411, 'image_count': 1898}, {'id': 412, 'image_count': 36}, {'id': 413, 'image_count': 4}, {'id': 414, 'image_count': 1}, {'id': 415, 'image_count': 191}, {'id': 416, 'image_count': 6}, {'id': 417, 'image_count': 41}, {'id': 418, 'image_count': 39}, {'id': 419, 'image_count': 46}, {'id': 420, 'image_count': 1}, {'id': 421, 'image_count': 1451}, {'id': 422, 
'image_count': 1878}, {'id': 423, 'image_count': 11}, {'id': 424, 'image_count': 82}, {'id': 425, 'image_count': 18}, {'id': 426, 'image_count': 1}, {'id': 427, 'image_count': 7}, {'id': 428, 'image_count': 3}, {'id': 429, 'image_count': 575}, {'id': 430, 'image_count': 1907}, {'id': 431, 'image_count': 8}, {'id': 432, 'image_count': 4}, {'id': 433, 'image_count': 32}, {'id': 434, 'image_count': 11}, {'id': 435, 'image_count': 4}, {'id': 436, 'image_count': 54}, {'id': 437, 'image_count': 202}, {'id': 438, 'image_count': 32}, {'id': 439, 'image_count': 3}, {'id': 440, 'image_count': 130}, {'id': 441, 'image_count': 119}, {'id': 442, 'image_count': 141}, {'id': 443, 'image_count': 29}, {'id': 444, 'image_count': 525}, {'id': 445, 'image_count': 1323}, {'id': 446, 'image_count': 2}, {'id': 447, 'image_count': 113}, {'id': 448, 'image_count': 16}, {'id': 449, 'image_count': 7}, {'id': 450, 'image_count': 35}, {'id': 451, 'image_count': 1908}, {'id': 452, 'image_count': 353}, {'id': 453, 'image_count': 18}, {'id': 454, 'image_count': 14}, {'id': 455, 'image_count': 77}, {'id': 456, 'image_count': 8}, {'id': 457, 'image_count': 37}, {'id': 458, 'image_count': 1}, {'id': 459, 'image_count': 346}, {'id': 460, 'image_count': 19}, {'id': 461, 'image_count': 1779}, {'id': 462, 'image_count': 23}, {'id': 463, 'image_count': 25}, {'id': 464, 'image_count': 67}, {'id': 465, 'image_count': 19}, {'id': 466, 'image_count': 28}, {'id': 467, 'image_count': 4}, {'id': 468, 'image_count': 27}, {'id': 469, 'image_count': 1861}, {'id': 470, 'image_count': 11}, {'id': 471, 'image_count': 13}, {'id': 472, 'image_count': 13}, {'id': 473, 'image_count': 32}, {'id': 474, 'image_count': 1767}, {'id': 475, 'image_count': 42}, {'id': 476, 'image_count': 17}, {'id': 477, 'image_count': 128}, {'id': 478, 'image_count': 1}, {'id': 479, 'image_count': 9}, {'id': 480, 'image_count': 10}, {'id': 481, 'image_count': 4}, {'id': 482, 'image_count': 9}, {'id': 483, 'image_count': 18}, {'id': 484, 
'image_count': 41}, {'id': 485, 'image_count': 28}, {'id': 486, 'image_count': 3}, {'id': 487, 'image_count': 65}, {'id': 488, 'image_count': 9}, {'id': 489, 'image_count': 23}, {'id': 490, 'image_count': 24}, {'id': 491, 'image_count': 1}, {'id': 492, 'image_count': 2}, {'id': 493, 'image_count': 59}, {'id': 494, 'image_count': 48}, {'id': 495, 'image_count': 17}, {'id': 496, 'image_count': 1877}, {'id': 497, 'image_count': 18}, {'id': 498, 'image_count': 1920}, {'id': 499, 'image_count': 50}, {'id': 500, 'image_count': 1890}, {'id': 501, 'image_count': 99}, {'id': 502, 'image_count': 1530}, {'id': 503, 'image_count': 3}, {'id': 504, 'image_count': 11}, {'id': 505, 'image_count': 19}, {'id': 506, 'image_count': 3}, {'id': 507, 'image_count': 63}, {'id': 508, 'image_count': 5}, {'id': 509, 'image_count': 6}, {'id': 510, 'image_count': 233}, {'id': 511, 'image_count': 54}, {'id': 512, 'image_count': 36}, {'id': 513, 'image_count': 10}, {'id': 514, 'image_count': 124}, {'id': 515, 'image_count': 101}, {'id': 516, 'image_count': 3}, {'id': 517, 'image_count': 363}, {'id': 518, 'image_count': 3}, {'id': 519, 'image_count': 30}, {'id': 520, 'image_count': 18}, {'id': 521, 'image_count': 199}, {'id': 522, 'image_count': 97}, {'id': 523, 'image_count': 32}, {'id': 524, 'image_count': 121}, {'id': 525, 'image_count': 16}, {'id': 526, 'image_count': 12}, {'id': 527, 'image_count': 2}, {'id': 528, 'image_count': 214}, {'id': 529, 'image_count': 48}, {'id': 530, 'image_count': 26}, {'id': 531, 'image_count': 13}, {'id': 532, 'image_count': 4}, {'id': 533, 'image_count': 11}, {'id': 534, 'image_count': 123}, {'id': 535, 'image_count': 7}, {'id': 536, 'image_count': 200}, {'id': 537, 'image_count': 91}, {'id': 538, 'image_count': 9}, {'id': 539, 'image_count': 72}, {'id': 540, 'image_count': 1886}, {'id': 541, 'image_count': 4}, {'id': 542, 'image_count': 1}, {'id': 543, 'image_count': 1}, {'id': 544, 'image_count': 1932}, {'id': 545, 'image_count': 4}, {'id': 546, 
'image_count': 56}, {'id': 547, 'image_count': 854}, {'id': 548, 'image_count': 755}, {'id': 549, 'image_count': 1843}, {'id': 550, 'image_count': 96}, {'id': 551, 'image_count': 7}, {'id': 552, 'image_count': 74}, {'id': 553, 'image_count': 66}, {'id': 554, 'image_count': 57}, {'id': 555, 'image_count': 44}, {'id': 556, 'image_count': 1905}, {'id': 557, 'image_count': 4}, {'id': 558, 'image_count': 90}, {'id': 559, 'image_count': 1635}, {'id': 560, 'image_count': 8}, {'id': 561, 'image_count': 5}, {'id': 562, 'image_count': 50}, {'id': 563, 'image_count': 545}, {'id': 564, 'image_count': 20}, {'id': 565, 'image_count': 193}, {'id': 566, 'image_count': 285}, {'id': 567, 'image_count': 3}, {'id': 568, 'image_count': 1}, {'id': 569, 'image_count': 1904}, {'id': 570, 'image_count': 294}, {'id': 571, 'image_count': 3}, {'id': 572, 'image_count': 5}, {'id': 573, 'image_count': 24}, {'id': 574, 'image_count': 2}, {'id': 575, 'image_count': 2}, {'id': 576, 'image_count': 16}, {'id': 577, 'image_count': 8}, {'id': 578, 'image_count': 154}, {'id': 579, 'image_count': 66}, {'id': 580, 'image_count': 1}, {'id': 581, 'image_count': 24}, {'id': 582, 'image_count': 1}, {'id': 583, 'image_count': 4}, {'id': 584, 'image_count': 75}, {'id': 585, 'image_count': 6}, {'id': 586, 'image_count': 126}, {'id': 587, 'image_count': 24}, {'id': 588, 'image_count': 22}, {'id': 589, 'image_count': 1872}, {'id': 590, 'image_count': 16}, {'id': 591, 'image_count': 423}, {'id': 592, 'image_count': 1927}, {'id': 593, 'image_count': 38}, {'id': 594, 'image_count': 3}, {'id': 595, 'image_count': 1945}, {'id': 596, 'image_count': 35}, {'id': 597, 'image_count': 1}, {'id': 598, 'image_count': 13}, {'id': 599, 'image_count': 9}, {'id': 600, 'image_count': 14}, {'id': 601, 'image_count': 37}, {'id': 602, 'image_count': 3}, {'id': 603, 'image_count': 4}, {'id': 604, 'image_count': 100}, {'id': 605, 'image_count': 195}, {'id': 606, 'image_count': 1}, {'id': 607, 'image_count': 12}, {'id': 608, 
'image_count': 24}, {'id': 609, 'image_count': 489}, {'id': 610, 'image_count': 10}, {'id': 611, 'image_count': 1689}, {'id': 612, 'image_count': 42}, {'id': 613, 'image_count': 81}, {'id': 614, 'image_count': 894}, {'id': 615, 'image_count': 1868}, {'id': 616, 'image_count': 7}, {'id': 617, 'image_count': 1567}, {'id': 618, 'image_count': 10}, {'id': 619, 'image_count': 8}, {'id': 620, 'image_count': 7}, {'id': 621, 'image_count': 629}, {'id': 622, 'image_count': 89}, {'id': 623, 'image_count': 15}, {'id': 624, 'image_count': 134}, {'id': 625, 'image_count': 4}, {'id': 626, 'image_count': 1802}, {'id': 627, 'image_count': 595}, {'id': 628, 'image_count': 1210}, {'id': 629, 'image_count': 48}, {'id': 630, 'image_count': 418}, {'id': 631, 'image_count': 1846}, {'id': 632, 'image_count': 5}, {'id': 633, 'image_count': 221}, {'id': 634, 'image_count': 10}, {'id': 635, 'image_count': 7}, {'id': 636, 'image_count': 76}, {'id': 637, 'image_count': 22}, {'id': 638, 'image_count': 10}, {'id': 639, 'image_count': 341}, {'id': 640, 'image_count': 1}, {'id': 641, 'image_count': 705}, {'id': 642, 'image_count': 1900}, {'id': 643, 'image_count': 188}, {'id': 644, 'image_count': 227}, {'id': 645, 'image_count': 861}, {'id': 646, 'image_count': 6}, {'id': 647, 'image_count': 115}, {'id': 648, 'image_count': 5}, {'id': 649, 'image_count': 43}, {'id': 650, 'image_count': 14}, {'id': 651, 'image_count': 6}, {'id': 652, 'image_count': 15}, {'id': 653, 'image_count': 1167}, {'id': 654, 'image_count': 15}, {'id': 655, 'image_count': 994}, {'id': 656, 'image_count': 28}, {'id': 657, 'image_count': 2}, {'id': 658, 'image_count': 338}, {'id': 659, 'image_count': 334}, {'id': 660, 'image_count': 15}, {'id': 661, 'image_count': 102}, {'id': 662, 'image_count': 1}, {'id': 663, 'image_count': 8}, {'id': 664, 'image_count': 1}, {'id': 665, 'image_count': 1}, {'id': 666, 'image_count': 28}, {'id': 667, 'image_count': 91}, {'id': 668, 'image_count': 260}, {'id': 669, 'image_count': 131}, {'id': 
670, 'image_count': 128}, {'id': 671, 'image_count': 3}, {'id': 672, 'image_count': 10}, {'id': 673, 'image_count': 39}, {'id': 674, 'image_count': 2}, {'id': 675, 'image_count': 925}, {'id': 676, 'image_count': 354}, {'id': 677, 'image_count': 31}, {'id': 678, 'image_count': 10}, {'id': 679, 'image_count': 215}, {'id': 680, 'image_count': 71}, {'id': 681, 'image_count': 43}, {'id': 682, 'image_count': 28}, {'id': 683, 'image_count': 34}, {'id': 684, 'image_count': 16}, {'id': 685, 'image_count': 273}, {'id': 686, 'image_count': 2}, {'id': 687, 'image_count': 999}, {'id': 688, 'image_count': 4}, {'id': 689, 'image_count': 107}, {'id': 690, 'image_count': 2}, {'id': 691, 'image_count': 1}, {'id': 692, 'image_count': 454}, {'id': 693, 'image_count': 9}, {'id': 694, 'image_count': 1901}, {'id': 695, 'image_count': 61}, {'id': 696, 'image_count': 91}, {'id': 697, 'image_count': 46}, {'id': 698, 'image_count': 1402}, {'id': 699, 'image_count': 74}, {'id': 700, 'image_count': 421}, {'id': 701, 'image_count': 226}, {'id': 702, 'image_count': 10}, {'id': 703, 'image_count': 1720}, {'id': 704, 'image_count': 261}, {'id': 705, 'image_count': 1337}, {'id': 706, 'image_count': 293}, {'id': 707, 'image_count': 62}, {'id': 708, 'image_count': 814}, {'id': 709, 'image_count': 407}, {'id': 710, 'image_count': 6}, {'id': 711, 'image_count': 16}, {'id': 712, 'image_count': 7}, {'id': 713, 'image_count': 1791}, {'id': 714, 'image_count': 2}, {'id': 715, 'image_count': 1915}, {'id': 716, 'image_count': 1940}, {'id': 717, 'image_count': 13}, {'id': 718, 'image_count': 16}, {'id': 719, 'image_count': 448}, {'id': 720, 'image_count': 12}, {'id': 721, 'image_count': 18}, {'id': 722, 'image_count': 4}, {'id': 723, 'image_count': 71}, {'id': 724, 'image_count': 189}, {'id': 725, 'image_count': 74}, {'id': 726, 'image_count': 103}, {'id': 727, 'image_count': 3}, {'id': 728, 'image_count': 110}, {'id': 729, 'image_count': 5}, {'id': 730, 'image_count': 9}, {'id': 731, 'image_count': 15}, 
{'id': 732, 'image_count': 25}, {'id': 733, 'image_count': 7}, {'id': 734, 'image_count': 647}, {'id': 735, 'image_count': 824}, {'id': 736, 'image_count': 100}, {'id': 737, 'image_count': 47}, {'id': 738, 'image_count': 121}, {'id': 739, 'image_count': 731}, {'id': 740, 'image_count': 73}, {'id': 741, 'image_count': 49}, {'id': 742, 'image_count': 23}, {'id': 743, 'image_count': 4}, {'id': 744, 'image_count': 62}, {'id': 745, 'image_count': 118}, {'id': 746, 'image_count': 99}, {'id': 747, 'image_count': 40}, {'id': 748, 'image_count': 1036}, {'id': 749, 'image_count': 105}, {'id': 750, 'image_count': 21}, {'id': 751, 'image_count': 229}, {'id': 752, 'image_count': 7}, {'id': 753, 'image_count': 72}, {'id': 754, 'image_count': 9}, {'id': 755, 'image_count': 10}, {'id': 756, 'image_count': 328}, {'id': 757, 'image_count': 468}, {'id': 758, 'image_count': 1}, {'id': 759, 'image_count': 2}, {'id': 760, 'image_count': 24}, {'id': 761, 'image_count': 11}, {'id': 762, 'image_count': 72}, {'id': 763, 'image_count': 17}, {'id': 764, 'image_count': 10}, {'id': 765, 'image_count': 17}, {'id': 766, 'image_count': 489}, {'id': 767, 'image_count': 47}, {'id': 768, 'image_count': 93}, {'id': 769, 'image_count': 1}, {'id': 770, 'image_count': 12}, {'id': 771, 'image_count': 228}, {'id': 772, 'image_count': 5}, {'id': 773, 'image_count': 76}, {'id': 774, 'image_count': 71}, {'id': 775, 'image_count': 30}, {'id': 776, 'image_count': 109}, {'id': 777, 'image_count': 14}, {'id': 778, 'image_count': 1}, {'id': 779, 'image_count': 8}, {'id': 780, 'image_count': 26}, {'id': 781, 'image_count': 339}, {'id': 782, 'image_count': 153}, {'id': 783, 'image_count': 2}, {'id': 784, 'image_count': 3}, {'id': 785, 'image_count': 8}, {'id': 786, 'image_count': 47}, {'id': 787, 'image_count': 8}, {'id': 788, 'image_count': 6}, {'id': 789, 'image_count': 116}, {'id': 790, 'image_count': 69}, {'id': 791, 'image_count': 13}, {'id': 792, 'image_count': 6}, {'id': 793, 'image_count': 1928}, {'id': 794, 
'image_count': 79}, {'id': 795, 'image_count': 14}, {'id': 796, 'image_count': 7}, {'id': 797, 'image_count': 20}, {'id': 798, 'image_count': 114}, {'id': 799, 'image_count': 221}, {'id': 800, 'image_count': 502}, {'id': 801, 'image_count': 62}, {'id': 802, 'image_count': 87}, {'id': 803, 'image_count': 4}, {'id': 804, 'image_count': 1912}, {'id': 805, 'image_count': 7}, {'id': 806, 'image_count': 186}, {'id': 807, 'image_count': 18}, {'id': 808, 'image_count': 4}, {'id': 809, 'image_count': 3}, {'id': 810, 'image_count': 7}, {'id': 811, 'image_count': 1413}, {'id': 812, 'image_count': 7}, {'id': 813, 'image_count': 12}, {'id': 814, 'image_count': 248}, {'id': 815, 'image_count': 4}, {'id': 816, 'image_count': 1881}, {'id': 817, 'image_count': 529}, {'id': 818, 'image_count': 1932}, {'id': 819, 'image_count': 50}, {'id': 820, 'image_count': 3}, {'id': 821, 'image_count': 28}, {'id': 822, 'image_count': 10}, {'id': 823, 'image_count': 5}, {'id': 824, 'image_count': 5}, {'id': 825, 'image_count': 18}, {'id': 826, 'image_count': 14}, {'id': 827, 'image_count': 1890}, {'id': 828, 'image_count': 660}, {'id': 829, 'image_count': 8}, {'id': 830, 'image_count': 25}, {'id': 831, 'image_count': 10}, {'id': 832, 'image_count': 218}, {'id': 833, 'image_count': 36}, {'id': 834, 'image_count': 16}, {'id': 835, 'image_count': 808}, {'id': 836, 'image_count': 479}, {'id': 837, 'image_count': 1404}, {'id': 838, 'image_count': 307}, {'id': 839, 'image_count': 57}, {'id': 840, 'image_count': 28}, {'id': 841, 'image_count': 80}, {'id': 842, 'image_count': 11}, {'id': 843, 'image_count': 92}, {'id': 844, 'image_count': 20}, {'id': 845, 'image_count': 194}, {'id': 846, 'image_count': 23}, {'id': 847, 'image_count': 52}, {'id': 848, 'image_count': 673}, {'id': 849, 'image_count': 2}, {'id': 850, 'image_count': 2}, {'id': 851, 'image_count': 1}, {'id': 852, 'image_count': 2}, {'id': 853, 'image_count': 8}, {'id': 854, 'image_count': 80}, {'id': 855, 'image_count': 3}, {'id': 856, 
'image_count': 3}, {'id': 857, 'image_count': 15}, {'id': 858, 'image_count': 2}, {'id': 859, 'image_count': 10}, {'id': 860, 'image_count': 386}, {'id': 861, 'image_count': 65}, {'id': 862, 'image_count': 3}, {'id': 863, 'image_count': 35}, {'id': 864, 'image_count': 5}, {'id': 865, 'image_count': 180}, {'id': 866, 'image_count': 99}, {'id': 867, 'image_count': 49}, {'id': 868, 'image_count': 28}, {'id': 869, 'image_count': 1}, {'id': 870, 'image_count': 52}, {'id': 871, 'image_count': 36}, {'id': 872, 'image_count': 70}, {'id': 873, 'image_count': 6}, {'id': 874, 'image_count': 29}, {'id': 875, 'image_count': 24}, {'id': 876, 'image_count': 1115}, {'id': 877, 'image_count': 61}, {'id': 878, 'image_count': 18}, {'id': 879, 'image_count': 18}, {'id': 880, 'image_count': 665}, {'id': 881, 'image_count': 1096}, {'id': 882, 'image_count': 29}, {'id': 883, 'image_count': 8}, {'id': 884, 'image_count': 14}, {'id': 885, 'image_count': 1622}, {'id': 886, 'image_count': 2}, {'id': 887, 'image_count': 3}, {'id': 888, 'image_count': 32}, {'id': 889, 'image_count': 55}, {'id': 890, 'image_count': 1}, {'id': 891, 'image_count': 10}, {'id': 892, 'image_count': 10}, {'id': 893, 'image_count': 47}, {'id': 894, 'image_count': 3}, {'id': 895, 'image_count': 29}, {'id': 896, 'image_count': 342}, {'id': 897, 'image_count': 25}, {'id': 898, 'image_count': 1469}, {'id': 899, 'image_count': 521}, {'id': 900, 'image_count': 347}, {'id': 901, 'image_count': 35}, {'id': 902, 'image_count': 7}, {'id': 903, 'image_count': 207}, {'id': 904, 'image_count': 108}, {'id': 905, 'image_count': 2}, {'id': 906, 'image_count': 34}, {'id': 907, 'image_count': 12}, {'id': 908, 'image_count': 10}, {'id': 909, 'image_count': 13}, {'id': 910, 'image_count': 361}, {'id': 911, 'image_count': 1023}, {'id': 912, 'image_count': 782}, {'id': 913, 'image_count': 2}, {'id': 914, 'image_count': 5}, {'id': 915, 'image_count': 247}, {'id': 916, 'image_count': 221}, {'id': 917, 'image_count': 4}, {'id': 918, 
'image_count': 8}, {'id': 919, 'image_count': 158}, {'id': 920, 'image_count': 3}, {'id': 921, 'image_count': 752}, {'id': 922, 'image_count': 64}, {'id': 923, 'image_count': 707}, {'id': 924, 'image_count': 143}, {'id': 925, 'image_count': 1}, {'id': 926, 'image_count': 49}, {'id': 927, 'image_count': 126}, {'id': 928, 'image_count': 76}, {'id': 929, 'image_count': 11}, {'id': 930, 'image_count': 11}, {'id': 931, 'image_count': 4}, {'id': 932, 'image_count': 39}, {'id': 933, 'image_count': 11}, {'id': 934, 'image_count': 13}, {'id': 935, 'image_count': 91}, {'id': 936, 'image_count': 14}, {'id': 937, 'image_count': 5}, {'id': 938, 'image_count': 3}, {'id': 939, 'image_count': 10}, {'id': 940, 'image_count': 18}, {'id': 941, 'image_count': 9}, {'id': 942, 'image_count': 6}, {'id': 943, 'image_count': 951}, {'id': 944, 'image_count': 2}, {'id': 945, 'image_count': 1}, {'id': 946, 'image_count': 19}, {'id': 947, 'image_count': 1942}, {'id': 948, 'image_count': 1916}, {'id': 949, 'image_count': 139}, {'id': 950, 'image_count': 43}, {'id': 951, 'image_count': 1969}, {'id': 952, 'image_count': 5}, {'id': 953, 'image_count': 134}, {'id': 954, 'image_count': 74}, {'id': 955, 'image_count': 381}, {'id': 956, 'image_count': 1}, {'id': 957, 'image_count': 381}, {'id': 958, 'image_count': 6}, {'id': 959, 'image_count': 1826}, {'id': 960, 'image_count': 28}, {'id': 961, 'image_count': 1635}, {'id': 962, 'image_count': 1967}, {'id': 963, 'image_count': 16}, {'id': 964, 'image_count': 1926}, {'id': 965, 'image_count': 1789}, {'id': 966, 'image_count': 401}, {'id': 967, 'image_count': 1968}, {'id': 968, 'image_count': 1167}, {'id': 969, 'image_count': 1}, {'id': 970, 'image_count': 56}, {'id': 971, 'image_count': 17}, {'id': 972, 'image_count': 1}, {'id': 973, 'image_count': 58}, {'id': 974, 'image_count': 9}, {'id': 975, 'image_count': 8}, {'id': 976, 'image_count': 1124}, {'id': 977, 'image_count': 31}, {'id': 978, 'image_count': 16}, {'id': 979, 'image_count': 491}, {'id': 
980, 'image_count': 432}, {'id': 981, 'image_count': 1945}, {'id': 982, 'image_count': 1899}, {'id': 983, 'image_count': 5}, {'id': 984, 'image_count': 28}, {'id': 985, 'image_count': 7}, {'id': 986, 'image_count': 146}, {'id': 987, 'image_count': 1}, {'id': 988, 'image_count': 25}, {'id': 989, 'image_count': 22}, {'id': 990, 'image_count': 1}, {'id': 991, 'image_count': 10}, {'id': 992, 'image_count': 9}, {'id': 993, 'image_count': 308}, {'id': 994, 'image_count': 4}, {'id': 995, 'image_count': 1969}, {'id': 996, 'image_count': 45}, {'id': 997, 'image_count': 12}, {'id': 998, 'image_count': 1}, {'id': 999, 'image_count': 85}, {'id': 1000, 'image_count': 1127}, {'id': 1001, 'image_count': 11}, {'id': 1002, 'image_count': 60}, {'id': 1003, 'image_count': 1}, {'id': 1004, 'image_count': 16}, {'id': 1005, 'image_count': 1}, {'id': 1006, 'image_count': 65}, {'id': 1007, 'image_count': 13}, {'id': 1008, 'image_count': 655}, {'id': 1009, 'image_count': 51}, {'id': 1010, 'image_count': 1}, {'id': 1011, 'image_count': 673}, {'id': 1012, 'image_count': 5}, {'id': 1013, 'image_count': 36}, {'id': 1014, 'image_count': 54}, {'id': 1015, 'image_count': 5}, {'id': 1016, 'image_count': 8}, {'id': 1017, 'image_count': 305}, {'id': 1018, 'image_count': 297}, {'id': 1019, 'image_count': 1053}, {'id': 1020, 'image_count': 223}, {'id': 1021, 'image_count': 1037}, {'id': 1022, 'image_count': 63}, {'id': 1023, 'image_count': 1881}, {'id': 1024, 'image_count': 507}, {'id': 1025, 'image_count': 333}, {'id': 1026, 'image_count': 1911}, {'id': 1027, 'image_count': 1765}, {'id': 1028, 'image_count': 1}, {'id': 1029, 'image_count': 5}, {'id': 1030, 'image_count': 1}, {'id': 1031, 'image_count': 9}, {'id': 1032, 'image_count': 2}, {'id': 1033, 'image_count': 151}, {'id': 1034, 'image_count': 82}, {'id': 1035, 'image_count': 1931}, {'id': 1036, 'image_count': 41}, {'id': 1037, 'image_count': 1895}, {'id': 1038, 'image_count': 24}, {'id': 1039, 'image_count': 22}, {'id': 1040, 'image_count': 
35}, {'id': 1041, 'image_count': 69}, {'id': 1042, 'image_count': 962}, {'id': 1043, 'image_count': 588}, {'id': 1044, 'image_count': 21}, {'id': 1045, 'image_count': 825}, {'id': 1046, 'image_count': 52}, {'id': 1047, 'image_count': 5}, {'id': 1048, 'image_count': 5}, {'id': 1049, 'image_count': 5}, {'id': 1050, 'image_count': 1860}, {'id': 1051, 'image_count': 56}, {'id': 1052, 'image_count': 1582}, {'id': 1053, 'image_count': 7}, {'id': 1054, 'image_count': 2}, {'id': 1055, 'image_count': 1562}, {'id': 1056, 'image_count': 1885}, {'id': 1057, 'image_count': 1}, {'id': 1058, 'image_count': 5}, {'id': 1059, 'image_count': 137}, {'id': 1060, 'image_count': 1094}, {'id': 1061, 'image_count': 134}, {'id': 1062, 'image_count': 29}, {'id': 1063, 'image_count': 22}, {'id': 1064, 'image_count': 522}, {'id': 1065, 'image_count': 50}, {'id': 1066, 'image_count': 68}, {'id': 1067, 'image_count': 16}, {'id': 1068, 'image_count': 40}, {'id': 1069, 'image_count': 35}, {'id': 1070, 'image_count': 135}, {'id': 1071, 'image_count': 1413}, {'id': 1072, 'image_count': 772}, {'id': 1073, 'image_count': 50}, {'id': 1074, 'image_count': 1015}, {'id': 1075, 'image_count': 1}, {'id': 1076, 'image_count': 65}, {'id': 1077, 'image_count': 1900}, {'id': 1078, 'image_count': 1302}, {'id': 1079, 'image_count': 1977}, {'id': 1080, 'image_count': 2}, {'id': 1081, 'image_count': 29}, {'id': 1082, 'image_count': 36}, {'id': 1083, 'image_count': 138}, {'id': 1084, 'image_count': 4}, {'id': 1085, 'image_count': 67}, {'id': 1086, 'image_count': 26}, {'id': 1087, 'image_count': 25}, {'id': 1088, 'image_count': 33}, {'id': 1089, 'image_count': 37}, {'id': 1090, 'image_count': 50}, {'id': 1091, 'image_count': 270}, {'id': 1092, 'image_count': 12}, {'id': 1093, 'image_count': 316}, {'id': 1094, 'image_count': 41}, {'id': 1095, 'image_count': 224}, {'id': 1096, 'image_count': 105}, {'id': 1097, 'image_count': 1925}, {'id': 1098, 'image_count': 1021}, {'id': 1099, 'image_count': 1213}, {'id': 1100, 
'image_count': 172}, {'id': 1101, 'image_count': 28}, {'id': 1102, 'image_count': 745}, {'id': 1103, 'image_count': 187}, {'id': 1104, 'image_count': 147}, {'id': 1105, 'image_count': 136}, {'id': 1106, 'image_count': 34}, {'id': 1107, 'image_count': 41}, {'id': 1108, 'image_count': 636}, {'id': 1109, 'image_count': 570}, {'id': 1110, 'image_count': 1149}, {'id': 1111, 'image_count': 61}, {'id': 1112, 'image_count': 1890}, {'id': 1113, 'image_count': 18}, {'id': 1114, 'image_count': 143}, {'id': 1115, 'image_count': 1517}, {'id': 1116, 'image_count': 7}, {'id': 1117, 'image_count': 943}, {'id': 1118, 'image_count': 6}, {'id': 1119, 'image_count': 1}, {'id': 1120, 'image_count': 11}, {'id': 1121, 'image_count': 101}, {'id': 1122, 'image_count': 1909}, {'id': 1123, 'image_count': 800}, {'id': 1124, 'image_count': 1}, {'id': 1125, 'image_count': 44}, {'id': 1126, 'image_count': 3}, {'id': 1127, 'image_count': 44}, {'id': 1128, 'image_count': 31}, {'id': 1129, 'image_count': 7}, {'id': 1130, 'image_count': 20}, {'id': 1131, 'image_count': 11}, {'id': 1132, 'image_count': 13}, {'id': 1133, 'image_count': 1924}, {'id': 1134, 'image_count': 113}, {'id': 1135, 'image_count': 2}, {'id': 1136, 'image_count': 139}, {'id': 1137, 'image_count': 12}, {'id': 1138, 'image_count': 37}, {'id': 1139, 'image_count': 1866}, {'id': 1140, 'image_count': 47}, {'id': 1141, 'image_count': 1468}, {'id': 1142, 'image_count': 729}, {'id': 1143, 'image_count': 24}, {'id': 1144, 'image_count': 1}, {'id': 1145, 'image_count': 10}, {'id': 1146, 'image_count': 3}, {'id': 1147, 'image_count': 14}, {'id': 1148, 'image_count': 4}, {'id': 1149, 'image_count': 29}, {'id': 1150, 'image_count': 4}, {'id': 1151, 'image_count': 70}, {'id': 1152, 'image_count': 46}, {'id': 1153, 'image_count': 14}, {'id': 1154, 'image_count': 48}, {'id': 1155, 'image_count': 1855}, {'id': 1156, 'image_count': 113}, {'id': 1157, 'image_count': 1}, {'id': 1158, 'image_count': 1}, {'id': 1159, 'image_count': 10}, {'id': 1160, 
'image_count': 54}, {'id': 1161, 'image_count': 1923}, {'id': 1162, 'image_count': 630}, {'id': 1163, 'image_count': 31}, {'id': 1164, 'image_count': 69}, {'id': 1165, 'image_count': 7}, {'id': 1166, 'image_count': 11}, {'id': 1167, 'image_count': 1}, {'id': 1168, 'image_count': 30}, {'id': 1169, 'image_count': 50}, {'id': 1170, 'image_count': 45}, {'id': 1171, 'image_count': 28}, {'id': 1172, 'image_count': 114}, {'id': 1173, 'image_count': 193}, {'id': 1174, 'image_count': 21}, {'id': 1175, 'image_count': 91}, {'id': 1176, 'image_count': 31}, {'id': 1177, 'image_count': 1469}, {'id': 1178, 'image_count': 1924}, {'id': 1179, 'image_count': 87}, {'id': 1180, 'image_count': 77}, {'id': 1181, 'image_count': 11}, {'id': 1182, 'image_count': 47}, {'id': 1183, 'image_count': 21}, {'id': 1184, 'image_count': 47}, {'id': 1185, 'image_count': 70}, {'id': 1186, 'image_count': 1838}, {'id': 1187, 'image_count': 19}, {'id': 1188, 'image_count': 531}, {'id': 1189, 'image_count': 11}, {'id': 1190, 'image_count': 941}, {'id': 1191, 'image_count': 113}, {'id': 1192, 'image_count': 26}, {'id': 1193, 'image_count': 5}, {'id': 1194, 'image_count': 56}, {'id': 1195, 'image_count': 73}, {'id': 1196, 'image_count': 32}, {'id': 1197, 'image_count': 128}, {'id': 1198, 'image_count': 623}, {'id': 1199, 'image_count': 12}, {'id': 1200, 'image_count': 52}, {'id': 1201, 'image_count': 11}, {'id': 1202, 'image_count': 1674}, {'id': 1203, 'image_count': 81}] # noqa +# fmt: on diff --git a/data_processing/detectron2/detectron2/data/datasets/pascal_voc.py b/data_processing/detectron2/detectron2/data/datasets/pascal_voc.py new file mode 100644 index 0000000..dbbf82c --- /dev/null +++ b/data_processing/detectron2/detectron2/data/datasets/pascal_voc.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. 
def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
    """
    Load Pascal VOC detection annotations to Detectron2 format.

    Args:
        dirname: Contain "Annotations", "ImageSets", "JPEGImages"
        split (str): one of "train", "test", "val", "trainval"
        class_names: list or tuple of class names

    Returns:
        list[dict]: one dict per image id listed in
        ``<dirname>/ImageSets/Main/<split>.txt``, with the keys "file_name",
        "image_id", "height", "width" and "annotations". Each annotation is a
        dict with "category_id" (index into `class_names`), "bbox" and
        "bbox_mode" (always ``BoxMode.XYXY_ABS``).

    Raises:
        ValueError: if an object's class name is not in `class_names`.
    """
    with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
        # `np.str` was only an alias of the builtin `str`; it was deprecated in
        # NumPy 1.20 and removed in 1.24, so use `str` directly.
        # `ndmin=1` keeps the result iterable when the split lists a single id
        # (np.loadtxt would otherwise return a 0-d array).
        fileids = np.loadtxt(f, dtype=str, ndmin=1)

    # Many small annotation files are read below, so resolve the annotation
    # directory to a local path once instead of once per file.
    annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
    dicts = []
    for fileid in fileids:
        anno_file = os.path.join(annotation_dirname, fileid + ".xml")
        jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")

        with PathManager.open(anno_file) as f:
            tree = ET.parse(f)

        r = {
            "file_name": jpeg_file,
            "image_id": fileid,
            "height": int(tree.findall("./size/height")[0].text),
            "width": int(tree.findall("./size/width")[0].text),
        }
        instances = []

        for obj in tree.findall("object"):
            cls = obj.find("name").text
            # We include "difficult" samples in training.
            # Based on limited experiments, they don't hurt accuracy.
            # (To skip them, check int(obj.find("difficult").text) == 1 here.)
            box_node = obj.find("bndbox")
            bbox = [float(box_node.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
            # Original annotations are integers in the range [1, W or H]
            # Assuming they mean 1-based pixel indices (inclusive),
            # a box with annotation (xmin=1, xmax=W) covers the whole image.
            # In coordinate space this is represented by (xmin=0, xmax=W)
            bbox[0] -= 1.0
            bbox[1] -= 1.0
            instances.append(
                {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
            )
        r["annotations"] = instances
        dicts.append(r)
    return dicts
class SizeMismatchError(ValueError):
    """
    Raised when a loaded image's width/height differs from the size recorded
    in its annotation.
    """
def convert_image_to_rgb(image, format):
    """
    Convert an image from given format to RGB.

    Args:
        image (np.ndarray or Tensor): an HWC image
        format (str): the format of input image, also see `read_image`

    Returns:
        (np.ndarray): (H,W,3) RGB image in 0-255 range, can be either float or uint8
    """
    # Tensors are moved to host memory and handled as numpy arrays from here on.
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()

    if format == "BGR":
        # Reverse the channel order only; dtype and value range are untouched.
        return image[:, :, [2, 1, 0]]

    if format == "YUV-BT.601":
        yuv_to_rgb = np.array(_M_YUV2RGB).T
        return np.dot(image, yuv_to_rgb) * 255.0

    # Every other format is treated as a PIL mode and converted by PIL itself.
    # For "L", drop the trailing channel axis before handing the array to PIL.
    pixels = image[:, :, 0] if format == "L" else image
    pil_image = Image.fromarray(pixels.astype(np.uint8), mode=format)
    return np.asarray(pil_image.convert("RGB"))
def check_image_size(dataset_dict, image):
    """
    Raise an error if the image does not match the size specified in the dict.

    Args:
        dataset_dict (dict): may contain "width" and/or "height"; an optional
            "file_name" is only used in the error message. Whichever of
            "width"/"height" is missing is filled in-place from `image`.
        image: array-like whose ``shape[0]`` is the height and ``shape[1]``
            is the width.

    Raises:
        SizeMismatchError: if a recorded "width" or "height" differs from the
            corresponding dimension of `image`.
    """
    image_wh = (image.shape[1], image.shape[0])
    if "width" in dataset_dict or "height" in dataset_dict:
        # Only one of the two keys may be present. Fall back to the image's own
        # dimension for the missing one, so only the recorded side is checked
        # and a partially specified dict no longer raises KeyError.
        expected_wh = (
            dataset_dict.get("width", image_wh[0]),
            dataset_dict.get("height", image_wh[1]),
        )
        if image_wh != expected_wh:
            raise SizeMismatchError(
                "Mismatched image shape{}, got {}, expect {}.".format(
                    " for image " + dataset_dict["file_name"]
                    if "file_name" in dataset_dict
                    else "",
                    image_wh,
                    expected_wh,
                )
                + " Please check the width/height in your annotation."
            )

    # To ensure bbox always remap to original image size
    if "width" not in dataset_dict:
        dataset_dict["width"] = image.shape[1]
    if "height" not in dataset_dict:
        dataset_dict["height"] = image.shape[0]
def get_bbox(annotation):
    """
    Return the bounding box of one instance in absolute XYXY format.

    Args:
        annotation (dict): dict of instance annotations for a single instance.
            Its "bbox" and "bbox_mode" entries are read.

    Returns:
        bbox: x1, y1, x2, y2 coordinates, as produced by `BoxMode.convert`
    """
    # bbox is 1d (per-instance bounding box)
    raw_box = annotation["bbox"]
    source_mode = annotation["bbox_mode"]
    return BoxMode.convert(raw_box, source_mode, BoxMode.XYXY_ABS)
+ The "bbox_mode" field will be set to XYXY_ABS. + """ + if isinstance(transforms, (tuple, list)): + transforms = T.TransformList(transforms) + # bbox is 1d (per-instance bounding box) + bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS) + # clip transformed bbox to image size + bbox = transforms.apply_box(np.array([bbox]))[0].clip(min=0) + annotation["bbox"] = np.minimum(bbox, list(image_size + image_size)[::-1]) + annotation["bbox_mode"] = BoxMode.XYXY_ABS + + if "segmentation" in annotation: + # each instance contains 1 or more polygons + segm = annotation["segmentation"] + if isinstance(segm, list): + # polygons + polygons = [np.asarray(p).reshape(-1, 2) for p in segm] + annotation["segmentation"] = [ + p.reshape(-1) for p in transforms.apply_polygons(polygons) + ] + elif isinstance(segm, dict): + # RLE + mask = mask_util.decode(segm) + mask = transforms.apply_segmentation(mask) + assert tuple(mask.shape[:2]) == image_size + annotation["segmentation"] = mask + else: + raise ValueError( + "Cannot transform segmentation of type '{}'!" + "Supported types are: polygons as list[list[float] or ndarray]," + " COCO-style RLE as a dict.".format(type(segm)) + ) + + if "keypoints" in annotation: + keypoints = transform_keypoint_annotations( + annotation["keypoints"], transforms, image_size, keypoint_hflip_indices + ) + annotation["keypoints"] = keypoints + + return annotation + + +def transform_keypoint_annotations(keypoints, transforms, image_size, keypoint_hflip_indices=None): + """ + Transform keypoint annotations of an image. + If a keypoint is transformed out of image boundary, it will be marked "unlabeled" (visibility=0) + + Args: + keypoints (list[float]): Nx3 float in Detectron2's Dataset format. + Each point is represented by (x, y, visibility). + transforms (TransformList): + image_size (tuple): the height, width of the transformed image + keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`. 
+ When `transforms` includes horizontal flip, will use the index + mapping to flip keypoints. + """ + # (N*3,) -> (N, 3) + keypoints = np.asarray(keypoints, dtype="float64").reshape(-1, 3) + keypoints_xy = transforms.apply_coords(keypoints[:, :2]) + + # Set all out-of-boundary points to "unlabeled" + inside = (keypoints_xy >= np.array([0, 0])) & (keypoints_xy <= np.array(image_size[::-1])) + inside = inside.all(axis=1) + keypoints[:, :2] = keypoints_xy + keypoints[:, 2][~inside] = 0 + + # This assumes that HorizFlipTransform is the only one that does flip + do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1 + + # Alternative way: check if probe points was horizontally flipped. + # probe = np.asarray([[0.0, 0.0], [image_width, 0.0]]) + # probe_aug = transforms.apply_coords(probe.copy()) + # do_hflip = np.sign(probe[1][0] - probe[0][0]) != np.sign(probe_aug[1][0] - probe_aug[0][0]) # noqa + + # If flipped, swap each keypoint with its opposite-handed equivalent + if do_hflip: + if keypoint_hflip_indices is None: + raise ValueError("Cannot flip keypoints without providing flip indices!") + if len(keypoints) != len(keypoint_hflip_indices): + raise ValueError( + "Keypoint data has {} points, but metadata " + "contains {} points!".format(len(keypoints), len(keypoint_hflip_indices)) + ) + keypoints = keypoints[np.asarray(keypoint_hflip_indices, dtype=np.int32), :] + + # Maintain COCO convention that if visibility == 0 (unlabeled), then x, y = 0 + keypoints[keypoints[:, 2] == 0] = 0 + return keypoints + + +def annotations_to_instances(annos, image_size, mask_format="polygon"): + """ + Create an :class:`Instances` object used by the models, + from instance annotations in the dataset dict. + + Args: + annos (list[dict]): a list of instance annotations in one image, each + element for one instance. 
+ image_size (tuple): height, width + + Returns: + Instances: + It will contain fields "gt_boxes", "gt_classes", + "gt_masks", "gt_keypoints", if they can be obtained from `annos`. + This is the format that builtin models expect. + """ + boxes = ( + np.stack( + [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos] + ) + if len(annos) + else np.zeros((0, 4)) + ) + target = Instances(image_size) + target.gt_boxes = Boxes(boxes) + + classes = [int(obj["category_id"]) for obj in annos] + classes = torch.tensor(classes, dtype=torch.int64) + target.gt_classes = classes + + if len(annos) and "segmentation" in annos[0]: + segms = [obj["segmentation"] for obj in annos] + if mask_format == "polygon": + try: + masks = PolygonMasks(segms) + except ValueError as e: + raise ValueError( + "Failed to use mask_format=='polygon' from the given annotations!" + ) from e + else: + assert mask_format == "bitmask", mask_format + masks = [] + for segm in segms: + if isinstance(segm, list): + # polygon + masks.append(polygons_to_bitmask(segm, *image_size)) + elif isinstance(segm, dict): + # COCO RLE + masks.append(mask_util.decode(segm)) + elif isinstance(segm, np.ndarray): + assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format( + segm.ndim + ) + # mask array + masks.append(segm) + else: + raise ValueError( + "Cannot convert segmentation of type '{}' to BitMasks!" + "Supported types are: polygons as list[list[float] or ndarray]," + " COCO-style RLE as a dict, or a binary segmentation mask " + " in a 2D numpy array of shape HxW.".format(type(segm)) + ) + # torch.from_numpy does not support array with negative stride. 
+ masks = BitMasks( + torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks]) + ) + target.gt_masks = masks + + if len(annos) and "keypoints" in annos[0]: + kpts = [obj.get("keypoints", []) for obj in annos] + target.gt_keypoints = Keypoints(kpts) + + return target + + +def annotations_to_instances_rotated(annos, image_size): + """ + Create an :class:`Instances` object used by the models, + from instance annotations in the dataset dict. + Compared to `annotations_to_instances`, this function is for rotated boxes only + + Args: + annos (list[dict]): a list of instance annotations in one image, each + element for one instance. + image_size (tuple): height, width + + Returns: + Instances: + Containing fields "gt_boxes", "gt_classes", + if they can be obtained from `annos`. + This is the format that builtin models expect. + """ + boxes = [obj["bbox"] for obj in annos] + target = Instances(image_size) + boxes = target.gt_boxes = RotatedBoxes(boxes) + boxes.clip(image_size) + + classes = [obj["category_id"] for obj in annos] + classes = torch.tensor(classes, dtype=torch.int64) + target.gt_classes = classes + + return target + + +def filter_empty_instances( + instances, by_box=True, by_mask=True, box_threshold=1e-5, return_mask=False +): + """ + Filter out empty instances in an `Instances` object. + + Args: + instances (Instances): + by_box (bool): whether to filter out instances with empty boxes + by_mask (bool): whether to filter out instances with empty masks + box_threshold (float): minimum width and height to be considered non-empty + return_mask (bool): whether to return boolean mask of filtered instances + + Returns: + Instances: the filtered instances. 
+ tensor[bool], optional: boolean mask of filtered instances + """ + assert by_box or by_mask + r = [] + if by_box: + r.append(instances.gt_boxes.nonempty(threshold=box_threshold)) + if instances.has("gt_masks") and by_mask: + r.append(instances.gt_masks.nonempty()) + + # TODO: can also filter visible keypoints + + if not r: + return instances + m = r[0] + for x in r[1:]: + m = m & x + if return_mask: + return instances[m], m + return instances[m] + + +def create_keypoint_hflip_indices(dataset_names: Union[str, List[str]]) -> List[int]: + """ + Args: + dataset_names: list of dataset names + + Returns: + list[int]: a list of size=#keypoints, storing the + horizontally-flipped keypoint indices. + """ + if isinstance(dataset_names, str): + dataset_names = [dataset_names] + + check_metadata_consistency("keypoint_names", dataset_names) + check_metadata_consistency("keypoint_flip_map", dataset_names) + + meta = MetadataCatalog.get(dataset_names[0]) + names = meta.keypoint_names + # TODO flip -> hflip + flip_map = dict(meta.keypoint_flip_map) + flip_map.update({v: k for k, v in flip_map.items()}) + flipped_names = [i if i not in flip_map else flip_map[i] for i in names] + flip_indices = [names.index(i) for i in flipped_names] + return flip_indices + + +def get_fed_loss_cls_weights(dataset_names: Union[str, List[str]], freq_weight_power=1.0): + """ + Get frequency weight for each class sorted by class id. + We now calcualte freqency weight using image_count to the power freq_weight_power. 
+ + Args: + dataset_names: list of dataset names + freq_weight_power: power value + """ + if isinstance(dataset_names, str): + dataset_names = [dataset_names] + + check_metadata_consistency("class_image_count", dataset_names) + + meta = MetadataCatalog.get(dataset_names[0]) + class_freq_meta = meta.class_image_count + class_freq = torch.tensor( + [c["image_count"] for c in sorted(class_freq_meta, key=lambda x: x["id"])] + ) + class_freq_weight = class_freq.float() ** freq_weight_power + return class_freq_weight + + +def gen_crop_transform_with_instance(crop_size, image_size, instance): + """ + Generate a CropTransform so that the cropping region contains + the center of the given instance. + + Args: + crop_size (tuple): h, w in pixels + image_size (tuple): h, w + instance (dict): an annotation dict of one instance, in Detectron2's + dataset format. + """ + crop_size = np.asarray(crop_size, dtype=np.int32) + bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS) + center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5 + assert ( + image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1] + ), "The annotation bounding box is outside of the image!" + assert ( + image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1] + ), "Crop size is larger than image size!" + + min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0) + max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0) + max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32)) + + y0 = np.random.randint(min_yx[0], max_yx[0] + 1) + x0 = np.random.randint(min_yx[1], max_yx[1] + 1) + return T.CropTransform(x0, y0, crop_size[1], crop_size[0]) + + +def check_metadata_consistency(key, dataset_names): + """ + Check that the datasets have consistent metadata. 
+ + Args: + key (str): a metadata key + dataset_names (list[str]): a list of dataset names + + Raises: + AttributeError: if the key does not exist in the metadata + ValueError: if the given datasets do not have the same metadata values defined by key + """ + if len(dataset_names) == 0: + return + logger = logging.getLogger(__name__) + entries_per_dataset = [getattr(MetadataCatalog.get(d), key) for d in dataset_names] + for idx, entry in enumerate(entries_per_dataset): + if entry != entries_per_dataset[0]: + logger.error( + "Metadata '{}' for dataset '{}' is '{}'".format(key, dataset_names[idx], str(entry)) + ) + logger.error( + "Metadata '{}' for dataset '{}' is '{}'".format( + key, dataset_names[0], str(entries_per_dataset[0]) + ) + ) + raise ValueError("Datasets have different metadata '{}'!".format(key)) + + +def build_augmentation(cfg, is_train): + """ + Create a list of default :class:`Augmentation` from config. + Now it includes resizing and flipping. + + Returns: + list[Augmentation] + """ + if is_train: + min_size = cfg.INPUT.MIN_SIZE_TRAIN + max_size = cfg.INPUT.MAX_SIZE_TRAIN + sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + sample_style = "choice" + augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)] + if is_train and cfg.INPUT.RANDOM_FLIP != "none": + augmentation.append( + T.RandomFlip( + horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal", + vertical=cfg.INPUT.RANDOM_FLIP == "vertical", + ) + ) + return augmentation + + +build_transform_gen = build_augmentation +""" +Alias for backward-compatibility. +""" diff --git a/data_processing/detectron2/detectron2/data/samplers/__init__.py b/data_processing/detectron2/detectron2/data/samplers/__init__.py new file mode 100644 index 0000000..85c9f1a --- /dev/null +++ b/data_processing/detectron2/detectron2/data/samplers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+from .distributed_sampler import ( + InferenceSampler, + RandomSubsetTrainingSampler, + RepeatFactorTrainingSampler, + TrainingSampler, +) + +from .grouped_batch_sampler import GroupedBatchSampler + +__all__ = [ + "GroupedBatchSampler", + "TrainingSampler", + "RandomSubsetTrainingSampler", + "InferenceSampler", + "RepeatFactorTrainingSampler", +] diff --git a/data_processing/detectron2/detectron2/data/samplers/distributed_sampler.py b/data_processing/detectron2/detectron2/data/samplers/distributed_sampler.py new file mode 100644 index 0000000..a098e6a --- /dev/null +++ b/data_processing/detectron2/detectron2/data/samplers/distributed_sampler.py @@ -0,0 +1,278 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import itertools +import logging +import math +from collections import defaultdict +from typing import Optional +import torch +from torch.utils.data.sampler import Sampler + +from detectron2.utils import comm + +logger = logging.getLogger(__name__) + + +class TrainingSampler(Sampler): + """ + In training, we only care about the "infinite stream" of training data. + So this sampler produces an infinite stream of indices and + all workers cooperate to correctly shuffle the indices and sample different indices. + + The samplers in each worker effectively produces `indices[worker_id::num_workers]` + where `indices` is an infinite stream of indices consisting of + `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True) + or `range(size) + range(size) + ...` (if shuffle is False) + + Note that this sampler does not shard based on pytorch DataLoader worker id. + A sampler passed to pytorch DataLoader is used only with map-style dataset + and will not be executed inside workers. + But if this sampler is used in a way that it gets execute inside a dataloader + worker, then extra work needs to be done to shard its outputs based on worker id. + This is required so that workers don't produce identical data. 
+ :class:`ToIterableDataset` implements this logic. + This note is true for all samplers in detectron2. + """ + + def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None): + """ + Args: + size (int): the total number of data of the underlying dataset to sample from + shuffle (bool): whether to shuffle the indices or not + seed (int): the initial seed of the shuffle. Must be the same + across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). + """ + if not isinstance(size, int): + raise TypeError(f"TrainingSampler(size=) expects an int. Got type {type(size)}.") + if size <= 0: + raise ValueError(f"TrainingSampler(size=) expects a positive int. Got {size}.") + self._size = size + self._shuffle = shuffle + if seed is None: + seed = comm.shared_random_seed() + self._seed = int(seed) + + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + + def __iter__(self): + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) + while True: + if self._shuffle: + yield from torch.randperm(self._size, generator=g).tolist() + else: + yield from torch.arange(self._size).tolist() + + +class RandomSubsetTrainingSampler(TrainingSampler): + """ + Similar to TrainingSampler, but only sample a random subset of indices. + This is useful when you want to estimate the accuracy vs data-number curves by + training the model with different subset_ratio. 
+ """ + + def __init__( + self, + size: int, + subset_ratio: float, + shuffle: bool = True, + seed_shuffle: Optional[int] = None, + seed_subset: Optional[int] = None, + ): + """ + Args: + size (int): the total number of data of the underlying dataset to sample from + subset_ratio (float): the ratio of subset data to sample from the underlying dataset + shuffle (bool): whether to shuffle the indices or not + seed_shuffle (int): the initial seed of the shuffle. Must be the same + across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). + seed_subset (int): the seed to randomize the subset to be sampled. + Must be the same across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). + """ + super().__init__(size=size, shuffle=shuffle, seed=seed_shuffle) + + assert 0.0 < subset_ratio <= 1.0 + self._size_subset = int(size * subset_ratio) + assert self._size_subset > 0 + if seed_subset is None: + seed_subset = comm.shared_random_seed() + self._seed_subset = int(seed_subset) + + # randomly generate the subset indexes to be sampled from + g = torch.Generator() + g.manual_seed(self._seed_subset) + indexes_randperm = torch.randperm(self._size, generator=g) + self._indexes_subset = indexes_randperm[: self._size_subset] + + logger.info("Using RandomSubsetTrainingSampler......") + logger.info(f"Randomly sample {self._size_subset} data from the original {self._size} data") + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) # self._seed equals seed_shuffle from __init__() + while True: + if self._shuffle: + # generate a random permutation to shuffle self._indexes_subset + randperm = torch.randperm(self._size_subset, generator=g) + yield from self._indexes_subset[randperm].tolist() + else: + yield from self._indexes_subset.tolist() + + +class RepeatFactorTrainingSampler(Sampler): + """ + Similar to TrainingSampler, but a sample 
may appear more times than others based + on its "repeat factor". This is suitable for training on class imbalanced datasets like LVIS. + """ + + def __init__(self, repeat_factors, *, shuffle=True, seed=None): + """ + Args: + repeat_factors (Tensor): a float vector, the repeat factor for each indice. When it's + full of ones, it is equivalent to ``TrainingSampler(len(repeat_factors), ...)``. + shuffle (bool): whether to shuffle the indices or not + seed (int): the initial seed of the shuffle. Must be the same + across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). + """ + self._shuffle = shuffle + if seed is None: + seed = comm.shared_random_seed() + self._seed = int(seed) + + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + + # Split into whole number (_int_part) and fractional (_frac_part) parts. + self._int_part = torch.trunc(repeat_factors) + self._frac_part = repeat_factors - self._int_part + + @staticmethod + def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh): + """ + Compute (fractional) per-image repeat factors based on category frequency. + The repeat factor for an image is a function of the frequency of the rarest + category labeled in that image. The "frequency of category c" in [0, 1] is defined + as the fraction of images in the training set (without repeats) in which category c + appears. + See :paper:`lvis` (>= v2) Appendix B.2. + + Args: + dataset_dicts (list[dict]): annotations in Detectron2 dataset format. + repeat_thresh (float): frequency threshold below which data is repeated. + If the frequency is half of `repeat_thresh`, the image will be + repeated twice. + + Returns: + torch.Tensor: + the i-th element is the repeat factor for the dataset image at index i. + """ + # 1. 
For each category c, compute the fraction of images that contain it: f(c) + category_freq = defaultdict(int) + for dataset_dict in dataset_dicts: # For each image (without repeats) + cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]} + for cat_id in cat_ids: + category_freq[cat_id] += 1 + num_images = len(dataset_dicts) + for k, v in category_freq.items(): + category_freq[k] = v / num_images + + # 2. For each category c, compute the category-level repeat factor: + # r(c) = max(1, sqrt(t / f(c))) + category_rep = { + cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq)) + for cat_id, cat_freq in category_freq.items() + } + + # 3. For each image I, compute the image-level repeat factor: + # r(I) = max_{c in I} r(c) + rep_factors = [] + for dataset_dict in dataset_dicts: + cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]} + rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0) + rep_factors.append(rep_factor) + + return torch.tensor(rep_factors, dtype=torch.float32) + + def _get_epoch_indices(self, generator): + """ + Create a list of dataset indices (with repeats) to use for one epoch. + + Args: + generator (torch.Generator): pseudo random number generator used for + stochastic rounding. + + Returns: + torch.Tensor: list of dataset indices to use in one epoch. Each index + is repeated based on its calculated repeat factor. 
+ """ + # Since repeat factors are fractional, we use stochastic rounding so + # that the target repeat factor is achieved in expectation over the + # course of training + rands = torch.rand(len(self._frac_part), generator=generator) + rep_factors = self._int_part + (rands < self._frac_part).float() + # Construct a list of indices in which we repeat images as specified + indices = [] + for dataset_index, rep_factor in enumerate(rep_factors): + indices.extend([dataset_index] * int(rep_factor.item())) + return torch.tensor(indices, dtype=torch.int64) + + def __iter__(self): + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) + while True: + # Sample indices with repeats determined by stochastic rounding; each + # "epoch" may have a slightly different size due to the rounding. + indices = self._get_epoch_indices(g) + if self._shuffle: + randperm = torch.randperm(len(indices), generator=g) + yield from indices[randperm].tolist() + else: + yield from indices.tolist() + + +class InferenceSampler(Sampler): + """ + Produce indices for inference across all workers. + Inference needs to run on the __exact__ set of samples, + therefore when the total number of samples is not divisible by the number of workers, + this sampler produces different number of samples on different workers. 
+ """ + + def __init__(self, size: int): + """ + Args: + size (int): the total number of data of the underlying dataset to sample from + """ + self._size = size + assert size > 0 + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + self._local_indices = self._get_local_indices(size, self._world_size, self._rank) + + @staticmethod + def _get_local_indices(total_size, world_size, rank): + shard_size = total_size // world_size + left = total_size % world_size + shard_sizes = [shard_size + int(r < left) for r in range(world_size)] + + begin = sum(shard_sizes[:rank]) + end = min(sum(shard_sizes[: rank + 1]), total_size) + return range(begin, end) + + def __iter__(self): + yield from self._local_indices + + def __len__(self): + return len(self._local_indices) diff --git a/data_processing/detectron2/detectron2/data/samplers/grouped_batch_sampler.py b/data_processing/detectron2/detectron2/data/samplers/grouped_batch_sampler.py new file mode 100644 index 0000000..5b24773 --- /dev/null +++ b/data_processing/detectron2/detectron2/data/samplers/grouped_batch_sampler.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +from torch.utils.data.sampler import BatchSampler, Sampler + + +class GroupedBatchSampler(BatchSampler): + """ + Wraps another sampler to yield a mini-batch of indices. + It enforces that the batch only contain elements from the same group. + It also tries to provide mini-batches which follows an ordering which is + as close as possible to the ordering from the original sampler. + """ + + def __init__(self, sampler, group_ids, batch_size): + """ + Args: + sampler (Sampler): Base sampler. + group_ids (list[int]): If the sampler produces indices in range [0, N), + `group_ids` must be a list of `N` ints which contains the group id of each sample. + The group ids must be a set of integers in the range [0, num_groups). + batch_size (int): Size of mini-batch. 
+ """ + if not isinstance(sampler, Sampler): + raise ValueError( + "sampler should be an instance of " + "torch.utils.data.Sampler, but got sampler={}".format(sampler) + ) + self.sampler = sampler + self.group_ids = np.asarray(group_ids) + assert self.group_ids.ndim == 1 + self.batch_size = batch_size + groups = np.unique(self.group_ids).tolist() + + # buffer the indices of each group until batch size is reached + self.buffer_per_group = {k: [] for k in groups} + + def __iter__(self): + for idx in self.sampler: + group_id = self.group_ids[idx] + group_buffer = self.buffer_per_group[group_id] + group_buffer.append(idx) + if len(group_buffer) == self.batch_size: + yield group_buffer[:] # yield a copy of the list + del group_buffer[:] + + def __len__(self): + raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.") diff --git a/data_processing/detectron2/detectron2/data/transforms/__init__.py b/data_processing/detectron2/detectron2/data/transforms/__init__.py new file mode 100644 index 0000000..ab3c63b --- /dev/null +++ b/data_processing/detectron2/detectron2/data/transforms/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from fvcore.transforms.transform import Transform, TransformList # order them first +from fvcore.transforms.transform import * +from .transform import * +from .augmentation import * +from .augmentation_impl import * + +__all__ = [k for k in globals().keys() if not k.startswith("_")] + + +from detectron2.utils.env import fixup_module_metadata + +fixup_module_metadata(__name__, globals(), __all__) +del fixup_module_metadata diff --git a/data_processing/detectron2/detectron2/data/transforms/augmentation.py b/data_processing/detectron2/detectron2/data/transforms/augmentation.py new file mode 100644 index 0000000..63dd41a --- /dev/null +++ b/data_processing/detectron2/detectron2/data/transforms/augmentation.py @@ -0,0 +1,380 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. 
and its affiliates. + +import inspect +import numpy as np +import pprint +from typing import Any, List, Optional, Tuple, Union +from fvcore.transforms.transform import Transform, TransformList + +""" +See "Data Augmentation" tutorial for an overview of the system: +https://detectron2.readthedocs.io/tutorials/augmentation.html +""" + + +__all__ = [ + "Augmentation", + "AugmentationList", + "AugInput", + "TransformGen", + "apply_transform_gens", + "StandardAugInput", + "apply_augmentations", +] + + +def _check_img_dtype(img): + assert isinstance(img, np.ndarray), "[Augmentation] Needs an numpy array, but got a {}!".format( + type(img) + ) + assert not isinstance(img.dtype, np.integer) or ( + img.dtype == np.uint8 + ), "[Augmentation] Got image of type {}, use uint8 or floating points instead!".format( + img.dtype + ) + assert img.ndim in [2, 3], img.ndim + + +def _get_aug_input_args(aug, aug_input) -> List[Any]: + """ + Get the arguments to be passed to ``aug.get_transform`` from the input ``aug_input``. + """ + if aug.input_args is None: + # Decide what attributes are needed automatically + prms = list(inspect.signature(aug.get_transform).parameters.items()) + # The default behavior is: if there is one parameter, then its "image" + # (work automatically for majority of use cases, and also avoid BC breaking), + # Otherwise, use the argument names. + if len(prms) == 1: + names = ("image",) + else: + names = [] + for name, prm in prms: + if prm.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + raise TypeError( + f""" \ +The default implementation of `{type(aug)}.__call__` does not allow \ +`{type(aug)}.get_transform` to use variable-length arguments (*args, **kwargs)! \ +If arguments are unknown, reimplement `__call__` instead. 
\ +""" + ) + names.append(name) + aug.input_args = tuple(names) + + args = [] + for f in aug.input_args: + try: + args.append(getattr(aug_input, f)) + except AttributeError as e: + raise AttributeError( + f"{type(aug)}.get_transform needs input attribute '{f}', " + f"but it is not an attribute of {type(aug_input)}!" + ) from e + return args + + +class Augmentation: + """ + Augmentation defines (often random) policies/strategies to generate :class:`Transform` + from data. It is often used for pre-processing of input data. + + A "policy" that generates a :class:`Transform` may, in the most general case, + need arbitrary information from input data in order to determine what transforms + to apply. Therefore, each :class:`Augmentation` instance defines the arguments + needed by its :meth:`get_transform` method. When called with the positional arguments, + the :meth:`get_transform` method executes the policy. + + Note that :class:`Augmentation` defines the policies to create a :class:`Transform`, + but not how to execute the actual transform operations to those data. + Its :meth:`__call__` method will use :meth:`AugInput.transform` to execute the transform. + + The returned `Transform` object is meant to describe deterministic transformation, which means + it can be re-applied on associated data, e.g. the geometry of an image and its segmentation + masks need to be transformed together. + (If such re-application is not needed, then determinism is not a crucial requirement.) + """ + + input_args: Optional[Tuple[str]] = None + """ + Stores the attribute names needed by :meth:`get_transform`, e.g. ``("image", "sem_seg")``. + By default, it is just a tuple of argument names in :meth:`self.get_transform`, which often only + contain "image". As long as the argument name convention is followed, there is no need for + users to touch this attribute. 
+ """ + + def _init(self, params=None): + if params: + for k, v in params.items(): + if k != "self" and not k.startswith("_"): + setattr(self, k, v) + + def get_transform(self, *args) -> Transform: + """ + Execute the policy based on input data, and decide what transform to apply to inputs. + + Args: + args: Any fixed-length positional arguments. By default, the name of the arguments + should exist in the :class:`AugInput` to be used. + + Returns: + Transform: Returns the deterministic transform to apply to the input. + + Examples: + :: + class MyAug: + # if a policy needs to know both image and semantic segmentation + def get_transform(image, sem_seg) -> T.Transform: + pass + tfm: Transform = MyAug().get_transform(image, sem_seg) + new_image = tfm.apply_image(image) + + Notes: + Users can freely use arbitrary new argument names in custom + :meth:`get_transform` method, as long as they are available in the + input data. In detectron2 we use the following convention: + + * image: (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or + floating point in range [0, 1] or [0, 255]. + * boxes: (N,4) ndarray of float32. It represents the instance bounding boxes + of N instances. Each is in XYXY format in unit of absolute coordinates. + * sem_seg: (H,W) ndarray of type uint8. Each element is an integer label of pixel. + + We do not specify convention for other types and do not include builtin + :class:`Augmentation` that uses other types in detectron2. + """ + raise NotImplementedError + + def __call__(self, aug_input) -> Transform: + """ + Augment the given `aug_input` **in-place**, and return the transform that's used. + + This method will be called to apply the augmentation. In most augmentation, it + is enough to use the default implementation, which calls :meth:`get_transform` + using the inputs. But a subclass can overwrite it to have more complicated logic. 
+ + Args: + aug_input (AugInput): an object that has attributes needed by this augmentation + (defined by ``self.get_transform``). Its ``transform`` method will be called + to in-place transform it. + + Returns: + Transform: the transform that is applied on the input. + """ + args = _get_aug_input_args(self, aug_input) + tfm = self.get_transform(*args) + assert isinstance(tfm, (Transform, TransformList)), ( + f"{type(self)}.get_transform must return an instance of Transform! " + f"Got {type(tfm)} instead." + ) + aug_input.transform(tfm) + return tfm + + def _rand_range(self, low=1.0, high=None, size=None): + """ + Uniform float random number between low and high. + """ + if high is None: + low, high = 0, low + if size is None: + size = [] + return np.random.uniform(low, high, size) + + def __repr__(self): + """ + Produce something like: + "MyAugmentation(field1={self.field1}, field2={self.field2})" + """ + try: + sig = inspect.signature(self.__init__) + classname = type(self).__name__ + argstr = [] + for name, param in sig.parameters.items(): + assert ( + param.kind != param.VAR_POSITIONAL and param.kind != param.VAR_KEYWORD + ), "The default __repr__ doesn't support *args or **kwargs" + assert hasattr(self, name), ( + "Attribute {} not found! " + "Default __repr__ only works if attributes match the constructor.".format(name) + ) + attr = getattr(self, name) + default = param.default + if default is attr: + continue + attr_str = pprint.pformat(attr) + if "\n" in attr_str: + # don't show it if pformat decides to use >1 lines + attr_str = "..." 
class AugInput:
    """
    Standard input container accepted by :meth:`Augmentation.__call__`.

    It holds the three attributes **"image", "boxes", "sem_seg"** that are set in
    :meth:`__init__`; augmentation policies read the ones they need by attribute
    name. Most policies need nothing beyond these three.

    Calling an augmentation on this object transforms the stored attributes in
    place (through :meth:`AugInput.transform`); the returned transforms can then
    be re-applied to any other data the user has.

    Examples:
    ::
        input = AugInput(image, boxes=boxes)
        tfms = augmentation(input)
        transformed_image = input.image
        transformed_boxes = input.boxes
        transformed_other_data = tfms.apply_other(other_data)

    Projects that work with other data types, or that must transform inputs
    differently, can write their own class instead of using this one. Such a
    class has to satisfy two conditions:

    * Every argument an :class:`Augmentation` asks for must be readable as an
      attribute (``getattr``). For example, a policy needing "image" and
      "sem_seg" requires the input to expose ``image`` and ``sem_seg``.
    * It must provide a ``transform(tfm: Transform) -> None`` method that
      transforms all of its attributes in place.
    """

    # TODO maybe should support more builtin data types here
    def __init__(
        self,
        image: np.ndarray,
        *,
        boxes: Optional[np.ndarray] = None,
        sem_seg: Optional[np.ndarray] = None,
    ):
        """
        Args:
            image (ndarray): (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
                floating point in range [0, 1] or [0, 255]. The meaning of C is up
                to users.
            boxes (ndarray or None): Nx4 float32 boxes in XYXY_ABS mode
            sem_seg (ndarray or None): HxW uint8 semantic segmentation mask. Each element
                is an integer label of pixel.
        """
        # Validate the image before storing anything on the instance.
        _check_img_dtype(image)
        self.image = image
        self.boxes = boxes
        self.sem_seg = sem_seg

    def transform(self, tfm: Transform) -> None:
        """
        Apply ``tfm`` to every attribute held by this object, in place.

        After this call, reading an attribute such as ``self.image`` yields the
        transformed data. Attributes that are ``None`` are left untouched.
        """
        self.image = tfm.apply_image(self.image)
        # Optional attributes: the transform method is only looked up when the
        # corresponding attribute is actually present.
        optional_fields = (("boxes", "apply_box"), ("sem_seg", "apply_segmentation"))
        for field_name, method_name in optional_fields:
            current = getattr(self, field_name)
            if current is not None:
                setattr(self, field_name, getattr(tfm, method_name)(current))

    def apply_augmentations(
        self, augmentations: List[Union[Augmentation, Transform]]
    ) -> TransformList:
        """
        Equivalent of ``AugmentationList(augmentations)(self)``
        """
        pipeline = AugmentationList(augmentations)
        return pipeline(self)
+""" +import numpy as np +import sys +from numpy import random +from typing import Tuple +import torch +from fvcore.transforms.transform import ( + BlendTransform, + CropTransform, + HFlipTransform, + NoOpTransform, + PadTransform, + Transform, + TransformList, + VFlipTransform, +) +from PIL import Image + +from detectron2.structures import Boxes, pairwise_iou + +from .augmentation import Augmentation, _transform_to_aug +from .transform import ExtentTransform, ResizeTransform, RotationTransform + +__all__ = [ + "FixedSizeCrop", + "RandomApply", + "RandomBrightness", + "RandomContrast", + "RandomCrop", + "RandomExtent", + "RandomFlip", + "RandomSaturation", + "RandomLighting", + "RandomRotation", + "Resize", + "ResizeScale", + "ResizeShortestEdge", + "RandomCrop_CategoryAreaConstraint", + "RandomResize", + "MinIoURandomCrop", +] + + +class RandomApply(Augmentation): + """ + Randomly apply an augmentation with a given probability. + """ + + def __init__(self, tfm_or_aug, prob=0.5): + """ + Args: + tfm_or_aug (Transform, Augmentation): the transform or augmentation + to be applied. It can either be a `Transform` or `Augmentation` + instance. + prob (float): probability between 0.0 and 1.0 that + the wrapper transformation is applied + """ + super().__init__() + self.aug = _transform_to_aug(tfm_or_aug) + assert 0.0 <= prob <= 1.0, f"Probablity must be between 0.0 and 1.0 (given: {prob})" + self.prob = prob + + def get_transform(self, *args): + do = self._rand_range() < self.prob + if do: + return self.aug.get_transform(*args) + else: + return NoOpTransform() + + def __call__(self, aug_input): + do = self._rand_range() < self.prob + if do: + return self.aug(aug_input) + else: + return NoOpTransform() + + +class RandomFlip(Augmentation): + """ + Flip the image horizontally or vertically with the given probability. + """ + + def __init__(self, prob=0.5, *, horizontal=True, vertical=False): + """ + Args: + prob (float): probability of flip. 
class ResizeShortestEdge(Augmentation):
    """
    Aspect-ratio-preserving resize driven by the length of the shorter edge.

    The shorter edge is scaled to a sampled `short_edge_length`. If that would make
    the longer edge exceed `max_size`, the image is scaled down further so that the
    longer edge equals `max_size`.
    """

    @torch.jit.unused
    def __init__(
        self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
    ):
        """
        Args:
            short_edge_length (list[int]): If ``sample_style=="range"``,
                a [min, max] interval from which to sample the shortest edge length.
                If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
                A single int is treated as the pair (value, value).
            max_size (int): maximum allowed longest edge length.
            sample_style (str): either "range" or "choice".
            interp: PIL interpolation method handed to :class:`ResizeTransform`.
        """
        super().__init__()
        assert sample_style in ("range", "choice"), sample_style

        self.is_range = sample_style == "range"
        if isinstance(short_edge_length, int):
            short_edge_length = (short_edge_length,) * 2
        if self.is_range:
            assert len(short_edge_length) == 2, (
                "short_edge_length must be two values using 'range' sample style."
                f" Got {short_edge_length}!"
            )
        # NOTE: _init() copies the local variables of this frame onto self, so no
        # extra locals may be introduced above this line.
        self._init(locals())

    @torch.jit.unused
    def get_transform(self, image):
        """
        Sample a target short-edge length and build the matching resize.

        Returns a :class:`NoOpTransform` when the sampled length is 0.
        """
        height, width = image.shape[:2]
        if self.is_range:
            # randint's upper bound is exclusive, hence the +1
            target = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
        else:
            target = np.random.choice(self.short_edge_length)
        if target == 0:
            return NoOpTransform()

        out_h, out_w = ResizeShortestEdge.get_output_shape(height, width, target, self.max_size)
        return ResizeTransform(height, width, out_h, out_w, self.interp)

    @staticmethod
    def get_output_shape(
        oldh: int, oldw: int, short_edge_length: int, max_size: int
    ) -> Tuple[int, int]:
        """
        Compute the output size given input size and target short edge length.

        Returns:
            (new_height, new_width), each rounded half-up to an int.
        """
        size = short_edge_length * 1.0
        scale = size / min(oldh, oldw)
        # The shorter side becomes exactly `size`; the other side follows the scale.
        if oldh < oldw:
            newh, neww = size, scale * oldw
        else:
            newh, neww = scale * oldh, size
        longest = max(newh, neww)
        if longest > max_size:
            # Shrink both sides so that the longer one lands on max_size.
            shrink = max_size * 1.0 / longest
            newh, neww = newh * shrink, neww * shrink
        return (int(newh + 0.5), int(neww + 0.5))
class RandomRotation(Augmentation):
    """
    Rotate the image counter clockwise by a randomly sampled angle (in degrees)
    around a randomly sampled (or the default) center.
    """

    def __init__(self, angle, expand=True, center=None, sample_style="range", interp=None):
        """
        Args:
            angle (list[float]): If ``sample_style=="range"``,
                a [min, max] interval from which to sample the angle (in degrees).
                If ``sample_style=="choice"``, a list of angles to sample from.
                A single number is treated as the pair (value, value).
            expand (bool): choose if the image should be resized to fit the whole
                rotated image (default), or simply cropped
            center (list[[float, float]]): If ``sample_style=="range"``,
                a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center,
                [0, 0] being the top left of the image and [1, 1] the bottom right.
                If ``sample_style=="choice"``, a list of centers to sample from.
                A single [x, y] center is treated as the pair (center, center).
                Default: None, which means that the center of rotation is the center of the image
                center has no effect if expand=True because it only affects shifting
            sample_style (str): either "range" or "choice".
            interp: interpolation method forwarded unchanged to :class:`RotationTransform`.
        """
        super().__init__()
        assert sample_style in ["range", "choice"], sample_style
        self.is_range = sample_style == "range"
        if isinstance(angle, (float, int)):
            angle = (angle, angle)
        if center is not None and isinstance(center[0], (float, int)):
            center = (center, center)
        # NOTE: _init() copies the local variables of this frame onto self, so no
        # extra locals may be introduced above this line.
        self._init(locals())

    def get_transform(self, image):
        """
        Sample an angle (and optionally a center) and build the rotation.

        Returns:
            :class:`NoOpTransform` when the sampled angle is a multiple of 360,
            otherwise a :class:`RotationTransform`.
        """
        h, w = image.shape[:2]
        center = None
        if self.is_range:
            angle = np.random.uniform(self.angle[0], self.angle[1])
            if self.center is not None:
                center = (
                    np.random.uniform(self.center[0][0], self.center[1][0]),
                    np.random.uniform(self.center[0][1], self.center[1][1]),
                )
        else:
            angle = np.random.choice(self.angle)
            if self.center is not None:
                # self.center is a sequence of [x, y] pairs, i.e. 2-D.
                # np.random.choice only accepts 1-D input and raises ValueError
                # on it, so draw an index and pick the pair ourselves.
                center = self.center[np.random.randint(len(self.center))]

        if center is not None:
            center = (w * center[0], h * center[1])  # Convert to absolute coordinates

        if angle % 360 == 0:
            return NoOpTransform()

        return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)
+ """ + + def __init__( + self, + crop_size: Tuple[int], + pad: bool = True, + pad_value: float = 128.0, + seg_pad_value: int = 255, + ): + """ + Args: + crop_size: target image (height, width). + pad: if True, will pad images smaller than `crop_size` up to `crop_size` + pad_value: the padding value to the image. + seg_pad_value: the padding value to the segmentation mask. + """ + super().__init__() + self._init(locals()) + + def _get_crop(self, image: np.ndarray) -> Transform: + # Compute the image scale and scaled size. + input_size = image.shape[:2] + output_size = self.crop_size + + # Add random crop if the image is scaled up. + max_offset = np.subtract(input_size, output_size) + max_offset = np.maximum(max_offset, 0) + offset = np.multiply(max_offset, np.random.uniform(0.0, 1.0)) + offset = np.round(offset).astype(int) + return CropTransform( + offset[1], offset[0], output_size[1], output_size[0], input_size[1], input_size[0] + ) + + def _get_pad(self, image: np.ndarray) -> Transform: + # Compute the image scale and scaled size. + input_size = image.shape[:2] + output_size = self.crop_size + + # Add padding if the image is scaled down. + pad_size = np.subtract(output_size, input_size) + pad_size = np.maximum(pad_size, 0) + original_size = np.minimum(input_size, output_size) + return PadTransform( + 0, + 0, + pad_size[1], + pad_size[0], + original_size[1], + original_size[0], + self.pad_value, + self.seg_pad_value, + ) + + def get_transform(self, image: np.ndarray) -> TransformList: + transforms = [self._get_crop(image)] + if self.pad: + transforms.append(self._get_pad(image)) + return TransformList(transforms) + + +class RandomCrop(Augmentation): + """ + Randomly crop a rectangle region out of an image. + """ + + def __init__(self, crop_type: str, crop_size): + """ + Args: + crop_type (str): one of "relative_range", "relative", "absolute", "absolute_range". + crop_size (tuple[float, float]): two floats, explained below. 
+ + - "relative": crop a (H * crop_size[0], W * crop_size[1]) region from an input image of + size (H, W). crop size should be in (0, 1] + - "relative_range": uniformly sample two values from [crop_size[0], 1] + and [crop_size[1]], 1], and use them as in "relative" crop type. + - "absolute" crop a (crop_size[0], crop_size[1]) region from input image. + crop_size must be smaller than the input image size. + - "absolute_range", for an input of size (H, W), uniformly sample H_crop in + [crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])]. + Then crop a region (H_crop, W_crop). + """ + # TODO style of relative_range and absolute_range are not consistent: + # one takes (h, w) but another takes (min, max) + super().__init__() + assert crop_type in ["relative_range", "relative", "absolute", "absolute_range"] + self._init(locals()) + + def get_transform(self, image): + h, w = image.shape[:2] + croph, cropw = self.get_crop_size((h, w)) + assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self) + h0 = np.random.randint(h - croph + 1) + w0 = np.random.randint(w - cropw + 1) + return CropTransform(w0, h0, cropw, croph) + + def get_crop_size(self, image_size): + """ + Args: + image_size (tuple): height, width + + Returns: + crop_size (tuple): height, width in absolute pixels + """ + h, w = image_size + if self.crop_type == "relative": + ch, cw = self.crop_size + return int(h * ch + 0.5), int(w * cw + 0.5) + elif self.crop_type == "relative_range": + crop_size = np.asarray(self.crop_size, dtype=np.float32) + ch, cw = crop_size + np.random.rand(2) * (1 - crop_size) + return int(h * ch + 0.5), int(w * cw + 0.5) + elif self.crop_type == "absolute": + return (min(self.crop_size[0], h), min(self.crop_size[1], w)) + elif self.crop_type == "absolute_range": + assert self.crop_size[0] <= self.crop_size[1] + ch = np.random.randint(min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1) + cw = np.random.randint(min(w, 
class RandomCrop_CategoryAreaConstraint(Augmentation):
    """
    Variant of :class:`RandomCrop` that looks for a crop window in which no single
    category covers more than `single_category_max_area` of the semantic
    segmentation ground truth (a dominating category can cause unstability in
    training). At most 10 candidate windows are drawn; if none satisfies the
    constraint, the last candidate is used.
    """

    def __init__(
        self,
        crop_type: str,
        crop_size,
        single_category_max_area: float = 1.0,
        ignored_category: int = None,
    ):
        """
        Args:
            crop_type, crop_size: same as in :class:`RandomCrop`
            single_category_max_area: the maximum allowed area ratio of a
                category. Set to 1.0 to disable
            ignored_category: allow this category in the semantic segmentation
                ground truth to exceed the area ratio. Usually set to the category
                that's ignored in training.
        """
        self.crop_aug = RandomCrop(crop_type, crop_size)
        # NOTE: _init() copies the local variables of this frame onto self, so no
        # extra locals may be introduced above this line.
        self._init(locals())

    def get_transform(self, image, sem_seg):
        """
        Build the crop transform.

        When the constraint is disabled (ratio >= 1.0) this simply defers to the
        wrapped :class:`RandomCrop`; otherwise candidate windows are scored
        against ``sem_seg``.
        """
        if self.single_category_max_area >= 1.0:
            return self.crop_aug.get_transform(image)

        height, width = sem_seg.shape
        for _ in range(10):
            crop_h, crop_w = self.crop_aug.get_crop_size((height, width))
            y0 = np.random.randint(height - crop_h + 1)
            x0 = np.random.randint(width - crop_w + 1)
            window = sem_seg[y0 : y0 + crop_h, x0 : x0 + crop_w]
            labels, counts = np.unique(window, return_counts=True)
            if self.ignored_category is not None:
                # Pixels of the ignored category do not take part in the ratio.
                counts = counts[labels != self.ignored_category]
            # Accept only windows with at least two categories where the largest
            # one stays below the allowed share.
            if len(counts) > 1 and np.max(counts) < np.sum(counts) * self.single_category_max_area:
                break
        # Reached either via `break` (valid window) or after the final attempt.
        return CropTransform(x0, y0, crop_w, crop_h)
+ + The subrect can be parameterized to include pixels outside the source image, + in which case they will be set to zeros (i.e. black). The size of the output + image will vary with the size of the random subrect. + """ + + def __init__(self, scale_range, shift_range): + """ + Args: + output_size (h, w): Dimensions of output image + scale_range (l, h): Range of input-to-output size scaling factor + shift_range (x, y): Range of shifts of the cropped subrect. The rect + is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)], + where (w, h) is the (width, height) of the input image. Set each + component to zero to crop at the image's center. + """ + super().__init__() + self._init(locals()) + + def get_transform(self, image): + img_h, img_w = image.shape[:2] + + # Initialize src_rect to fit the input image. + src_rect = np.array([-0.5 * img_w, -0.5 * img_h, 0.5 * img_w, 0.5 * img_h]) + + # Apply a random scaling to the src_rect. + src_rect *= np.random.uniform(self.scale_range[0], self.scale_range[1]) + + # Apply a random shift to the coordinates origin. + src_rect[0::2] += self.shift_range[0] * img_w * (np.random.rand() - 0.5) + src_rect[1::2] += self.shift_range[1] * img_h * (np.random.rand() - 0.5) + + # Map src_rect coordinates into image coordinates (center at corner). + src_rect[0::2] += 0.5 * img_w + src_rect[1::2] += 0.5 * img_h + + return ExtentTransform( + src_rect=(src_rect[0], src_rect[1], src_rect[2], src_rect[3]), + output_size=(int(src_rect[3] - src_rect[1]), int(src_rect[2] - src_rect[0])), + ) + + +class RandomContrast(Augmentation): + """ + Randomly transforms image contrast. + + Contrast intensity is uniformly sampled in (intensity_min, intensity_max). 
+ - intensity < 1 will reduce contrast + - intensity = 1 will preserve the input image + - intensity > 1 will increase contrast + + See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html + """ + + def __init__(self, intensity_min, intensity_max): + """ + Args: + intensity_min (float): Minimum augmentation + intensity_max (float): Maximum augmentation + """ + super().__init__() + self._init(locals()) + + def get_transform(self, image): + w = np.random.uniform(self.intensity_min, self.intensity_max) + return BlendTransform(src_image=image.mean(), src_weight=1 - w, dst_weight=w) + + +class RandomBrightness(Augmentation): + """ + Randomly transforms image brightness. + + Brightness intensity is uniformly sampled in (intensity_min, intensity_max). + - intensity < 1 will reduce brightness + - intensity = 1 will preserve the input image + - intensity > 1 will increase brightness + + See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html + """ + + def __init__(self, intensity_min, intensity_max): + """ + Args: + intensity_min (float): Minimum augmentation + intensity_max (float): Maximum augmentation + """ + super().__init__() + self._init(locals()) + + def get_transform(self, image): + w = np.random.uniform(self.intensity_min, self.intensity_max) + return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w) + + +class RandomSaturation(Augmentation): + """ + Randomly transforms saturation of an RGB image. + Input images are assumed to have 'RGB' channel order. + + Saturation intensity is uniformly sampled in (intensity_min, intensity_max). 
+ - intensity < 1 will reduce saturation (make the image more grayscale) + - intensity = 1 will preserve the input image + - intensity > 1 will increase saturation + + See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html + """ + + def __init__(self, intensity_min, intensity_max): + """ + Args: + intensity_min (float): Minimum augmentation (1 preserves input). + intensity_max (float): Maximum augmentation (1 preserves input). + """ + super().__init__() + self._init(locals()) + + def get_transform(self, image): + assert image.shape[-1] == 3, "RandomSaturation only works on RGB images" + w = np.random.uniform(self.intensity_min, self.intensity_max) + grayscale = image.dot([0.299, 0.587, 0.114])[:, :, np.newaxis] + return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w) + + +class RandomLighting(Augmentation): + """ + The "lighting" augmentation described in AlexNet, using fixed PCA over ImageNet. + Input images are assumed to have 'RGB' channel order. + + The degree of color jittering is randomly sampled via a normal distribution, + with standard deviation given by the scale parameter. + """ + + def __init__(self, scale): + """ + Args: + scale (float): Standard deviation of principal component weighting. 
+ """ + super().__init__() + self._init(locals()) + self.eigen_vecs = np.array( + [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]] + ) + self.eigen_vals = np.array([0.2175, 0.0188, 0.0045]) + + def get_transform(self, image): + assert image.shape[-1] == 3, "RandomLighting only works on RGB images" + weights = np.random.normal(scale=self.scale, size=3) + return BlendTransform( + src_image=self.eigen_vecs.dot(weights * self.eigen_vals), src_weight=1.0, dst_weight=1.0 + ) + + +class RandomResize(Augmentation): + """Randomly resize image to a target size in shape_list""" + + def __init__(self, shape_list, interp=Image.BILINEAR): + """ + Args: + shape_list: a list of shapes in (h, w) + interp: PIL interpolation method + """ + self.shape_list = shape_list + self._init(locals()) + + def get_transform(self, image): + shape_idx = np.random.randint(low=0, high=len(self.shape_list)) + h, w = self.shape_list[shape_idx] + return ResizeTransform(image.shape[0], image.shape[1], h, w, self.interp) + + +class MinIoURandomCrop(Augmentation): + """Random crop the image & bboxes, the cropped patches have minimum IoU + requirement with original image & bboxes, the IoU threshold is randomly + selected from min_ious. + + Args: + min_ious (tuple): minimum IoU threshold for all intersections with + bounding boxes + min_crop_size (float): minimum crop's size (i.e. 
h,w := a*h, a*w, + where a >= min_crop_size) + mode_trials: number of trials for sampling min_ious threshold + crop_trials: number of trials for sampling crop_size after cropping + """ + + def __init__( + self, + min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), + min_crop_size=0.3, + mode_trials=1000, + crop_trials=50, + ): + self.min_ious = min_ious + self.sample_mode = (1, *min_ious, 0) + self.min_crop_size = min_crop_size + self.mode_trials = mode_trials + self.crop_trials = crop_trials + + def get_transform(self, image, boxes): + """Call function to crop images and bounding boxes with minimum IoU + constraint. + + Args: + boxes: ground truth boxes in (x1, y1, x2, y2) format + """ + if boxes is None: + return NoOpTransform() + h, w, c = image.shape + for _ in range(self.mode_trials): + mode = random.choice(self.sample_mode) + self.mode = mode + if mode == 1: + return NoOpTransform() + + min_iou = mode + for _ in range(self.crop_trials): + new_w = random.uniform(self.min_crop_size * w, w) + new_h = random.uniform(self.min_crop_size * h, h) + + # h / w in [0.5, 2] + if new_h / new_w < 0.5 or new_h / new_w > 2: + continue + + left = random.uniform(w - new_w) + top = random.uniform(h - new_h) + + patch = np.array((int(left), int(top), int(left + new_w), int(top + new_h))) + # Line or point crop is not allowed + if patch[2] == patch[0] or patch[3] == patch[1]: + continue + overlaps = pairwise_iou( + Boxes(patch.reshape(-1, 4)), Boxes(boxes.reshape(-1, 4)) + ).reshape(-1) + if len(overlaps) > 0 and overlaps.min() < min_iou: + continue + + # center of boxes should inside the crop img + # only adjust boxes and instance masks when the gt is not empty + if len(overlaps) > 0: + # adjust boxes + def is_center_of_bboxes_in_patch(boxes, patch): + center = (boxes[:, :2] + boxes[:, 2:]) / 2 + mask = ( + (center[:, 0] > patch[0]) + * (center[:, 1] > patch[1]) + * (center[:, 0] < patch[2]) + * (center[:, 1] < patch[3]) + ) + return mask + + mask = is_center_of_bboxes_in_patch(boxes, patch) + 
class ExtentTransform(Transform):
    """
    Extracts a subregion from the source image and scales it to the output size.

    The fill color is used to map pixels from the source rect that fall outside
    the source image.

    See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform
    """

    # Image.BILINEAR replaces the former default Image.LINEAR: both names denote
    # the same filter, but Image.LINEAR no longer exists in recent Pillow
    # releases, and a missing default breaks the import of this module.
    def __init__(self, src_rect, output_size, interp=Image.BILINEAR, fill=0):
        """
        Args:
            src_rect (x0, y0, x1, y1): src coordinates
            output_size (h, w): dst image size
            interp: PIL interpolation methods
            fill: Fill color used when src_rect extends outside image
        """
        super().__init__()
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        """
        Extract ``src_rect`` from ``img`` and resample it to ``output_size``.

        Args:
            img (ndarray): image array; an (H, W, 1) array is processed as a
                single-channel image and returned with the trailing axis restored.
            interp: PIL interpolation method overriding ``self.interp`` for this
                call. ``None`` means "use ``self.interp``".

        Returns:
            ndarray: the resampled image.
        """
        h, w = self.output_size
        single_channel = len(img.shape) > 2 and img.shape[2] == 1
        if single_channel:
            pil_image = Image.fromarray(img[:, :, 0], mode="L")
        else:
            pil_image = Image.fromarray(img)
        # Compare against None explicitly: Image.NEAREST has the integer value 0,
        # so a truthiness test would silently discard it and fall back to
        # self.interp (which would blend segmentation labels).
        resample = interp if interp is not None else self.interp
        pil_image = pil_image.transform(
            size=(w, h),
            method=Image.EXTENT,
            data=self.src_rect,
            resample=resample,
            fill=self.fill,
        )
        ret = np.asarray(pil_image)
        if single_channel:
            ret = np.expand_dims(ret, -1)
        return ret

    def apply_coords(self, coords):
        """
        Map (x, y) points from source-image coordinates to output coordinates.

        Args:
            coords (ndarray): array whose columns 0 and 1 hold x and y.

        Returns:
            ndarray: a new float32 array; the input is not modified.
        """
        # Transform image center from source coordinates into output coordinates
        # and then map the new origin to the corner of the output image.
        h, w = self.output_size
        x0, y0, x1, y1 = self.src_rect
        new_coords = coords.astype(np.float32)
        new_coords[:, 0] -= 0.5 * (x0 + x1)
        new_coords[:, 1] -= 0.5 * (y0 + y1)
        new_coords[:, 0] *= w / (x1 - x0)
        new_coords[:, 1] *= h / (y1 - y0)
        new_coords[:, 0] += 0.5 * w
        new_coords[:, 1] += 0.5 * h
        return new_coords

    def apply_segmentation(self, segmentation):
        """
        Transform a segmentation mask with nearest-neighbor resampling so that
        label values are never interpolated.
        """
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation
class RotationTransform(Transform):
    """
    This method returns a copy of this image, rotated the given
    number of degrees counter clockwise around its center.
    """

    def __init__(self, h, w, angle, expand=True, center=None, interp=None):
        """
        Args:
            h, w (int): original image size
            angle (float): degrees for rotation
            expand (bool): choose if the image should be resized to fit the whole
                rotated image (default), or simply cropped
            center (tuple (width, height)): coordinates of the rotation center
                if left to None, the center will be fit to the center of each image
                center has no effect if expand=True because it only affects shifting
            interp: cv2 interpolation method, default cv2.INTER_LINEAR
        """
        super().__init__()
        image_center = np.array((w / 2, h / 2))
        if center is None:
            center = image_center
        if interp is None:
            # NOTE(review): cv2 is imported optionally at module level, so this
            # line needs OpenCV to be installed.
            interp = cv2.INTER_LINEAR
        abs_cos, abs_sin = (abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle))))
        if expand:
            # find the new width and height bounds
            bound_w, bound_h = np.rint(
                [h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin]
            ).astype(int)
        else:
            bound_w, bound_h = w, h

        # The methods below read h, w, angle, expand, center, interp, image_center,
        # bound_w and bound_h back from ``self``; presumably ``_set_attributes``
        # is what stores these locals as attributes -- confirm against fvcore.
        self._set_attributes(locals())
        # Matrix used for point coordinates (no offset).
        self.rm_coords = self.create_rotation_matrix()
        # Needed because of this problem https://github.com/opencv/opencv/issues/11784
        self.rm_image = self.create_rotation_matrix(offset=-0.5)

    def apply_image(self, img, interp=None):
        """
        img should be a numpy array, formatted as Height * Width * Nchannels
        """
        # Empty input, or a rotation by a multiple of 360 degrees: return the
        # input object itself, unchanged.
        if len(img) == 0 or self.angle % 360 == 0:
            return img
        assert img.shape[:2] == (self.h, self.w)
        interp = interp if interp is not None else self.interp
        return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp)

    def apply_coords(self, coords):
        """
        coords should be a N * 2 array-like, containing N couples of (x, y) points
        """
        coords = np.asarray(coords, dtype=float)
        if len(coords) == 0 or self.angle % 360 == 0:
            return coords
        # cv2.transform is fed an (N, 1, 2) array here; the middle axis is
        # dropped again from its result.
        return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]

    def apply_segmentation(self, segmentation):
        # Same warp as apply_image, but with nearest-neighbor interpolation.
        segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST)
        return segmentation

    def create_rotation_matrix(self, offset=0):
        """
        Build the affine matrix for this rotation.

        Args:
            offset: value added to both coordinates of the rotation center (and
                to the image center / new center when ``expand`` is set).

        Returns:
            the matrix returned by ``cv2.getRotationMatrix2D``, with its last
            column shifted when ``expand`` is set.
        """
        center = (self.center[0] + offset, self.center[1] + offset)
        rm = cv2.getRotationMatrix2D(tuple(center), self.angle, 1)
        if self.expand:
            # Find the coordinates of the center of rotation in the new image
            # The only point for which we know the future coordinates is the center of the image
            rot_im_center = cv2.transform(self.image_center[None, None, :] + offset, rm)[0, 0, :]
            new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center
            # shift the rotation center to the new coordinates
            rm[:, 2] += new_center
        return rm

    def inverse(self):
        """
        The inverse is to rotate it back with expand, and crop to get the original shape.

        Raises:
            NotImplementedError: if this transform was created with ``expand=False``.
        """
        if not self.expand:  # Not possible to inverse if a part of the image is lost
            raise NotImplementedError()
        rotation = RotationTransform(
            self.bound_h, self.bound_w, -self.angle, True, None, self.interp
        )
        # Center-crop the re-expanded image back to the original (w, h).
        crop = CropTransform(
            (rotation.bound_w - self.w) // 2, (rotation.bound_h - self.h) // 2, self.w, self.h
        )
        return TransformList([rotation, crop])
def Resize_rotated_box(transform, rotated_boxes):
    """
    Apply the resizing transform on rotated boxes. For details of how these (approximation)
    formulas are derived, please refer to :meth:`RotatedBoxes.scale`.

    The array is updated in place and the same object is returned.

    Args:
        transform: object exposing ``w``, ``h``, ``new_w`` and ``new_h``.
        rotated_boxes (ndarray): Nx5 floating point array of
            (x_center, y_center, width, height, angle_degrees) format
            in absolute coordinates.

    Returns:
        ndarray: ``rotated_boxes`` itself, after modification.
    """
    sx = transform.new_w * 1.0 / transform.w
    sy = transform.new_h * 1.0 / transform.h

    # Box centers scale independently along each axis.
    rotated_boxes[:, 0] *= sx
    rotated_boxes[:, 1] *= sy

    # Trig terms are taken from the angle column before it is overwritten below.
    angle_rad = rotated_boxes[:, 4] * np.pi / 180.0
    cos_a = np.cos(angle_rad)
    sin_a = np.sin(angle_rad)

    width_gain = np.sqrt(np.square(sx * cos_a) + np.square(sy * sin_a))
    height_gain = np.sqrt(np.square(sx * sin_a) + np.square(sy * cos_a))
    rotated_boxes[:, 2] *= width_gain
    rotated_boxes[:, 3] *= height_gain
    rotated_boxes[:, 4] = np.arctan2(sx * sin_a, sy * cos_a) * 180 / np.pi

    return rotated_boxes
+ +The behavior of functions/classes in this file is subject to change, +since they are meant to represent the "common default behavior" people need in their projects. +""" + +import argparse +import logging +import os +import sys +import weakref +from collections import OrderedDict +from typing import Optional +import torch +from fvcore.nn.precise_bn import get_bn_modules +from omegaconf import OmegaConf +from torch.nn.parallel import DistributedDataParallel + +import detectron2.data.transforms as T +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import CfgNode, LazyConfig +from detectron2.data import ( + MetadataCatalog, + build_detection_test_loader, + build_detection_train_loader, +) +from detectron2.evaluation import ( + DatasetEvaluator, + inference_on_dataset, + print_csv_format, + verify_results, +) +from detectron2.modeling import build_model +from detectron2.solver import build_lr_scheduler, build_optimizer +from detectron2.utils import comm +from detectron2.utils.collect_env import collect_env_info +from detectron2.utils.env import seed_all_rng +from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import setup_logger + +from . import hooks +from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase + +__all__ = [ + "create_ddp_model", + "default_argument_parser", + "default_setup", + "default_writers", + "DefaultPredictor", + "DefaultTrainer", +] + + +def create_ddp_model(model, *, fp16_compression=False, **kwargs): + """ + Create a DistributedDataParallel model if there are >1 processes. + + Args: + model: a torch.nn.Module + fp16_compression: add fp16 compression hooks to the ddp object. 
def default_argument_parser(epilog=None):
    """
    Create a parser with some common arguments used by detectron2 users.

    Args:
        epilog (str): epilog passed to ArgumentParser describing the usage.
            When falsy, a default epilog with usage examples is used.

    Returns:
        argparse.ArgumentParser:
    """
    # The multi-machine examples must show a value after ``--dist-url``: the
    # option takes one argument, so omitting it makes the example unusable.
    parser = argparse.ArgumentParser(
        epilog=epilog
        or f"""
Examples:

Run on single machine:
    $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml

Change some config options:
    $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001

Run on multiple machines:
    (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url tcp://MACHINE0_IP:PORT [--other-flags]
    (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url tcp://MACHINE0_IP:PORT [--other-flags]
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume from the checkpoint directory. "
        "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )

    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    # The result always lies in [2**15 + 2**14, 2**16).
    port = 2**15 + 2**14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2**14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the detectron2 logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    Args:
        cfg (CfgNode or omegaconf.DictConfig): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        # Read through a context manager so the handle is closed right away
        # instead of being left for the garbage collector.
        with PathManager.open(args.config_file, "r") as f:
            config_text = f.read()
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file,
                _highlight(config_text, args.config_file),
            )
        )

    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        if isinstance(cfg, CfgNode):
            logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
            with PathManager.open(path, "w") as f:
                f.write(cfg.dump())
        else:
            LazyConfig.save(cfg, path)
        logger.info("Full config saved to {}".format(path))

    # make sure each worker has a different, yet deterministic seed if specified
    # (a negative or missing seed is passed on as None)
    seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
    seed_all_rng(None if seed < 0 else seed + rank)

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = _try_get_key(
            cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
        )
class DefaultPredictor:
    """
    Create a simple end-to-end predictor with the given config that runs on
    single device for a single input image.

    Compared to using the model directly, this class does the following additions:

    1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
    2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
    3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
    4. Take one input image and produce a single output, instead of a batch.

    This is meant for simple demo purposes, so it does the above steps automatically.
    This is not meant for benchmarks or running complicated inference logic.
    If you'd like to do anything more complicated, please refer to its source code as
    examples to build and use the model manually.

    Attributes:
        metadata (Metadata): the metadata of the underlying dataset, obtained from
            cfg.DATASETS.TEST. Only set when cfg.DATASETS.TEST is non-empty.

    Examples:
    ::
        pred = DefaultPredictor(cfg)
        inputs = cv2.imread("input.jpg")
        outputs = pred(inputs)
    """

    def __init__(self, cfg):
        # Work on a private copy: cfg can be modified by model
        self.cfg = cfg.clone()
        self.model = build_model(self.cfg)
        self.model.eval()

        test_datasets = cfg.DATASETS.TEST
        if len(test_datasets):
            self.metadata = MetadataCatalog.get(test_datasets[0])

        DetectionCheckpointer(self.model).load(cfg.MODEL.WEIGHTS)

        # The same value is given as both ends of the shortest-edge range.
        min_size = cfg.INPUT.MIN_SIZE_TEST
        self.aug = T.ResizeShortestEdge([min_size, min_size], cfg.INPUT.MAX_SIZE_TEST)

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).

        Returns:
            predictions (dict):
                the output of the model for one image only.
                See :doc:`/tutorials/models` for details about the format.
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]

            resize = self.aug.get_transform(original_image)
            resized = resize.apply_image(original_image)
            # HWC -> CHW float32 tensor
            image = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))

            batch = [{"image": image, "height": height, "width": width}]
            predictions = self.model(batch)[0]
            return predictions
+ + Note that the behavior of this class, like other functions/classes in + this file, is not stable, since it is meant to represent the "common default behavior". + It is only guaranteed to work well with the standard models and training workflow in detectron2. + To obtain more stable behavior, write your own training logic with other public APIs. + + Examples: + :: + trainer = DefaultTrainer(cfg) + trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS + trainer.train() + + Attributes: + scheduler: + checkpointer (DetectionCheckpointer): + cfg (CfgNode): + """ + + def __init__(self, cfg): + """ + Args: + cfg (CfgNode): + """ + super().__init__() + logger = logging.getLogger("detectron2") + if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2 + setup_logger() + cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size()) + + # Assume these objects must be constructed in this order. + model = self.build_model(cfg) + optimizer = self.build_optimizer(cfg, model) + data_loader = self.build_train_loader(cfg) + + model = create_ddp_model(model, broadcast_buffers=False) + self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)( + model, data_loader, optimizer + ) + + self.scheduler = self.build_lr_scheduler(cfg, optimizer) + self.checkpointer = DetectionCheckpointer( + # Assume you want to save checkpoints together with logs/statistics + model, + cfg.OUTPUT_DIR, + trainer=weakref.proxy(self), + ) + self.start_iter = 0 + self.max_iter = cfg.SOLVER.MAX_ITER + self.cfg = cfg + + self.register_hooks(self.build_hooks()) + + def resume_or_load(self, resume=True): + """ + If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by + a `last_checkpoint` file), resume from the file. Resuming means loading all + available states (eg. optimizer and scheduler) and update iteration counter + from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used. 
+ + Otherwise, this is considered as an independent training. The method will load model + weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start + from iteration 0. + + Args: + resume (bool): whether to do resume or not + """ + self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume) + if resume and self.checkpointer.has_checkpoint(): + # The checkpoint stores the training iteration that just finished, thus we start + # at the next iteration + self.start_iter = self.iter + 1 + + def build_hooks(self): + """ + Build a list of default hooks, including timing, evaluation, + checkpointing, lr scheduling, precise BN, writing events. + + Returns: + list[HookBase]: + """ + cfg = self.cfg.clone() + cfg.defrost() + cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN + + ret = [ + hooks.IterationTimer(), + hooks.LRScheduler(), + hooks.PreciseBN( + # Run at the same freq as (but before) evaluation. + cfg.TEST.EVAL_PERIOD, + self.model, + # Build a new data loader to not affect training + self.build_train_loader(cfg), + cfg.TEST.PRECISE_BN.NUM_ITER, + ) + if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model) + else None, + ] + + # Do PreciseBN before checkpointer, because it updates the model and need to + # be saved by checkpointer. + # This is not always the best: if checkpointing has a different frequency, + # some checkpoints may have more precise statistics than others. + if comm.is_main_process(): + ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) + + def test_and_save_results(): + self._last_eval_results = self.test(self.cfg, self.model) + return self._last_eval_results + + # Do evaluation after checkpointer, because then if it fails, + # we can use the saved checkpoint to debug. + ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)) + + if comm.is_main_process(): + # Here the default print/log frequency of each writer is used. 
+ # run writers in the end, so that evaluation metrics are written + ret.append(hooks.PeriodicWriter(self.build_writers(), period=20)) + return ret + + def build_writers(self): + """ + Build a list of writers to be used using :func:`default_writers()`. + If you'd like a different list of writers, you can overwrite it in + your trainer. + + Returns: + list[EventWriter]: a list of :class:`EventWriter` objects. + """ + return default_writers(self.cfg.OUTPUT_DIR, self.max_iter) + + def train(self): + """ + Run training. + + Returns: + OrderedDict of results, if evaluation is enabled. Otherwise None. + """ + super().train(self.start_iter, self.max_iter) + if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process(): + assert hasattr( + self, "_last_eval_results" + ), "No evaluation results obtained during training!" + verify_results(self.cfg, self._last_eval_results) + return self._last_eval_results + + def run_step(self): + self._trainer.iter = self.iter + self._trainer.run_step() + + def state_dict(self): + ret = super().state_dict() + ret["_trainer"] = self._trainer.state_dict() + return ret + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self._trainer.load_state_dict(state_dict["_trainer"]) + + @classmethod + def build_model(cls, cfg): + """ + Returns: + torch.nn.Module: + + It now calls :func:`detectron2.modeling.build_model`. + Overwrite it if you'd like a different model. + """ + model = build_model(cfg) + logger = logging.getLogger(__name__) + logger.info("Model:\n{}".format(model)) + return model + + @classmethod + def build_optimizer(cls, cfg, model): + """ + Returns: + torch.optim.Optimizer: + + It now calls :func:`detectron2.solver.build_optimizer`. + Overwrite it if you'd like a different optimizer. + """ + return build_optimizer(cfg, model) + + @classmethod + def build_lr_scheduler(cls, cfg, optimizer): + """ + It now calls :func:`detectron2.solver.build_lr_scheduler`. 
+ Overwrite it if you'd like a different scheduler. + """ + return build_lr_scheduler(cfg, optimizer) + + @classmethod + def build_train_loader(cls, cfg): + """ + Returns: + iterable + + It now calls :func:`detectron2.data.build_detection_train_loader`. + Overwrite it if you'd like a different data loader. + """ + return build_detection_train_loader(cfg) + + @classmethod + def build_test_loader(cls, cfg, dataset_name): + """ + Returns: + iterable + + It now calls :func:`detectron2.data.build_detection_test_loader`. + Overwrite it if you'd like a different data loader. + """ + return build_detection_test_loader(cfg, dataset_name) + + @classmethod + def build_evaluator(cls, cfg, dataset_name): + """ + Returns: + DatasetEvaluator or None + + It is not implemented by default. + """ + raise NotImplementedError( + """ +If you want DefaultTrainer to automatically run evaluation, +please implement `build_evaluator()` in subclasses (see train_net.py for example). +Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example). +""" + ) + + @classmethod + def test(cls, cfg, model, evaluators=None): + """ + Evaluate the given model. The given model is expected to already contain + weights to evaluate. + + Args: + cfg (CfgNode): + model (nn.Module): + evaluators (list[DatasetEvaluator] or None): if None, will call + :meth:`build_evaluator`. Otherwise, must have the same length as + ``cfg.DATASETS.TEST``. 
+ + Returns: + dict: a dict of result metrics + """ + logger = logging.getLogger(__name__) + if isinstance(evaluators, DatasetEvaluator): + evaluators = [evaluators] + if evaluators is not None: + assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format( + len(cfg.DATASETS.TEST), len(evaluators) + ) + + results = OrderedDict() + for idx, dataset_name in enumerate(cfg.DATASETS.TEST): + data_loader = cls.build_test_loader(cfg, dataset_name) + # When evaluators are passed in as arguments, + # implicitly assume that evaluators can be created before data_loader. + if evaluators is not None: + evaluator = evaluators[idx] + else: + try: + evaluator = cls.build_evaluator(cfg, dataset_name) + except NotImplementedError: + logger.warn( + "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, " + "or implement its `build_evaluator` method." + ) + results[dataset_name] = {} + continue + results_i = inference_on_dataset(model, data_loader, evaluator) + results[dataset_name] = results_i + if comm.is_main_process(): + assert isinstance( + results_i, dict + ), "Evaluator must return a dict on the main process. Got {} instead.".format( + results_i + ) + logger.info("Evaluation results for {} in csv format:".format(dataset_name)) + print_csv_format(results_i) + + if len(results) == 1: + results = list(results.values())[0] + return results + + @staticmethod + def auto_scale_workers(cfg, num_workers: int): + """ + When the config is defined for certain number of workers (according to + ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of + workers currently in use, returns a new cfg where the total batch size + is scaled so that the per-GPU batch size stays the same as the + original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``. + + Other config options are also scaled accordingly: + * training steps and warmup steps are scaled inverse proportionally. + * learning rate are scaled proportionally, following :paper:`ImageNet in 1h`. 
+ + For example, with the original config like the following: + + .. code-block:: yaml + + IMS_PER_BATCH: 16 + BASE_LR: 0.1 + REFERENCE_WORLD_SIZE: 8 + MAX_ITER: 5000 + STEPS: (4000,) + CHECKPOINT_PERIOD: 1000 + + When this config is used on 16 GPUs instead of the reference number 8, + calling this method will return a new config with: + + .. code-block:: yaml + + IMS_PER_BATCH: 32 + BASE_LR: 0.2 + REFERENCE_WORLD_SIZE: 16 + MAX_ITER: 2500 + STEPS: (2000,) + CHECKPOINT_PERIOD: 500 + + Note that both the original config and this new config can be trained on 16 GPUs. + It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``). + + Returns: + CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``. + """ + old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE + if old_world_size == 0 or old_world_size == num_workers: + return cfg + cfg = cfg.clone() + frozen = cfg.is_frozen() + cfg.defrost() + + assert ( + cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0 + ), "Invalid REFERENCE_WORLD_SIZE in config!" + scale = num_workers / old_world_size + bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale)) + lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale + max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale)) + warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale)) + cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS) + cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale)) + cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale)) + cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers # maintain invariant + logger = logging.getLogger(__name__) + logger.info( + f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, " + f"max_iter={max_iter}, warmup={warmup_iter}." 
+ ) + + if frozen: + cfg.freeze() + return cfg + + +# Access basic attributes from the underlying trainer +for _attr in ["model", "data_loader", "optimizer"]: + setattr( + DefaultTrainer, + _attr, + property( + # getter + lambda self, x=_attr: getattr(self._trainer, x), + # setter + lambda self, value, x=_attr: setattr(self._trainer, x, value), + ), + ) diff --git a/data_processing/detectron2/detectron2/engine/hooks.py b/data_processing/detectron2/detectron2/engine/hooks.py new file mode 100644 index 0000000..fc37af0 --- /dev/null +++ b/data_processing/detectron2/detectron2/engine/hooks.py @@ -0,0 +1,690 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import datetime +import itertools +import logging +import math +import operator +import os +import tempfile +import time +import warnings +from collections import Counter +import torch +from fvcore.common.checkpoint import Checkpointer +from fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer +from fvcore.common.param_scheduler import ParamScheduler +from fvcore.common.timer import Timer +from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats + +import detectron2.utils.comm as comm +from detectron2.evaluation.testing import flatten_results_dict +from detectron2.solver import LRMultiplier +from detectron2.solver import LRScheduler as _LRScheduler +from detectron2.utils.events import EventStorage, EventWriter +from detectron2.utils.file_io import PathManager + +from .train_loop import HookBase + +__all__ = [ + "CallbackHook", + "IterationTimer", + "PeriodicWriter", + "PeriodicCheckpointer", + "BestCheckpointer", + "LRScheduler", + "AutogradProfiler", + "EvalHook", + "PreciseBN", + "TorchProfiler", + "TorchMemoryStats", +] + + +""" +Implement some common hooks. +""" + + +class CallbackHook(HookBase): + """ + Create a hook using callback functions provided by the user. 
+ """ + + def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None): + """ + Each argument is a function that takes one argument: the trainer. + """ + self._before_train = before_train + self._before_step = before_step + self._after_step = after_step + self._after_train = after_train + + def before_train(self): + if self._before_train: + self._before_train(self.trainer) + + def after_train(self): + if self._after_train: + self._after_train(self.trainer) + # The functions may be closures that hold reference to the trainer + # Therefore, delete them to avoid circular reference. + del self._before_train, self._after_train + del self._before_step, self._after_step + + def before_step(self): + if self._before_step: + self._before_step(self.trainer) + + def after_step(self): + if self._after_step: + self._after_step(self.trainer) + + +class IterationTimer(HookBase): + """ + Track the time spent for each iteration (each run_step call in the trainer). + Print a summary in the end of training. + + This hook uses the time between the call to its :meth:`before_step` + and :meth:`after_step` methods. + Under the convention that :meth:`before_step` of all hooks should only + take negligible amount of time, the :class:`IterationTimer` hook should be + placed at the beginning of the list of hooks to obtain accurate timing. + """ + + def __init__(self, warmup_iter=3): + """ + Args: + warmup_iter (int): the number of iterations at the beginning to exclude + from timing. 
+ """ + self._warmup_iter = warmup_iter + self._step_timer = Timer() + self._start_time = time.perf_counter() + self._total_timer = Timer() + + def before_train(self): + self._start_time = time.perf_counter() + self._total_timer.reset() + self._total_timer.pause() + + def after_train(self): + logger = logging.getLogger(__name__) + total_time = time.perf_counter() - self._start_time + total_time_minus_hooks = self._total_timer.seconds() + hook_time = total_time - total_time_minus_hooks + + num_iter = self.trainer.storage.iter + 1 - self.trainer.start_iter - self._warmup_iter + + if num_iter > 0 and total_time_minus_hooks > 0: + # Speed is meaningful only after warmup + # NOTE this format is parsed by grep in some scripts + logger.info( + "Overall training speed: {} iterations in {} ({:.4f} s / it)".format( + num_iter, + str(datetime.timedelta(seconds=int(total_time_minus_hooks))), + total_time_minus_hooks / num_iter, + ) + ) + + logger.info( + "Total training time: {} ({} on hooks)".format( + str(datetime.timedelta(seconds=int(total_time))), + str(datetime.timedelta(seconds=int(hook_time))), + ) + ) + + def before_step(self): + self._step_timer.reset() + self._total_timer.resume() + + def after_step(self): + # +1 because we're in after_step, the current step is done + # but not yet counted + iter_done = self.trainer.storage.iter - self.trainer.start_iter + 1 + if iter_done >= self._warmup_iter: + sec = self._step_timer.seconds() + self.trainer.storage.put_scalars(time=sec) + else: + self._start_time = time.perf_counter() + self._total_timer.reset() + + self._total_timer.pause() + + +class PeriodicWriter(HookBase): + """ + Write events to EventStorage (by calling ``writer.write()``) periodically. + + It is executed every ``period`` iterations and after the last iteration. + Note that ``period`` does not affect how data is smoothed by each writer. 
+ """ + + def __init__(self, writers, period=20): + """ + Args: + writers (list[EventWriter]): a list of EventWriter objects + period (int): + """ + self._writers = writers + for w in writers: + assert isinstance(w, EventWriter), w + self._period = period + + def after_step(self): + if (self.trainer.iter + 1) % self._period == 0 or ( + self.trainer.iter == self.trainer.max_iter - 1 + ): + for writer in self._writers: + writer.write() + + def after_train(self): + for writer in self._writers: + # If any new data is found (e.g. produced by other after_train), + # write them before closing + writer.write() + writer.close() + + +class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase): + """ + Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook. + + Note that when used as a hook, + it is unable to save additional data other than what's defined + by the given `checkpointer`. + + It is executed every ``period`` iterations and after the last iteration. + """ + + def before_train(self): + self.max_iter = self.trainer.max_iter + + def after_step(self): + # No way to use **kwargs + self.step(self.trainer.iter) + + +class BestCheckpointer(HookBase): + """ + Checkpoints best weights based off given metric. + + This hook should be used in conjunction to and executed after the hook + that produces the metric, e.g. `EvalHook`. + """ + + def __init__( + self, + eval_period: int, + checkpointer: Checkpointer, + val_metric: str, + mode: str = "max", + file_prefix: str = "model_best", + ) -> None: + """ + Args: + eval_period (int): the period `EvalHook` is set to run. + checkpointer: the checkpointer object used to save checkpoints. + val_metric (str): validation metric to track for best checkpoint, e.g. "bbox/AP50" + mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be + maximized or minimized, e.g. 
for "bbox/AP50" it should be "max" + file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best" + """ + self._logger = logging.getLogger(__name__) + self._period = eval_period + self._val_metric = val_metric + assert mode in [ + "max", + "min", + ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.' + if mode == "max": + self._compare = operator.gt + else: + self._compare = operator.lt + self._checkpointer = checkpointer + self._file_prefix = file_prefix + self.best_metric = None + self.best_iter = None + + def _update_best(self, val, iteration): + if math.isnan(val) or math.isinf(val): + return False + self.best_metric = val + self.best_iter = iteration + return True + + def _best_checking(self): + metric_tuple = self.trainer.storage.latest().get(self._val_metric) + if metric_tuple is None: + self._logger.warning( + f"Given val metric {self._val_metric} does not seem to be computed/stored." + "Will not be checkpointing based on it." + ) + return + else: + latest_metric, metric_iter = metric_tuple + + if self.best_metric is None: + if self._update_best(latest_metric, metric_iter): + additional_state = {"iteration": metric_iter} + self._checkpointer.save(f"{self._file_prefix}", **additional_state) + self._logger.info( + f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps" + ) + elif self._compare(latest_metric, self.best_metric): + additional_state = {"iteration": metric_iter} + self._checkpointer.save(f"{self._file_prefix}", **additional_state) + self._logger.info( + f"Saved best model as latest eval score for {self._val_metric} is " + f"{latest_metric:0.5f}, better than last best score " + f"{self.best_metric:0.5f} @ iteration {self.best_iter}." 
+ ) + self._update_best(latest_metric, metric_iter) + else: + self._logger.info( + f"Not saving as latest eval score for {self._val_metric} is {latest_metric:0.5f}, " + f"not better than best score {self.best_metric:0.5f} @ iteration {self.best_iter}." + ) + + def after_step(self): + # same conditions as `EvalHook` + next_iter = self.trainer.iter + 1 + if ( + self._period > 0 + and next_iter % self._period == 0 + and next_iter != self.trainer.max_iter + ): + self._best_checking() + + def after_train(self): + # same conditions as `EvalHook` + if self.trainer.iter + 1 >= self.trainer.max_iter: + self._best_checking() + + +class LRScheduler(HookBase): + """ + A hook which executes a torch builtin LR scheduler and summarizes the LR. + It is executed after every iteration. + """ + + def __init__(self, optimizer=None, scheduler=None): + """ + Args: + optimizer (torch.optim.Optimizer): + scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler): + if a :class:`ParamScheduler` object, it defines the multiplier over the base LR + in the optimizer. + + If any argument is not given, will try to obtain it from the trainer. 
+ """ + self._optimizer = optimizer + self._scheduler = scheduler + + def before_train(self): + self._optimizer = self._optimizer or self.trainer.optimizer + if isinstance(self.scheduler, ParamScheduler): + self._scheduler = LRMultiplier( + self._optimizer, + self.scheduler, + self.trainer.max_iter, + last_iter=self.trainer.iter - 1, + ) + self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer) + + @staticmethod + def get_best_param_group_id(optimizer): + # NOTE: some heuristics on what LR to summarize + # summarize the param group with most parameters + largest_group = max(len(g["params"]) for g in optimizer.param_groups) + + if largest_group == 1: + # If all groups have one parameter, + # then find the most common initial LR, and use it for summary + lr_count = Counter([g["lr"] for g in optimizer.param_groups]) + lr = lr_count.most_common()[0][0] + for i, g in enumerate(optimizer.param_groups): + if g["lr"] == lr: + return i + else: + for i, g in enumerate(optimizer.param_groups): + if len(g["params"]) == largest_group: + return i + + def after_step(self): + lr = self._optimizer.param_groups[self._best_param_group_id]["lr"] + self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False) + self.scheduler.step() + + @property + def scheduler(self): + return self._scheduler or self.trainer.scheduler + + def state_dict(self): + if isinstance(self.scheduler, _LRScheduler): + return self.scheduler.state_dict() + return {} + + def load_state_dict(self, state_dict): + if isinstance(self.scheduler, _LRScheduler): + logger = logging.getLogger(__name__) + logger.info("Loading scheduler from state_dict ...") + self.scheduler.load_state_dict(state_dict) + + +class TorchProfiler(HookBase): + """ + A hook which runs `torch.profiler.profile`. + + Examples: + :: + hooks.TorchProfiler( + lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR + ) + + The above example will run the profiler for iteration 10~20 and dump + results to ``OUTPUT_DIR``. 
We did not profile the first few iterations + because they are typically slower than the rest. + The result files can be loaded in the ``chrome://tracing`` page in chrome browser, + and the tensorboard visualizations can be visualized using + ``tensorboard --logdir OUTPUT_DIR/log`` + """ + + def __init__(self, enable_predicate, output_dir, *, activities=None, save_tensorboard=True): + """ + Args: + enable_predicate (callable[trainer -> bool]): a function which takes a trainer, + and returns whether to enable the profiler. + It will be called once every step, and can be used to select which steps to profile. + output_dir (str): the output directory to dump tracing files. + activities (iterable): same as in `torch.profiler.profile`. + save_tensorboard (bool): whether to save tensorboard visualizations at (output_dir)/log/ + """ + self._enable_predicate = enable_predicate + self._activities = activities + self._output_dir = output_dir + self._save_tensorboard = save_tensorboard + + def before_step(self): + if self._enable_predicate(self.trainer): + if self._save_tensorboard: + on_trace_ready = torch.profiler.tensorboard_trace_handler( + os.path.join( + self._output_dir, + "log", + "profiler-tensorboard-iter{}".format(self.trainer.iter), + ), + f"worker{comm.get_rank()}", + ) + else: + on_trace_ready = None + self._profiler = torch.profiler.profile( + activities=self._activities, + on_trace_ready=on_trace_ready, + record_shapes=True, + profile_memory=True, + with_stack=True, + with_flops=True, + ) + self._profiler.__enter__() + else: + self._profiler = None + + def after_step(self): + if self._profiler is None: + return + self._profiler.__exit__(None, None, None) + if not self._save_tensorboard: + PathManager.mkdirs(self._output_dir) + out_file = os.path.join( + self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter) + ) + if "://" not in out_file: + self._profiler.export_chrome_trace(out_file) + else: + # Support non-posix filesystems + with 
tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d: + tmp_file = os.path.join(d, "tmp.json") + self._profiler.export_chrome_trace(tmp_file) + with open(tmp_file) as f: + content = f.read() + with PathManager.open(out_file, "w") as f: + f.write(content) + + +class AutogradProfiler(TorchProfiler): + """ + A hook which runs `torch.autograd.profiler.profile`. + + Examples: + :: + hooks.AutogradProfiler( + lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR + ) + + The above example will run the profiler for iteration 10~20 and dump + results to ``OUTPUT_DIR``. We did not profile the first few iterations + because they are typically slower than the rest. + The result files can be loaded in the ``chrome://tracing`` page in chrome browser. + + Note: + When used together with NCCL on older version of GPUs, + autograd profiler may cause deadlock because it unnecessarily allocates + memory on every device it sees. The memory management calls, if + interleaved with NCCL calls, lead to deadlock on GPUs that do not + support ``cudaLaunchCooperativeKernelMultiDevice``. + """ + + def __init__(self, enable_predicate, output_dir, *, use_cuda=True): + """ + Args: + enable_predicate (callable[trainer -> bool]): a function which takes a trainer, + and returns whether to enable the profiler. + It will be called once every step, and can be used to select which steps to profile. + output_dir (str): the output directory to dump tracing files. + use_cuda (bool): same as in `torch.autograd.profiler.profile`. 
+ """ + warnings.warn("AutogradProfiler has been deprecated in favor of TorchProfiler.") + self._enable_predicate = enable_predicate + self._use_cuda = use_cuda + self._output_dir = output_dir + + def before_step(self): + if self._enable_predicate(self.trainer): + self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda) + self._profiler.__enter__() + else: + self._profiler = None + + +class EvalHook(HookBase): + """ + Run an evaluation function periodically, and at the end of training. + + It is executed every ``eval_period`` iterations and after the last iteration. + """ + + def __init__(self, eval_period, eval_function, eval_after_train=True): + """ + Args: + eval_period (int): the period to run `eval_function`. Set to 0 to + not evaluate periodically (but still evaluate after the last iteration + if `eval_after_train` is True). + eval_function (callable): a function which takes no arguments, and + returns a nested dict of evaluation metrics. + eval_after_train (bool): whether to evaluate after the last iteration + + Note: + This hook must be enabled in all or none workers. + If you would like only certain workers to perform evaluation, + give other workers a no-op function (`eval_function=lambda: None`). + """ + self._period = eval_period + self._func = eval_function + self._eval_after_train = eval_after_train + + def _do_eval(self): + results = self._func() + + if results: + assert isinstance( + results, dict + ), "Eval function must return a dict. Got {} instead.".format(results) + + flattened_results = flatten_results_dict(results) + for k, v in flattened_results.items(): + try: + v = float(v) + except Exception as e: + raise ValueError( + "[EvalHook] eval_function should return a nested dict of float. " + "Got '{}: {}' instead.".format(k, v) + ) from e + self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False) + + # Evaluation may take different time among workers. + # A barrier make them start the next iteration together. 
+ comm.synchronize() + + def after_step(self): + next_iter = self.trainer.iter + 1 + if self._period > 0 and next_iter % self._period == 0: + # do the last eval in after_train + if next_iter != self.trainer.max_iter: + self._do_eval() + + def after_train(self): + # This condition is to prevent the eval from running after a failed training + if self._eval_after_train and self.trainer.iter + 1 >= self.trainer.max_iter: + self._do_eval() + # func is likely a closure that holds reference to the trainer + # therefore we clean it to avoid circular reference in the end + del self._func + + +class PreciseBN(HookBase): + """ + The standard implementation of BatchNorm uses EMA in inference, which is + sometimes suboptimal. + This class computes the true average of statistics rather than the moving average, + and put true averages to every BN layer in the given model. + + It is executed every ``period`` iterations and after the last iteration. + """ + + def __init__(self, period, model, data_loader, num_iter): + """ + Args: + period (int): the period this hook is run, or 0 to not run during training. + The hook will always run in the end of training. + model (nn.Module): a module whose all BN layers in training mode will be + updated by precise BN. + Note that user is responsible for ensuring the BN layers to be + updated are in training mode when this hook is triggered. + data_loader (iterable): it will produce data to be run by `model(data)`. + num_iter (int): number of iterations used to compute the precise + statistics. + """ + self._logger = logging.getLogger(__name__) + if len(get_bn_modules(model)) == 0: + self._logger.info( + "PreciseBN is disabled because model does not contain BN layers in training mode." 
+ ) + self._disabled = True + return + + self._model = model + self._data_loader = data_loader + self._num_iter = num_iter + self._period = period + self._disabled = False + + self._data_iter = None + + def after_step(self): + next_iter = self.trainer.iter + 1 + is_final = next_iter == self.trainer.max_iter + if is_final or (self._period > 0 and next_iter % self._period == 0): + self.update_stats() + + def update_stats(self): + """ + Update the model with precise statistics. Users can manually call this method. + """ + if self._disabled: + return + + if self._data_iter is None: + self._data_iter = iter(self._data_loader) + + def data_loader(): + for num_iter in itertools.count(1): + if num_iter % 100 == 0: + self._logger.info( + "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter) + ) + # This way we can reuse the same iterator + yield next(self._data_iter) + + with EventStorage(): # capture events in a new storage to discard them + self._logger.info( + "Running precise-BN for {} iterations... ".format(self._num_iter) + + "Note that this could produce different statistics every time." + ) + update_bn_stats(self._model, data_loader(), self._num_iter) + + +class TorchMemoryStats(HookBase): + """ + Writes pytorch's cuda memory statistics periodically. 
+ """ + + def __init__(self, period=20, max_runs=10): + """ + Args: + period (int): Output stats each 'period' iterations + max_runs (int): Stop the logging after 'max_runs' + """ + + self._logger = logging.getLogger(__name__) + self._period = period + self._max_runs = max_runs + self._runs = 0 + + def after_step(self): + if self._runs > self._max_runs: + return + + if (self.trainer.iter + 1) % self._period == 0 or ( + self.trainer.iter == self.trainer.max_iter - 1 + ): + if torch.cuda.is_available(): + max_reserved_mb = torch.cuda.max_memory_reserved() / 1024.0 / 1024.0 + reserved_mb = torch.cuda.memory_reserved() / 1024.0 / 1024.0 + max_allocated_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 + allocated_mb = torch.cuda.memory_allocated() / 1024.0 / 1024.0 + + self._logger.info( + ( + " iter: {} " + " max_reserved_mem: {:.0f}MB " + " reserved_mem: {:.0f}MB " + " max_allocated_mem: {:.0f}MB " + " allocated_mem: {:.0f}MB " + ).format( + self.trainer.iter, + max_reserved_mb, + reserved_mb, + max_allocated_mb, + allocated_mb, + ) + ) + + self._runs += 1 + if self._runs == self._max_runs: + mem_summary = torch.cuda.memory_summary() + self._logger.info("\n" + mem_summary) + + torch.cuda.reset_peak_memory_stats() diff --git a/data_processing/detectron2/detectron2/engine/launch.py b/data_processing/detectron2/detectron2/engine/launch.py new file mode 100644 index 0000000..7052c50 --- /dev/null +++ b/data_processing/detectron2/detectron2/engine/launch.py @@ -0,0 +1,123 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import logging +from datetime import timedelta +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from detectron2.utils import comm + +__all__ = ["DEFAULT_TIMEOUT", "launch"] + +DEFAULT_TIMEOUT = timedelta(minutes=30) + + +def _find_free_port(): + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + + +def launch( + main_func, + # Should be num_processes_per_machine, but kept for compatibility. + num_gpus_per_machine, + num_machines=1, + machine_rank=0, + dist_url=None, + args=(), + timeout=DEFAULT_TIMEOUT, +): + """ + Launch multi-process or distributed training. + This function must be called on all machines involved in the training. + It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine. + + Args: + main_func: a function that will be called by `main_func(*args)` + num_gpus_per_machine (int): number of processes per machine. When + using GPUs, this should be the number of GPUs. + num_machines (int): the total number of machines + machine_rank (int): the rank of this machine + dist_url (str): url to connect to for distributed jobs, including protocol + e.g. "tcp://127.0.0.1:8686". + Can be set to "auto" to automatically select a free port on localhost + timeout (timedelta): timeout of the distributed workers + args (tuple): arguments passed to main_func + """ + world_size = num_machines * num_gpus_per_machine + if world_size > 1: + # https://github.com/pytorch/pytorch/pull/14391 + # TODO prctl in spawned processes + + if dist_url == "auto": + assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs." 
+ port = _find_free_port() + dist_url = f"tcp://127.0.0.1:{port}" + if num_machines > 1 and dist_url.startswith("file://"): + logger = logging.getLogger(__name__) + logger.warning( + "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://" + ) + + mp.start_processes( + _distributed_worker, + nprocs=num_gpus_per_machine, + args=( + main_func, + world_size, + num_gpus_per_machine, + machine_rank, + dist_url, + args, + timeout, + ), + daemon=False, + ) + else: + main_func(*args) + + +def _distributed_worker( + local_rank, + main_func, + world_size, + num_gpus_per_machine, + machine_rank, + dist_url, + args, + timeout=DEFAULT_TIMEOUT, +): + has_gpu = torch.cuda.is_available() + if has_gpu: + assert num_gpus_per_machine <= torch.cuda.device_count() + global_rank = machine_rank * num_gpus_per_machine + local_rank + try: + dist.init_process_group( + backend="NCCL" if has_gpu else "GLOO", + init_method=dist_url, + world_size=world_size, + rank=global_rank, + timeout=timeout, + ) + except Exception as e: + logger = logging.getLogger(__name__) + logger.error("Process group URL: {}".format(dist_url)) + raise e + + # Setup the local process group. + comm.create_local_process_group(num_gpus_per_machine) + if has_gpu: + torch.cuda.set_device(local_rank) + + # synchronize is needed here to prevent a possible timeout after calling init_process_group + # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172 + comm.synchronize() + + main_func(*args) diff --git a/data_processing/detectron2/detectron2/engine/train_loop.py b/data_processing/detectron2/detectron2/engine/train_loop.py new file mode 100644 index 0000000..2f6b96d --- /dev/null +++ b/data_processing/detectron2/detectron2/engine/train_loop.py @@ -0,0 +1,528 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. 
+import concurrent.futures +import logging +import numpy as np +import time +import weakref +from typing import List, Mapping, Optional +import torch +from torch.nn.parallel import DataParallel, DistributedDataParallel + +import detectron2.utils.comm as comm +from detectron2.utils.events import EventStorage, get_event_storage +from detectron2.utils.logger import _log_api_usage + +__all__ = ["HookBase", "TrainerBase", "SimpleTrainer", "AMPTrainer"] + + +class HookBase: + """ + Base class for hooks that can be registered with :class:`TrainerBase`. + + Each hook can implement 4 methods. The way they are called is demonstrated + in the following snippet: + :: + hook.before_train() + for iter in range(start_iter, max_iter): + hook.before_step() + trainer.run_step() + hook.after_step() + iter += 1 + hook.after_train() + + Notes: + 1. In the hook method, users can access ``self.trainer`` to access more + properties about the context (e.g., model, current iteration, or config + if using :class:`DefaultTrainer`). + + 2. A hook that does something in :meth:`before_step` can often be + implemented equivalently in :meth:`after_step`. + If the hook takes non-trivial time, it is strongly recommended to + implement the hook in :meth:`after_step` instead of :meth:`before_step`. + The convention is that :meth:`before_step` should only take negligible time. + + Following this convention will allow hooks that do care about the difference + between :meth:`before_step` and :meth:`after_step` (e.g., timer) to + function properly. + + """ + + trainer: "TrainerBase" = None + """ + A weak reference to the trainer object. Set by the trainer when the hook is registered. + """ + + def before_train(self): + """ + Called before the first iteration. + """ + pass + + def after_train(self): + """ + Called after the last iteration. + """ + pass + + def before_step(self): + """ + Called before each iteration. 
+ """ + pass + + def after_backward(self): + """ + Called after the backward pass of each iteration. + """ + pass + + def after_step(self): + """ + Called after each iteration. + """ + pass + + def state_dict(self): + """ + Hooks are stateless by default, but can be made checkpointable by + implementing `state_dict` and `load_state_dict`. + """ + return {} + + +class TrainerBase: + """ + Base class for iterative trainer with hooks. + + The only assumption we made here is: the training runs in a loop. + A subclass can implement what the loop is. + We made no assumptions about the existence of dataloader, optimizer, model, etc. + + Attributes: + iter(int): the current iteration. + + start_iter(int): The iteration to start with. + By convention the minimum possible value is 0. + + max_iter(int): The iteration to end training. + + storage(EventStorage): An EventStorage that's opened during the course of training. + """ + + def __init__(self) -> None: + self._hooks: List[HookBase] = [] + self.iter: int = 0 + self.start_iter: int = 0 + self.max_iter: int + self.storage: EventStorage + _log_api_usage("trainer." + self.__class__.__name__) + + def register_hooks(self, hooks: List[Optional[HookBase]]) -> None: + """ + Register hooks to the trainer. The hooks are executed in the order + they are registered. + + Args: + hooks (list[Optional[HookBase]]): list of hooks + """ + hooks = [h for h in hooks if h is not None] + for h in hooks: + assert isinstance(h, HookBase) + # To avoid circular reference, hooks and trainer cannot own each other. 
+ # This normally does not matter, but will cause memory leak if the + # involved objects contain __del__: + # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/ + h.trainer = weakref.proxy(self) + self._hooks.extend(hooks) + + def train(self, start_iter: int, max_iter: int): + """ + Args: + start_iter, max_iter (int): See docs above + """ + logger = logging.getLogger(__name__) + logger.info("Starting training from iteration {}".format(start_iter)) + + self.iter = self.start_iter = start_iter + self.max_iter = max_iter + + with EventStorage(start_iter) as self.storage: + try: + self.before_train() + for self.iter in range(start_iter, max_iter): + self.before_step() + self.run_step() + self.after_step() + # self.iter == max_iter can be used by `after_train` to + # tell whether the training successfully finished or failed + # due to exceptions. + self.iter += 1 + except Exception: + logger.exception("Exception during training:") + raise + finally: + self.after_train() + + def before_train(self): + for h in self._hooks: + h.before_train() + + def after_train(self): + self.storage.iter = self.iter + for h in self._hooks: + h.after_train() + + def before_step(self): + # Maintain the invariant that storage.iter == trainer.iter + # for the entire execution of each step + self.storage.iter = self.iter + + for h in self._hooks: + h.before_step() + + def after_backward(self): + for h in self._hooks: + h.after_backward() + + def after_step(self): + for h in self._hooks: + h.after_step() + + def run_step(self): + raise NotImplementedError + + def state_dict(self): + ret = {"iteration": self.iter} + hooks_state = {} + for h in self._hooks: + sd = h.state_dict() + if sd: + name = type(h).__qualname__ + if name in hooks_state: + # TODO handle repetitive stateful hooks + continue + hooks_state[name] = sd + if hooks_state: + ret["hooks"] = hooks_state + return ret + + def load_state_dict(self, 
state_dict): + logger = logging.getLogger(__name__) + self.iter = state_dict["iteration"] + for key, value in state_dict.get("hooks", {}).items(): + for h in self._hooks: + try: + name = type(h).__qualname__ + except AttributeError: + continue + if name == key: + h.load_state_dict(value) + break + else: + logger.warning(f"Cannot find the hook '{key}', its state_dict is ignored.") + + +class SimpleTrainer(TrainerBase): + """ + A simple trainer for the most common type of task: + single-cost single-optimizer single-data-source iterative optimization, + optionally using data-parallelism. + It assumes that every step, you: + + 1. Compute the loss with a data from the data_loader. + 2. Compute the gradients with the above loss. + 3. Update the model with the optimizer. + + All other tasks during training (checkpointing, logging, evaluation, LR schedule) + are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`. + + If you want to do anything fancier than this, + either subclass TrainerBase and implement your own `run_step`, + or write your own training loop. + """ + + def __init__( + self, + model, + data_loader, + optimizer, + gather_metric_period=1, + zero_grad_before_forward=False, + async_write_metrics=False, + ): + """ + Args: + model: a torch Module. Takes a data from data_loader and returns a + dict of losses. + data_loader: an iterable. Contains data to be used to call model. + optimizer: a torch optimizer. + gather_metric_period: an int. Every gather_metric_period iterations + the metrics are gathered from all the ranks to rank 0 and logged. + zero_grad_before_forward: whether to zero the gradients before the forward. + async_write_metrics: bool. If True, then write metrics asynchronously to improve + training speed + """ + super().__init__() + + """ + We set the model to training mode in the trainer. + However it's valid to train a model that's in eval mode. 
def run_step(self):
    """
    Run one standard training step: fetch a batch, compute the losses,
    back-propagate, report metrics and apply the optimizer update.
    """
    assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"

    # Time only the wait on the data loader.
    # To pre-process the data, wrap the data loader.
    fetch_began = time.perf_counter()
    batch = next(self._data_loader_iter)
    data_time = time.perf_counter() - fetch_began

    if self.zero_grad_before_forward:
        # To accumulate gradients (or similar), wrap the optimizer with a
        # custom `zero_grad()` method.
        self.optimizer.zero_grad()

    # To post-process the losses, wrap the model.
    outputs = self.model(batch)
    if isinstance(outputs, torch.Tensor):
        # A bare tensor is treated as the total loss.
        losses, loss_dict = outputs, {"total_loss": outputs}
    else:
        loss_dict = outputs
        losses = sum(loss_dict.values())

    if not self.zero_grad_before_forward:
        # Same remark as above about wrapping the optimizer.
        self.optimizer.zero_grad()
    losses.backward()

    self.after_backward()

    if not self.async_write_metrics:
        self._write_metrics(loss_dict, data_time)
    else:
        # Hand metric writing to the executor; pass the current iteration
        # explicitly since self.iter may have advanced when the task runs.
        self.concurrent_executor.submit(
            self._write_metrics, loss_dict, data_time, iter=self.iter
        )

    # Gradient clipping/scaling can be done by wrapping the optimizer with a
    # custom `step()` method, but it is suboptimal as explained in
    # https://arxiv.org/abs/2006.15704 Sec 3.2.4
    self.optimizer.step()
+ """ + del self.data_loader + data_loader = data_loader_builder() + self.data_loader = data_loader + self._data_loader_iter_obj = None + + def _write_metrics( + self, + loss_dict: Mapping[str, torch.Tensor], + data_time: float, + prefix: str = "", + iter: Optional[int] = None, + ) -> None: + logger = logging.getLogger(__name__) + + iter = self.iter if iter is None else iter + if (iter + 1) % self.gather_metric_period == 0: + try: + SimpleTrainer.write_metrics(loss_dict, data_time, iter, prefix) + except Exception: + logger.exception("Exception in writing metrics: ") + raise + + @staticmethod + def write_metrics( + loss_dict: Mapping[str, torch.Tensor], + data_time: float, + cur_iter: int, + prefix: str = "", + ) -> None: + """ + Args: + loss_dict (dict): dict of scalar losses + data_time (float): time taken by the dataloader iteration + prefix (str): prefix for logging keys + """ + metrics_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()} + metrics_dict["data_time"] = data_time + + # Gather metrics among all workers for logging + # This assumes we do DDP-style training, which is currently the only + # supported method in detectron2. + all_metrics_dict = comm.gather(metrics_dict) + + if comm.is_main_process(): + storage = get_event_storage() + + # data_time among workers can have high variance. The actual latency + # caused by data_time is the maximum among workers. 
class AMPTrainer(SimpleTrainer):
    """
    Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision
    in the training loop.
    """

    def __init__(
        self,
        model,
        data_loader,
        optimizer,
        gather_metric_period=1,
        zero_grad_before_forward=False,
        grad_scaler=None,
        precision: torch.dtype = torch.float16,
        log_grad_scaler: bool = False,
        async_write_metrics=False,
    ):
        """
        Args:
            model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward,
                async_write_metrics: same as in :class:`SimpleTrainer`.
            grad_scaler: torch GradScaler to automatically scale gradients.
                A default ``GradScaler()`` is created when None.
            precision: torch.dtype as the target precision to cast to in computations
            log_grad_scaler: if True, log the current gradient scale to the event
                storage after every backward pass.
        """
        unsupported = "AMPTrainer does not support single-process multi-device training!"
        if isinstance(model, DistributedDataParallel):
            assert not (model.device_ids and len(model.device_ids) > 1), unsupported
        assert not isinstance(model, DataParallel), unsupported

        # Forward async_write_metrics too; without it the base class default
        # (False) always wins and the constructor argument has no effect.
        super().__init__(
            model,
            data_loader,
            optimizer,
            gather_metric_period,
            zero_grad_before_forward,
            async_write_metrics=async_write_metrics,
        )

        if grad_scaler is None:
            from torch.cuda.amp import GradScaler

            grad_scaler = GradScaler()
        self.grad_scaler = grad_scaler
        self.precision = precision
        self.log_grad_scaler = log_grad_scaler

    def run_step(self):
        """
        Implement the AMP training logic.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        # Time only the wait on the data loader.
        start = time.perf_counter()
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        if self.zero_grad_before_forward:
            self.optimizer.zero_grad()
        with autocast(dtype=self.precision):
            loss_dict = self.model(data)
            if isinstance(loss_dict, torch.Tensor):
                # A bare tensor is treated as the total loss.
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        if not self.zero_grad_before_forward:
            self.optimizer.zero_grad()

        # Back-propagate the scaled loss.
        self.grad_scaler.scale(losses).backward()

        if self.log_grad_scaler:
            storage = get_event_storage()
            storage.put_scalar("[metric]grad_scaler", self.grad_scaler.get_scale())

        self.after_backward()

        if self.async_write_metrics:
            # write metrics asynchronously
            self.concurrent_executor.submit(
                self._write_metrics, loss_dict, data_time, iter=self.iter
            )
        else:
            self._write_metrics(loss_dict, data_time)

        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()

    def state_dict(self):
        """Return the base trainer state plus the grad scaler state."""
        ret = super().state_dict()
        ret["grad_scaler"] = self.grad_scaler.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        """Restore the base trainer state and the grad scaler state."""
        super().load_state_dict(state_dict)
        self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
class CityscapesEvaluator(DatasetEvaluator):
    """
    Common base for the evaluators that rely on the cityscapes API.
    """

    def __init__(self, dataset_name):
        """
        Args:
            dataset_name (str): the name of the dataset.
                It must have the following metadata associated with it:
                "thing_classes", "gt_dir".
        """
        self._logger = logging.getLogger(__name__)
        self._cpu_device = torch.device("cpu")
        self._metadata = MetadataCatalog.get(dataset_name)

    def reset(self):
        """
        Create a fresh temporary results directory and agree on a single one
        (the first gathered entry) across all workers.
        """
        own_dir = tempfile.TemporaryDirectory(prefix="cityscapes_eval_")
        self._working_dir = own_dir
        self._temp_dir = own_dir.name
        # All workers will write to the same results directory
        # TODO this does not work in distributed training
        single_machine = comm.get_local_size() == comm.get_world_size()
        assert (
            single_machine
        ), "CityscapesEvaluator currently do not work with multiple machines."
        self._temp_dir = comm.all_gather(self._temp_dir)[0]
        if self._temp_dir != own_dir.name:
            # Another worker's directory was chosen; ours is not needed.
            own_dir.cleanup()
        self._logger.info(
            "Writing cityscapes results to temporary directory {} ...".format(self._temp_dir)
        )
def evaluate(self):
    """
    Score the prediction files in the results directory with the cityscapes
    instance-level evaluation script.

    Returns:
        dict: has a key "segm", whose value is a dict of "AP" and "AP50".
        Ranks other than 0 return None.
    """
    comm.synchronize()
    if comm.get_rank() > 0:
        return
    import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as cityscapes_eval

    self._logger.info("Evaluating results under {} ...".format(self._temp_dir))

    # The cityscapes evaluation API is configured through module-level state.
    eval_args = cityscapes_eval.args
    eval_args.predictionPath = os.path.abspath(self._temp_dir)
    eval_args.predictionWalk = None
    eval_args.JSONOutput = False
    eval_args.colorized = False
    eval_args.gtInstancesFile = os.path.join(self._temp_dir, "gtInstances.json")

    # The remainder is adopted from
    # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
    gt_root = PathManager.get_local_path(self._metadata.gt_dir)
    gt_files = glob.glob(os.path.join(gt_root, "*", "*_gtFine_instanceIds.png"))
    assert len(
        gt_files
    ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
        eval_args.groundTruthSearch
    )
    pred_files = [cityscapes_eval.getPrediction(gt, eval_args) for gt in gt_files]
    averages = cityscapes_eval.evaluateImgLists(pred_files, gt_files, eval_args)["averages"]

    ret = OrderedDict()
    ret["segm"] = {"AP": averages["allAp"] * 100, "AP50": averages["allAp50%"] * 100}
    self._working_dir.cleanup()
    return ret
def evaluate(self):
    """
    Score the prediction files in the results directory with the cityscapes
    pixel-level evaluation script.

    Returns:
        dict: has a key "sem_seg", whose value is a dict of "IoU", "iIoU",
        "IoU_sup" and "iIoU_sup". Ranks other than 0 return None.
    """
    comm.synchronize()
    if comm.get_rank() > 0:
        return
    # The cityscapes script is imported lazily, on the evaluating rank only.
    import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as cityscapes_eval

    self._logger.info("Evaluating results under {} ...".format(self._temp_dir))

    # The cityscapes evaluation API is configured through module-level state.
    eval_args = cityscapes_eval.args
    eval_args.predictionPath = os.path.abspath(self._temp_dir)
    eval_args.predictionWalk = None
    eval_args.JSONOutput = False
    eval_args.colorized = False

    # The remainder is adopted from
    # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalPixelLevelSemanticLabeling.py # noqa
    gt_root = PathManager.get_local_path(self._metadata.gt_dir)
    gt_files = glob.glob(os.path.join(gt_root, "*", "*_gtFine_labelIds.png"))
    assert len(
        gt_files
    ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
        eval_args.groundTruthSearch
    )
    # Note the argument order: this script takes (args, ground truth file).
    pred_files = [cityscapes_eval.getPrediction(eval_args, gt) for gt in gt_files]
    scores = cityscapes_eval.evaluateImgLists(pred_files, gt_files, eval_args)

    ret = OrderedDict()
    ret["sem_seg"] = {
        "IoU": 100.0 * scores["averageScoreClasses"],
        "iIoU": 100.0 * scores["averageScoreInstClasses"],
        "IoU_sup": 100.0 * scores["averageScoreCategories"],
        "iIoU_sup": 100.0 * scores["averageScoreInstCategories"],
    }
    self._working_dir.cleanup()
    return ret
def __init__(
    self,
    dataset_name,
    tasks=None,
    distributed=True,
    output_dir=None,
    *,
    max_dets_per_image=None,
    use_fast_impl=True,
    kpt_oks_sigmas=(),
    allow_cached_coco=True,
):
    """
    Args:
        dataset_name (str): name of the dataset to be evaluated.
            It must have either the following corresponding metadata:

                "json_file": the path to the COCO format annotation

            Or it must be in detectron2's standard dataset format
            so it can be converted to COCO format automatically.
        tasks (tuple[str]): tasks that can be evaluated under the given
            configuration. A task is one of "bbox", "segm", "keypoints".
            By default, will infer this automatically from predictions.
        distributed (True): if True, will collect results from all ranks and run evaluation
            in the main process.
            Otherwise, will only evaluate the results in the current process.
        output_dir (str): optional, an output directory to dump all
            results predicted on the dataset. The dump contains two files:

            1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
               contains all the results in the format they are produced by the model.
            2. "coco_instances_results.json" a json file in COCO's result format.
        max_dets_per_image (int): limit on the maximum number of detections per image.
            By default in COCO, this limit is to 100, but this can be customized
            to be greater, as is needed in evaluation metrics AP fixed and AP pool
            (see https://arxiv.org/pdf/2102.01066.pdf)
            This doesn't affect keypoint evaluation.
        use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
            Although the results should be very close to the official implementation in COCO
            API, it is still recommended to compute results with the official API for use in
            papers. The faster implementation also uses more RAM.
        kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
            See http://cocodataset.org/#keypoints-eval
            When empty, it will use the defaults in COCO.
            Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
        allow_cached_coco (bool): Whether to use cached coco json from previous validation
            runs. You should set this to False if you need to use different validation data.
            Defaults to True.

    Raises:
        ValueError: if the dataset has no "json_file" metadata and ``output_dir`` is None,
            since the converted COCO json has to be written somewhere.
    """
    self._logger = logging.getLogger(__name__)
    self._distributed = distributed
    self._output_dir = output_dir

    if use_fast_impl and (COCOeval_opt is COCOeval):
        self._logger.info("Fast COCO eval is not built. Falling back to official COCO eval.")
        use_fast_impl = False
    self._use_fast_impl = use_fast_impl

    # COCOeval requires the limit on the number of detections per image (maxDets) to be a list
    # with at least 3 elements. The default maxDets in COCOeval is [1, 10, 100], in which the
    # 3rd element (100) is used as the limit on the number of detections per image when
    # evaluating AP. COCOEvaluator expects an integer for max_dets_per_image, so for COCOeval,
    # we reformat max_dets_per_image into [1, 10, max_dets_per_image], based on the defaults.
    if max_dets_per_image is None:
        max_dets_per_image = [1, 10, 100]
    else:
        max_dets_per_image = [1, 10, max_dets_per_image]
    self._max_dets_per_image = max_dets_per_image

    if isinstance(tasks, CfgNode):
        # Deprecated calling convention: a config object passed as `tasks`.
        kpt_oks_sigmas = (
            tasks.TEST.KEYPOINT_OKS_SIGMAS if not kpt_oks_sigmas else kpt_oks_sigmas
        )
        # Logger.warn is a deprecated alias; use Logger.warning.
        self._logger.warning(
            "COCO Evaluator instantiated using config, this is deprecated behavior."
            " Please pass in explicit arguments instead."
        )
        self._tasks = None  # Inferring it from predictions should be better
    else:
        self._tasks = tasks

    self._cpu_device = torch.device("cpu")

    self._metadata = MetadataCatalog.get(dataset_name)
    if not hasattr(self._metadata, "json_file"):
        if output_dir is None:
            raise ValueError(
                "output_dir must be provided to COCOEvaluator "
                "for datasets not in COCO format."
            )
        self._logger.info(f"Trying to convert '{dataset_name}' to COCO format ...")

        cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json")
        self._metadata.json_file = cache_path
        convert_to_coco_json(dataset_name, cache_path, allow_cached=allow_cached_coco)

    json_file = PathManager.get_local_path(self._metadata.json_file)
    # Silence whatever the COCO constructor prints to stdout.
    with contextlib.redirect_stdout(io.StringIO()):
        self._coco_api = COCO(json_file)

    # Test set json files do not contain annotations (evaluation must be
    # performed using the COCO evaluation server).
    self._do_evaluation = "annotations" in self._coco_api.dataset
    # Always define the attribute so later access cannot raise AttributeError;
    # it is only consumed when evaluation actually runs.
    self._kpt_oks_sigmas = kpt_oks_sigmas
+ """ + for input, output in zip(inputs, outputs): + prediction = {"image_id": input["image_id"]} + + if "instances" in output: + instances = output["instances"].to(self._cpu_device) + prediction["instances"] = instances_to_coco_json(instances, input["image_id"]) + if "proposals" in output: + prediction["proposals"] = output["proposals"].to(self._cpu_device) + if len(prediction) > 1: + self._predictions.append(prediction) + + def evaluate(self, img_ids=None): + """ + Args: + img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset + """ + if self._distributed: + comm.synchronize() + predictions = comm.gather(self._predictions, dst=0) + predictions = list(itertools.chain(*predictions)) + + if not comm.is_main_process(): + return {} + else: + predictions = self._predictions + + if len(predictions) == 0: + self._logger.warning("[COCOEvaluator] Did not receive valid predictions.") + return {} + + if self._output_dir: + PathManager.mkdirs(self._output_dir) + file_path = os.path.join(self._output_dir, "instances_predictions.pth") + with PathManager.open(file_path, "wb") as f: + torch.save(predictions, f) + + self._results = OrderedDict() + if "proposals" in predictions[0]: + self._eval_box_proposals(predictions) + if "instances" in predictions[0]: + self._eval_predictions(predictions, img_ids=img_ids) + # Copy so the caller can do whatever with results + return copy.deepcopy(self._results) + + def _tasks_from_predictions(self, predictions): + """ + Get COCO API "tasks" (i.e. iou_type) from COCO-format predictions. + """ + tasks = {"bbox"} + for pred in predictions: + if "segmentation" in pred: + tasks.add("segm") + if "keypoints" in pred: + tasks.add("keypoints") + return sorted(tasks) + + def _eval_predictions(self, predictions, img_ids=None): + """ + Evaluate predictions. Fill self._results with the metrics of the tasks. 
+ """ + self._logger.info("Preparing results for COCO format ...") + coco_results = list(itertools.chain(*[x["instances"] for x in predictions])) + tasks = self._tasks or self._tasks_from_predictions(coco_results) + + # unmap the category ids for COCO + if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"): + dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id + all_contiguous_ids = list(dataset_id_to_contiguous_id.values()) + num_classes = len(all_contiguous_ids) + assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1 + + reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()} + for result in coco_results: + category_id = result["category_id"] + assert category_id < num_classes, ( + f"A prediction has class={category_id}, " + f"but the dataset only has {num_classes} classes and " + f"predicted class id should be in [0, {num_classes - 1}]." + ) + result["category_id"] = reverse_id_mapping[category_id] + + if self._output_dir: + file_path = os.path.join(self._output_dir, "coco_instances_results.json") + self._logger.info("Saving results to {}".format(file_path)) + with PathManager.open(file_path, "w") as f: + f.write(json.dumps(coco_results)) + f.flush() + + if not self._do_evaluation: + self._logger.info("Annotations are not available for evaluation.") + return + + self._logger.info( + "Evaluating predictions with {} COCO API...".format( + "unofficial" if self._use_fast_impl else "official" + ) + ) + for task in sorted(tasks): + assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!" 
+ coco_eval = ( + _evaluate_predictions_on_coco( + self._coco_api, + coco_results, + task, + kpt_oks_sigmas=self._kpt_oks_sigmas, + cocoeval_fn=COCOeval_opt if self._use_fast_impl else COCOeval, + img_ids=img_ids, + max_dets_per_image=self._max_dets_per_image, + ) + if len(coco_results) > 0 + else None # cocoapi does not handle empty results very well + ) + + res = self._derive_coco_results( + coco_eval, task, class_names=self._metadata.get("thing_classes") + ) + self._results[task] = res + + def _eval_box_proposals(self, predictions): + """ + Evaluate the box proposals in predictions. + Fill self._results with the metrics for "box_proposals" task. + """ + if self._output_dir: + # Saving generated box proposals to file. + # Predicted box_proposals are in XYXY_ABS mode. + bbox_mode = BoxMode.XYXY_ABS.value + ids, boxes, objectness_logits = [], [], [] + for prediction in predictions: + ids.append(prediction["image_id"]) + boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy()) + objectness_logits.append(prediction["proposals"].objectness_logits.numpy()) + + proposal_data = { + "boxes": boxes, + "objectness_logits": objectness_logits, + "ids": ids, + "bbox_mode": bbox_mode, + } + with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f: + pickle.dump(proposal_data, f) + + if not self._do_evaluation: + self._logger.info("Annotations are not available for evaluation.") + return + + self._logger.info("Evaluating bbox proposals ...") + res = {} + areas = {"all": "", "small": "s", "medium": "m", "large": "l"} + for limit in [100, 1000]: + for area, suffix in areas.items(): + stats = _evaluate_box_proposals(predictions, self._coco_api, area=area, limit=limit) + key = "AR{}@{:d}".format(suffix, limit) + res[key] = float(stats["ar"].item() * 100) + self._logger.info("Proposal metrics: \n" + create_small_table(res)) + self._results["box_proposals"] = res + + def _derive_coco_results(self, coco_eval, iou_type, class_names=None): + """ + 
Derive the desired score numbers from summarized COCOeval. + + Args: + coco_eval (None or COCOEval): None represents no predictions from model. + iou_type (str): + class_names (None or list[str]): if provided, will use it to predict + per-category AP. + + Returns: + a dict of {metric name: score} + """ + + metrics = { + "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"], + "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"], + "keypoints": ["AP", "AP50", "AP75", "APm", "APl"], + }[iou_type] + + if coco_eval is None: + self._logger.warn("No predictions from the model!") + return {metric: float("nan") for metric in metrics} + + # the standard metrics + results = { + metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan") + for idx, metric in enumerate(metrics) + } + self._logger.info( + "Evaluation results for {}: \n".format(iou_type) + create_small_table(results) + ) + if not np.isfinite(sum(results.values())): + self._logger.info("Some metrics cannot be computed and is shown as NaN.") + + if class_names is None or len(class_names) <= 1: + return results + # Compute per-category AP + # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa + precisions = coco_eval.eval["precision"] + # precision has dims (iou, recall, cls, area range, max dets) + assert len(class_names) == precisions.shape[2] + + results_per_category = [] + for idx, name in enumerate(class_names): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + ap = np.mean(precision) if precision.size else float("nan") + results_per_category.append(("{}".format(name), float(ap * 100))) + + # tabulate it + N_COLS = min(6, len(results_per_category) * 2) + results_flatten = 
def instances_to_coco_json(instances, img_id):
    """
    Dump an "Instances" object to a COCO-format json that's used for evaluation.

    The input is not modified: keypoint coordinates are shifted on a copy.

    Args:
        instances (Instances): predictions for one image; reads "pred_boxes",
            "scores", "pred_classes" and, when present, "pred_masks" and
            "pred_keypoints".
        img_id (int): the image id

    Returns:
        list[dict]: list of json annotations in COCO format.
    """
    num_instance = len(instances)
    if num_instance == 0:
        return []

    boxes = instances.pred_boxes.tensor.numpy()
    boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
    boxes = boxes.tolist()
    scores = instances.scores.tolist()
    classes = instances.pred_classes.tolist()

    has_mask = instances.has("pred_masks")
    if has_mask:
        # use RLE to encode the masks, because they are too large and takes memory
        # since this evaluator stores outputs of the entire dataset
        rles = [
            mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
            for mask in instances.pred_masks
        ]
        for rle in rles:
            # "counts" is an array encoded by mask_util as a byte-stream. Python3's
            # json writer which always produces strings cannot serialize a bytestream
            # unless you decode it. Thankfully, utf-8 works out (which is also what
            # the pycocotools/_mask.pyx does).
            rle["counts"] = rle["counts"].decode("utf-8")

    has_keypoints = instances.has("pred_keypoints")
    if has_keypoints:
        # In COCO annotations,
        # keypoints coordinates are pixel indices.
        # However our predictions are floating point coordinates.
        # Therefore we subtract 0.5 to be consistent with the annotation format.
        # This is the inverse of data loading logic in `datasets/coco.py`.
        # Work on a clone: subtracting in place would silently alter the caller's
        # predictions and shift them again on every further call.
        # NOTE(review): assumes pred_keypoints is a torch.Tensor whose last
        # axis starts with (x, y) -- confirm against the model output.
        keypoints = instances.pred_keypoints.clone()
        keypoints[..., :2] -= 0.5

    results = []
    for k in range(num_instance):
        result = {
            "image_id": img_id,
            "category_id": classes[k],
            "bbox": boxes[k],
            "score": scores[k],
        }
        if has_mask:
            result["segmentation"] = rles[k]
        if has_keypoints:
            result["keypoints"] = keypoints[k].flatten().tolist()
        results.append(result)
    return results
+ """ + # Record max overlap value for each gt box + # Return vector of overlap values + areas = { + "all": 0, + "small": 1, + "medium": 2, + "large": 3, + "96-128": 4, + "128-256": 5, + "256-512": 6, + "512-inf": 7, + } + area_ranges = [ + [0**2, 1e5**2], # all + [0**2, 32**2], # small + [32**2, 96**2], # medium + [96**2, 1e5**2], # large + [96**2, 128**2], # 96-128 + [128**2, 256**2], # 128-256 + [256**2, 512**2], # 256-512 + [512**2, 1e5**2], + ] # 512-inf + assert area in areas, "Unknown area range: {}".format(area) + area_range = area_ranges[areas[area]] + gt_overlaps = [] + num_pos = 0 + + for prediction_dict in dataset_predictions: + predictions = prediction_dict["proposals"] + + # sort predictions in descending order + # TODO maybe remove this and make it explicit in the documentation + inds = predictions.objectness_logits.sort(descending=True)[1] + predictions = predictions[inds] + + ann_ids = coco_api.getAnnIds(imgIds=prediction_dict["image_id"]) + anno = coco_api.loadAnns(ann_ids) + gt_boxes = [ + BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) + for obj in anno + if obj["iscrowd"] == 0 + ] + gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes + gt_boxes = Boxes(gt_boxes) + gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0]) + + if len(gt_boxes) == 0 or len(predictions) == 0: + continue + + valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1]) + gt_boxes = gt_boxes[valid_gt_inds] + + num_pos += len(gt_boxes) + + if len(gt_boxes) == 0: + continue + + if limit is not None and len(predictions) > limit: + predictions = predictions[:limit] + + overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes) + + _gt_overlaps = torch.zeros(len(gt_boxes)) + for j in range(min(len(predictions), len(gt_boxes))): + # find which proposal box maximally covers each gt box + # and get the iou amount of coverage for each gt box + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + 
+ # find which gt box is 'best' covered (i.e. 'best' = most iou) + gt_ovr, gt_ind = max_overlaps.max(dim=0) + assert gt_ovr >= 0 + # find the proposal box that covers the best covered gt box + box_ind = argmax_overlaps[gt_ind] + # record the iou coverage of this gt box + _gt_overlaps[j] = overlaps[box_ind, gt_ind] + assert _gt_overlaps[j] == gt_ovr + # mark the proposal box and the gt box as used + overlaps[box_ind, :] = -1 + overlaps[:, gt_ind] = -1 + + # append recorded iou coverage level + gt_overlaps.append(_gt_overlaps) + gt_overlaps = ( + torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32) + ) + gt_overlaps, _ = torch.sort(gt_overlaps) + + if thresholds is None: + step = 0.05 + thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32) + recalls = torch.zeros_like(thresholds) + # compute recall for each iou threshold + for i, t in enumerate(thresholds): + recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos) + # ar = 2 * np.trapz(recalls, thresholds) + ar = recalls.mean() + return { + "ar": ar, + "recalls": recalls, + "thresholds": thresholds, + "gt_overlaps": gt_overlaps, + "num_pos": num_pos, + } + + +def _evaluate_predictions_on_coco( + coco_gt, + coco_results, + iou_type, + kpt_oks_sigmas=None, + cocoeval_fn=COCOeval_opt, + img_ids=None, + max_dets_per_image=None, +): + """ + Evaluate the coco results using COCOEval API. + """ + assert len(coco_results) > 0 + + if iou_type == "segm": + coco_results = copy.deepcopy(coco_results) + # When evaluating mask AP, if the results contain bbox, cocoapi will + # use the box area as the area of the instance, instead of the mask area. + # This leads to a different definition of small/medium/large. + # We remove the bbox field to let mask AP use mask area. + for c in coco_results: + c.pop("bbox", None) + + coco_dt = coco_gt.loadRes(coco_results) + coco_eval = cocoeval_fn(coco_gt, coco_dt, iou_type) + # For COCO, the default max_dets_per_image is [1, 10, 100]. 
+ if max_dets_per_image is None: + max_dets_per_image = [1, 10, 100] # Default from COCOEval + else: + assert ( + len(max_dets_per_image) >= 3 + ), "COCOeval requires maxDets (and max_dets_per_image) to have length at least 3" + # In the case that user supplies a custom input for max_dets_per_image, + # apply COCOevalMaxDets to evaluate AP with the custom input. + if max_dets_per_image[2] != 100: + coco_eval = COCOevalMaxDets(coco_gt, coco_dt, iou_type) + if iou_type != "keypoints": + coco_eval.params.maxDets = max_dets_per_image + + if img_ids is not None: + coco_eval.params.imgIds = img_ids + + if iou_type == "keypoints": + # Use the COCO default keypoint OKS sigmas unless overrides are specified + if kpt_oks_sigmas: + assert hasattr(coco_eval.params, "kpt_oks_sigmas"), "pycocotools is too old!" + coco_eval.params.kpt_oks_sigmas = np.array(kpt_oks_sigmas) + # COCOAPI requires every detection and every gt to have keypoints, so + # we just take the first entry from both + num_keypoints_dt = len(coco_results[0]["keypoints"]) // 3 + num_keypoints_gt = len(next(iter(coco_gt.anns.values()))["keypoints"]) // 3 + num_keypoints_oks = len(coco_eval.params.kpt_oks_sigmas) + assert num_keypoints_oks == num_keypoints_dt == num_keypoints_gt, ( + f"[COCOEvaluator] Prediction contain {num_keypoints_dt} keypoints. " + f"Ground truth contains {num_keypoints_gt} keypoints. " + f"The length of cfg.TEST.KEYPOINT_OKS_SIGMAS is {num_keypoints_oks}. " + "They have to agree with each other. For meaning of OKS, please refer to " + "http://cocodataset.org/#keypoints-eval." 
+ ) + + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + return coco_eval + + +class COCOevalMaxDets(COCOeval): + """ + Modified version of COCOeval for evaluating AP with a custom + maxDets (by default for COCO, maxDets is 100) + """ + + def summarize(self): + """ + Compute and display summary metrics for evaluation results given + a custom value for max_dets_per_image + """ + + def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100): + p = self.params + iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}" + titleStr = "Average Precision" if ap == 1 else "Average Recall" + typeStr = "(AP)" if ap == 1 else "(AR)" + iouStr = ( + "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1]) + if iouThr is None + else "{:0.2f}".format(iouThr) + ) + + aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] + mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] + if ap == 1: + # dimension of precision: [TxRxKxAxM] + s = self.eval["precision"] + # IoU + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, :, aind, mind] + else: + # dimension of recall: [TxKxAxM] + s = self.eval["recall"] + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, aind, mind] + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)) + return mean_s + + def _summarizeDets(): + stats = np.zeros((12,)) + # Evaluate AP using the custom limit on maximum detections per image + stats[0] = _summarize(1, maxDets=self.params.maxDets[2]) + stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2]) + stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2]) + stats[3] = _summarize(1, areaRng="small", maxDets=self.params.maxDets[2]) + stats[4] = _summarize(1, areaRng="medium", maxDets=self.params.maxDets[2]) + stats[5] = _summarize(1, areaRng="large", 
maxDets=self.params.maxDets[2]) + stats[6] = _summarize(0, maxDets=self.params.maxDets[0]) + stats[7] = _summarize(0, maxDets=self.params.maxDets[1]) + stats[8] = _summarize(0, maxDets=self.params.maxDets[2]) + stats[9] = _summarize(0, areaRng="small", maxDets=self.params.maxDets[2]) + stats[10] = _summarize(0, areaRng="medium", maxDets=self.params.maxDets[2]) + stats[11] = _summarize(0, areaRng="large", maxDets=self.params.maxDets[2]) + return stats + + def _summarizeKps(): + stats = np.zeros((10,)) + stats[0] = _summarize(1, maxDets=20) + stats[1] = _summarize(1, maxDets=20, iouThr=0.5) + stats[2] = _summarize(1, maxDets=20, iouThr=0.75) + stats[3] = _summarize(1, maxDets=20, areaRng="medium") + stats[4] = _summarize(1, maxDets=20, areaRng="large") + stats[5] = _summarize(0, maxDets=20) + stats[6] = _summarize(0, maxDets=20, iouThr=0.5) + stats[7] = _summarize(0, maxDets=20, iouThr=0.75) + stats[8] = _summarize(0, maxDets=20, areaRng="medium") + stats[9] = _summarize(0, maxDets=20, areaRng="large") + return stats + + if not self.eval: + raise Exception("Please run accumulate() first") + iouType = self.params.iouType + if iouType == "segm" or iouType == "bbox": + summarize = _summarizeDets + elif iouType == "keypoints": + summarize = _summarizeKps + self.stats = summarize() + + def __str__(self): + self.summarize() diff --git a/data_processing/detectron2/detectron2/evaluation/evaluator.py b/data_processing/detectron2/detectron2/evaluation/evaluator.py new file mode 100644 index 0000000..baf9960 --- /dev/null +++ b/data_processing/detectron2/detectron2/evaluation/evaluator.py @@ -0,0 +1,224 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import datetime +import logging +import time +from collections import OrderedDict, abc +from contextlib import ExitStack, contextmanager +from typing import List, Union +import torch +from torch import nn + +from detectron2.utils.comm import get_world_size, is_main_process +from detectron2.utils.logger import log_every_n_seconds + + +class DatasetEvaluator: + """ + Base class for a dataset evaluator. + + The function :func:`inference_on_dataset` runs the model over + all samples in the dataset, and have a DatasetEvaluator to process the inputs/outputs. + + This class will accumulate information of the inputs/outputs (by :meth:`process`), + and produce evaluation results in the end (by :meth:`evaluate`). + """ + + def reset(self): + """ + Preparation for a new round of evaluation. + Should be called before starting a round of evaluation. + """ + pass + + def process(self, inputs, outputs): + """ + Process the pair of inputs and outputs. + If they contain batches, the pairs can be consumed one-by-one using `zip`: + + .. code-block:: python + + for input_, output in zip(inputs, outputs): + # do evaluation on single input/output pair + ... + + Args: + inputs (list): the inputs that's used to call the model. + outputs (list): the return value of `model(inputs)` + """ + pass + + def evaluate(self): + """ + Evaluate/summarize the performance, after processing all input/output pairs. + + Returns: + dict: + A new evaluator class can return a dict of arbitrary format + as long as the user can process the results. + In our train_net.py, we expect the following format: + + * key: the name of the task (e.g., bbox) + * value: a dict of {metric name: score}, e.g.: {"AP50": 80} + """ + pass + + +class DatasetEvaluators(DatasetEvaluator): + """ + Wrapper class to combine multiple :class:`DatasetEvaluator` instances. + + This class dispatches every evaluation call to + all of its :class:`DatasetEvaluator`. 
+ """ + + def __init__(self, evaluators): + """ + Args: + evaluators (list): the evaluators to combine. + """ + super().__init__() + self._evaluators = evaluators + + def reset(self): + for evaluator in self._evaluators: + evaluator.reset() + + def process(self, inputs, outputs): + for evaluator in self._evaluators: + evaluator.process(inputs, outputs) + + def evaluate(self): + results = OrderedDict() + for evaluator in self._evaluators: + result = evaluator.evaluate() + if is_main_process() and result is not None: + for k, v in result.items(): + assert ( + k not in results + ), "Different evaluators produce results with the same key {}".format(k) + results[k] = v + return results + + +def inference_on_dataset( + model, data_loader, evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None] +): + """ + Run model on the data_loader and evaluate the metrics with evaluator. + Also benchmark the inference speed of `model.__call__` accurately. + The model will be used in eval mode. + + Args: + model (callable): a callable which takes an object from + `data_loader` and returns some outputs. + + If it's an nn.Module, it will be temporarily set to `eval` mode. + If you wish to evaluate a model in `training` mode instead, you can + wrap the given model and override its behavior of `.eval()` and `.train()`. + data_loader: an iterable object with a length. + The elements it generates will be the inputs to the model. + evaluator: the evaluator(s) to run. Use `None` if you only want to benchmark, + but don't want to do any evaluation. 
+ + Returns: + The return value of `evaluator.evaluate()` + """ + num_devices = get_world_size() + logger = logging.getLogger(__name__) + logger.info("Start inference on {} batches".format(len(data_loader))) + + total = len(data_loader) # inference data loader must have a fixed length + if evaluator is None: + # create a no-op evaluator + evaluator = DatasetEvaluators([]) + if isinstance(evaluator, abc.MutableSequence): + evaluator = DatasetEvaluators(evaluator) + evaluator.reset() + + num_warmup = min(5, total - 1) + start_time = time.perf_counter() + total_data_time = 0 + total_compute_time = 0 + total_eval_time = 0 + with ExitStack() as stack: + if isinstance(model, nn.Module): + stack.enter_context(inference_context(model)) + stack.enter_context(torch.no_grad()) + + start_data_time = time.perf_counter() + for idx, inputs in enumerate(data_loader): + total_data_time += time.perf_counter() - start_data_time + if idx == num_warmup: + start_time = time.perf_counter() + total_data_time = 0 + total_compute_time = 0 + total_eval_time = 0 + + start_compute_time = time.perf_counter() + outputs = model(inputs) + if torch.cuda.is_available(): + torch.cuda.synchronize() + total_compute_time += time.perf_counter() - start_compute_time + + start_eval_time = time.perf_counter() + evaluator.process(inputs, outputs) + total_eval_time += time.perf_counter() - start_eval_time + + iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup) + data_seconds_per_iter = total_data_time / iters_after_start + compute_seconds_per_iter = total_compute_time / iters_after_start + eval_seconds_per_iter = total_eval_time / iters_after_start + total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start + if idx >= num_warmup * 2 or compute_seconds_per_iter > 5: + eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1))) + log_every_n_seconds( + logging.INFO, + ( + f"Inference done {idx + 1}/{total}. 
" + f"Dataloading: {data_seconds_per_iter:.4f} s/iter. " + f"Inference: {compute_seconds_per_iter:.4f} s/iter. " + f"Eval: {eval_seconds_per_iter:.4f} s/iter. " + f"Total: {total_seconds_per_iter:.4f} s/iter. " + f"ETA={eta}" + ), + n=5, + ) + start_data_time = time.perf_counter() + + # Measure the time only for this worker (before the synchronization barrier) + total_time = time.perf_counter() - start_time + total_time_str = str(datetime.timedelta(seconds=total_time)) + # NOTE this format is parsed by grep + logger.info( + "Total inference time: {} ({:.6f} s / iter per device, on {} devices)".format( + total_time_str, total_time / (total - num_warmup), num_devices + ) + ) + total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time))) + logger.info( + "Total inference pure compute time: {} ({:.6f} s / iter per device, on {} devices)".format( + total_compute_time_str, total_compute_time / (total - num_warmup), num_devices + ) + ) + + results = evaluator.evaluate() + # An evaluator may return None when not in main process. + # Replace it by an empty dict instead to make it easier for downstream code to handle + if results is None: + results = {} + return results + + +@contextmanager +def inference_context(model): + """ + A context where the model is temporarily changed to eval mode, + and restored to previous mode afterwards. + + Args: + model: a torch Module + """ + training_mode = model.training + model.eval() + yield + model.train(training_mode) diff --git a/data_processing/detectron2/detectron2/evaluation/fast_eval_api.py b/data_processing/detectron2/detectron2/evaluation/fast_eval_api.py new file mode 100644 index 0000000..2eb202b --- /dev/null +++ b/data_processing/detectron2/detectron2/evaluation/fast_eval_api.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import copy +import logging +import numpy as np +import time +from pycocotools.cocoeval import COCOeval + +from detectron2 import _C + +logger = logging.getLogger(__name__) + + +class COCOeval_opt(COCOeval): + """ + This is a slightly modified version of the original COCO API, where the functions evaluateImg() + and accumulate() are implemented in C++ to speedup evaluation + """ + + def evaluate(self): + """ + Run per image evaluation on given images and store results in self.evalImgs_cpp, a + datastructure that isn't readable from Python but is used by a c++ implementation of + accumulate(). Unlike the original COCO PythonAPI, we don't populate the datastructure + self.evalImgs because this datastructure is a computational bottleneck. + :return: None + """ + tic = time.time() + + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = "segm" if p.useSegm == 1 else "bbox" + logger.info("Evaluate annotation type *{}*".format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() # bottleneck + + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == "segm" or p.iouType == "bbox": + computeIoU = self.computeIoU + elif p.iouType == "keypoints": + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds + } # bottleneck + + maxDet = p.maxDets[-1] + + # <<<< Beginning of code differences with original COCO API + def convert_instances_to_cpp(instances, is_det=False): + # Convert annotations for a list of instances in an image to a format that's fast + # to access in C++ + instances_cpp = [] + for instance in instances: + instance_cpp = _C.InstanceAnnotation( + int(instance["id"]), + instance["score"] if is_det else instance.get("score", 0.0), + 
instance["area"], + bool(instance.get("iscrowd", 0)), + bool(instance.get("ignore", 0)), + ) + instances_cpp.append(instance_cpp) + return instances_cpp + + # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++ + ground_truth_instances = [ + [convert_instances_to_cpp(self._gts[imgId, catId]) for catId in p.catIds] + for imgId in p.imgIds + ] + detected_instances = [ + [convert_instances_to_cpp(self._dts[imgId, catId], is_det=True) for catId in p.catIds] + for imgId in p.imgIds + ] + ious = [[self.ious[imgId, catId] for catId in catIds] for imgId in p.imgIds] + + if not p.useCats: + # For each image, flatten per-category lists into a single list + ground_truth_instances = [[[o for c in i for o in c]] for i in ground_truth_instances] + detected_instances = [[[o for c in i for o in c]] for i in detected_instances] + + # Call C++ implementation of self.evaluateImgs() + self._evalImgs_cpp = _C.COCOevalEvaluateImages( + p.areaRng, maxDet, p.iouThrs, ious, ground_truth_instances, detected_instances + ) + self._evalImgs = None + + self._paramsEval = copy.deepcopy(self.params) + toc = time.time() + logger.info("COCOeval_opt.evaluate() finished in {:0.2f} seconds.".format(toc - tic)) + # >>>> End of code differences with original COCO API + + def accumulate(self): + """ + Accumulate per image evaluation results and store the result in self.eval. Does not + support changing parameter settings from those used by self.evaluate() + """ + logger.info("Accumulating evaluation results...") + tic = time.time() + assert hasattr( + self, "_evalImgs_cpp" + ), "evaluate() must be called before accmulate() is called." 
+ + self.eval = _C.COCOevalAccumulate(self._paramsEval, self._evalImgs_cpp) + + # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections + self.eval["recall"] = np.array(self.eval["recall"]).reshape( + self.eval["counts"][:1] + self.eval["counts"][2:] + ) + + # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X + # num_area_ranges X num_max_detections + self.eval["precision"] = np.array(self.eval["precision"]).reshape(self.eval["counts"]) + self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"]) + toc = time.time() + logger.info("COCOeval_opt.accumulate() finished in {:0.2f} seconds.".format(toc - tic)) diff --git a/data_processing/detectron2/detectron2/evaluation/lvis_evaluation.py b/data_processing/detectron2/detectron2/evaluation/lvis_evaluation.py new file mode 100644 index 0000000..6cc854a --- /dev/null +++ b/data_processing/detectron2/detectron2/evaluation/lvis_evaluation.py @@ -0,0 +1,380 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import itertools +import json +import logging +import os +import pickle +from collections import OrderedDict +import torch + +import detectron2.utils.comm as comm +from detectron2.config import CfgNode +from detectron2.data import MetadataCatalog +from detectron2.structures import Boxes, BoxMode, pairwise_iou +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import create_small_table + +from .coco_evaluation import instances_to_coco_json +from .evaluator import DatasetEvaluator + + +class LVISEvaluator(DatasetEvaluator): + """ + Evaluate object proposal and instance detection/segmentation outputs using + LVIS's metrics and evaluation API. + """ + + def __init__( + self, + dataset_name, + tasks=None, + distributed=True, + output_dir=None, + *, + max_dets_per_image=None, + ): + """ + Args: + dataset_name (str): name of the dataset to be evaluated. 
+ It must have the following corresponding metadata: + "json_file": the path to the LVIS format annotation + tasks (tuple[str]): tasks that can be evaluated under the given + configuration. A task is one of "bbox", "segm". + By default, will infer this automatically from predictions. + distributed (True): if True, will collect results from all ranks for evaluation. + Otherwise, will evaluate the results in the current process. + output_dir (str): optional, an output directory to dump results. + max_dets_per_image (None or int): limit on maximum detections per image in evaluating AP + This limit, by default of the LVIS dataset, is 300. + """ + from lvis import LVIS + + self._logger = logging.getLogger(__name__) + + if tasks is not None and isinstance(tasks, CfgNode): + self._logger.warn( + "COCO Evaluator instantiated using config, this is deprecated behavior." + " Please pass in explicit arguments instead." + ) + self._tasks = None # Infering it from predictions should be better + else: + self._tasks = tasks + + self._distributed = distributed + self._output_dir = output_dir + self._max_dets_per_image = max_dets_per_image + + self._cpu_device = torch.device("cpu") + + self._metadata = MetadataCatalog.get(dataset_name) + json_file = PathManager.get_local_path(self._metadata.json_file) + self._lvis_api = LVIS(json_file) + # Test set json files do not contain annotations (evaluation must be + # performed using the LVIS evaluation server). + self._do_evaluation = len(self._lvis_api.get_ann_ids()) > 0 + + def reset(self): + self._predictions = [] + + def process(self, inputs, outputs): + """ + Args: + inputs: the inputs to a LVIS model (e.g., GeneralizedRCNN). + It is a list of dict. Each dict corresponds to an image and + contains keys like "height", "width", "file_name", "image_id". + outputs: the outputs of a LVIS model. It is a list of dicts with key + "instances" that contains :class:`Instances`. 
+ """ + for input, output in zip(inputs, outputs): + prediction = {"image_id": input["image_id"]} + + if "instances" in output: + instances = output["instances"].to(self._cpu_device) + prediction["instances"] = instances_to_coco_json(instances, input["image_id"]) + if "proposals" in output: + prediction["proposals"] = output["proposals"].to(self._cpu_device) + self._predictions.append(prediction) + + def evaluate(self): + if self._distributed: + comm.synchronize() + predictions = comm.gather(self._predictions, dst=0) + predictions = list(itertools.chain(*predictions)) + + if not comm.is_main_process(): + return + else: + predictions = self._predictions + + if len(predictions) == 0: + self._logger.warning("[LVISEvaluator] Did not receive valid predictions.") + return {} + + if self._output_dir: + PathManager.mkdirs(self._output_dir) + file_path = os.path.join(self._output_dir, "instances_predictions.pth") + with PathManager.open(file_path, "wb") as f: + torch.save(predictions, f) + + self._results = OrderedDict() + if "proposals" in predictions[0]: + self._eval_box_proposals(predictions) + if "instances" in predictions[0]: + self._eval_predictions(predictions) + # Copy so the caller can do whatever with results + return copy.deepcopy(self._results) + + def _tasks_from_predictions(self, predictions): + for pred in predictions: + if "segmentation" in pred: + return ("bbox", "segm") + return ("bbox",) + + def _eval_predictions(self, predictions): + """ + Evaluate predictions. Fill self._results with the metrics of the tasks. + + Args: + predictions (list[dict]): list of outputs from the model + """ + self._logger.info("Preparing results in the LVIS format ...") + lvis_results = list(itertools.chain(*[x["instances"] for x in predictions])) + tasks = self._tasks or self._tasks_from_predictions(lvis_results) + + # LVIS evaluator can be used to evaluate results for COCO dataset categories. 
+ # In this case `_metadata` variable will have a field with COCO-specific category mapping. + if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"): + reverse_id_mapping = { + v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items() + } + for result in lvis_results: + result["category_id"] = reverse_id_mapping[result["category_id"]] + else: + # unmap the category ids for LVIS (from 0-indexed to 1-indexed) + for result in lvis_results: + result["category_id"] += 1 + + if self._output_dir: + file_path = os.path.join(self._output_dir, "lvis_instances_results.json") + self._logger.info("Saving results to {}".format(file_path)) + with PathManager.open(file_path, "w") as f: + f.write(json.dumps(lvis_results)) + f.flush() + + if not self._do_evaluation: + self._logger.info("Annotations are not available for evaluation.") + return + + self._logger.info("Evaluating predictions ...") + for task in sorted(tasks): + res = _evaluate_predictions_on_lvis( + self._lvis_api, + lvis_results, + task, + max_dets_per_image=self._max_dets_per_image, + class_names=self._metadata.get("thing_classes"), + ) + self._results[task] = res + + def _eval_box_proposals(self, predictions): + """ + Evaluate the box proposals in predictions. + Fill self._results with the metrics for "box_proposals" task. + """ + if self._output_dir: + # Saving generated box proposals to file. + # Predicted box_proposals are in XYXY_ABS mode. 
+ bbox_mode = BoxMode.XYXY_ABS.value + ids, boxes, objectness_logits = [], [], [] + for prediction in predictions: + ids.append(prediction["image_id"]) + boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy()) + objectness_logits.append(prediction["proposals"].objectness_logits.numpy()) + + proposal_data = { + "boxes": boxes, + "objectness_logits": objectness_logits, + "ids": ids, + "bbox_mode": bbox_mode, + } + with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f: + pickle.dump(proposal_data, f) + + if not self._do_evaluation: + self._logger.info("Annotations are not available for evaluation.") + return + + self._logger.info("Evaluating bbox proposals ...") + res = {} + areas = {"all": "", "small": "s", "medium": "m", "large": "l"} + for limit in [100, 1000]: + for area, suffix in areas.items(): + stats = _evaluate_box_proposals(predictions, self._lvis_api, area=area, limit=limit) + key = "AR{}@{:d}".format(suffix, limit) + res[key] = float(stats["ar"].item() * 100) + self._logger.info("Proposal metrics: \n" + create_small_table(res)) + self._results["box_proposals"] = res + + +# inspired from Detectron: +# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa +def _evaluate_box_proposals(dataset_predictions, lvis_api, thresholds=None, area="all", limit=None): + """ + Evaluate detection proposal recall metrics. This function is a much + faster alternative to the official LVIS API recall evaluation code. However, + it produces slightly different results. 
+ """ + # Record max overlap value for each gt box + # Return vector of overlap values + areas = { + "all": 0, + "small": 1, + "medium": 2, + "large": 3, + "96-128": 4, + "128-256": 5, + "256-512": 6, + "512-inf": 7, + } + area_ranges = [ + [0**2, 1e5**2], # all + [0**2, 32**2], # small + [32**2, 96**2], # medium + [96**2, 1e5**2], # large + [96**2, 128**2], # 96-128 + [128**2, 256**2], # 128-256 + [256**2, 512**2], # 256-512 + [512**2, 1e5**2], + ] # 512-inf + assert area in areas, "Unknown area range: {}".format(area) + area_range = area_ranges[areas[area]] + gt_overlaps = [] + num_pos = 0 + + for prediction_dict in dataset_predictions: + predictions = prediction_dict["proposals"] + + # sort predictions in descending order + # TODO maybe remove this and make it explicit in the documentation + inds = predictions.objectness_logits.sort(descending=True)[1] + predictions = predictions[inds] + + ann_ids = lvis_api.get_ann_ids(img_ids=[prediction_dict["image_id"]]) + anno = lvis_api.load_anns(ann_ids) + gt_boxes = [ + BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) for obj in anno + ] + gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes + gt_boxes = Boxes(gt_boxes) + gt_areas = torch.as_tensor([obj["area"] for obj in anno]) + + if len(gt_boxes) == 0 or len(predictions) == 0: + continue + + valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1]) + gt_boxes = gt_boxes[valid_gt_inds] + + num_pos += len(gt_boxes) + + if len(gt_boxes) == 0: + continue + + if limit is not None and len(predictions) > limit: + predictions = predictions[:limit] + + overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes) + + _gt_overlaps = torch.zeros(len(gt_boxes)) + for j in range(min(len(predictions), len(gt_boxes))): + # find which proposal box maximally covers each gt box + # and get the iou amount of coverage for each gt box + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + + # find which gt box is 'best' covered 
def _evaluate_predictions_on_lvis(
    lvis_gt, lvis_results, iou_type, max_dets_per_image=None, class_names=None
):
    """
    Run the LVIS API evaluation on a list of results and collect the standard metrics.

    Args:
        lvis_gt: ground-truth object handed to `LVISResults` / `LVISEval`.
        lvis_results (list[dict]): predictions in LVIS result format.
        iou_type (str): "bbox" or "segm"; any other value raises KeyError.
        max_dets_per_image (None or int): limit on maximum detections per image in
            evaluating AP. This limit, by default of the LVIS dataset, is 300.
        class_names (None or list[str]): accepted for interface compatibility;
            not used by this function.

    Returns:
        a dict of {metric name: score}. Scores are percentages; every score is
        NaN when `lvis_results` is empty.
    """
    metrics = {
        "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"],
        "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"],
    }[iou_type]

    logger = logging.getLogger(__name__)

    if len(lvis_results) == 0:  # TODO: check if needed
        # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
        logger.warning("No predictions from the model!")
        return {metric: float("nan") for metric in metrics}

    if iou_type == "segm":
        lvis_results = copy.deepcopy(lvis_results)
        # When evaluating mask AP, if the results contain bbox, LVIS API will
        # use the box area as the area of the instance, instead of the mask area.
        # This leads to a different definition of small/medium/large.
        # We remove the bbox field to let mask AP use mask area.
        for c in lvis_results:
            c.pop("bbox", None)

    if max_dets_per_image is None:
        max_dets_per_image = 300  # Default for LVIS dataset

    # Imported lazily so the module can be loaded without the `lvis` package.
    from lvis import LVISEval, LVISResults

    logger.info(f"Evaluating with max detections per image = {max_dets_per_image}")
    lvis_results = LVISResults(lvis_gt, lvis_results, max_dets=max_dets_per_image)
    lvis_eval = LVISEval(lvis_gt, lvis_results, iou_type)
    lvis_eval.run()
    lvis_eval.print_results()

    # Pull the standard metrics from the LVIS results
    results = lvis_eval.get_results()
    results = {metric: float(results[metric] * 100) for metric in metrics}
    logger.info("Evaluation results for {}: \n".format(iou_type) + create_small_table(results))
    return results
class COCOPanopticEvaluator(DatasetEvaluator):
    """
    Panoptic Quality (PQ / SQ / RQ) evaluator for COCO-style data, backed by PanopticAPI.
    Panoptic predictions are written under `output_dir` when one is given.

    `evaluate` performs a synchronize call, so it has to be invoked on all workers.
    """

    def __init__(self, dataset_name: str, output_dir: Optional[str] = None):
        """
        Args:
            dataset_name: name of the dataset
            output_dir: directory in which the prediction json is stored; when
                None a temporary directory is used during evaluation.
        """
        self._metadata = MetadataCatalog.get(dataset_name)
        # Invert the dataset-id -> contiguous-id maps so predictions can be
        # mapped back to the dataset's own category ids.
        self._thing_contiguous_id_to_dataset_id = {
            contiguous: dataset
            for dataset, contiguous in self._metadata.thing_dataset_id_to_contiguous_id.items()
        }
        self._stuff_contiguous_id_to_dataset_id = {
            contiguous: dataset
            for dataset, contiguous in self._metadata.stuff_dataset_id_to_contiguous_id.items()
        }

        self._output_dir = output_dir
        if self._output_dir is not None:
            PathManager.mkdirs(self._output_dir)

    def reset(self):
        """Drop all predictions collected so far."""
        self._predictions = []

    def _convert_category_id(self, segment_info):
        """
        Rewrite `segment_info["category_id"]` from contiguous id to dataset id (in place).

        The "isthing" key is removed from the dict. When it is absent the
        category id is taken to be a panoptic category id already and is left as is.
        """
        is_thing = segment_info.pop("isthing", None)
        if is_thing is None:
            # the model produces panoptic category id directly. No more conversion needed
            return segment_info
        lookup = (
            self._thing_contiguous_id_to_dataset_id
            if is_thing is True
            else self._stuff_contiguous_id_to_dataset_id
        )
        segment_info["category_id"] = lookup[segment_info["category_id"]]
        return segment_info

    def process(self, inputs, outputs):
        """
        Encode each panoptic prediction as a PNG plus segment metadata and store it.

        Args:
            inputs: list of dicts with at least "file_name" and "image_id".
            outputs: list of dicts whose "panoptic_seg" entry is a pair
                (panoptic image tensor, segments_info or None).
        """
        from panopticapi.utils import id2rgb

        for sample, prediction in zip(inputs, outputs):
            panoptic_img, segments_info = prediction["panoptic_seg"]
            panoptic_img = panoptic_img.cpu().numpy()
            if segments_info is None:
                # If "segments_info" is None, we assume "panoptic_img" is a
                # H*W int32 image storing the panoptic_id in the format of
                # category_id * label_divisor + instance_id. We reserve -1 for
                # VOID label, and add 1 to panoptic_img since the official
                # evaluation script uses 0 for VOID label.
                label_divisor = self._metadata.label_divisor
                segments_info = []
                for panoptic_label in np.unique(panoptic_img):
                    if panoptic_label == -1:
                        # VOID region.
                        continue
                    pred_class = panoptic_label // label_divisor
                    is_thing = (
                        pred_class in self._metadata.thing_dataset_id_to_contiguous_id.values()
                    )
                    segments_info.append(
                        {
                            "id": int(panoptic_label) + 1,
                            "category_id": int(pred_class),
                            "isthing": bool(is_thing),
                        }
                    )
                # Official evaluation script uses 0 for VOID label.
                panoptic_img += 1

            base_name = os.path.basename(sample["file_name"])
            png_name = os.path.splitext(base_name)[0] + ".png"
            with io.BytesIO() as buffer:
                Image.fromarray(id2rgb(panoptic_img)).save(buffer, format="PNG")
                segments_info = [self._convert_category_id(seg) for seg in segments_info]
                record = {
                    "image_id": sample["image_id"],
                    "file_name": png_name,
                    "png_string": buffer.getvalue(),
                    "segments_info": segments_info,
                }
                self._predictions.append(record)

    def evaluate(self):
        """
        Gather predictions from all workers and compute PQ/SQ/RQ with PanopticAPI.

        Returns:
            On the main process, an OrderedDict {"panoptic_seg": {metric: value}}
            with values scaled to percentages; None on every other process.
        """
        comm.synchronize()

        gathered = comm.gather(self._predictions)
        self._predictions = list(itertools.chain(*gathered))
        if not comm.is_main_process():
            return

        # PanopticApi requires local files
        gt_json = PathManager.get_local_path(self._metadata.panoptic_json)
        gt_folder = PathManager.get_local_path(self._metadata.panoptic_root)

        with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
            logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
            for pred in self._predictions:
                png_path = os.path.join(pred_dir, pred["file_name"])
                with open(png_path, "wb") as f:
                    # popping keeps the raw PNG bytes out of the json written below
                    f.write(pred.pop("png_string"))

            with open(gt_json, "r") as f:
                json_data = json.load(f)
            json_data["annotations"] = self._predictions

            output_dir = self._output_dir or pred_dir
            predictions_json = os.path.join(output_dir, "predictions.json")
            with PathManager.open(predictions_json, "w") as f:
                f.write(json.dumps(json_data))

            from panopticapi.evaluation import pq_compute

            # pq_compute prints its progress; keep it out of stdout
            with contextlib.redirect_stdout(io.StringIO()):
                pq_res = pq_compute(
                    gt_json,
                    PathManager.get_local_path(predictions_json),
                    gt_folder=gt_folder,
                    pred_folder=pred_dir,
                )

        res = {}
        for suffix, group in (("", "All"), ("_th", "Things"), ("_st", "Stuff")):
            for metric in ("pq", "sq", "rq"):
                res[metric.upper() + suffix] = 100 * pq_res[group][metric]

        results = OrderedDict({"panoptic_seg": res})
        _print_panoptic_results(pq_res)

        return results
def _print_panoptic_results(pq_res):
    """
    Log a table of PQ / SQ / RQ (as percentages) and the category count for
    the "All", "Things" and "Stuff" groups of a `pq_compute` result.
    """
    headers = ["", "PQ", "SQ", "RQ", "#categories"]
    rows = []
    for group in ("All", "Things", "Stuff"):
        stats = pq_res[group]
        scores = [stats[metric] * 100 for metric in ("pq", "sq", "rq")]
        rows.append([group] + scores + [stats["n"]])
    table = tabulate(
        rows, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
    )
    logger.info("Panoptic Evaluation Results:\n" + table)


if __name__ == "__main__":
    # Command-line entry: score an existing prediction json / folder pair.
    from detectron2.utils.logger import setup_logger

    logger = setup_logger()
    import argparse

    parser = argparse.ArgumentParser()
    for flag in ("--gt-json", "--gt-dir", "--pred-json", "--pred-dir"):
        parser.add_argument(flag)
    args = parser.parse_args()

    from panopticapi.evaluation import pq_compute

    with contextlib.redirect_stdout(io.StringIO()):
        pq_res = pq_compute(
            args.gt_json, args.pred_json, gt_folder=args.gt_dir, pred_folder=args.pred_dir
        )
        _print_panoptic_results(pq_res)
class PascalVOCDetectionEvaluator(DatasetEvaluator):
    """
    Pascal VOC style AP evaluation for the Pascal VOC dataset.
    `evaluate` gathers predictions across ranks, so it has to be called from all ranks.

    AP can be implemented in several ways that do not give identical numbers.
    This class follows the official Pascal VOC Matlab API and should produce
    similar, but not identical, results to it.
    """

    def __init__(self, dataset_name):
        """
        Args:
            dataset_name (str): name of the dataset, e.g., "voc_2007_test"
        """
        self._dataset_name = dataset_name
        meta = MetadataCatalog.get(dataset_name)

        # Too many tiny files, download all to local for speed.
        local_annotation_dir = PathManager.get_local_path(
            os.path.join(meta.dirname, "Annotations/")
        )
        self._anno_file_template = os.path.join(local_annotation_dir, "{}.xml")
        self._image_set_path = os.path.join(meta.dirname, "ImageSets", "Main", meta.split + ".txt")
        self._class_names = meta.thing_classes
        assert meta.year in [2007, 2012], meta.year
        # The 2007 challenge uses the 11-point AP; later years use the area under the curve.
        self._is_2007 = meta.year == 2007
        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)

    def reset(self):
        """Drop all predictions collected so far."""
        self._predictions = defaultdict(list)  # class name -> list of prediction strings

    def process(self, inputs, outputs):
        """
        Store every predicted box as one text line
        "image_id score xmin ymin xmax ymax", grouped by predicted class.

        Args:
            inputs: list of dicts with key "image_id".
            outputs: list of dicts with key "instances" exposing `pred_boxes`,
                `scores` and `pred_classes`.
        """
        for sample, result in zip(inputs, outputs):
            image_id = sample["image_id"]
            instances = result["instances"].to(self._cpu_device)
            boxes = instances.pred_boxes.tensor.numpy()
            scores = instances.scores.tolist()
            classes = instances.pred_classes.tolist()
            for (xmin, ymin, xmax, ymax), score, cls in zip(boxes, scores, classes):
                # The inverse of data loading logic in `datasets/pascal_voc.py`
                xmin += 1
                ymin += 1
                self._predictions[cls].append(
                    f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}"
                )

    def evaluate(self):
        """
        Returns:
            dict: on the main process, has a key "bbox" whose value is a dict of
            "AP", "AP50", and "AP75". Other processes return None.
        """
        gathered = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return

        # Merge the per-rank {class id: lines} dicts into one.
        predictions = defaultdict(list)
        for rank_predictions in gathered:
            for cls_id, cls_lines in rank_predictions.items():
                predictions[cls_id].extend(cls_lines)
        del gathered

        self._logger.info(
            "Evaluating {} using {} metric. "
            "Note that results do not use the official Matlab API.".format(
                self._dataset_name, 2007 if self._is_2007 else 2012
            )
        )

        with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname:
            res_file_template = os.path.join(dirname, "{}.txt")

            aps = defaultdict(list)  # iou -> ap per class
            for cls_id, cls_name in enumerate(self._class_names):
                # classes without any detection still get an (empty) result file
                cls_lines = predictions.get(cls_id, [""])

                with open(res_file_template.format(cls_name), "w") as f:
                    f.write("\n".join(cls_lines))

                # IoU thresholds 0.50, 0.55, ..., 0.95, keyed by integer percent
                for thresh in range(50, 100, 5):
                    _, _, ap = voc_eval(
                        res_file_template,
                        self._anno_file_template,
                        self._image_set_path,
                        cls_name,
                        ovthresh=thresh / 100.0,
                        use_07_metric=self._is_2007,
                    )
                    aps[thresh].append(ap * 100)

        mAP = {iou: np.mean(per_class) for iou, per_class in aps.items()}
        ret = OrderedDict()
        ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
        return ret
def voc_ap(rec, prec, use_07_metric=False):
    """Compute VOC AP given precision and recall.

    Args:
        rec: recall values; any 1-D array-like (list, tuple or ndarray).
        prec: precision values, same length as `rec`.
        use_07_metric: if True, use the VOC 07 11-point method (default: False);
            otherwise integrate the monotone precision envelope over recall.

    Returns:
        The average precision as a float (0.0 for empty inputs).
    """
    # Coerce so that plain Python sequences work with the elementwise
    # comparisons / boolean indexing below.
    rec = np.asarray(rec, dtype=float)
    prec = np.asarray(prec, dtype=float)

    if use_07_metric:
        # 11 point metric: mean of the best precision at recall >= 0.0, 0.1, ..., 1.0
        ap = 0.0
        for t in np.arange(0.0, 1.1, 0.1):
            reached = rec >= t
            p = np.max(prec[reached]) if reached.any() else 0
            ap = ap + p / 11.0
        return ap

    # correct AP calculation
    # first append sentinel values at the end
    mrec = np.concatenate(([0.0], rec, [1.0]))
    mpre = np.concatenate(([0.0], prec, [0.0]))

    # compute the precision envelope: running maximum taken from the right
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]

    # to calculate area under PR curve, look for points
    # where X axis (recall) changes value
    i = np.where(mrec[1:] != mrec[:-1])[0]

    # and sum (\Delta recall) * prec
    return np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
@lru_cache(maxsize=None)
def parse_rec(filename):
    """Parse a PASCAL VOC xml file into a list of per-object dicts.

    Each dict has the keys "name", "pose", "truncated", "difficult" and
    "bbox" ([xmin, ymin, xmax, ymax] as ints). Results are cached per filename.
    """
    with PathManager.open(filename) as f:
        tree = ET.parse(f)

    objects = []
    for node in tree.findall("object"):
        bndbox = node.find("bndbox")
        objects.append(
            {
                "name": node.find("name").text,
                "pose": node.find("pose").text,
                "truncated": int(node.find("truncated").text),
                "difficult": int(node.find("difficult").text),
                "bbox": [
                    int(bndbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")
                ],
            }
        )
    return objects


def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])

    Top level function that does the PASCAL VOC evaluation for one class.

    detpath: Path template to detections;
        detpath.format(classname) should produce the detection results file.
    annopath: Path template to annotations;
        annopath.format(imagename) should be the xml annotations file.
    imagesetfile: Text file containing the list of images, one image per line.
    classname: Category name
    [ovthresh]: Overlap threshold (default = 0.5)
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)
    """
    # ---- ground truth -------------------------------------------------------
    with PathManager.open(imagesetfile, "r") as f:
        imagenames = [line.strip() for line in f.readlines()]

    recs = {name: parse_rec(annopath.format(name)) for name in imagenames}

    # keep only this class' objects, per image
    class_recs = {}
    npos = 0
    for name in imagenames:
        gt_objs = [obj for obj in recs[name] if obj["name"] == classname]
        bbox = np.array([obj["bbox"] for obj in gt_objs])
        difficult = np.array([obj["difficult"] for obj in gt_objs]).astype(bool)
        # difficult = np.array([False for x in R]).astype(bool)  # treat all "difficult" as GT
        det = [False] * len(gt_objs)
        # "difficult" objects are not counted as positives
        npos = npos + sum(~difficult)
        class_recs[name] = {"bbox": bbox, "difficult": difficult, "det": det}

    # ---- detections ---------------------------------------------------------
    with open(detpath.format(classname), "r") as f:
        det_rows = [line.strip().split(" ") for line in f.readlines()]

    image_ids = [row[0] for row in det_rows]
    confidence = np.array([float(row[1]) for row in det_rows])
    BB = np.array([[float(v) for v in row[2:]] for row in det_rows]).reshape(-1, 4)

    # most confident detection first
    sorted_ind = np.argsort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[k] for k in sorted_ind]

    # ---- mark every detection as true or false positive ---------------------
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d, image_id in enumerate(image_ids):
        gt = class_recs[image_id]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = gt["bbox"].astype(float)

        if BBGT.size > 0:
            # intersection (the +1.0 terms treat coordinates as inclusive pixels)
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
            ih = np.maximum(iymax - iymin + 1.0, 0.0)
            inters = iw * ih

            # union
            det_area = (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
            gt_area = (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
            uni = det_area + gt_area - inters

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax <= ovthresh:
            fp[d] = 1.0
        elif not gt["difficult"][jmax]:
            if gt["det"][jmax]:
                # this ground-truth box was already claimed by a better detection
                fp[d] = 1.0
            else:
                tp[d] = 1.0
                gt["det"][jmax] = 1
        # a match with a "difficult" box counts as neither tp nor fp

    # ---- precision / recall -------------------------------------------------
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap
class RotatedCOCOeval(COCOeval):
    """COCOeval variant whose box IoU also handles rotated (5-value) boxes."""

    @staticmethod
    def is_rotated(box_list):
        """
        Tell whether `box_list` holds rotated boxes, i.e. 5 values per box.

        Args:
            box_list: an ndarray (checked via `shape[1]`) or a list of boxes.

        Returns:
            True when every box is a 5-element list/ndarray; False for an empty
            list (the box dimension cannot be decided) and for any other type.
        """
        # isinstance (rather than an exact type comparison) also accepts
        # ndarray / list subclasses.
        if isinstance(box_list, np.ndarray):
            return box_list.shape[1] == 5
        if isinstance(box_list, list):
            if not box_list:  # cannot decide the box_dim
                return False
            return all(
                isinstance(obj, (list, np.ndarray)) and len(obj) == 5 for obj in box_list
            )
        return False

    @staticmethod
    def boxlist_to_tensor(boxlist, output_box_dim):
        """
        Convert an ndarray / list of boxes to a tensor with `output_box_dim` columns.

        4-dim boxes are converted from XYWH_ABS to XYWHA_ABS when 5 columns are
        requested; any other dimension mismatch raises an Exception.
        """
        if isinstance(boxlist, np.ndarray):
            box_tensor = torch.from_numpy(boxlist)
        elif isinstance(boxlist, list):
            if not boxlist:
                return torch.zeros((0, output_box_dim), dtype=torch.float32)
            box_tensor = torch.FloatTensor(boxlist)
        else:
            raise Exception("Unrecognized boxlist type")

        input_box_dim = box_tensor.shape[1]
        if input_box_dim != output_box_dim:
            if input_box_dim == 4 and output_box_dim == 5:
                box_tensor = BoxMode.convert(box_tensor, BoxMode.XYWH_ABS, BoxMode.XYWHA_ABS)
            else:
                raise Exception(
                    "Unable to convert from {}-dim box to {}-dim box".format(
                        input_box_dim, output_box_dim
                    )
                )
        return box_tensor

    def compute_iou_dt_gt(self, dt, gt, is_crowd):
        """
        Pairwise IoU between detection boxes `dt` and ground-truth boxes `gt`.

        If either side is rotated, both are converted to rotated boxes and no
        ground truth may be flagged as crowd; otherwise the standard COCO IoU is used.
        """
        if self.is_rotated(dt) or self.is_rotated(gt):
            # TODO: take is_crowd into consideration
            assert all(c == 0 for c in is_crowd)
            dt = RotatedBoxes(self.boxlist_to_tensor(dt, output_box_dim=5))
            gt = RotatedBoxes(self.boxlist_to_tensor(gt, output_box_dim=5))
            return pairwise_iou_rotated(dt, gt)
        else:
            # This is the same as the classical COCO evaluation
            return maskUtils.iou(dt, gt, is_crowd)

    def computeIoU(self, imgId, catId):
        """
        IoU matrix between the detections and ground truths of one image/category.

        Detections are sorted by descending score and truncated to the largest
        `maxDets` value. Returns an empty list when there is nothing to compare.
        """
        p = self.params
        if p.useCats:
            gt = self._gts[imgId, catId]
            dt = self._dts[imgId, catId]
        else:
            gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
            dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
        if len(gt) == 0 and len(dt) == 0:
            return []
        # mergesort is stable, so equally-scored detections keep their order
        inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
        dt = [dt[i] for i in inds]
        if len(dt) > p.maxDets[-1]:
            dt = dt[0 : p.maxDets[-1]]

        assert p.iouType == "bbox", "unsupported iouType for iou computation"

        g = [g["bbox"] for g in gt]
        d = [d["bbox"] for d in dt]

        # compute iou between each dt and gt region
        iscrowd = [int(o["iscrowd"]) for o in gt]

        # Note: this function is copied from cocoeval.py in cocoapi
        # and the major difference is here.
        ious = self.compute_iou_dt_gt(d, g, iscrowd)
        return ious
class RotatedCOCOEvaluator(COCOEvaluator):
    """
    COCO-style evaluation of object proposal / instance detection outputs
    with support for rotated boxes.
    Note: this uses IOU only and does not consider angle differences.
    """

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
                It is a list of dict. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name", "image_id".
            outputs: the outputs of a COCO model. It is a list of dicts with key
                "instances" that contains :class:`Instances`.
        """
        for sample, result in zip(inputs, outputs):
            image_id = sample["image_id"]
            prediction = {"image_id": image_id}

            if "instances" in result:
                cpu_instances = result["instances"].to(self._cpu_device)
                prediction["instances"] = self.instances_to_json(cpu_instances, image_id)
            if "proposals" in result:
                prediction["proposals"] = result["proposals"].to(self._cpu_device)
            self._predictions.append(prediction)

    def instances_to_json(self, instances, img_id):
        """
        Turn the predictions of one image into a list of COCO-format result dicts.

        4-column boxes are converted from XYXY_ABS to XYWH_ABS; boxes with any
        other number of columns are written out unchanged.
        """
        if len(instances) == 0:
            return []

        boxes = instances.pred_boxes.tensor.numpy()
        if boxes.shape[1] == 4:
            boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
        boxes = boxes.tolist()
        scores = instances.scores.tolist()
        classes = instances.pred_classes.tolist()

        return [
            {"image_id": img_id, "category_id": category, "bbox": box, "score": score}
            for category, box, score in zip(classes, boxes, scores)
        ]

    def _eval_predictions(self, predictions, img_ids=None):  # img_ids: unused
        """
        Evaluate predictions on the given tasks.
        Fill self._results with the metrics of the tasks.
        """
        self._logger.info("Preparing results for COCO format ...")
        coco_results = list(itertools.chain(*[pred["instances"] for pred in predictions]))

        # unmap the category ids for COCO
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            contiguous_to_dataset = {
                contiguous: dataset
                for dataset, contiguous in (
                    self._metadata.thing_dataset_id_to_contiguous_id.items()
                )
            }
            for result in coco_results:
                result["category_id"] = contiguous_to_dataset[result["category_id"]]

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "coco_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(coco_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating predictions ...")

        assert self._tasks is None or set(self._tasks) == {
            "bbox"
        }, "[RotatedCOCOEvaluator] Only bbox evaluation is supported"

        if len(coco_results) > 0:
            coco_eval = self._evaluate_predictions_on_coco(self._coco_api, coco_results)
        else:
            coco_eval = None  # cocoapi does not handle empty results very well

        task = "bbox"
        self._results[task] = self._derive_coco_results(
            coco_eval, task, class_names=self._metadata.get("thing_classes")
        )

    def _evaluate_predictions_on_coco(self, coco_gt, coco_results):
        """
        Evaluate the coco results using COCOEval API.
        """
        assert len(coco_results) > 0

        coco_dt = coco_gt.loadRes(coco_results)

        # Only bbox is supported for now
        coco_eval = RotatedCOCOeval(coco_gt, coco_dt, iouType="bbox")
        for step in (coco_eval.evaluate, coco_eval.accumulate, coco_eval.summarize):
            step()

        return coco_eval
def load_image_into_numpy_array(
    filename: str,
    copy: bool = False,
    dtype: Optional[Union[np.dtype, str]] = None,
) -> np.ndarray:
    """
    Read an image file through PathManager and return it as a numpy array.

    Args:
        filename: path of the image, opened with `PathManager.open`.
        copy: if True, always return a freshly allocated array; if False, a copy
            is made only when the conversion requires one.
        dtype: optional dtype of the returned array.
    """
    with PathManager.open(filename, "rb") as f:
        image = Image.open(f)
        if copy:
            array = np.array(image, dtype=dtype)
        else:
            # Not `np.array(..., copy=False)`: since NumPy 2.0 that raises
            # ValueError whenever a copy (e.g. for a dtype cast) cannot be avoided.
            array = np.asarray(image, dtype=dtype)
    return array
class SemSegEvaluator(DatasetEvaluator):
    """
    Evaluate semantic segmentation metrics.
    """

    def __init__(
        self,
        dataset_name,
        distributed=True,
        output_dir=None,
        *,
        sem_seg_loading_fn=load_image_into_numpy_array,
        num_classes=None,
        ignore_label=None,
    ):
        """
        Args:
            dataset_name (str): name of the dataset to be evaluated.
            distributed (bool): if True, will collect results from all ranks for evaluation.
                Otherwise, will evaluate the results in the current process.
            output_dir (str): an output directory to dump results.
            sem_seg_loading_fn: function to read sem seg file and load into numpy array.
                Default provided, but projects can customize.
            num_classes, ignore_label: deprecated argument
        """
        self._logger = logging.getLogger(__name__)
        if num_classes is not None:
            self._logger.warning(
                "SemSegEvaluator(num_classes) is deprecated! It should be obtained from metadata."
            )
        if ignore_label is not None:
            self._logger.warning(
                "SemSegEvaluator(ignore_label) is deprecated! It should be obtained from metadata."
            )
        self._dataset_name = dataset_name
        self._distributed = distributed
        self._output_dir = output_dir

        self._cpu_device = torch.device("cpu")

        # input image path -> ground-truth segmentation path
        self.input_file_to_gt_file = {
            dataset_record["file_name"]: dataset_record["sem_seg_file_name"]
            for dataset_record in DatasetCatalog.get(dataset_name)
        }

        meta = MetadataCatalog.get(dataset_name)
        # Dict that maps contiguous training ids to COCO category ids
        try:
            c2d = meta.stuff_dataset_id_to_contiguous_id
            self._contiguous_id_to_dataset_id = {v: k for k, v in c2d.items()}
        except AttributeError:
            self._contiguous_id_to_dataset_id = None
        self._class_names = meta.stuff_classes
        self.sem_seg_loading_fn = sem_seg_loading_fn
        self._num_classes = len(meta.stuff_classes)
        if num_classes is not None:
            assert self._num_classes == num_classes, f"{self._num_classes} != {num_classes}"
        self._ignore_label = ignore_label if ignore_label is not None else meta.ignore_label

        # This is because cv2.erode did not work for int datatype. Only works for uint8.
        self._compute_boundary_iou = True
        if not _CV2_IMPORTED:
            self._compute_boundary_iou = False
            self._logger.warning(
                """Boundary IoU calculation requires OpenCV. B-IoU metrics are
                not going to be computed because OpenCV is not available to import."""
            )
        if self._num_classes >= np.iinfo(np.uint8).max:
            self._compute_boundary_iou = False
            self._logger.warning(
                f"""SemSegEvaluator(num_classes) is more than supported value for Boundary IoU calculation!
                B-IoU metrics are not going to be computed. Max allowed value (exclusive)
                for num_classes for calculating Boundary IoU is {np.iinfo(np.uint8).max}.
                The number of classes of dataset {self._dataset_name} is {self._num_classes}"""
            )

    def reset(self):
        """Zero both confusion matrices and drop the collected predictions."""
        # One extra row/column holds the pixels carrying the ignore label.
        self._conf_matrix = np.zeros((self._num_classes + 1, self._num_classes + 1), dtype=np.int64)
        self._b_conf_matrix = np.zeros(
            (self._num_classes + 1, self._num_classes + 1), dtype=np.int64
        )
        self._predictions = []

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a model.
                It is a list of dicts. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name".
            outputs: the outputs of a model. It is a list of dicts with key
                "sem_seg" holding per-class scores; the predicted label of a
                pixel is the argmax over dim 0.
        """
        for input, output in zip(inputs, outputs):
            output = output["sem_seg"].argmax(dim=0).to(self._cpu_device)
            # builtin `int`: the `np.int` alias was removed in NumPy 1.24
            pred = np.array(output, dtype=int)
            gt_filename = self.input_file_to_gt_file[input["file_name"]]
            gt = self.sem_seg_loading_fn(gt_filename, dtype=int)

            # fold the ignore label into the extra last class
            gt[gt == self._ignore_label] = self._num_classes

            # (pred, gt) pairs are flattened to a single index and counted
            self._conf_matrix += np.bincount(
                (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
                minlength=self._conf_matrix.size,
            ).reshape(self._conf_matrix.shape)

            if self._compute_boundary_iou:
                b_gt = self._mask_to_boundary(gt.astype(np.uint8))
                b_pred = self._mask_to_boundary(pred.astype(np.uint8))

                self._b_conf_matrix += np.bincount(
                    (self._num_classes + 1) * b_pred.reshape(-1) + b_gt.reshape(-1),
                    minlength=self._conf_matrix.size,
                ).reshape(self._conf_matrix.shape)

            self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))

    def evaluate(self):
        """
        Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):

        * Mean intersection-over-union averaged across classes (mIoU)
        * Frequency Weighted IoU (fwIoU)
        * Mean pixel accuracy averaged across classes (mACC)
        * Pixel Accuracy (pACC)

        Returns:
            OrderedDict {"sem_seg": {metric name: value in percent}}; None on
            non-main processes when running distributed.
        """
        if self._distributed:
            synchronize()
            conf_matrix_list = all_gather(self._conf_matrix)
            b_conf_matrix_list = all_gather(self._b_conf_matrix)
            self._predictions = all_gather(self._predictions)
            self._predictions = list(itertools.chain(*self._predictions))
            if not is_main_process():
                return

            self._conf_matrix = np.zeros_like(self._conf_matrix)
            for conf_matrix in conf_matrix_list:
                self._conf_matrix += conf_matrix

            self._b_conf_matrix = np.zeros_like(self._b_conf_matrix)
            for b_conf_matrix in b_conf_matrix_list:
                self._b_conf_matrix += b_conf_matrix

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)
            file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(self._predictions))

        # builtin `float`: the `np.float` alias was removed in NumPy 1.24
        acc = np.full(self._num_classes, np.nan, dtype=float)
        iou = np.full(self._num_classes, np.nan, dtype=float)
        # [:-1] drops the ignore-label row/column
        tp = self._conf_matrix.diagonal()[:-1].astype(float)
        pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(float)
        class_weights = pos_gt / np.sum(pos_gt)
        pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(float)
        acc_valid = pos_gt > 0
        acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
        union = pos_gt + pos_pred - tp
        iou_valid = np.logical_and(acc_valid, union > 0)
        iou[iou_valid] = tp[iou_valid] / union[iou_valid]
        macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
        miou = np.sum(iou[iou_valid]) / np.sum(iou_valid)
        fiou = np.sum(iou[iou_valid] * class_weights[iou_valid])
        pacc = np.sum(tp) / np.sum(pos_gt)

        if self._compute_boundary_iou:
            b_iou = np.full(self._num_classes, np.nan, dtype=float)
            b_tp = self._b_conf_matrix.diagonal()[:-1].astype(float)
            b_pos_gt = np.sum(self._b_conf_matrix[:-1, :-1], axis=0).astype(float)
            b_pos_pred = np.sum(self._b_conf_matrix[:-1, :-1], axis=1).astype(float)
            b_union = b_pos_gt + b_pos_pred - b_tp
            b_iou_valid = b_union > 0
            b_iou[b_iou_valid] = b_tp[b_iou_valid] / b_union[b_iou_valid]

        res = {}
        res["mIoU"] = 100 * miou
        res["fwIoU"] = 100 * fiou
        for i, name in enumerate(self._class_names):
            res[f"IoU-{name}"] = 100 * iou[i]
            if self._compute_boundary_iou:
                res[f"BoundaryIoU-{name}"] = 100 * b_iou[i]
                res[f"min(IoU, B-Iou)-{name}"] = 100 * min(iou[i], b_iou[i])
        res["mACC"] = 100 * macc
        res["pACC"] = 100 * pacc
        for i, name in enumerate(self._class_names):
            res[f"ACC-{name}"] = 100 * acc[i]

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(res, f)
        results = OrderedDict({"sem_seg": res})
        self._logger.info(results)
        return results

    def encode_json_sem_seg(self, sem_seg, input_file_name):
        """
        Convert semantic segmentation to COCO stuff format with segments encoded as RLEs.
        See http://cocodataset.org/#format-results

        Returns:
            list of dicts, one per label present in `sem_seg`, with keys
            "file_name", "category_id" and "segmentation".
        """
        json_list = []
        for label in np.unique(sem_seg):
            if self._contiguous_id_to_dataset_id is not None:
                assert (
                    label in self._contiguous_id_to_dataset_id
                ), "Label {} is not in the metadata info for {}".format(label, self._dataset_name)
                dataset_id = self._contiguous_id_to_dataset_id[label]
            else:
                dataset_id = int(label)
            mask = (sem_seg == label).astype(np.uint8)
            mask_rle = mask_util.encode(np.array(mask[:, :, None], order="F"))[0]
            # bytes -> str so that the RLE can be dumped to json
            mask_rle["counts"] = mask_rle["counts"].decode("utf-8")
            json_list.append(
                {"file_name": input_file_name, "category_id": dataset_id, "segmentation": mask_rle}
            )
        return json_list

    def _mask_to_boundary(self, mask: np.ndarray, dilation_ratio=0.02):
        """
        Return `mask` minus its erosion, i.e. only the pixels near region borders.

        The number of erosion iterations is `dilation_ratio` times the image
        diagonal, rounded, and at least 1.
        """
        assert mask.ndim == 2, "mask_to_boundary expects a 2-dimensional image"
        h, w = mask.shape
        diag_len = np.sqrt(h**2 + w**2)
        dilation = max(1, int(round(dilation_ratio * diag_len)))
        kernel = np.ones((3, 3), dtype=np.uint8)

        # pad with zeros so that regions touching the image border erode too
        padded_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
        eroded_mask_with_padding = cv2.erode(padded_mask, kernel, iterations=dilation)
        eroded_mask = eroded_mask_with_padding[1:-1, 1:-1]
        boundary = mask - eroded_mask
        return boundary
def print_csv_format(results):
    """
    Print main metrics in a format similar to Detectron,
    so that they are easy to copypaste into a spreadsheet.

    For a task whose result is a mapping, three log lines are emitted: the
    task name, the comma-separated metric names, and the comma-separated
    values with four decimals. Any other result is logged as ``task=value``.

    Args:
        results (OrderedDict[dict]): task_name -> {metric -> score}
            unordered dict can also be printed, but in arbitrary order
    """
    assert isinstance(results, Mapping) or not len(results), results
    logger = logging.getLogger(__name__)
    for task, res in results.items():
        if not isinstance(res, Mapping):
            logger.info(f"copypaste: {task}={res}")
            continue
        # Don't print "AP-category" metrics since they are usually not tracked.
        important_res = [(name, score) for name, score in res.items() if "-" not in name]
        logger.info("copypaste: Task: {}".format(task))
        logger.info("copypaste: " + ",".join([pair[0] for pair in important_res]))
        logger.info("copypaste: " + ",".join(["{0:.4f}".format(pair[1]) for pair in important_res]))


def verify_results(cfg, results):
    """
    Compare results against ``cfg.TEST.EXPECTED_RESULTS`` and terminate the
    process with exit status 1 when any entry is missing, non-finite, or
    further from its expected value than the allowed tolerance.

    Args:
        cfg: config whose ``TEST.EXPECTED_RESULTS`` is a sequence of
            ``(task, metric, expected, tolerance)`` entries.
        results (OrderedDict[dict]): task_name -> {metric -> score}

    Returns:
        bool: whether the verification succeeds or not
    """
    expected_results = cfg.TEST.EXPECTED_RESULTS
    if not len(expected_results):
        return True

    ok = True
    for task, metric, expected, tolerance in expected_results:
        actual = results[task].get(metric, None)
        if actual is None or not np.isfinite(actual) or abs(actual - expected) > tolerance:
            ok = False

    logger = logging.getLogger(__name__)
    if ok:
        logger.info("Results verification passed.")
        return ok

    logger.error("Result verification failed!")
    logger.error("Expected Results: " + str(expected_results))
    logger.error("Actual Results: " + pprint.pformat(results))

    sys.exit(1)


def flatten_results_dict(results):
    """
    Expand a hierarchical dict of scalars into a flat dict of scalars.
    If results[k1][k2][k3] = v, the returned dict will have the entry
    {"k1/k2/k3": v}.

    Args:
        results (dict):

    Returns:
        dict: a new flat dict; keys of nested mappings are joined with "/".
    """
    flat = {}
    for key, value in results.items():
        if not isinstance(value, Mapping):
            flat[key] = value
            continue
        for sub_key, sub_value in flatten_results_dict(value).items():
            flat[key + "/" + sub_key] = sub_value
    return flat
def add_export_config(cfg):
    """
    Deprecated no-op kept for backward compatibility.

    Emits a ``DeprecationWarning`` and hands the config back untouched.

    Args:
        cfg: any config object.

    Returns:
        The same ``cfg`` object that was passed in.
    """
    # stacklevel=2 attributes the warning to the caller instead of this
    # module, so the report points at the code that needs to change.
    warnings.warn(
        "add_export_config has been deprecated and behaves as no-op function.",
        DeprecationWarning,
        stacklevel=2,
    )
    return cfg
class Caffe2Tracer:
    """
    Make a detectron2 model traceable with Caffe2 operators.
    This class creates a traceable version of a detectron2 model which:

    1. Rewrite parts of the model using ops in Caffe2. Note that some ops do
       not have GPU implementation in Caffe2.
    2. Remove post-processing and only produce raw layer outputs

    After making a traceable model, the class provide methods to export such a
    model to different deployment formats.
    Exported graph produced by this class take two input tensors:

    1. (1, C, H, W) float "data" which is an image (usually in [0, 255]).
       (H, W) often has to be padded to multiple of 32 (depend on the model
       architecture).
    2. 1x3 float "im_info", each row of which is (height, width, 1.0).
       Height and width are true image shapes before padding.

    The class currently only supports models using builtin meta architectures.
    Batch inference is not supported, and contributions are welcome.
    """

    def __init__(self, cfg: CfgNode, model: nn.Module, inputs):
        """
        Args:
            cfg (CfgNode): a detectron2 config used to construct caffe2-compatible model.
            model (nn.Module): An original pytorch model. Must be among a few official models
                in detectron2 that can be converted to become caffe2-compatible automatically.
                Weights have to be already loaded to this model.
            inputs: sample inputs that the given model takes for inference.
                Will be used to trace the model. For most models, random inputs with
                no detected objects will not work as they lead to wrong traces.
        """
        assert isinstance(cfg, CfgNode), cfg
        assert isinstance(model, torch.nn.Module), type(model)

        # TODO make it support custom models, by passing in c2 model directly
        meta_arch_cls = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
        # The wrapper gets its own deep copy so the caller's model is not touched.
        model_copy = copy.deepcopy(model)
        self.traceable_model = meta_arch_cls(cfg, model_copy)
        self.inputs = inputs
        self.traceable_inputs = self.traceable_model.get_caffe2_inputs(inputs)

    def export_caffe2(self):
        """
        Export the model to Caffe2's protobuf format.
        The returned object can be saved with its :meth:`.save_protobuf()` method.
        The result can be loaded and executed using Caffe2 runtime.

        Returns:
            :class:`Caffe2Model`
        """
        from .caffe2_export import export_caffe2_detection_model

        traced_model = self.traceable_model
        traced_inputs = self.traceable_inputs
        predict_net, init_net = export_caffe2_detection_model(traced_model, traced_inputs)
        return Caffe2Model(predict_net, init_net)

    def export_onnx(self):
        """
        Export the model to ONNX format.
        Note that the exported model contains custom ops only available in caffe2, therefore it
        cannot be directly executed by other runtime (such as onnxruntime or TensorRT).
        Post-processing or transformation passes may be applied on the model to accommodate
        different runtimes, but we currently do not provide support for them.

        Returns:
            onnx.ModelProto: an onnx model.
        """
        from .caffe2_export import export_onnx_model as export_onnx_model_impl

        # The traced inputs are passed as a one-element tuple of positional args.
        trace_args = (self.traceable_inputs,)
        return export_onnx_model_impl(self.traceable_model, trace_args)

    def export_torchscript(self):
        """
        Export the model to a ``torch.jit.TracedModule`` by tracing.
        The returned object can be saved to a file by ``.save()``.

        Returns:
            torch.jit.TracedModule: a torch TracedModule
        """
        logger = logging.getLogger(__name__)
        logger.info("Tracing the model with torch.jit.trace ...")
        trace_args = (self.traceable_inputs,)
        with torch.no_grad():
            traced = torch.jit.trace(self.traceable_model, trace_args)
        return traced
class Caffe2Model(nn.Module):
    """
    A wrapper around the traced model in Caffe2's protobuf format.
    The exported graph has different inputs/outputs from the original Pytorch
    model, as explained in :class:`Caffe2Tracer`. This class wraps around the
    exported graph to simulate the same interface as the original Pytorch model.
    It also provides functions to save/load models in Caffe2's format.

    Examples:
    ::
        c2_model = Caffe2Tracer(cfg, torch_model, inputs).export_caffe2()
        inputs = [{"image": img_tensor_CHW}]
        outputs = c2_model(inputs)
        orig_outputs = torch_model(inputs)
    """

    def __init__(self, predict_net, init_net):
        """
        Args:
            predict_net: the caffe2 predict net to wrap.
            init_net: the caffe2 init net to wrap.
        """
        super().__init__()
        self.eval()  # always in eval mode
        self._predict_net = predict_net
        self._init_net = init_net
        # Built lazily on the first call, see __call__.
        self._predictor = None

    __init__.__HIDE_SPHINX_DOC__ = True

    @property
    def predict_net(self):
        """
        caffe2.core.Net: the underlying caffe2 predict net
        """
        return self._predict_net

    @property
    def init_net(self):
        """
        caffe2.core.Net: the underlying caffe2 init net
        """
        return self._init_net

    def save_protobuf(self, output_dir):
        """
        Save the model as caffe2's protobuf format.
        It saves the following files:

            * "model.pb": definition of the graph. Can be visualized with
              tools like `netron <https://github.com/lutzroeder/netron>`_.
            * "model_init.pb": model parameters
            * "model.pbtxt": human-readable definition of the graph. Not
              needed for deployment.

        Args:
            output_dir (str): the output directory to save protobuf files.
                It is created if it does not exist yet.
        """
        logger = logging.getLogger(__name__)
        logger.info("Saving model to {} ...".format(output_dir))
        if not PathManager.exists(output_dir):
            PathManager.mkdirs(output_dir)

        with PathManager.open(os.path.join(output_dir, "model.pb"), "wb") as f:
            f.write(self._predict_net.SerializeToString())
        with PathManager.open(os.path.join(output_dir, "model.pbtxt"), "w") as f:
            f.write(str(self._predict_net))
        with PathManager.open(os.path.join(output_dir, "model_init.pb"), "wb") as f:
            f.write(self._init_net.SerializeToString())

    def save_graph(self, output_file, inputs=None):
        """
        Save the graph as SVG format.

        Args:
            output_file (str): a SVG file
            inputs: optional inputs given to the model.
                If given, the inputs will be used to run the graph to record
                shape of every tensor. The shape information will be
                saved together with the graph.
        """
        from .caffe2_export import run_and_save_graph

        if inputs is None:
            save_graph(self._predict_net, output_file, op_only=False)
        else:
            # Size divisibility and device are read back from the predict
            # net's arguments (defaults: 0 and "cpu") to convert the inputs.
            size_divisibility = get_pb_arg_vali(self._predict_net, "size_divisibility", 0)
            device = get_pb_arg_vals(self._predict_net, "device", b"cpu").decode("ascii")
            inputs = convert_batched_inputs_to_c2_format(inputs, size_divisibility, device)
            inputs = [x.cpu().numpy() for x in inputs]
            run_and_save_graph(self._predict_net, self._init_net, inputs, output_file)

    @staticmethod
    def load_protobuf(dir):
        """
        Args:
            dir (str): a directory used to save Caffe2Model with
                :meth:`save_protobuf`.
                The files "model.pb" and "model_init.pb" are needed.

        Returns:
            Caffe2Model: the caffe2 model loaded from this directory.
        """
        predict_net = caffe2_pb2.NetDef()
        with PathManager.open(os.path.join(dir, "model.pb"), "rb") as f:
            predict_net.ParseFromString(f.read())

        init_net = caffe2_pb2.NetDef()
        with PathManager.open(os.path.join(dir, "model_init.pb"), "rb") as f:
            init_net.ParseFromString(f.read())

        return Caffe2Model(predict_net, init_net)

    def __call__(self, inputs):
        """
        An interface that wraps around a Caffe2 model and mimics detectron2's models'
        input/output format. See details about the format at :doc:`/tutorials/models`.
        This is used to compare the outputs of caffe2 model with its original torch model.

        Due to the extra conversion between Pytorch/Caffe2, this method is not meant for
        benchmark. Because of the conversion, this method also has dependency
        on detectron2 in order to convert to detectron2's output format.
        """
        # The predictor is created on first use and reused afterwards.
        if self._predictor is None:
            self._predictor = ProtobufDetectionModel(self._predict_net, self._init_net)
        return self._predictor(inputs)
class Caffe2Boxes(Boxes):
    """
    Representing a list of detectron2.structures.Boxes from minibatch, each box
    is represented by a 5d vector (batch index + 4 coordinates), or a 6d vector
    (batch index + 5 coordinates) for RotatedBoxes.
    """

    def __init__(self, tensor):
        """
        Args:
            tensor (torch.Tensor): 2-D tensor whose last dimension has size
                4, 5 or 6. It is stored as-is (no copy, no conversion).
        """
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dim() == 2 and tensor.size(-1) in [4, 5, 6], tensor.size()
        # TODO: make tensor immutable when dim is Nx5 for Boxes,
        # and Nx6 for RotatedBoxes?
        self.tensor = tensor
class InstancesList(object):
    """
    Tensor representation of a list of Instances object for a batch of images.

    When dealing with a batch of images with Caffe2 ops, a list of bboxes
    (instances) are usually represented by single Tensor with size
    (sigma(Ni), 5) or (sigma(Ni), 4) plus a batch split Tensor. This class is
    for providing common functions to convert between these two representations.
    """

    def __init__(self, im_info, indices, extra_fields=None):
        """
        Args:
            im_info: per-image info, one row per image.
            indices: for every instance, the index of the image it belongs to.
            extra_fields (dict or None): initial field name -> value mapping.
        """
        # [N, 3] -> (H, W, Scale)
        self.im_info = im_info
        # [N,] -> index of the batch to which the instance belongs
        self.indices = indices
        # [N, ...]
        self.batch_extra_fields = extra_fields or {}

        self.image_size = self.im_info

    def get_fields(self):
        """like `get_fields` in the Instances object,
        but return each field in tensor representations"""
        # Values are passed through untouched; only the dict itself is new.
        return dict(self.batch_extra_fields)

    def has(self, name):
        """Return whether a field called ``name`` has been set."""
        return name in self.batch_extra_fields

    def set(self, name, value):
        """
        Store ``value`` under ``name``. Once at least one field exists, the
        length of a new value must equal ``len(self)``.
        """
        # len(tensor) is a bad practice that generates ONNX constants during tracing.
        # Although not a problem for the `assert` statement below, the torch ONNX
        # exporter still raises a misleading warning as it does not know this call
        # comes from `assert`.
        if isinstance(value, Boxes):
            data_len = value.tensor.shape[0]
        elif isinstance(value, torch.Tensor):
            data_len = value.shape[0]
        else:
            data_len = len(value)
        if len(self.batch_extra_fields):
            assert (
                len(self) == data_len
            ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
        self.batch_extra_fields[name] = value

    def __getattr__(self, name):
        """
        Expose stored fields as attributes.

        Raises:
            AttributeError: if no field called ``name`` exists.
        """
        # Guard the lookup of ``batch_extra_fields`` itself: on an object that
        # does not have it yet (e.g. while being copied or unpickled) the
        # unguarded access would re-enter __getattr__ until RecursionError.
        if name == "batch_extra_fields" or name not in self.batch_extra_fields:
            raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
        return self.batch_extra_fields[name]

    def __len__(self):
        """Number of instances, i.e. the length of ``indices``."""
        return len(self.indices)

    def flatten(self):
        """
        Return the field values as a list, in insertion order, with Boxes and
        Keypoints replaced by their underlying ``.tensor``.
        """
        return [
            v.tensor if isinstance(v, (Boxes, Keypoints)) else v
            for v in self.batch_extra_fields.values()
        ]

    @staticmethod
    def to_d2_instances_list(instances_list):
        """
        Convert InstancesList to List[Instances]. The input `instances_list` can
        also be a List[Instances], in this case this method is a non-op.
        """
        if not isinstance(instances_list, InstancesList):
            assert all(isinstance(x, Instances) for x in instances_list)
            return instances_list

        ret = []
        for i, info in enumerate(instances_list.im_info):
            instances = Instances(torch.Size([int(info[0].item()), int(info[1].item())]))

            # Boolean mask selecting the instances that belong to image i.
            ids = instances_list.indices == i
            for k, v in instances_list.batch_extra_fields.items():
                if isinstance(v, torch.Tensor):
                    instances.set(k, v[ids])
                    continue
                elif isinstance(v, Boxes):
                    # Keep only the last 4 columns, dropping the leading batch index.
                    instances.set(k, v[ids, -4:])
                    continue

                # Anything else must be a (target_type, tensor) pair.
                target_type, tensor_source = v
                assert isinstance(tensor_source, torch.Tensor)
                assert tensor_source.shape[0] == instances_list.indices.shape[0]
                tensor_source = tensor_source[ids]

                if issubclass(target_type, Boxes):
                    instances.set(k, Boxes(tensor_source[:, -4:]))
                elif issubclass(target_type, Keypoints):
                    instances.set(k, Keypoints(tensor_source))
                elif issubclass(target_type, torch.Tensor):
                    instances.set(k, tensor_source)
                else:
                    raise ValueError("Can't handle targe type: {}".format(target_type))

            ret.append(instances)
        return ret


class Caffe2Compatible(object):
    """
    A model can inherit this class to indicate that it can be traced and deployed with caffe2.
    """

    def _get_tensor_mode(self):
        return self._tensor_mode

    def _set_tensor_mode(self, v):
        self._tensor_mode = v

    tensor_mode = property(_get_tensor_mode, _set_tensor_mode)
    """
    If true, the model expects C2-style tensor only inputs/outputs format.
    """
class Caffe2RPN(Caffe2Compatible, rpn.RPN):
    """
    RPN variant whose proposal generation is carried out by the
    ``torch.ops._caffe2`` operators GenerateProposals and CollectRpnProposals.
    """

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """
        Build the constructor arguments via the next ``from_config`` after
        Caffe2Compatible in the MRO, after checking that the RPN box
        regression weights are all ones (4 or 5 of them).
        """
        ret = super(Caffe2Compatible, cls).from_config(cfg, input_shape)
        assert tuple(cfg.MODEL.RPN.BBOX_REG_WEIGHTS) == (1.0, 1.0, 1.0, 1.0) or tuple(
            cfg.MODEL.RPN.BBOX_REG_WEIGHTS
        ) == (1.0, 1.0, 1.0, 1.0, 1.0)
        return ret

    def _generate_proposals(
        self, images, objectness_logits_pred, anchor_deltas_pred, gt_instances=None
    ):
        """
        Run GenerateProposals once per feature level and, when there is more
        than one level, merge the per-level results with CollectRpnProposals.

        Args:
            images (ImageList): the input images.
            objectness_logits_pred: per-level objectness predictions.
            anchor_deltas_pred: per-level anchor delta predictions.
            gt_instances: not used in this method.

        Returns:
            tuple: (proposals as produced by :meth:`c2_postprocess`, empty dict)
        """
        assert isinstance(images, ImageList)
        if self.tensor_mode:
            # In tensor mode image_sizes must already be a tensor (asserted below).
            im_info = images.image_sizes
        else:
            # Build one (size[0], size[1], 1.0) row per image on the images' device.
            im_info = torch.tensor([[im_sz[0], im_sz[1], 1.0] for im_sz in images.image_sizes]).to(
                images.tensor.device
            )
        assert isinstance(im_info, torch.Tensor)

        rpn_rois_list = []
        rpn_roi_probs_list = []
        for scores, bbox_deltas, cell_anchors_tensor, feat_stride in zip(
            objectness_logits_pred,
            anchor_deltas_pred,
            [b for (n, b) in self.anchor_generator.cell_anchors.named_buffers()],
            self.anchor_generator.strides,
        ):
            scores = scores.detach()
            bbox_deltas = bbox_deltas.detach()

            rpn_rois, rpn_roi_probs = torch.ops._caffe2.GenerateProposals(
                scores,
                bbox_deltas,
                im_info,
                cell_anchors_tensor,
                spatial_scale=1.0 / feat_stride,
                pre_nms_topN=self.pre_nms_topk[self.training],
                post_nms_topN=self.post_nms_topk[self.training],
                nms_thresh=self.nms_thresh,
                min_size=self.min_box_size,
                # correct_transform_coords=True,  # deprecated argument
                angle_bound_on=True,  # Default
                angle_bound_lo=-180,
                angle_bound_hi=180,
                clip_angle_thresh=1.0,  # Default
                legacy_plus_one=False,
            )
            rpn_rois_list.append(rpn_rois)
            rpn_roi_probs_list.append(rpn_roi_probs)

        # For FPN in D2, in RPN all proposals from different levels are concated
        # together, ranked and picked by top post_nms_topk. Then in ROIPooler
        # it calculates level_assignments and calls the RoIAlign from
        # the corresponding level.

        if len(objectness_logits_pred) == 1:
            rpn_rois = rpn_rois_list[0]
            rpn_roi_probs = rpn_roi_probs_list[0]
        else:
            assert len(rpn_rois_list) == len(rpn_roi_probs_list)
            rpn_post_nms_topN = self.post_nms_topk[self.training]

            # The collect op is fed CPU tensors; its result is moved back afterwards.
            device = rpn_rois_list[0].device
            input_list = [to_device(x, "cpu") for x in (rpn_rois_list + rpn_roi_probs_list)]

            # TODO remove this after confirming rpn_max_level/rpn_min_level
            # is not needed in CollectRpnProposals.
            feature_strides = list(self.anchor_generator.strides)
            rpn_min_level = int(math.log2(feature_strides[0]))
            rpn_max_level = int(math.log2(feature_strides[-1]))
            assert (rpn_max_level - rpn_min_level + 1) == len(
                rpn_rois_list
            ), "CollectRpnProposals requires continuous levels"

            rpn_rois = torch.ops._caffe2.CollectRpnProposals(
                input_list,
                # NOTE: in current implementation, rpn_max_level and rpn_min_level
                # are not needed, only the subtraction of two matters and it
                # can be infer from the number of inputs. Keep them now for
                # consistency.
                rpn_max_level=2 + len(rpn_rois_list) - 1,
                rpn_min_level=2,
                rpn_post_nms_topN=rpn_post_nms_topN,
            )
            rpn_rois = to_device(rpn_rois, device)
            # No per-proposal scores are kept on the multi-level path.
            rpn_roi_probs = []

        proposals = self.c2_postprocess(im_info, rpn_rois, rpn_roi_probs, self.tensor_mode)
        return proposals, {}

    def forward(self, images, features, gt_instances=None):
        """
        Inference-only forward pass (asserts the module is not in training
        mode): run the RPN head on ``self.in_features`` and generate proposals.
        """
        assert not self.training
        features = [features[f] for f in self.in_features]
        objectness_logits_pred, anchor_deltas_pred = self.rpn_head(features)
        return self._generate_proposals(
            images,
            objectness_logits_pred,
            anchor_deltas_pred,
            gt_instances,
        )

    @staticmethod
    def c2_postprocess(im_info, rpn_rois, rpn_roi_probs, tensor_mode):
        """
        Wrap the raw op outputs into an InstancesList with fields
        "proposal_boxes" and "objectness_logits". Column 0 of ``rpn_rois`` is
        used as the per-proposal image index.

        Returns:
            a one-element list holding the InstancesList when ``tensor_mode``
            is true, otherwise the result of
            ``InstancesList.to_d2_instances_list``.
        """
        proposals = InstancesList(
            im_info=im_info,
            indices=rpn_rois[:, 0],
            extra_fields={
                "proposal_boxes": Caffe2Boxes(rpn_rois),
                "objectness_logits": (torch.Tensor, rpn_roi_probs),
            },
        )
        if not tensor_mode:
            proposals = InstancesList.to_d2_instances_list(proposals)
        else:
            proposals = [proposals]
        return proposals
class Caffe2ROIPooler(Caffe2Compatible, poolers.ROIPooler):
    """
    ROIPooler variant whose pooling is carried out by ``torch.ops._caffe2``
    operators (RoIAlign / RoIAlignRotated, DistributeFpnProposals,
    BatchPermutation).
    """

    @staticmethod
    def c2_preprocess(box_lists):
        """
        Return the boxes as a single tensor in pooler format.

        A list made only of Caffe2Boxes must have exactly one element, whose
        tensor is used directly; any other list of Boxes goes through
        ``poolers.convert_boxes_to_pooler_format``.
        """
        assert all(isinstance(x, Boxes) for x in box_lists)
        if all(isinstance(x, Caffe2Boxes) for x in box_lists):
            # input is pure-tensor based
            assert len(box_lists) == 1
            pooler_fmt_boxes = box_lists[0].tensor
        else:
            pooler_fmt_boxes = poolers.convert_boxes_to_pooler_format(box_lists)
        return pooler_fmt_boxes

    def forward(self, x, box_lists):
        """
        Inference-only pooling (asserts the module is not in training mode).

        Args:
            x: per-level feature maps, in the same order as ``self.level_poolers``.
            box_lists: boxes accepted by :meth:`c2_preprocess`.

        Returns:
            the pooled features: the RoIAlign output for a single level, or
            the per-level outputs concatenated and permuted with the restore
            indices from DistributeFpnProposals.
        """
        assert not self.training

        pooler_fmt_boxes = self.c2_preprocess(box_lists)
        num_level_assignments = len(self.level_poolers)

        if num_level_assignments == 1:
            # Single level: no proposal distribution needed, pool directly.
            if isinstance(self.level_poolers[0], ROIAlignRotated):
                c2_roi_align = torch.ops._caffe2.RoIAlignRotated
                aligned = True
            else:
                c2_roi_align = torch.ops._caffe2.RoIAlign
                aligned = self.level_poolers[0].aligned

            x0 = x[0]
            if x0.is_quantized:
                x0 = x0.dequantize()

            out = c2_roi_align(
                x0,
                pooler_fmt_boxes,
                order="NCHW",
                spatial_scale=float(self.level_poolers[0].spatial_scale),
                pooled_h=int(self.output_size[0]),
                pooled_w=int(self.output_size[1]),
                sampling_ratio=int(self.level_poolers[0].sampling_ratio),
                aligned=aligned,
            )
            return out

        device = pooler_fmt_boxes.device
        assert (
            self.max_level - self.min_level + 1 == 4
        ), "Currently DistributeFpnProposals only support 4 levels"
        # The op is fed CPU boxes; its outputs are moved back to the original device.
        fpn_outputs = torch.ops._caffe2.DistributeFpnProposals(
            to_device(pooler_fmt_boxes, "cpu"),
            roi_canonical_scale=self.canonical_box_size,
            roi_canonical_level=self.canonical_level,
            roi_max_level=self.max_level,
            roi_min_level=self.min_level,
            legacy_plus_one=False,
        )
        fpn_outputs = [to_device(x, device) for x in fpn_outputs]

        # All outputs but the last are per-level RoIs; the last one is used
        # below as the permutation that restores the original RoI order.
        rois_fpn_list = fpn_outputs[:-1]
        rois_idx_restore_int32 = fpn_outputs[-1]

        roi_feat_fpn_list = []
        for roi_fpn, x_level, pooler in zip(rois_fpn_list, x, self.level_poolers):
            if isinstance(pooler, ROIAlignRotated):
                c2_roi_align = torch.ops._caffe2.RoIAlignRotated
                aligned = True
            else:
                c2_roi_align = torch.ops._caffe2.RoIAlign
                aligned = bool(pooler.aligned)

            if x_level.is_quantized:
                x_level = x_level.dequantize()

            roi_feat_fpn = c2_roi_align(
                x_level,
                roi_fpn,
                order="NCHW",
                spatial_scale=float(pooler.spatial_scale),
                pooled_h=int(self.output_size[0]),
                pooled_w=int(self.output_size[1]),
                sampling_ratio=int(pooler.sampling_ratio),
                aligned=aligned,
            )
            roi_feat_fpn_list.append(roi_feat_fpn)

        roi_feat_shuffled = cat(roi_feat_fpn_list, dim=0)
        assert roi_feat_shuffled.numel() > 0 and rois_idx_restore_int32.numel() > 0, (
            "Caffe2 export requires tracing with a model checkpoint + input that can produce valid"
            " detections. But no detections were obtained with the given checkpoint and input!"
        )
        roi_feat = torch.ops._caffe2.BatchPermutation(roi_feat_shuffled, rois_idx_restore_int32)
        return roi_feat
def caffe2_fast_rcnn_outputs_inference(tensor_mode, box_predictor, predictions, proposals):
    """
    equivalent to FastRCNNOutputLayers.inference

    Decodes the predicted box deltas with the BBoxTransform op and filters the
    result with the BoxWithNMSLimit op.

    Args:
        tensor_mode (bool): when true, results stay in the batched
            InstancesList form; otherwise they are converted with
            ``InstancesList.to_d2_instances_list``.
        box_predictor: provides ``num_classes``, ``test_score_thresh``,
            ``test_nms_thresh``, ``test_topk_per_image`` and
            ``box2box_transform.weights``.
        predictions: pair ``(class_logits, box_regression)``.
        proposals: list whose items expose ``proposal_boxes``.

    Returns:
        tuple: ``(results, kept_indices)``. In tensor mode both are
        one-element lists; otherwise ``kept_indices`` is split per image.
    """
    num_classes = box_predictor.num_classes
    score_thresh = box_predictor.test_score_thresh
    nms_thresh = box_predictor.test_nms_thresh
    topk_per_image = box_predictor.test_topk_per_image
    is_rotated = len(box_predictor.box2box_transform.weights) == 5

    if is_rotated:
        box_dim = 5
        assert box_predictor.box2box_transform.weights[4] == 1, (
            "The weights for Rotated BBoxTransform in C2 have only 4 dimensions,"
            + " thus enforcing the angle weight to be 1 for now"
        )
        box2box_transform_weights = box_predictor.box2box_transform.weights[:4]
    else:
        box_dim = 4
        box2box_transform_weights = box_predictor.box2box_transform.weights

    class_logits, box_regression = predictions
    if num_classes + 1 == class_logits.shape[1]:
        class_prob = F.softmax(class_logits, -1)
    else:
        assert num_classes == class_logits.shape[1]
        # torch.sigmoid replaces the deprecated F.sigmoid.
        class_prob = torch.sigmoid(class_logits)
        # BoxWithNMSLimit will infer num_classes from the shape of the class_prob
        # So append a zero column as placeholder for the background class.
        # The column is created on class_prob's device so that the
        # concatenation also works when the logits are not on the default device.
        background = torch.zeros(class_prob.shape[0], 1, device=class_prob.device)
        class_prob = torch.cat((class_prob, background), dim=1)

    assert box_regression.shape[1] % box_dim == 0
    cls_agnostic_bbox_reg = box_regression.shape[1] // box_dim == 1

    # Boxes carrying one extra leading column already include the batch index.
    input_tensor_mode = proposals[0].proposal_boxes.tensor.shape[1] == box_dim + 1

    proposal_boxes = proposals[0].proposal_boxes
    if isinstance(proposal_boxes, Caffe2Boxes):
        rois = Caffe2Boxes.cat([p.proposal_boxes for p in proposals])
    elif isinstance(proposal_boxes, RotatedBoxes):
        rois = RotatedBoxes.cat([p.proposal_boxes for p in proposals])
    elif isinstance(proposal_boxes, Boxes):
        rois = Boxes.cat([p.proposal_boxes for p in proposals])
    else:
        raise NotImplementedError(
            'Expected proposals[0].proposal_boxes to be type "Boxes", '
            f"instead got {type(proposal_boxes)}"
        )

    device, dtype = rois.tensor.device, rois.tensor.dtype
    if input_tensor_mode:
        im_info = proposals[0].image_size
        rois = rois.tensor
    else:
        im_info = torch.tensor([[sz[0], sz[1], 1.0] for sz in [x.image_size for x in proposals]])
        # Prepend a batch-index column: image i contributes len(proposals[i]) rows.
        batch_ids = cat(
            [
                torch.full((b, 1), i, dtype=dtype, device=device)
                for i, b in enumerate(len(p) for p in proposals)
            ],
            dim=0,
        )
        rois = torch.cat([batch_ids, rois.tensor], dim=1)

    # Both ops below are fed CPU tensors; outputs are moved back to `device`.
    roi_pred_bbox, roi_batch_splits = torch.ops._caffe2.BBoxTransform(
        to_device(rois, "cpu"),
        to_device(box_regression, "cpu"),
        to_device(im_info, "cpu"),
        weights=box2box_transform_weights,
        apply_scale=True,
        rotated=is_rotated,
        angle_bound_on=True,
        angle_bound_lo=-180,
        angle_bound_hi=180,
        clip_angle_thresh=1.0,
        legacy_plus_one=False,
    )
    roi_pred_bbox = to_device(roi_pred_bbox, device)
    roi_batch_splits = to_device(roi_batch_splits, device)

    nms_outputs = torch.ops._caffe2.BoxWithNMSLimit(
        to_device(class_prob, "cpu"),
        to_device(roi_pred_bbox, "cpu"),
        to_device(roi_batch_splits, "cpu"),
        score_thresh=float(score_thresh),
        nms=float(nms_thresh),
        detections_per_im=int(topk_per_image),
        soft_nms_enabled=False,
        soft_nms_method="linear",
        soft_nms_sigma=0.5,
        soft_nms_min_score_thres=0.001,
        rotated=is_rotated,
        cls_agnostic_bbox_reg=cls_agnostic_bbox_reg,
        input_boxes_include_bg_cls=False,
        output_classes_include_bg_cls=False,
        legacy_plus_one=False,
    )
    roi_score_nms = to_device(nms_outputs[0], device)
    roi_bbox_nms = to_device(nms_outputs[1], device)
    roi_class_nms = to_device(nms_outputs[2], device)
    roi_batch_splits_nms = to_device(nms_outputs[3], device)
    roi_keeps_nms = to_device(nms_outputs[4], device)
    roi_keeps_size_nms = to_device(nms_outputs[5], device)
    if not tensor_mode:
        roi_class_nms = roi_class_nms.to(torch.int64)

    # Rebuild the per-detection image index from the per-image detection counts.
    roi_batch_ids = cat(
        [
            torch.full((b, 1), i, dtype=dtype, device=device)
            for i, b in enumerate(int(x.item()) for x in roi_batch_splits_nms)
        ],
        dim=0,
    )

    roi_class_nms = alias(roi_class_nms, "class_nms")
    roi_score_nms = alias(roi_score_nms, "score_nms")
    roi_bbox_nms = alias(roi_bbox_nms, "bbox_nms")
    roi_batch_splits_nms = alias(roi_batch_splits_nms, "batch_splits_nms")
    roi_keeps_nms = alias(roi_keeps_nms, "keeps_nms")
    roi_keeps_size_nms = alias(roi_keeps_size_nms, "keeps_size_nms")

    results = InstancesList(
        im_info=im_info,
        indices=roi_batch_ids[:, 0],
        extra_fields={
            "pred_boxes": Caffe2Boxes(roi_bbox_nms),
            "scores": roi_score_nms,
            "pred_classes": roi_class_nms,
        },
    )

    if not tensor_mode:
        results = InstancesList.to_d2_instances_list(results)
        batch_splits = roi_batch_splits_nms.int().tolist()
        kept_indices = list(roi_keeps_nms.to(torch.int64).split(batch_splits))
    else:
        results = [results]
        kept_indices = [roi_keeps_nms]

    return results, kept_indices


class Caffe2FastRCNNOutputsInference:
    """Callable wrapper that binds ``tensor_mode`` for caffe2_fast_rcnn_outputs_inference."""

    def __init__(self, tensor_mode):
        self.tensor_mode = tensor_mode  # whether the output is caffe2 tensor mode

    def __call__(self, box_predictor, predictions, proposals):
        return caffe2_fast_rcnn_outputs_inference(
            self.tensor_mode, box_predictor, predictions, proposals
        )


def caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances):
    """
    equivalent to mask_head.mask_rcnn_inference

    When every item of ``pred_instances`` is an InstancesList (exactly one is
    allowed), the sigmoid of the logits is stored on it as "pred_masks";
    otherwise the call is forwarded to ``mask_rcnn_inference``.
    """
    if all(isinstance(x, InstancesList) for x in pred_instances):
        assert len(pred_instances) == 1
        mask_probs_pred = pred_mask_logits.sigmoid()
        mask_probs_pred = alias(mask_probs_pred, "mask_fcn_probs")
        pred_instances[0].set("pred_masks", mask_probs_pred)
    else:
        mask_rcnn_inference(pred_mask_logits, pred_instances)


class Caffe2MaskRCNNInference:
    """Callable wrapper around caffe2_mask_rcnn_inference."""

    def __call__(self, pred_mask_logits, pred_instances):
        return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)
class Caffe2KeypointRCNNInference:
    """Callable that forwards to caffe2_keypoint_rcnn_inference with a fixed option."""

    def __init__(self, use_heatmap_max_keypoint):
        # Remembered once and passed along on every call.
        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint

    def __call__(self, pred_keypoint_logits, pred_instances):
        heatmap_max = self.use_heatmap_max_keypoint
        return caffe2_keypoint_rcnn_inference(heatmap_max, pred_keypoint_logits, pred_instances)
def _op_stats(net_def):
    """
    Summarize how often each operator type occurs in a net.

    Returns one "<count>x <type>" line per operator type, most frequent first;
    types with the same count are listed alphabetically.
    """
    tally = {}
    for op in net_def.op:
        kind = op.type
        tally[kind] = tally.get(kind, 0) + 1
    # Descending count, then ascending name: same order as two stable sorts.
    ranked = sorted(tally.items(), key=lambda entry: (-entry[1], entry[0]))
    rows = ["{:>4}x {}".format(n_ops, kind) for kind, n_ops in ranked]
    return "\n".join(rows)
def export_caffe2_detection_model(model: torch.nn.Module, tensor_inputs: List[torch.Tensor]):
    """
    Export a caffe2-compatible Detectron2 model to caffe2 format via ONNX.

    Args:
        model: a caffe2-compatible version of detectron2 model, defined in caffe2_modeling.py
        tensor_inputs: a list of tensors that caffe2 model takes as input.

    Returns:
        tuple: (predict_net, init_net), the optimized caffe2 protobuf nets.
    """
    # Everything below operates on a copy of the model that was passed in.
    model = copy.deepcopy(model)
    assert isinstance(model, torch.nn.Module)
    assert hasattr(model, "encode_additional_info")

    # Export via ONNX
    logger.info(
        "Exporting a {} model via ONNX ...".format(type(model).__name__)
        + " Some warnings from ONNX are expected and are usually not to worry about."
    )
    # The model is traced as model(tensor_inputs): the list is passed as one argument.
    onnx_model = export_onnx_model(model, (tensor_inputs,))
    # Convert ONNX model to Caffe2 protobuf
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
    # Log the raw op list before any of the graph rewrites below are applied.
    ops_table = [[op.type, op.input, op.output] for op in predict_net.op]
    table = tabulate(ops_table, headers=["type", "input", "output"], tablefmt="pipe")
    logger.info(
        "ONNX export Done. Exported predict_net (before optimizations):\n" + colored(table, "cyan")
    )

    # Apply protobuf optimization
    fuse_alias_placeholder(predict_net, init_net)
    # The device-related passes only run when at least one input is not on CPU.
    if any(t.device.type != "cpu" for t in tensor_inputs):
        fuse_copy_between_cpu_and_gpu(predict_net)
        remove_dead_end_ops(init_net)
        _assign_device_option(predict_net, init_net, tensor_inputs)
    # init_net is rebuilt from the extracted parameters after the reshape removal.
    params, device_options = get_params_from_init_net(init_net)
    predict_net, params = remove_reshape_for_fc(predict_net, params)
    init_net = construct_init_net_from_params(params, device_options)
    group_norm_replace_aten_with_caffe2(predict_net)

    # Record necessary information for running the pb model in Detectron2 system.
    model.encode_additional_info(predict_net, init_net)

    logger.info("Operators used in predict_net: \n{}".format(_op_stats(predict_net)))
    logger.info("Operators used in init_net: \n{}".format(_op_stats(init_net)))

    return predict_net, init_net
class ProtobufModel(torch.nn.Module):
    """
    Wrapper of a caffe2's protobuf model.
    It works just like nn.Module, but running caffe2 under the hood.
    Input/Output are tuple[tensor] that match the caffe2 net's external_input/output.
    """

    # Class-wide counter used to give every instance its own workspace name.
    _ids = count(0)

    def __init__(self, predict_net, init_net):
        """
        Args:
            predict_net (caffe2_pb2.NetDef): the net that is run on every forward call.
            init_net (caffe2_pb2.NetDef): run once here to fill the workspace.
        """
        logger.info(f"Initializing ProtobufModel for: {predict_net.name} ...")
        super().__init__()
        assert isinstance(predict_net, caffe2_pb2.NetDef)
        assert isinstance(init_net, caffe2_pb2.NetDef)
        # create unique temporary workspace for each instance
        self.ws_name = "__tmp_ProtobufModel_{}__".format(next(self._ids))
        self.net = core.Net(predict_net)

        logger.info("Running init_net once to fill the parameters ...")
        with ScopedWS(self.ws_name, is_reset=True, is_cleanup=False) as ws:
            ws.RunNetOnce(init_net)
            # External inputs that init_net did not create are the ones the caller
            # has to feed; create empty blobs for them so the net can be created.
            uninitialized_external_input = []
            for blob in self.net.Proto().external_input:
                if blob not in ws.Blobs():
                    uninitialized_external_input.append(blob)
                    ws.CreateBlob(blob)
            ws.CreateNet(self.net)

        # Messages of RuntimeErrors already reported, so each is logged in full only once.
        self._error_msgs = set()
        # Names of the blobs that forward() feeds, in the order inputs are expected.
        self._input_blobs = uninitialized_external_input

    def _infer_output_devices(self, inputs):
        """
        Infer the device of every external output from the devices of `inputs`.

        Args:
            inputs (tuple[torch.Tensor]): tensors matching `self._input_blobs`.

        Returns:
            list[str]: list of device for each external output
        """

        def _get_device_type(torch_tensor):
            # Only "cpu"/"cuda" tensors with device index 0 are accepted.
            # NOTE(review): a plain CPU tensor normally reports device.index None,
            # which would trip the second assert; confirm mixed CPU/GPU inputs are
            # never passed here.
            assert torch_tensor.device.type in ["cpu", "cuda"]
            assert torch_tensor.device.index == 0
            return torch_tensor.device.type

        predict_net = self.net.Proto()
        # Input blobs are keyed as (name, version 0) for the device inference.
        input_device_types = {
            (name, 0): _get_device_type(tensor) for name, tensor in zip(self._input_blobs, inputs)
        }
        device_type_map = infer_device_type(
            predict_net, known_status=input_device_types, device_name_style="pytorch"
        )
        # Outputs are looked up under the version core.get_ssa reports for their name.
        ssa, versions = core.get_ssa(predict_net)
        versioned_outputs = [(name, versions[name]) for name in predict_net.external_output]
        output_devices = [device_type_map[outp] for outp in versioned_outputs]
        return output_devices

    def forward(self, inputs):
        """
        Feed `inputs` into the workspace, run the net and fetch its external outputs.

        Args:
            inputs (tuple[torch.Tensor]): one tensor per blob in `self._input_blobs`.

        Returns:
            tuple[torch.Tensor]: one tensor per external output of the net.

        Raises:
            RuntimeError: if a fetched output is not a numpy array.
        """
        assert len(inputs) == len(self._input_blobs), (
            f"Length of inputs ({len(inputs)}) "
            f"doesn't match the required input blobs: {self._input_blobs}"
        )

        with ScopedWS(self.ws_name, is_reset=False, is_cleanup=False) as ws:
            for b, tensor in zip(self._input_blobs, inputs):
                ws.FeedBlob(b, tensor)

            # A RuntimeError from the net is logged, not re-raised; whatever outputs
            # exist afterwards are still fetched below.
            try:
                ws.RunNet(self.net.Proto().name)
            except RuntimeError as e:
                if not str(e) in self._error_msgs:
                    self._error_msgs.add(str(e))
                    logger.warning("Encountered new RuntimeError: \n{}".format(str(e)))
                logger.warning("Catch the error and use partial results.")

            c2_outputs = [ws.FetchBlob(b) for b in self.net.Proto().external_output]
            # Remove outputs of current run, this is necessary in order to
            # prevent fetching the result from previous run if the model fails
            # in the middle.
            for b in self.net.Proto().external_output:
                # Needs to create uninitialized blob to make the net runnable.
                # This is "equivalent" to: ws.RemoveBlob(b) then ws.CreateBlob(b),
                # but there's no such API.
                ws.FeedBlob(b, f"{b}, a C++ native class of type nullptr (uninitialized).")

        # Cast output to torch.Tensor on the desired device
        # Device inference is skipped when every input is on CPU.
        output_devices = (
            self._infer_output_devices(inputs)
            if any(t.device.type != "cpu" for t in inputs)
            else ["cpu" for _ in self.net.Proto().external_output]
        )

        outputs = []
        for name, c2_output, device in zip(
            self.net.Proto().external_output, c2_outputs, output_devices
        ):
            # The placeholder string fed above is not an ndarray, so an output the
            # net never produced is reported here.
            if not isinstance(c2_output, np.ndarray):
                raise RuntimeError(
                    "Invalid output for blob {}, received: {}".format(name, c2_output)
                )
            outputs.append(torch.tensor(c2_output).to(device=device))
        return tuple(outputs)
+ """ + + def __init__(self, predict_net, init_net, *, convert_outputs=None): + """ + Args: + predict_net, init_net (core.Net): caffe2 nets + convert_outptus (callable): a function that converts caffe2 + outputs to the same format of the original pytorch model. + By default, use the one defined in the caffe2 meta_arch. + """ + super().__init__() + self.protobuf_model = ProtobufModel(predict_net, init_net) + self.size_divisibility = get_pb_arg_vali(predict_net, "size_divisibility", 0) + self.device = get_pb_arg_vals(predict_net, "device", b"cpu").decode("ascii") + + if convert_outputs is None: + meta_arch = get_pb_arg_vals(predict_net, "meta_architecture", b"GeneralizedRCNN") + meta_arch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[meta_arch.decode("ascii")] + self._convert_outputs = meta_arch.get_outputs_converter(predict_net, init_net) + else: + self._convert_outputs = convert_outputs + + def _convert_inputs(self, batched_inputs): + # currently all models convert inputs in the same way + return convert_batched_inputs_to_c2_format( + batched_inputs, self.size_divisibility, self.device + ) + + def forward(self, batched_inputs): + c2_inputs = self._convert_inputs(batched_inputs) + c2_results = self.protobuf_model(c2_inputs) + c2_results = dict(zip(self.protobuf_model.net.Proto().external_output, c2_results)) + return self._convert_outputs(batched_inputs, c2_inputs, c2_results) diff --git a/data_processing/detectron2/detectron2/export/caffe2_modeling.py b/data_processing/detectron2/detectron2/export/caffe2_modeling.py new file mode 100644 index 0000000..3e675c4 --- /dev/null +++ b/data_processing/detectron2/detectron2/export/caffe2_modeling.py @@ -0,0 +1,420 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import functools +import io +import struct +import types +import torch + +from detectron2.modeling import meta_arch +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.roi_heads import keypoint_head +from detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes + +from .c10 import Caffe2Compatible +from .caffe2_patch import ROIHeadsPatcher, patch_generalized_rcnn +from .shared import ( + alias, + check_set_pb_arg, + get_pb_arg_floats, + get_pb_arg_valf, + get_pb_arg_vali, + get_pb_arg_vals, + mock_torch_nn_functional_interpolate, +) + + +def assemble_rcnn_outputs_by_name(image_sizes, tensor_outputs, force_mask_on=False): + """ + A function to assemble caffe2 model's outputs (i.e. Dict[str, Tensor]) + to detectron2's format (i.e. list of Instances instance). + This only works when the model follows the Caffe2 detectron's naming convention. + + Args: + image_sizes (List[List[int, int]]): [H, W] of every image. + tensor_outputs (Dict[str, Tensor]): external_output to its tensor. 
+ + force_mask_on (Bool): if true, the it make sure there'll be pred_masks even + if the mask is not found from tensor_outputs (usually due to model crash) + """ + + results = [Instances(image_size) for image_size in image_sizes] + + batch_splits = tensor_outputs.get("batch_splits", None) + if batch_splits: + raise NotImplementedError() + assert len(image_sizes) == 1 + result = results[0] + + bbox_nms = tensor_outputs["bbox_nms"] + score_nms = tensor_outputs["score_nms"] + class_nms = tensor_outputs["class_nms"] + # Detection will always success because Conv support 0-batch + assert bbox_nms is not None + assert score_nms is not None + assert class_nms is not None + if bbox_nms.shape[1] == 5: + result.pred_boxes = RotatedBoxes(bbox_nms) + else: + result.pred_boxes = Boxes(bbox_nms) + result.scores = score_nms + result.pred_classes = class_nms.to(torch.int64) + + mask_fcn_probs = tensor_outputs.get("mask_fcn_probs", None) + if mask_fcn_probs is not None: + # finish the mask pred + mask_probs_pred = mask_fcn_probs + num_masks = mask_probs_pred.shape[0] + class_pred = result.pred_classes + indices = torch.arange(num_masks, device=class_pred.device) + mask_probs_pred = mask_probs_pred[indices, class_pred][:, None] + result.pred_masks = mask_probs_pred + elif force_mask_on: + # NOTE: there's no way to know the height/width of mask here, it won't be + # used anyway when batch size is 0, so just set them to 0. + result.pred_masks = torch.zeros([0, 1, 0, 0], dtype=torch.uint8) + + keypoints_out = tensor_outputs.get("keypoints_out", None) + kps_score = tensor_outputs.get("kps_score", None) + if keypoints_out is not None: + # keypoints_out: [N, 4, #kypoints], where 4 is in order of (x, y, score, prob) + keypoints_tensor = keypoints_out + # NOTE: it's possible that prob is not calculated if "should_output_softmax" + # is set to False in HeatmapMaxKeypoint, so just using raw score, seems + # it doesn't affect mAP. TODO: check more carefully. 
def _cast_to_f32(f64):
    """Round-trip a number through a packed C float and return the resulting value."""
    packed = struct.pack("f", f64)
    (f32,) = struct.unpack("f", packed)
    return f32
+ The forward is traceable and its traced graph can be converted to caffe2 + graph through ONNX. + """ + + def __init__(self, cfg, torch_model, enable_tensor_mode=True): + """ + Args: + cfg (CfgNode): + torch_model (nn.Module): the detectron2 model (meta_arch) to be + converted. + """ + super().__init__() + self._wrapped_model = torch_model + self.eval() + set_caffe2_compatible_tensor_mode(self, enable_tensor_mode) + + def get_caffe2_inputs(self, batched_inputs): + """ + Convert pytorch-style structured inputs to caffe2-style inputs that + are tuples of tensors. + + Args: + batched_inputs (list[dict]): inputs to a detectron2 model + in its standard format. Each dict has "image" (CHW tensor), and optionally + "height" and "width". + + Returns: + tuple[Tensor]: + tuple of tensors that will be the inputs to the + :meth:`forward` method. For existing models, the first + is an NCHW tensor (padded and batched); the second is + a im_info Nx3 tensor, where the rows are + (height, width, unused legacy parameter) + """ + return convert_batched_inputs_to_c2_format( + batched_inputs, + self._wrapped_model.backbone.size_divisibility, + self._wrapped_model.device, + ) + + def encode_additional_info(self, predict_net, init_net): + """ + Save extra metadata that will be used by inference in the output protobuf. + """ + pass + + def forward(self, inputs): + """ + Run the forward in caffe2-style. It has to use caffe2-compatible ops + and the method will be used for tracing. + + Args: + inputs (tuple[Tensor]): inputs defined by :meth:`get_caffe2_input`. + They will be the inputs of the converted caffe2 graph. + + Returns: + tuple[Tensor]: output tensors. They will be the outputs of the + converted caffe2 graph. + """ + raise NotImplementedError + + def _caffe2_preprocess_image(self, inputs): + """ + Caffe2 implementation of preprocess_image, which is called inside each MetaArch's forward. 
+ It normalizes the input images, and the final caffe2 graph assumes the + inputs have been batched already. + """ + data, im_info = inputs + data = alias(data, "data") + im_info = alias(im_info, "im_info") + mean, std = self._wrapped_model.pixel_mean, self._wrapped_model.pixel_std + normalized_data = (data - mean) / std + normalized_data = alias(normalized_data, "normalized_data") + + # Pack (data, im_info) into ImageList which is recognized by self.inference. + images = ImageList(tensor=normalized_data, image_sizes=im_info) + return images + + @staticmethod + def get_outputs_converter(predict_net, init_net): + """ + Creates a function that converts outputs of the caffe2 model to + detectron2's standard format. + The function uses information in `predict_net` and `init_net` that are + available at inferene time. Therefore the function logic can be used in inference. + + The returned function has the following signature: + + def convert(batched_inputs, c2_inputs, c2_results) -> detectron2_outputs + + Where + + * batched_inputs (list[dict]): the original input format of the meta arch + * c2_inputs (tuple[Tensor]): the caffe2 inputs. + * c2_results (dict[str, Tensor]): the caffe2 output format, + corresponding to the outputs of the :meth:`forward` function. + * detectron2_outputs: the original output format of the meta arch. + + This function can be used to compare the outputs of the original meta arch and + the converted caffe2 graph. + + Returns: + callable: a callable of the above signature. 
class Caffe2GeneralizedRCNN(Caffe2MetaArch):
    """Caffe2-exportable wrapper around a detectron2 GeneralizedRCNN model."""

    def __init__(self, cfg, torch_model, enable_tensor_mode=True):
        assert isinstance(torch_model, meta_arch.GeneralizedRCNN)
        super().__init__(cfg, patch_generalized_rcnn(torch_model), enable_tensor_mode)

        # The export option may be absent from the config; treat that as "off".
        try:
            heatmap_max = cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT
        except AttributeError:
            heatmap_max = False
        roi_heads = self._wrapped_model.roi_heads
        self.roi_heads_patcher = ROIHeadsPatcher(roi_heads, heatmap_max)
        if self.tensor_mode:
            self.roi_heads_patcher.patch_roi_heads()

    def encode_additional_info(self, predict_net, init_net):
        """Store size divisibility, device and meta architecture name in predict_net."""
        wrapped = self._wrapped_model
        check_set_pb_arg(predict_net, "size_divisibility", "i", wrapped.backbone.size_divisibility)
        device_name = str(wrapped.device).encode("ascii")
        check_set_pb_arg(predict_net, "device", "s", device_name)
        check_set_pb_arg(predict_net, "meta_architecture", "s", b"GeneralizedRCNN")

    @mock_torch_nn_functional_interpolate()
    def forward(self, inputs):
        """Run the wrapped model; in tensor mode return the flattened first result."""
        wrapped = self._wrapped_model
        if self.tensor_mode:
            images = self._caffe2_preprocess_image(inputs)
            features = wrapped.backbone(images.tensor)
            proposals, _ = wrapped.proposal_generator(images, features)
            detector_results, _ = wrapped.roi_heads(images, features, proposals)
            return tuple(detector_results[0].flatten())
        return wrapped.inference(inputs)

    @staticmethod
    def get_outputs_converter(predict_net, init_net):
        """Return a function turning caffe2 outputs into GeneralizedRCNN outputs."""

        def convert(batched_inputs, c2_inputs, c2_results):
            _, im_info = c2_inputs
            image_sizes = []
            for row in im_info:
                image_sizes.append([int(row[0]), int(row[1])])
            instances = assemble_rcnn_outputs_by_name(image_sizes, c2_results)
            return meta_arch.GeneralizedRCNN._postprocess(instances, batched_inputs, image_sizes)

        return convert
super().__init__(cfg, torch_model) + + @mock_torch_nn_functional_interpolate() + def forward(self, inputs): + assert self.tensor_mode + images = self._caffe2_preprocess_image(inputs) + + # explicitly return the images sizes to avoid removing "im_info" by ONNX + # since it's not used in the forward path + return_tensors = [images.image_sizes] + + features = self._wrapped_model.backbone(images.tensor) + features = [features[f] for f in self._wrapped_model.head_in_features] + for i, feature_i in enumerate(features): + features[i] = alias(feature_i, "feature_{}".format(i), is_backward=True) + return_tensors.append(features[i]) + + pred_logits, pred_anchor_deltas = self._wrapped_model.head(features) + for i, (box_cls_i, box_delta_i) in enumerate(zip(pred_logits, pred_anchor_deltas)): + return_tensors.append(alias(box_cls_i, "box_cls_{}".format(i))) + return_tensors.append(alias(box_delta_i, "box_delta_{}".format(i))) + + return tuple(return_tensors) + + def encode_additional_info(self, predict_net, init_net): + size_divisibility = self._wrapped_model.backbone.size_divisibility + check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility) + check_set_pb_arg( + predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii") + ) + check_set_pb_arg(predict_net, "meta_architecture", "s", b"RetinaNet") + + # Inference parameters: + check_set_pb_arg( + predict_net, "score_threshold", "f", _cast_to_f32(self._wrapped_model.test_score_thresh) + ) + check_set_pb_arg( + predict_net, "topk_candidates", "i", self._wrapped_model.test_topk_candidates + ) + check_set_pb_arg( + predict_net, "nms_threshold", "f", _cast_to_f32(self._wrapped_model.test_nms_thresh) + ) + check_set_pb_arg( + predict_net, + "max_detections_per_image", + "i", + self._wrapped_model.max_detections_per_image, + ) + + check_set_pb_arg( + predict_net, + "bbox_reg_weights", + "floats", + [_cast_to_f32(w) for w in self._wrapped_model.box2box_transform.weights], + ) + 
self._encode_anchor_generator_cfg(predict_net) + + def _encode_anchor_generator_cfg(self, predict_net): + # serialize anchor_generator for future use + serialized_anchor_generator = io.BytesIO() + torch.save(self._wrapped_model.anchor_generator, serialized_anchor_generator) + # Ideally we can put anchor generating inside the model, then we don't + # need to store this information. + bytes = serialized_anchor_generator.getvalue() + check_set_pb_arg(predict_net, "serialized_anchor_generator", "s", bytes) + + @staticmethod + def get_outputs_converter(predict_net, init_net): + self = types.SimpleNamespace() + serialized_anchor_generator = io.BytesIO( + get_pb_arg_vals(predict_net, "serialized_anchor_generator", None) + ) + self.anchor_generator = torch.load(serialized_anchor_generator) + bbox_reg_weights = get_pb_arg_floats(predict_net, "bbox_reg_weights", None) + self.box2box_transform = Box2BoxTransform(weights=tuple(bbox_reg_weights)) + self.test_score_thresh = get_pb_arg_valf(predict_net, "score_threshold", None) + self.test_topk_candidates = get_pb_arg_vali(predict_net, "topk_candidates", None) + self.test_nms_thresh = get_pb_arg_valf(predict_net, "nms_threshold", None) + self.max_detections_per_image = get_pb_arg_vali( + predict_net, "max_detections_per_image", None + ) + + # hack to reuse inference code from RetinaNet + for meth in [ + "forward_inference", + "inference_single_image", + "_transpose_dense_predictions", + "_decode_multi_level_predictions", + "_decode_per_level_predictions", + ]: + setattr(self, meth, functools.partial(getattr(meta_arch.RetinaNet, meth), self)) + + def f(batched_inputs, c2_inputs, c2_results): + _, im_info = c2_inputs + image_sizes = [[int(im[0]), int(im[1])] for im in im_info] + dummy_images = ImageList( + torch.randn( + ( + len(im_info), + 3, + ) + + tuple(image_sizes[0]) + ), + image_sizes, + ) + + num_features = len([x for x in c2_results.keys() if x.startswith("box_cls_")]) + pred_logits = [c2_results["box_cls_{}".format(i)] 
for i in range(num_features)] + pred_anchor_deltas = [c2_results["box_delta_{}".format(i)] for i in range(num_features)] + + # For each feature level, feature should have the same batch size and + # spatial dimension as the box_cls and box_delta. + dummy_features = [x.clone()[:, 0:0, :, :] for x in pred_logits] + # self.num_classess can be inferred + self.num_classes = pred_logits[0].shape[1] // (pred_anchor_deltas[0].shape[1] // 4) + + results = self.forward_inference( + dummy_images, dummy_features, [pred_logits, pred_anchor_deltas] + ) + return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes) + + return f + + +META_ARCH_CAFFE2_EXPORT_TYPE_MAP = { + "GeneralizedRCNN": Caffe2GeneralizedRCNN, + "RetinaNet": Caffe2RetinaNet, +} diff --git a/data_processing/detectron2/detectron2/export/caffe2_patch.py b/data_processing/detectron2/detectron2/export/caffe2_patch.py new file mode 100644 index 0000000..2da70ae --- /dev/null +++ b/data_processing/detectron2/detectron2/export/caffe2_patch.py @@ -0,0 +1,189 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import contextlib +from unittest import mock +import torch + +from detectron2.modeling import poolers +from detectron2.modeling.proposal_generator import rpn +from detectron2.modeling.roi_heads import keypoint_head, mask_head +from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers + +from .c10 import ( + Caffe2Compatible, + Caffe2FastRCNNOutputsInference, + Caffe2KeypointRCNNInference, + Caffe2MaskRCNNInference, + Caffe2ROIPooler, + Caffe2RPN, + caffe2_fast_rcnn_outputs_inference, + caffe2_keypoint_rcnn_inference, + caffe2_mask_rcnn_inference, +) + + +class GenericMixin(object): + pass + + +class Caffe2CompatibleConverter(object): + """ + A GenericUpdater which implements the `create_from` interface, by modifying + module object and assign it with another class replaceCls. 
+ """ + + def __init__(self, replaceCls): + self.replaceCls = replaceCls + + def create_from(self, module): + # update module's class to the new class + assert isinstance(module, torch.nn.Module) + if issubclass(self.replaceCls, GenericMixin): + # replaceCls should act as mixin, create a new class on-the-fly + new_class = type( + "{}MixedWith{}".format(self.replaceCls.__name__, module.__class__.__name__), + (self.replaceCls, module.__class__), + {}, # {"new_method": lambda self: ...}, + ) + module.__class__ = new_class + else: + # replaceCls is complete class, this allow arbitrary class swap + module.__class__ = self.replaceCls + + # initialize Caffe2Compatible + if isinstance(module, Caffe2Compatible): + module.tensor_mode = False + + return module + + +def patch(model, target, updater, *args, **kwargs): + """ + recursively (post-order) update all modules with the target type and its + subclasses, make a initialization/composition/inheritance/... via the + updater.create_from. + """ + for name, module in model.named_children(): + model._modules[name] = patch(module, target, updater, *args, **kwargs) + if isinstance(model, target): + return updater.create_from(model, *args, **kwargs) + return model + + +def patch_generalized_rcnn(model): + ccc = Caffe2CompatibleConverter + model = patch(model, rpn.RPN, ccc(Caffe2RPN)) + model = patch(model, poolers.ROIPooler, ccc(Caffe2ROIPooler)) + + return model + + +@contextlib.contextmanager +def mock_fastrcnn_outputs_inference( + tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers +): + with mock.patch.object( + box_predictor_type, + "inference", + autospec=True, + side_effect=Caffe2FastRCNNOutputsInference(tensor_mode), + ) as mocked_func: + yield + if check: + assert mocked_func.call_count > 0 + + +@contextlib.contextmanager +def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True): + with mock.patch( + "{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference() + ) as 
mocked_func: + yield + if check: + assert mocked_func.call_count > 0 + + +@contextlib.contextmanager +def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True): + with mock.patch( + "{}.keypoint_rcnn_inference".format(patched_module), + side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint), + ) as mocked_func: + yield + if check: + assert mocked_func.call_count > 0 + + +class ROIHeadsPatcher: + def __init__(self, heads, use_heatmap_max_keypoint): + self.heads = heads + self.use_heatmap_max_keypoint = use_heatmap_max_keypoint + self.previous_patched = {} + + @contextlib.contextmanager + def mock_roi_heads(self, tensor_mode=True): + """ + Patching several inference functions inside ROIHeads and its subclasses + + Args: + tensor_mode (bool): whether the inputs/outputs are caffe2's tensor + format or not. Default to True. + """ + # NOTE: this requries the `keypoint_rcnn_inference` and `mask_rcnn_inference` + # are called inside the same file as BaseXxxHead due to using mock.patch. 
+ kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__ + mask_head_mod = mask_head.BaseMaskRCNNHead.__module__ + + mock_ctx_managers = [ + mock_fastrcnn_outputs_inference( + tensor_mode=tensor_mode, + check=True, + box_predictor_type=type(self.heads.box_predictor), + ) + ] + if getattr(self.heads, "keypoint_on", False): + mock_ctx_managers += [ + mock_keypoint_rcnn_inference( + tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint + ) + ] + if getattr(self.heads, "mask_on", False): + mock_ctx_managers += [mock_mask_rcnn_inference(tensor_mode, mask_head_mod)] + + with contextlib.ExitStack() as stack: # python 3.3+ + for mgr in mock_ctx_managers: + stack.enter_context(mgr) + yield + + def patch_roi_heads(self, tensor_mode=True): + self.previous_patched["box_predictor"] = self.heads.box_predictor.inference + self.previous_patched["keypoint_rcnn"] = keypoint_head.keypoint_rcnn_inference + self.previous_patched["mask_rcnn"] = mask_head.mask_rcnn_inference + + def patched_fastrcnn_outputs_inference(predictions, proposal): + return caffe2_fast_rcnn_outputs_inference( + True, self.heads.box_predictor, predictions, proposal + ) + + self.heads.box_predictor.inference = patched_fastrcnn_outputs_inference + + if getattr(self.heads, "keypoint_on", False): + + def patched_keypoint_rcnn_inference(pred_keypoint_logits, pred_instances): + return caffe2_keypoint_rcnn_inference( + self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances + ) + + keypoint_head.keypoint_rcnn_inference = patched_keypoint_rcnn_inference + + if getattr(self.heads, "mask_on", False): + + def patched_mask_rcnn_inference(pred_mask_logits, pred_instances): + return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances) + + mask_head.mask_rcnn_inference = patched_mask_rcnn_inference + + def unpatch_roi_heads(self): + self.heads.box_predictor.inference = self.previous_patched["box_predictor"] + keypoint_head.keypoint_rcnn_inference = self.previous_patched["keypoint_rcnn"] + 
mask_head.mask_rcnn_inference = self.previous_patched["mask_rcnn"] diff --git a/data_processing/detectron2/detectron2/export/flatten.py b/data_processing/detectron2/detectron2/export/flatten.py new file mode 100644 index 0000000..f5ba429 --- /dev/null +++ b/data_processing/detectron2/detectron2/export/flatten.py @@ -0,0 +1,330 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import collections +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple +import torch +from torch import nn + +from detectron2.structures import Boxes, Instances, ROIMasks +from detectron2.utils.registry import _convert_target_to_string, locate + +from .torchscript_patch import patch_builtin_len + + +@dataclass +class Schema: + """ + A Schema defines how to flatten a possibly hierarchical object into tuple of + primitive objects, so it can be used as inputs/outputs of PyTorch's tracing. + + PyTorch does not support tracing a function that produces rich output + structures (e.g. dict, Instances, Boxes). To trace such a function, we + flatten the rich object into tuple of tensors, and return this tuple of tensors + instead. Meanwhile, we also need to know how to "rebuild" the original object + from the flattened results, so we can evaluate the flattened results. + A Schema defines how to flatten an object, and while flattening it, it records + necessary schemas so that the object can be rebuilt using the flattened outputs. + + The flattened object and the schema object is returned by ``.flatten`` classmethod. + Then the original object can be rebuilt with the ``__call__`` method of schema. + + A Schema is a dataclass that can be serialized easily. 
+ """ + + # inspired by FetchMapper in tensorflow/python/client/session.py + + @classmethod + def flatten(cls, obj): + raise NotImplementedError + + def __call__(self, values): + raise NotImplementedError + + @staticmethod + def _concat(values): + ret = () + sizes = [] + for v in values: + assert isinstance(v, tuple), "Flattened results must be a tuple" + ret = ret + v + sizes.append(len(v)) + return ret, sizes + + @staticmethod + def _split(values, sizes): + if len(sizes): + expected_len = sum(sizes) + assert ( + len(values) == expected_len + ), f"Values has length {len(values)} but expect length {expected_len}." + ret = [] + for k in range(len(sizes)): + begin, end = sum(sizes[:k]), sum(sizes[: k + 1]) + ret.append(values[begin:end]) + return ret + + +@dataclass +class ListSchema(Schema): + schemas: List[Schema] # the schemas that define how to flatten each element in the list + sizes: List[int] # the flattened length of each element + + def __call__(self, values): + values = self._split(values, self.sizes) + if len(values) != len(self.schemas): + raise ValueError( + f"Values has length {len(values)} but schemas " f"has length {len(self.schemas)}!" 
+ ) + values = [m(v) for m, v in zip(self.schemas, values)] + return list(values) + + @classmethod + def flatten(cls, obj): + res = [flatten_to_tuple(k) for k in obj] + values, sizes = cls._concat([k[0] for k in res]) + return values, cls([k[1] for k in res], sizes) + + +@dataclass +class TupleSchema(ListSchema): + def __call__(self, values): + return tuple(super().__call__(values)) + + +@dataclass +class IdentitySchema(Schema): + def __call__(self, values): + return values[0] + + @classmethod + def flatten(cls, obj): + return (obj,), cls() + + +@dataclass +class DictSchema(ListSchema): + keys: List[str] + + def __call__(self, values): + values = super().__call__(values) + return dict(zip(self.keys, values)) + + @classmethod + def flatten(cls, obj): + for k in obj.keys(): + if not isinstance(k, str): + raise KeyError("Only support flattening dictionaries if keys are str.") + keys = sorted(obj.keys()) + values = [obj[k] for k in keys] + ret, schema = ListSchema.flatten(values) + return ret, cls(schema.schemas, schema.sizes, keys) + + +@dataclass +class InstancesSchema(DictSchema): + def __call__(self, values): + image_size, fields = values[-1], values[:-1] + fields = super().__call__(fields) + return Instances(image_size, **fields) + + @classmethod + def flatten(cls, obj): + ret, schema = super().flatten(obj.get_fields()) + size = obj.image_size + if not isinstance(size, torch.Tensor): + size = torch.tensor(size) + return ret + (size,), schema + + +@dataclass +class TensorWrapSchema(Schema): + """ + For classes that are simple wrapper of tensors, e.g. 
+ Boxes, RotatedBoxes, BitMasks + """ + + class_name: str + + def __call__(self, values): + return locate(self.class_name)(values[0]) + + @classmethod + def flatten(cls, obj): + return (obj.tensor,), cls(_convert_target_to_string(type(obj))) + + +# if more custom structures needed in the future, can allow +# passing in extra schemas for custom types +def flatten_to_tuple(obj): + """ + Flatten an object so it can be used for PyTorch tracing. + Also returns how to rebuild the original object from the flattened outputs. + + Returns: + res (tuple): the flattened results that can be used as tracing outputs + schema: an object with a ``__call__`` method such that ``schema(res) == obj``. + It is a pure dataclass that can be serialized. + """ + schemas = [ + ((str, bytes), IdentitySchema), + (list, ListSchema), + (tuple, TupleSchema), + (collections.abc.Mapping, DictSchema), + (Instances, InstancesSchema), + ((Boxes, ROIMasks), TensorWrapSchema), + ] + for klass, schema in schemas: + if isinstance(obj, klass): + F = schema + break + else: + F = IdentitySchema + + return F.flatten(obj) + + +class TracingAdapter(nn.Module): + """ + A model may take rich input/output format (e.g. dict or custom classes), + but `torch.jit.trace` requires tuple of tensors as input/output. + This adapter flattens input/output format of a model so it becomes traceable. + + It also records the necessary schema to rebuild model's inputs/outputs from flattened + inputs/outputs. 
+ + Example: + :: + outputs = model(inputs) # inputs/outputs may be rich structure + adapter = TracingAdapter(model, inputs) + + # can now trace the model, with adapter.flattened_inputs, or another + # tuple of tensors with the same length and meaning + traced = torch.jit.trace(adapter, adapter.flattened_inputs) + + # traced model can only produce flattened outputs (tuple of tensors) + flattened_outputs = traced(*adapter.flattened_inputs) + # adapter knows the schema to convert it back (new_outputs == outputs) + new_outputs = adapter.outputs_schema(flattened_outputs) + """ + + flattened_inputs: Tuple[torch.Tensor] = None + """ + Flattened version of inputs given to this class's constructor. + """ + + inputs_schema: Schema = None + """ + Schema of the inputs given to this class's constructor. + """ + + outputs_schema: Schema = None + """ + Schema of the output produced by calling the given model with inputs. + """ + + def __init__( + self, + model: nn.Module, + inputs, + inference_func: Optional[Callable] = None, + allow_non_tensor: bool = False, + ): + """ + Args: + model: an nn.Module + inputs: An input argument or a tuple of input arguments used to call model. + After flattening, it has to only consist of tensors. + inference_func: a callable that takes (model, *inputs), calls the + model with inputs, and return outputs. By default it + is ``lambda model, *inputs: model(*inputs)``. Can be override + if you need to call the model differently. + allow_non_tensor: allow inputs/outputs to contain non-tensor objects. + This option will filter out non-tensor objects to make the + model traceable, but ``inputs_schema``/``outputs_schema`` cannot be + used anymore because inputs/outputs cannot be rebuilt from pure tensors. + This is useful when you're only interested in the single trace of + execution (e.g. for flop count), but not interested in + generalizing the traced graph to new inputs. 
+ """ + super().__init__() + if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)): + model = model.module + self.model = model + if not isinstance(inputs, tuple): + inputs = (inputs,) + self.inputs = inputs + self.allow_non_tensor = allow_non_tensor + + if inference_func is None: + inference_func = lambda model, *inputs: model(*inputs) # noqa + self.inference_func = inference_func + + self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs) + + if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs): + return + if self.allow_non_tensor: + self.flattened_inputs = tuple( + [x for x in self.flattened_inputs if isinstance(x, torch.Tensor)] + ) + self.inputs_schema = None + else: + for input in self.flattened_inputs: + if not isinstance(input, torch.Tensor): + raise ValueError( + "Inputs for tracing must only contain tensors. " + f"Got a {type(input)} instead." + ) + + def forward(self, *args: torch.Tensor): + with torch.no_grad(), patch_builtin_len(): + if self.inputs_schema is not None: + inputs_orig_format = self.inputs_schema(args) + else: + if len(args) != len(self.flattened_inputs) or any( + x is not y for x, y in zip(args, self.flattened_inputs) + ): + raise ValueError( + "TracingAdapter does not contain valid inputs_schema." + " So it cannot generalize to other inputs and must be" + " traced with `.flattened_inputs`." + ) + inputs_orig_format = self.inputs + + outputs = self.inference_func(self.model, *inputs_orig_format) + flattened_outputs, schema = flatten_to_tuple(outputs) + + flattened_output_tensors = tuple( + [x for x in flattened_outputs if isinstance(x, torch.Tensor)] + ) + if len(flattened_output_tensors) < len(flattened_outputs): + if self.allow_non_tensor: + flattened_outputs = flattened_output_tensors + self.outputs_schema = None + else: + raise ValueError( + "Model cannot be traced because some model outputs " + "cannot flatten to tensors." 
+ ) + else: # schema is valid + if self.outputs_schema is None: + self.outputs_schema = schema + else: + assert self.outputs_schema == schema, ( + "Model should always return outputs with the same " + "structure so it can be traced!" + ) + return flattened_outputs + + def _create_wrapper(self, traced_model): + """ + Return a function that has an input/output interface the same as the + original model, but it calls the given traced model under the hood. + """ + + def forward(*args): + flattened_inputs, _ = flatten_to_tuple(args) + flattened_outputs = traced_model(*flattened_inputs) + return self.outputs_schema(flattened_outputs) + + return forward diff --git a/data_processing/detectron2/detectron2/export/shared.py b/data_processing/detectron2/detectron2/export/shared.py new file mode 100644 index 0000000..53ba933 --- /dev/null +++ b/data_processing/detectron2/detectron2/export/shared.py @@ -0,0 +1,1039 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import collections +import copy +import functools +import logging +import numpy as np +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest import mock +import caffe2.python.utils as putils +import torch +import torch.nn.functional as F +from caffe2.proto import caffe2_pb2 +from caffe2.python import core, net_drawer, workspace +from torch.nn.functional import interpolate as interp + +logger = logging.getLogger(__name__) + + +# ==== torch/utils_toffee/cast.py ======================================= + + +def to_device(t, device_str): + """ + This function is a replacement of .to(another_device) such that it allows the + casting to be traced properly by explicitly calling the underlying copy ops. + It also avoids introducing unncessary op when casting to the same device. 
+ """ + src = t.device + dst = torch.device(device_str) + + if src == dst: + return t + elif src.type == "cuda" and dst.type == "cpu": + return torch.ops._caffe2.CopyGPUToCPU(t) + elif src.type == "cpu" and dst.type == "cuda": + return torch.ops._caffe2.CopyCPUToGPU(t) + else: + raise RuntimeError("Can't cast tensor from device {} to device {}".format(src, dst)) + + +# ==== torch/utils_toffee/interpolate.py ======================================= + + +# Note: borrowed from vision/detection/fair/detectron/detectron/modeling/detector.py +def BilinearInterpolation(tensor_in, up_scale): + assert up_scale % 2 == 0, "Scale should be even" + + def upsample_filt(size): + factor = (size + 1) // 2 + if size % 2 == 1: + center = factor - 1 + else: + center = factor - 0.5 + + og = np.ogrid[:size, :size] + return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) + + kernel_size = int(up_scale) * 2 + bil_filt = upsample_filt(kernel_size) + + dim = int(tensor_in.shape[1]) + kernel = np.zeros((dim, dim, kernel_size, kernel_size), dtype=np.float32) + kernel[range(dim), range(dim), :, :] = bil_filt + + tensor_out = F.conv_transpose2d( + tensor_in, + weight=to_device(torch.Tensor(kernel), tensor_in.device), + bias=None, + stride=int(up_scale), + padding=int(up_scale / 2), + ) + + return tensor_out + + +# NOTE: ONNX is incompatible with traced torch.nn.functional.interpolate if +# using dynamic `scale_factor` rather than static `size`. (T43166860) +# NOTE: Caffe2 Int8 conversion might not be able to quantize `size` properly. +def onnx_compatibale_interpolate( + input, size=None, scale_factor=None, mode="nearest", align_corners=None +): + # NOTE: The input dimensions are interpreted in the form: + # `mini-batch x channels x [optional depth] x [optional height] x width`. 
+ if size is None and scale_factor is not None: + if input.dim() == 4: + if isinstance(scale_factor, (int, float)): + height_scale, width_scale = (scale_factor, scale_factor) + else: + assert isinstance(scale_factor, (tuple, list)) + assert len(scale_factor) == 2 + height_scale, width_scale = scale_factor + + assert not align_corners, "No matching C2 op for align_corners == True" + if mode == "nearest": + return torch.ops._caffe2.ResizeNearest( + input, order="NCHW", width_scale=width_scale, height_scale=height_scale + ) + elif mode == "bilinear": + logger.warning( + "Use F.conv_transpose2d for bilinear interpolate" + " because there's no such C2 op, this may cause significant" + " slowdown and the boundary pixels won't be as same as" + " using F.interpolate due to padding." + ) + assert height_scale == width_scale + return BilinearInterpolation(input, up_scale=height_scale) + logger.warning("Output size is not static, it might cause ONNX conversion issue") + + return interp(input, size, scale_factor, mode, align_corners) + + +def mock_torch_nn_functional_interpolate(): + def decorator(func): + @functools.wraps(func) + def _mock_torch_nn_functional_interpolate(*args, **kwargs): + if torch.onnx.is_in_onnx_export(): + with mock.patch( + "torch.nn.functional.interpolate", side_effect=onnx_compatibale_interpolate + ): + return func(*args, **kwargs) + else: + return func(*args, **kwargs) + + return _mock_torch_nn_functional_interpolate + + return decorator + + +# ==== torch/utils_caffe2/ws_utils.py ========================================== + + +class ScopedWS(object): + def __init__(self, ws_name, is_reset, is_cleanup=False): + self.ws_name = ws_name + self.is_reset = is_reset + self.is_cleanup = is_cleanup + self.org_ws = "" + + def __enter__(self): + self.org_ws = workspace.CurrentWorkspace() + if self.ws_name is not None: + workspace.SwitchWorkspace(self.ws_name, True) + if self.is_reset: + workspace.ResetWorkspace() + + return workspace + + def __exit__(self, 
*args): + if self.is_cleanup: + workspace.ResetWorkspace() + if self.ws_name is not None: + workspace.SwitchWorkspace(self.org_ws) + + +def fetch_any_blob(name): + bb = None + try: + bb = workspace.FetchBlob(name) + except TypeError: + bb = workspace.FetchInt8Blob(name) + except Exception as e: + logger.error("Get blob {} error: {}".format(name, e)) + + return bb + + +# ==== torch/utils_caffe2/protobuf.py ========================================== + + +def get_pb_arg(pb, arg_name): + for x in pb.arg: + if x.name == arg_name: + return x + return None + + +def get_pb_arg_valf(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return arg.f if arg is not None else default_val + + +def get_pb_arg_floats(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return list(map(float, arg.floats)) if arg is not None else default_val + + +def get_pb_arg_ints(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return list(map(int, arg.ints)) if arg is not None else default_val + + +def get_pb_arg_vali(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return arg.i if arg is not None else default_val + + +def get_pb_arg_vals(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return arg.s if arg is not None else default_val + + +def get_pb_arg_valstrings(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return list(arg.strings) if arg is not None else default_val + + +def check_set_pb_arg(pb, arg_name, arg_attr, arg_value, allow_override=False): + arg = get_pb_arg(pb, arg_name) + if arg is None: + arg = putils.MakeArgument(arg_name, arg_value) + assert hasattr(arg, arg_attr) + pb.arg.extend([arg]) + if allow_override and getattr(arg, arg_attr) != arg_value: + logger.warning( + "Override argument {}: {} -> {}".format(arg_name, getattr(arg, arg_attr), arg_value) + ) + setattr(arg, arg_attr, arg_value) + else: + assert arg is not None + assert getattr(arg, arg_attr) == arg_value, "Existing value {}, new value 
{}".format( + getattr(arg, arg_attr), arg_value + ) + + +def _create_const_fill_op_from_numpy(name, tensor, device_option=None): + assert type(tensor) == np.ndarray + kTypeNameMapper = { + np.dtype("float32"): "GivenTensorFill", + np.dtype("int32"): "GivenTensorIntFill", + np.dtype("int64"): "GivenTensorInt64Fill", + np.dtype("uint8"): "GivenTensorStringFill", + } + + args_dict = {} + if tensor.dtype == np.dtype("uint8"): + args_dict.update({"values": [str(tensor.data)], "shape": [1]}) + else: + args_dict.update({"values": tensor, "shape": tensor.shape}) + + if device_option is not None: + args_dict["device_option"] = device_option + + return core.CreateOperator(kTypeNameMapper[tensor.dtype], [], [name], **args_dict) + + +def _create_const_fill_op_from_c2_int8_tensor(name, int8_tensor): + assert type(int8_tensor) == workspace.Int8Tensor + kTypeNameMapper = { + np.dtype("int32"): "Int8GivenIntTensorFill", + np.dtype("uint8"): "Int8GivenTensorFill", + } + + tensor = int8_tensor.data + assert tensor.dtype in [np.dtype("uint8"), np.dtype("int32")] + values = tensor.tobytes() if tensor.dtype == np.dtype("uint8") else tensor + + return core.CreateOperator( + kTypeNameMapper[tensor.dtype], + [], + [name], + values=values, + shape=tensor.shape, + Y_scale=int8_tensor.scale, + Y_zero_point=int8_tensor.zero_point, + ) + + +def create_const_fill_op( + name: str, + blob: Union[np.ndarray, workspace.Int8Tensor], + device_option: Optional[caffe2_pb2.DeviceOption] = None, +) -> caffe2_pb2.OperatorDef: + """ + Given a blob object, return the Caffe2 operator that creates this blob + as constant. Currently support NumPy tensor and Caffe2 Int8Tensor. 
+ """ + + tensor_type = type(blob) + assert tensor_type in [ + np.ndarray, + workspace.Int8Tensor, + ], 'Error when creating const fill op for "{}", unsupported blob type: {}'.format( + name, type(blob) + ) + + if tensor_type == np.ndarray: + return _create_const_fill_op_from_numpy(name, blob, device_option) + elif tensor_type == workspace.Int8Tensor: + assert device_option is None + return _create_const_fill_op_from_c2_int8_tensor(name, blob) + + +def construct_init_net_from_params( + params: Dict[str, Any], device_options: Optional[Dict[str, caffe2_pb2.DeviceOption]] = None +) -> caffe2_pb2.NetDef: + """ + Construct the init_net from params dictionary + """ + init_net = caffe2_pb2.NetDef() + device_options = device_options or {} + for name, blob in params.items(): + if isinstance(blob, str): + logger.warning( + ( + "Blob {} with type {} is not supported in generating init net," + " skipped.".format(name, type(blob)) + ) + ) + continue + init_net.op.extend( + [create_const_fill_op(name, blob, device_option=device_options.get(name, None))] + ) + init_net.external_output.append(name) + return init_net + + +def get_producer_map(ssa): + """ + Return dict from versioned blob to (i, j), + where i is index of producer op, j is the index of output of that op. + """ + producer_map = {} + for i in range(len(ssa)): + outputs = ssa[i][1] + for j, outp in enumerate(outputs): + producer_map[outp] = (i, j) + return producer_map + + +def get_consumer_map(ssa): + """ + Return dict from versioned blob to list of (i, j), + where i is index of consumer op, j is the index of input of that op. + """ + consumer_map = collections.defaultdict(list) + for i in range(len(ssa)): + inputs = ssa[i][0] + for j, inp in enumerate(inputs): + consumer_map[inp].append((i, j)) + return consumer_map + + +def get_params_from_init_net( + init_net: caffe2_pb2.NetDef, +) -> [Dict[str, Any], Dict[str, caffe2_pb2.DeviceOption]]: + """ + Take the output blobs from init_net by running it. 
+ Outputs: + params: dict from blob name to numpy array + device_options: dict from blob name to the device option of its creating op + """ + # NOTE: this assumes that the params is determined by producer op with the + # only exception be CopyGPUToCPU which is CUDA op but returns CPU tensor. + def _get_device_option(producer_op): + if producer_op.type == "CopyGPUToCPU": + return caffe2_pb2.DeviceOption() + else: + return producer_op.device_option + + with ScopedWS("__get_params_from_init_net__", is_reset=True, is_cleanup=True) as ws: + ws.RunNetOnce(init_net) + params = {b: fetch_any_blob(b) for b in init_net.external_output} + ssa, versions = core.get_ssa(init_net) + producer_map = get_producer_map(ssa) + device_options = { + b: _get_device_option(init_net.op[producer_map[(b, versions[b])][0]]) + for b in init_net.external_output + } + return params, device_options + + +def _updater_raise(op, input_types, output_types): + raise RuntimeError( + "Failed to apply updater for op {} given input_types {} and" + " output_types {}".format(op, input_types, output_types) + ) + + +def _generic_status_identifier( + predict_net: caffe2_pb2.NetDef, + status_updater: Callable, + known_status: Dict[Tuple[str, int], Any], +) -> Dict[Tuple[str, int], Any]: + """ + Statically infer the status of each blob, the status can be such as device type + (CPU/GPU), layout (NCHW/NHWC), data type (float32/int8), etc. "Blob" here + is versioned blob (Tuple[str, int]) in the format compatible with ssa. + Inputs: + predict_net: the caffe2 network + status_updater: a callable, given an op and the status of its input/output, + it returns the updated status of input/output. `None` is used for + representing unknown status. + known_status: a dict containing known status, used as initialization. 
+ Outputs: + A dict mapping from versioned blob to its status + """ + ssa, versions = core.get_ssa(predict_net) + versioned_ext_input = [(b, 0) for b in predict_net.external_input] + versioned_ext_output = [(b, versions[b]) for b in predict_net.external_output] + all_versioned_blobs = set().union(*[set(x[0] + x[1]) for x in ssa]) + + allowed_vbs = all_versioned_blobs.union(versioned_ext_input).union(versioned_ext_output) + assert all(k in allowed_vbs for k in known_status) + assert all(v is not None for v in known_status.values()) + _known_status = copy.deepcopy(known_status) + + def _check_and_update(key, value): + assert value is not None + if key in _known_status: + if not _known_status[key] == value: + raise RuntimeError( + "Confilict status for {}, existing status {}, new status {}".format( + key, _known_status[key], value + ) + ) + _known_status[key] = value + + def _update_i(op, ssa_i): + versioned_inputs = ssa_i[0] + versioned_outputs = ssa_i[1] + + inputs_status = [_known_status.get(b, None) for b in versioned_inputs] + outputs_status = [_known_status.get(b, None) for b in versioned_outputs] + + new_inputs_status, new_outputs_status = status_updater(op, inputs_status, outputs_status) + + for versioned_blob, status in zip( + versioned_inputs + versioned_outputs, new_inputs_status + new_outputs_status + ): + if status is not None: + _check_and_update(versioned_blob, status) + + for op, ssa_i in zip(predict_net.op, ssa): + _update_i(op, ssa_i) + for op, ssa_i in zip(reversed(predict_net.op), reversed(ssa)): + _update_i(op, ssa_i) + + # NOTE: This strictly checks all the blob from predict_net must be assgined + # a known status. However sometimes it's impossible (eg. having deadend op), + # we may relax this constraint if + for k in all_versioned_blobs: + if k not in _known_status: + raise NotImplementedError( + "Can not infer the status for {}. 
Currently only support the case where" + " a single forward and backward pass can identify status for all blobs.".format(k) + ) + + return _known_status + + +def infer_device_type( + predict_net: caffe2_pb2.NetDef, + known_status: Dict[Tuple[str, int], Any], + device_name_style: str = "caffe2", +) -> Dict[Tuple[str, int], str]: + """Return the device type ("cpu" or "gpu"/"cuda") of each (versioned) blob""" + + assert device_name_style in ["caffe2", "pytorch"] + _CPU_STR = "cpu" + _GPU_STR = "gpu" if device_name_style == "caffe2" else "cuda" + + def _copy_cpu_to_gpu_updater(op, input_types, output_types): + if input_types[0] == _GPU_STR or output_types[0] == _CPU_STR: + _updater_raise(op, input_types, output_types) + return ([_CPU_STR], [_GPU_STR]) + + def _copy_gpu_to_cpu_updater(op, input_types, output_types): + if input_types[0] == _CPU_STR or output_types[0] == _GPU_STR: + _updater_raise(op, input_types, output_types) + return ([_GPU_STR], [_CPU_STR]) + + def _other_ops_updater(op, input_types, output_types): + non_none_types = [x for x in input_types + output_types if x is not None] + if len(non_none_types) > 0: + the_type = non_none_types[0] + if not all(x == the_type for x in non_none_types): + _updater_raise(op, input_types, output_types) + else: + the_type = None + return ([the_type for _ in op.input], [the_type for _ in op.output]) + + def _device_updater(op, *args, **kwargs): + return { + "CopyCPUToGPU": _copy_cpu_to_gpu_updater, + "CopyGPUToCPU": _copy_gpu_to_cpu_updater, + }.get(op.type, _other_ops_updater)(op, *args, **kwargs) + + return _generic_status_identifier(predict_net, _device_updater, known_status) + + +# ==== torch/utils_caffe2/vis.py =============================================== + + +def _modify_blob_names(ops, blob_rename_f): + ret = [] + + def _replace_list(blob_list, replaced_list): + del blob_list[:] + blob_list.extend(replaced_list) + + for x in ops: + cur = copy.deepcopy(x) + _replace_list(cur.input, list(map(blob_rename_f, 
cur.input))) + _replace_list(cur.output, list(map(blob_rename_f, cur.output))) + ret.append(cur) + + return ret + + +def _rename_blob(name, blob_sizes, blob_ranges): + def _list_to_str(bsize): + ret = ", ".join([str(x) for x in bsize]) + ret = "[" + ret + "]" + return ret + + ret = name + if blob_sizes is not None and name in blob_sizes: + ret += "\n" + _list_to_str(blob_sizes[name]) + if blob_ranges is not None and name in blob_ranges: + ret += "\n" + _list_to_str(blob_ranges[name]) + + return ret + + +# graph_name could not contain word 'graph' +def save_graph(net, file_name, graph_name="net", op_only=True, blob_sizes=None, blob_ranges=None): + blob_rename_f = functools.partial(_rename_blob, blob_sizes=blob_sizes, blob_ranges=blob_ranges) + return save_graph_base(net, file_name, graph_name, op_only, blob_rename_f) + + +def save_graph_base(net, file_name, graph_name="net", op_only=True, blob_rename_func=None): + graph = None + ops = net.op + if blob_rename_func is not None: + ops = _modify_blob_names(ops, blob_rename_func) + if not op_only: + graph = net_drawer.GetPydotGraph(ops, graph_name, rankdir="TB") + else: + graph = net_drawer.GetPydotGraphMinimal( + ops, graph_name, rankdir="TB", minimal_dependency=True + ) + + try: + par_dir = os.path.dirname(file_name) + if not os.path.exists(par_dir): + os.makedirs(par_dir) + + format = os.path.splitext(os.path.basename(file_name))[-1] + if format == ".png": + graph.write_png(file_name) + elif format == ".pdf": + graph.write_pdf(file_name) + elif format == ".svg": + graph.write_svg(file_name) + else: + print("Incorrect format {}".format(format)) + except Exception as e: + print("Error when writing graph to image {}".format(e)) + + return graph + + +# ==== torch/utils_toffee/aten_to_caffe2.py ==================================== + + +def group_norm_replace_aten_with_caffe2(predict_net: caffe2_pb2.NetDef): + """ + For ONNX exported model, GroupNorm will be represented as ATen op, + this can be a drop in replacement from 
ATen to GroupNorm + """ + count = 0 + for op in predict_net.op: + if op.type == "ATen": + op_name = get_pb_arg_vals(op, "operator", None) # return byte in py3 + if op_name and op_name.decode() == "group_norm": + op.arg.remove(get_pb_arg(op, "operator")) + + if get_pb_arg_vali(op, "cudnn_enabled", None): + op.arg.remove(get_pb_arg(op, "cudnn_enabled")) + + num_groups = get_pb_arg_vali(op, "num_groups", None) + if num_groups is not None: + op.arg.remove(get_pb_arg(op, "num_groups")) + check_set_pb_arg(op, "group", "i", num_groups) + + op.type = "GroupNorm" + count += 1 + if count > 1: + logger.info("Replaced {} ATen operator to GroupNormOp".format(count)) + + +# ==== torch/utils_toffee/alias.py ============================================= + + +def alias(x, name, is_backward=False): + if not torch.onnx.is_in_onnx_export(): + return x + assert isinstance(x, torch.Tensor) + return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward) + + +def fuse_alias_placeholder(predict_net, init_net): + """Remove AliasWithName placeholder and rename the input/output of it""" + # First we finish all the re-naming + for i, op in enumerate(predict_net.op): + if op.type == "AliasWithName": + assert len(op.input) == 1 + assert len(op.output) == 1 + name = get_pb_arg_vals(op, "name", None).decode() + is_backward = bool(get_pb_arg_vali(op, "is_backward", 0)) + rename_op_input(predict_net, init_net, i, 0, name, from_producer=is_backward) + rename_op_output(predict_net, i, 0, name) + + # Remove AliasWithName, should be very safe since it's a non-op + new_ops = [] + for op in predict_net.op: + if op.type != "AliasWithName": + new_ops.append(op) + else: + # safety check + assert op.input == op.output + assert op.input[0] == op.arg[0].s.decode() + del predict_net.op[:] + predict_net.op.extend(new_ops) + + +# ==== torch/utils_caffe2/graph_transform.py =================================== + + +class IllegalGraphTransformError(ValueError): + """When a graph transform function call 
can't be executed.""" + + +def _rename_versioned_blob_in_proto( + proto: caffe2_pb2.NetDef, + old_name: str, + new_name: str, + version: int, + ssa: List[Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]], + start_versions: Dict[str, int], + end_versions: Dict[str, int], +): + """In given proto, rename all blobs with matched version""" + # Operater list + for op, i_th_ssa in zip(proto.op, ssa): + versioned_inputs, versioned_outputs = i_th_ssa + for i in range(len(op.input)): + if versioned_inputs[i] == (old_name, version): + op.input[i] = new_name + for i in range(len(op.output)): + if versioned_outputs[i] == (old_name, version): + op.output[i] = new_name + # external_input + if start_versions.get(old_name, 0) == version: + for i in range(len(proto.external_input)): + if proto.external_input[i] == old_name: + proto.external_input[i] = new_name + # external_output + if end_versions.get(old_name, 0) == version: + for i in range(len(proto.external_output)): + if proto.external_output[i] == old_name: + proto.external_output[i] = new_name + + +def rename_op_input( + predict_net: caffe2_pb2.NetDef, + init_net: caffe2_pb2.NetDef, + op_id: int, + input_id: int, + new_name: str, + from_producer: bool = False, +): + """ + Rename the op_id-th operator in predict_net, change it's input_id-th input's + name to the new_name. It also does automatic re-route and change + external_input and init_net if necessary. + - It requires the input is only consumed by this op. + - This function modifies predict_net and init_net in-place. + - When from_producer is enable, this also updates other operators that consumes + the same input. Be cautious because may trigger unintended behavior. 
+ """ + assert isinstance(predict_net, caffe2_pb2.NetDef) + assert isinstance(init_net, caffe2_pb2.NetDef) + + init_net_ssa, init_net_versions = core.get_ssa(init_net) + predict_net_ssa, predict_net_versions = core.get_ssa( + predict_net, copy.deepcopy(init_net_versions) + ) + + versioned_inputs, versioned_outputs = predict_net_ssa[op_id] + old_name, version = versioned_inputs[input_id] + + if from_producer: + producer_map = get_producer_map(predict_net_ssa) + if not (old_name, version) in producer_map: + raise NotImplementedError( + "Can't find producer, the input {} is probably from" + " init_net, this is not supported yet.".format(old_name) + ) + producer = producer_map[(old_name, version)] + rename_op_output(predict_net, producer[0], producer[1], new_name) + return + + def contain_targets(op_ssa): + return (old_name, version) in op_ssa[0] + + is_consumer = [contain_targets(op_ssa) for op_ssa in predict_net_ssa] + if sum(is_consumer) > 1: + raise IllegalGraphTransformError( + ( + "Input '{}' of operator(#{}) are consumed by other ops, please use" + + " rename_op_output on the producer instead. Offending op: \n{}" + ).format(old_name, op_id, predict_net.op[op_id]) + ) + + # update init_net + _rename_versioned_blob_in_proto( + init_net, old_name, new_name, version, init_net_ssa, {}, init_net_versions + ) + # update predict_net + _rename_versioned_blob_in_proto( + predict_net, + old_name, + new_name, + version, + predict_net_ssa, + init_net_versions, + predict_net_versions, + ) + + +def rename_op_output(predict_net: caffe2_pb2.NetDef, op_id: int, output_id: int, new_name: str): + """ + Rename the op_id-th operator in predict_net, change it's output_id-th input's + name to the new_name. It also does automatic re-route and change + external_output and if necessary. + - It allows multiple consumers of its output. + - This function modifies predict_net in-place, doesn't need init_net. 
+ """ + assert isinstance(predict_net, caffe2_pb2.NetDef) + + ssa, blob_versions = core.get_ssa(predict_net) + + versioned_inputs, versioned_outputs = ssa[op_id] + old_name, version = versioned_outputs[output_id] + + # update predict_net + _rename_versioned_blob_in_proto( + predict_net, old_name, new_name, version, ssa, {}, blob_versions + ) + + +def get_sub_graph_external_input_output( + predict_net: caffe2_pb2.NetDef, sub_graph_op_indices: List[int] +) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]: + """ + Return the list of external input/output of sub-graph, + each element is tuple of the name and corresponding version in predict_net. + + external input/output is defined the same way as caffe2 NetDef. + """ + ssa, versions = core.get_ssa(predict_net) + + all_inputs = [] + all_outputs = [] + for op_id in sub_graph_op_indices: + all_inputs += [inp for inp in ssa[op_id][0] if inp not in all_inputs] + all_outputs += list(ssa[op_id][1]) # ssa output won't repeat + + # for versioned blobs, external inputs are just those blob in all_inputs + # but not in all_outputs + ext_inputs = [inp for inp in all_inputs if inp not in all_outputs] + + # external outputs are essentially outputs of this subgraph that are used + # outside of this sub-graph (including predict_net.external_output) + all_other_inputs = sum( + (ssa[i][0] for i in range(len(ssa)) if i not in sub_graph_op_indices), + [(outp, versions[outp]) for outp in predict_net.external_output], + ) + ext_outputs = [outp for outp in all_outputs if outp in set(all_other_inputs)] + + return ext_inputs, ext_outputs + + +class DiGraph: + """A DAG representation of caffe2 graph, each vertice is a versioned blob.""" + + def __init__(self): + self.vertices = set() + self.graph = collections.defaultdict(list) + + def add_edge(self, u, v): + self.graph[u].append(v) + self.vertices.add(u) + self.vertices.add(v) + + # grab from 
https://www.geeksforgeeks.org/find-paths-given-source-destination/ + def get_all_paths(self, s, d): + visited = {k: False for k in self.vertices} + path = [] + all_paths = [] + + def _get_all_paths_util(graph, u, d, visited, path): + visited[u] = True + path.append(u) + if u == d: + all_paths.append(copy.deepcopy(path)) + else: + for i in graph[u]: + if not visited[i]: + _get_all_paths_util(graph, i, d, visited, path) + path.pop() + visited[u] = False + + _get_all_paths_util(self.graph, s, d, visited, path) + return all_paths + + @staticmethod + def from_ssa(ssa): + graph = DiGraph() + for op_id in range(len(ssa)): + for inp in ssa[op_id][0]: + for outp in ssa[op_id][1]: + graph.add_edge(inp, outp) + return graph + + +def _get_dependency_chain(ssa, versioned_target, versioned_source): + """ + Return the index list of relevant operator to produce target blob from source blob, + if there's no dependency, return empty list. + """ + + # finding all paths between nodes can be O(N!), thus we can only search + # in the subgraph using the op starting from the first consumer of source blob + # to the producer of the target blob. 
+ consumer_map = get_consumer_map(ssa) + producer_map = get_producer_map(ssa) + start_op = min(x[0] for x in consumer_map[versioned_source]) - 15 + end_op = ( + producer_map[versioned_target][0] + 15 if versioned_target in producer_map else start_op + ) + sub_graph_ssa = ssa[start_op : end_op + 1] + if len(sub_graph_ssa) > 30: + logger.warning( + "Subgraph bebetween {} and {} is large (from op#{} to op#{}), it" + " might take non-trival time to find all paths between them.".format( + versioned_source, versioned_target, start_op, end_op + ) + ) + + dag = DiGraph.from_ssa(sub_graph_ssa) + paths = dag.get_all_paths(versioned_source, versioned_target) # include two ends + ops_in_paths = [[producer_map[blob][0] for blob in path[1:]] for path in paths] + return sorted(set().union(*[set(ops) for ops in ops_in_paths])) + + +def identify_reshape_sub_graph(predict_net: caffe2_pb2.NetDef) -> List[List[int]]: + """ + Idenfity the reshape sub-graph in a protobuf. + The reshape sub-graph is defined as matching the following pattern: + + (input_blob) -> Op_1 -> ... -> Op_N -> (new_shape) -─┐ + └-------------------------------------------> Reshape -> (output_blob) + + Return: + List of sub-graphs, each sub-graph is represented as a list of indices + of the relavent ops, [Op_1, Op_2, ..., Op_N, Reshape] + """ + + ssa, _ = core.get_ssa(predict_net) + + ret = [] + for i, op in enumerate(predict_net.op): + if op.type == "Reshape": + assert len(op.input) == 2 + input_ssa = ssa[i][0] + data_source = input_ssa[0] + shape_source = input_ssa[1] + op_indices = _get_dependency_chain(ssa, shape_source, data_source) + ret.append(op_indices + [i]) + return ret + + +def remove_reshape_for_fc(predict_net, params): + """ + In PyTorch nn.Linear has to take 2D tensor, this often leads to reshape + a 4D tensor to 2D by calling .view(). However this (dynamic) reshaping + doesn't work well with ONNX and Int8 tools, and cause using extra + ops (eg. ExpandDims) that might not be available on mobile. 
+ Luckily Caffe2 supports 4D tensor for FC, so we can remove those reshape + after exporting ONNX model. + """ + from caffe2.python import core + + # find all reshape sub-graph that can be removed, which is now all Reshape + # sub-graph whose output is only consumed by FC. + # TODO: to make it safer, we may need the actually value to better determine + # if a Reshape before FC is removable. + reshape_sub_graphs = identify_reshape_sub_graph(predict_net) + sub_graphs_to_remove = [] + for reshape_sub_graph in reshape_sub_graphs: + reshape_op_id = reshape_sub_graph[-1] + assert predict_net.op[reshape_op_id].type == "Reshape" + ssa, _ = core.get_ssa(predict_net) + reshape_output = ssa[reshape_op_id][1][0] + consumers = [i for i in range(len(ssa)) if reshape_output in ssa[i][0]] + if all(predict_net.op[consumer].type == "FC" for consumer in consumers): + # safety check if the sub-graph is isolated, for this reshape sub-graph, + # it means it has one non-param external input and one external output. + ext_inputs, ext_outputs = get_sub_graph_external_input_output( + predict_net, reshape_sub_graph + ) + non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0] + if len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1: + sub_graphs_to_remove.append(reshape_sub_graph) + + # perform removing subgraph by: + # 1: rename the Reshape's output to its input, then the graph can be + # seen as in-place itentify, meaning whose external input/output are the same. + # 2: simply remove those ops. 
+ remove_op_ids = [] + params_to_remove = [] + for sub_graph in sub_graphs_to_remove: + logger.info( + "Remove Reshape sub-graph:\n{}".format( + "".join(["(#{:>4})\n{}".format(i, predict_net.op[i]) for i in sub_graph]) + ) + ) + reshape_op_id = sub_graph[-1] + new_reshap_output = predict_net.op[reshape_op_id].input[0] + rename_op_output(predict_net, reshape_op_id, 0, new_reshap_output) + ext_inputs, ext_outputs = get_sub_graph_external_input_output(predict_net, sub_graph) + non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0] + params_ext_inputs = [inp for inp in ext_inputs if inp[1] == 0] + assert len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1 + assert ext_outputs[0][0] == non_params_ext_inputs[0][0] + assert ext_outputs[0][1] == non_params_ext_inputs[0][1] + 1 + remove_op_ids.extend(sub_graph) + params_to_remove.extend(params_ext_inputs) + + predict_net = copy.deepcopy(predict_net) + new_ops = [op for i, op in enumerate(predict_net.op) if i not in remove_op_ids] + del predict_net.op[:] + predict_net.op.extend(new_ops) + for versioned_params in params_to_remove: + name = versioned_params[0] + logger.info("Remove params: {} from init_net and predict_net.external_input".format(name)) + del params[name] + predict_net.external_input.remove(name) + + return predict_net, params + + +def fuse_copy_between_cpu_and_gpu(predict_net: caffe2_pb2.NetDef): + """ + In-place fuse extra copy ops between cpu/gpu for the following case: + a -CopyAToB-> b -CopyBToA> c1 -NextOp1-> d1 + -CopyBToA> c2 -NextOp2-> d2 + The fused network will look like: + a -NextOp1-> d1 + -NextOp2-> d2 + """ + + _COPY_OPS = ["CopyCPUToGPU", "CopyGPUToCPU"] + + def _fuse_once(predict_net): + ssa, blob_versions = core.get_ssa(predict_net) + consumer_map = get_consumer_map(ssa) + versioned_external_output = [ + (name, blob_versions[name]) for name in predict_net.external_output + ] + + for op_id, op in enumerate(predict_net.op): + if op.type in _COPY_OPS: + fw_copy_versioned_output = 
ssa[op_id][1][0] + consumer_ids = [x[0] for x in consumer_map[fw_copy_versioned_output]] + reverse_op_type = _COPY_OPS[1 - _COPY_OPS.index(op.type)] + + is_fusable = ( + len(consumer_ids) > 0 + and fw_copy_versioned_output not in versioned_external_output + and all( + predict_net.op[_op_id].type == reverse_op_type + and ssa[_op_id][1][0] not in versioned_external_output + for _op_id in consumer_ids + ) + ) + + if is_fusable: + for rv_copy_op_id in consumer_ids: + # making each NextOp uses "a" directly and removing Copy ops + rs_copy_versioned_output = ssa[rv_copy_op_id][1][0] + next_op_id, inp_id = consumer_map[rs_copy_versioned_output][0] + predict_net.op[next_op_id].input[inp_id] = op.input[0] + # remove CopyOps + new_ops = [ + op + for i, op in enumerate(predict_net.op) + if i != op_id and i not in consumer_ids + ] + del predict_net.op[:] + predict_net.op.extend(new_ops) + return True + + return False + + # _fuse_once returns False is nothing can be fused + while _fuse_once(predict_net): + pass + + +def remove_dead_end_ops(net_def: caffe2_pb2.NetDef): + """remove ops if its output is not used or not in external_output""" + ssa, versions = core.get_ssa(net_def) + versioned_external_output = [(name, versions[name]) for name in net_def.external_output] + consumer_map = get_consumer_map(ssa) + removed_op_ids = set() + + def _is_dead_end(versioned_blob): + return not ( + versioned_blob in versioned_external_output + or ( + len(consumer_map[versioned_blob]) > 0 + and all(x[0] not in removed_op_ids for x in consumer_map[versioned_blob]) + ) + ) + + for i, ssa_i in reversed(list(enumerate(ssa))): + versioned_outputs = ssa_i[1] + if all(_is_dead_end(outp) for outp in versioned_outputs): + removed_op_ids.add(i) + + # simply removing those deadend ops should have no effect to external_output + new_ops = [op for i, op in enumerate(net_def.op) if i not in removed_op_ids] + del net_def.op[:] + net_def.op.extend(new_ops) diff --git 
a/data_processing/detectron2/detectron2/export/torchscript.py b/data_processing/detectron2/detectron2/export/torchscript.py new file mode 100644 index 0000000..24fe59b --- /dev/null +++ b/data_processing/detectron2/detectron2/export/torchscript.py @@ -0,0 +1,132 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import os +import torch + +from detectron2.utils.file_io import PathManager + +from .torchscript_patch import freeze_training_mode, patch_instances + +__all__ = ["scripting_with_instances", "dump_torchscript_IR"] + + +def scripting_with_instances(model, fields): + """ + Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since + attributes of :class:`Instances` are "dynamically" added in eager mode,it is difficult + for scripting to support it out of the box. This function is made to support scripting + a model that uses :class:`Instances`. It does the following: + + 1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``, + but with all attributes been "static". + The attributes need to be statically declared in the ``fields`` argument. + 2. Register ``new_Instances``, and force scripting compiler to + use it when trying to compile ``Instances``. + + After this function, the process will be reverted. User should be able to script another model + using different fields. + + Example: + Assume that ``Instances`` in the model consist of two attributes named + ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and + :class:`Tensor` respectively during inference. You can call this function like: + :: + fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor} + torchscipt_model = scripting_with_instances(model, fields) + + Note: + It only support models in evaluation mode. + + Args: + model (nn.Module): The input model to be exported by scripting. + fields (Dict[str, type]): Attribute names and corresponding type that + ``Instances`` will use in the model. 
Note that all attributes used in ``Instances`` + need to be added, regardless of whether they are inputs/outputs of the model. + Data type not defined in detectron2 is not supported for now. + + Returns: + torch.jit.ScriptModule: the model in torchscript format + """ + assert ( + not model.training + ), "Currently we only support exporting models in evaluation mode to torchscript" + + with freeze_training_mode(model), patch_instances(fields): + scripted_model = torch.jit.script(model) + return scripted_model + + +# alias for old name +export_torchscript_with_instances = scripting_with_instances + + +def dump_torchscript_IR(model, dir): + """ + Dump IR of a TracedModule/ScriptModule/Function in various format (code, graph, + inlined graph). Useful for debugging. + + Args: + model (TracedModule/ScriptModule/ScriptFUnction): traced or scripted module + dir (str): output directory to dump files. + """ + dir = os.path.expanduser(dir) + PathManager.mkdirs(dir) + + def _get_script_mod(mod): + if isinstance(mod, torch.jit.TracedModule): + return mod._actual_script_module + return mod + + # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code + with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f: + + def get_code(mod): + # Try a few ways to get code using private attributes. + try: + # This contains more information than just `mod.code` + return _get_script_mod(mod)._c.code + except AttributeError: + pass + try: + return mod.code + except AttributeError: + return None + + def dump_code(prefix, mod): + code = get_code(mod) + name = prefix or "root model" + if code is None: + f.write(f"Could not found code for {name} (type={mod.original_name})\n") + f.write("\n") + else: + f.write(f"\nCode for {name}, type={mod.original_name}:\n") + f.write(code) + f.write("\n") + f.write("-" * 80) + + for name, m in mod.named_children(): + dump_code(prefix + "." 
+ name, m) + + if isinstance(model, torch.jit.ScriptFunction): + f.write(get_code(model)) + else: + dump_code("", model) + + def _get_graph(model): + try: + # Recursively dump IR of all modules + return _get_script_mod(model)._c.dump_to_str(True, False, False) + except AttributeError: + return model.graph.str() + + with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f: + f.write(_get_graph(model)) + + # Dump IR of the entire graph (all submodules inlined) + with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f: + f.write(str(model.inlined_graph)) + + if not isinstance(model, torch.jit.ScriptFunction): + # Dump the model structure in pytorch style + with PathManager.open(os.path.join(dir, "model.txt"), "w") as f: + f.write(str(model)) diff --git a/data_processing/detectron2/detectron2/export/torchscript_patch.py b/data_processing/detectron2/detectron2/export/torchscript_patch.py new file mode 100644 index 0000000..da9b324 --- /dev/null +++ b/data_processing/detectron2/detectron2/export/torchscript_patch.py @@ -0,0 +1,406 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import os +import sys +import tempfile +from contextlib import ExitStack, contextmanager +from copy import deepcopy +from unittest import mock +import torch +from torch import nn + +# need some explicit imports due to https://github.com/pytorch/pytorch/issues/38964 +import detectron2 # noqa F401 +from detectron2.structures import Boxes, Instances +from detectron2.utils.env import _import_file + +_counter = 0 + + +def _clear_jit_cache(): + from torch.jit._recursive import concrete_type_store + from torch.jit._state import _jit_caching_layer + + concrete_type_store.type_store.clear() # for modules + _jit_caching_layer.clear() # for free functions + + +def _add_instances_conversion_methods(newInstances): + """ + Add from_instances methods to the scripted Instances class. 
+ """ + cls_name = newInstances.__name__ + + @torch.jit.unused + def from_instances(instances: Instances): + """ + Create scripted Instances from original Instances + """ + fields = instances.get_fields() + image_size = instances.image_size + ret = newInstances(image_size) + for name, val in fields.items(): + assert hasattr(ret, f"_{name}"), f"No attribute named {name} in {cls_name}" + setattr(ret, name, deepcopy(val)) + return ret + + newInstances.from_instances = from_instances + + +@contextmanager +def patch_instances(fields): + """ + A contextmanager, under which the Instances class in detectron2 is replaced + by a statically-typed scriptable class, defined by `fields`. + See more in `scripting_with_instances`. + """ + + with tempfile.TemporaryDirectory(prefix="detectron2") as dir, tempfile.NamedTemporaryFile( + mode="w", encoding="utf-8", suffix=".py", dir=dir, delete=False + ) as f: + try: + # Objects that use Instances should not reuse previously-compiled + # results in cache, because `Instances` could be a new class each time. 
+ _clear_jit_cache() + + cls_name, s = _gen_instance_module(fields) + f.write(s) + f.flush() + f.close() + + module = _import(f.name) + new_instances = getattr(module, cls_name) + _ = torch.jit.script(new_instances) + # let torchscript think Instances was scripted already + Instances.__torch_script_class__ = True + # let torchscript find new_instances when looking for the jit type of Instances + Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances) + + _add_instances_conversion_methods(new_instances) + yield new_instances + finally: + try: + del Instances.__torch_script_class__ + del Instances._jit_override_qualname + except AttributeError: + pass + sys.modules.pop(module.__name__) + + +def _gen_instance_class(fields): + """ + Args: + fields (dict[name: type]) + """ + + class _FieldType: + def __init__(self, name, type_): + assert isinstance(name, str), f"Field name must be str, got {name}" + self.name = name + self.type_ = type_ + self.annotation = f"{type_.__module__}.{type_.__name__}" + + fields = [_FieldType(k, v) for k, v in fields.items()] + + def indent(level, s): + return " " * 4 * level + s + + lines = [] + + global _counter + _counter += 1 + + cls_name = "ScriptedInstances{}".format(_counter) + + field_names = tuple(x.name for x in fields) + extra_args = ", ".join([f"{f.name}: Optional[{f.annotation}] = None" for f in fields]) + lines.append( + f""" +class {cls_name}: + def __init__(self, image_size: Tuple[int, int], {extra_args}): + self.image_size = image_size + self._field_names = {field_names} +""" + ) + + for f in fields: + lines.append( + indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], {f.name})") + ) + + for f in fields: + lines.append( + f""" + @property + def {f.name}(self) -> {f.annotation}: + # has to use a local for type refinement + # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement + t = 
self._{f.name} + assert t is not None, "{f.name} is None and cannot be accessed!" + return t + + @{f.name}.setter + def {f.name}(self, value: {f.annotation}) -> None: + self._{f.name} = value +""" + ) + + # support method `__len__` + lines.append( + """ + def __len__(self) -> int: +""" + ) + for f in fields: + lines.append( + f""" + t = self._{f.name} + if t is not None: + return len(t) +""" + ) + lines.append( + """ + raise NotImplementedError("Empty Instances does not support __len__!") +""" + ) + + # support method `has` + lines.append( + """ + def has(self, name: str) -> bool: +""" + ) + for f in fields: + lines.append( + f""" + if name == "{f.name}": + return self._{f.name} is not None +""" + ) + lines.append( + """ + return False +""" + ) + + # support method `to` + none_args = ", None" * len(fields) + lines.append( + f""" + def to(self, device: torch.device) -> "{cls_name}": + ret = {cls_name}(self.image_size{none_args}) +""" + ) + for f in fields: + if hasattr(f.type_, "to"): + lines.append( + f""" + t = self._{f.name} + if t is not None: + ret._{f.name} = t.to(device) +""" + ) + else: + # For now, ignore fields that cannot be moved to devices. + # Maybe can support other tensor-like classes (e.g. 
__torch_function__) + pass + lines.append( + """ + return ret +""" + ) + + # support method `getitem` + none_args = ", None" * len(fields) + lines.append( + f""" + def __getitem__(self, item) -> "{cls_name}": + ret = {cls_name}(self.image_size{none_args}) +""" + ) + for f in fields: + lines.append( + f""" + t = self._{f.name} + if t is not None: + ret._{f.name} = t[item] +""" + ) + lines.append( + """ + return ret +""" + ) + + # support method `cat` + # this version does not contain checks that all instances have same size and fields + none_args = ", None" * len(fields) + lines.append( + f""" + def cat(self, instances: List["{cls_name}"]) -> "{cls_name}": + ret = {cls_name}(self.image_size{none_args}) +""" + ) + for f in fields: + lines.append( + f""" + t = self._{f.name} + if t is not None: + values: List[{f.annotation}] = [x.{f.name} for x in instances] + if torch.jit.isinstance(t, torch.Tensor): + ret._{f.name} = torch.cat(values, dim=0) + else: + ret._{f.name} = t.cat(values) +""" + ) + lines.append( + """ + return ret""" + ) + + # support method `get_fields()` + lines.append( + """ + def get_fields(self) -> Dict[str, Tensor]: + ret = {} + """ + ) + for f in fields: + if f.type_ == Boxes: + stmt = "t.tensor" + elif f.type_ == torch.Tensor: + stmt = "t" + else: + stmt = f'assert False, "unsupported type {str(f.type_)}"' + lines.append( + f""" + t = self._{f.name} + if t is not None: + ret["{f.name}"] = {stmt} + """ + ) + lines.append( + """ + return ret""" + ) + return cls_name, os.linesep.join(lines) + + +def _gen_instance_module(fields): + # TODO: find a more automatic way to enable import of other classes + s = """ +from copy import deepcopy +import torch +from torch import Tensor +import typing +from typing import * + +import detectron2 +from detectron2.structures import Boxes, Instances + +""" + + cls_name, cls_def = _gen_instance_class(fields) + s += cls_def + return cls_name, s + + +def _import(path): + return _import_file( + 
"{}{}".format(sys.modules[__name__].__name__, _counter), path, make_importable=True + ) + + +@contextmanager +def patch_builtin_len(modules=()): + """ + Patch the builtin len() function of a few detectron2 modules + to use __len__ instead, because __len__ does not convert values to + integers and therefore is friendly to tracing. + + Args: + modules (list[stsr]): names of extra modules to patch len(), in + addition to those in detectron2. + """ + + def _new_len(obj): + return obj.__len__() + + with ExitStack() as stack: + MODULES = [ + "detectron2.modeling.roi_heads.fast_rcnn", + "detectron2.modeling.roi_heads.mask_head", + "detectron2.modeling.roi_heads.keypoint_head", + ] + list(modules) + ctxs = [stack.enter_context(mock.patch(mod + ".len")) for mod in MODULES] + for m in ctxs: + m.side_effect = _new_len + yield + + +def patch_nonscriptable_classes(): + """ + Apply patches on a few nonscriptable detectron2 classes. + Should not have side-effects on eager usage. + """ + # __prepare_scriptable__ can also be added to models for easier maintenance. + # But it complicates the clean model code. + + from detectron2.modeling.backbone import ResNet, FPN + + # Due to https://github.com/pytorch/pytorch/issues/36061, + # we change backbone to use ModuleList for scripting. 
+ # (note: this changes param names in state_dict) + + def prepare_resnet(self): + ret = deepcopy(self) + ret.stages = nn.ModuleList(ret.stages) + for k in self.stage_names: + delattr(ret, k) + return ret + + ResNet.__prepare_scriptable__ = prepare_resnet + + def prepare_fpn(self): + ret = deepcopy(self) + ret.lateral_convs = nn.ModuleList(ret.lateral_convs) + ret.output_convs = nn.ModuleList(ret.output_convs) + for name, _ in self.named_children(): + if name.startswith("fpn_"): + delattr(ret, name) + return ret + + FPN.__prepare_scriptable__ = prepare_fpn + + # Annotate some attributes to be constants for the purpose of scripting, + # even though they are not constants in eager mode. + from detectron2.modeling.roi_heads import StandardROIHeads + + if hasattr(StandardROIHeads, "__annotations__"): + # copy first to avoid editing annotations of base class + StandardROIHeads.__annotations__ = deepcopy(StandardROIHeads.__annotations__) + StandardROIHeads.__annotations__["mask_on"] = torch.jit.Final[bool] + StandardROIHeads.__annotations__["keypoint_on"] = torch.jit.Final[bool] + + +# These patches are not supposed to have side-effects. +patch_nonscriptable_classes() + + +@contextmanager +def freeze_training_mode(model): + """ + A context manager that annotates the "training" attribute of every submodule + to constant, so that the training codepath in these modules can be + meta-compiled away. Upon exiting, the annotations are reverted. + """ + classes = {type(x) for x in model.modules()} + # __constants__ is the old way to annotate constants and not compatible + # with __annotations__ . 
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP).

    The module runs five parallel branches on the input (one 1x1 conv, three
    dilated 3x3 convs and one image-pooling branch), concatenates their
    outputs along the channel dimension and projects them back to
    ``out_channels`` with a 1x1 conv.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dilations,
        *,
        norm,
        activation,
        pool_kernel_size=None,
        dropout: float = 0.0,
        use_depthwise_separable_conv=False,
    ):
        """
        Args:
            in_channels (int): number of input channels for ASPP.
            out_channels (int): number of output channels.
            dilations (list): a list of 3 dilations in ASPP.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format. norm is
                applied to all conv layers except the conv following
                global average pooling.
            activation (callable): activation function.
            pool_kernel_size (tuple, list): the average pooling size (kh, kw)
                for image pooling layer in ASPP. If set to None, it always
                performs global average pooling. If not None, the spatial
                shape of inputs in forward() must be divisible by it. It is
                recommended to use a fixed input feature size in training,
                and set this option to match this size, so that it performs
                global average pooling in training, and the size of the
                pooling window stays consistent in inference.
            dropout (float): apply dropout on the output of ASPP. It is used in
                the official DeepLab implementation with a rate of 0.1:
                https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532  # noqa
            use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d
                for 3x3 convs in ASPP, proposed in :paper:`DeepLabV3+`.
        """
        super(ASPP, self).__init__()
        assert len(dilations) == 3, "ASPP expects 3 dilations, got {}".format(len(dilations))
        self.pool_kernel_size = pool_kernel_size
        self.dropout = dropout
        # A conv followed by a norm layer does not need its own bias.
        use_bias = norm == ""
        self.convs = nn.ModuleList()
        # conv 1x1
        self.convs.append(
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                bias=use_bias,
                norm=get_norm(norm, out_channels),
                activation=deepcopy(activation),
            )
        )
        weight_init.c2_xavier_fill(self.convs[-1])
        # atrous convs
        for dilation in dilations:
            if use_depthwise_separable_conv:
                # DepthwiseSeparableConv2d initializes its own weights.
                self.convs.append(
                    DepthwiseSeparableConv2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        padding=dilation,
                        dilation=dilation,
                        norm1=norm,
                        activation1=deepcopy(activation),
                        norm2=norm,
                        activation2=deepcopy(activation),
                    )
                )
            else:
                self.convs.append(
                    Conv2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        padding=dilation,
                        dilation=dilation,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                        activation=deepcopy(activation),
                    )
                )
                weight_init.c2_xavier_fill(self.convs[-1])
        # image pooling
        # We do not add BatchNorm because the spatial resolution is 1x1,
        # the original TF implementation has BatchNorm.
        if pool_kernel_size is None:
            image_pooling = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
            )
        else:
            image_pooling = nn.Sequential(
                nn.AvgPool2d(kernel_size=pool_kernel_size, stride=1),
                Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
            )
        weight_init.c2_xavier_fill(image_pooling[1])
        self.convs.append(image_pooling)

        # 5 branches (1x1 + 3 atrous + image pooling) are concatenated.
        self.project = Conv2d(
            5 * out_channels,
            out_channels,
            kernel_size=1,
            bias=use_bias,
            norm=get_norm(norm, out_channels),
            activation=deepcopy(activation),
        )
        weight_init.c2_xavier_fill(self.project)

    def forward(self, x):
        """
        Apply all ASPP branches to ``x`` and fuse them.

        Raises:
            ValueError: if ``pool_kernel_size`` is set and the spatial size of
                ``x`` is not a multiple of it.
        """
        size = x.shape[-2:]
        if self.pool_kernel_size is not None:
            if size[0] % self.pool_kernel_size[0] or size[1] % self.pool_kernel_size[1]:
                raise ValueError(
                    "The shape of inputs must be divisible by `pool_kernel_size`. "
                    "Input size: {} `pool_kernel_size`: {}".format(size, self.pool_kernel_size)
                )
        res = []
        for conv in self.convs:
            res.append(conv(x))
        # Only the image-pooling branch (last) changes the spatial size;
        # bring it back to the input resolution before concatenation.
        res[-1] = F.interpolate(res[-1], size=size, mode="bilinear", align_corners=False)
        res = torch.cat(res, dim=1)
        res = self.project(res)
        res = F.dropout(res, self.dropout, training=self.training) if self.dropout > 0 else res
        return res
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d with fixed batch statistics and fixed affine parameters.

    All four tensors ("weight", "bias", "running_mean", "running_var") are
    registered as non-trainable buffers and start out as an identity
    transformation.

    Backbones converted from Caffe2 only ship "weight" and "bias" (already
    folded from the four original BN parameters), so that
    `x * weight + bias` equals
    `(x - running_mean) / sqrt(running_var) * weight + bias`; in that case
    "running_mean" and "running_var" simply keep their identity values.
    Other pre-trained backbones may provide all 4 tensors.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        # running_var is initialized to (1 - eps) so that adding eps back in
        # forward() yields exactly the identity scale.
        initial_buffers = (
            ("weight", torch.ones(num_features)),
            ("bias", torch.zeros(num_features)),
            ("running_mean", torch.zeros(num_features)),
            ("running_var", torch.ones(num_features) - eps),
            ("num_batches_tracked", None),
        )
        for buffer_name, initial_value in initial_buffers:
            self.register_buffer(buffer_name, initial_value)

    def forward(self, x):
        if not x.requires_grad:
            # No gradients needed: F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

        # Gradients needed: F.batch_norm would use extra memory because its
        # backward op also computes gradients for weight/bias, so apply the
        # folded affine transform by hand instead.
        inv_std = (self.running_var + self.eps).rsqrt()
        scale = self.weight * inv_std
        shift = self.bias - self.running_mean * scale
        target_dtype = x.dtype  # may be half
        scale = scale.reshape(1, -1, 1, 1).to(target_dtype)
        shift = shift.reshape(1, -1, 1, 1).to(target_dtype)
        return x * scale + shift

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Early checkpoints carry no running_mean/var; fill in identity
            # values so that loading does not report missing keys.
            fallbacks = (
                ("running_mean", torch.zeros_like),
                ("running_var", torch.ones_like),
            )
            for stat_name, make_default in fallbacks:
                key = prefix + stat_name
                if key not in state_dict:
                    state_dict[key] = make_default(getattr(self, stat_name))

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        batchnorm_types = (
            nn.modules.batchnorm.BatchNorm2d,
            nn.modules.batchnorm.SyncBatchNorm,
        )
        if not isinstance(module, batchnorm_types):
            # Container (or unrelated layer): recurse and swap converted
            # children in place.
            for child_name, child in module.named_children():
                converted = cls.convert_frozen_batchnorm(child)
                if converted is not child:
                    module.add_module(child_name, converted)
            return module

        frozen = cls(module.num_features)
        if module.affine:
            frozen.weight.data = module.weight.data.clone().detach()
            frozen.bias.data = module.bias.data.clone().detach()
        frozen.running_mean.data = module.running_mean.data
        frozen.running_var.data = module.running_var.data
        frozen.eps = module.eps
        frozen.num_batches_tracked = module.num_batches_tracked
        return frozen
def get_norm(norm, out_channels):
    """
    Build a normalization layer from a short name or a factory.

    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.
        out_channels (int): channel count handed to the layer factory.

    Returns:
        nn.Module or None: the normalization layer; None when ``norm`` is
        None or the empty string.
    """
    if norm is None:
        return None
    if not isinstance(norm, str):
        # Already a factory: just call it.
        return norm(out_channels)
    if len(norm) == 0:
        return None

    # Fixed in https://github.com/pytorch/pytorch/pull/36382
    sync_bn = NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm
    factories = {
        "BN": BatchNorm2d,
        "SyncBN": sync_bn,
        "FrozenBN": FrozenBatchNorm2d,
        "GN": lambda channels: nn.GroupNorm(32, channels),
        # for debugging:
        "nnSyncBN": nn.SyncBatchNorm,
        "naiveSyncBN": NaiveSyncBatchNorm,
        # expose stats_mode N as an option to caller, required for zero-len inputs
        "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"),
        "LN": lambda channels: LayerNorm(channels),
    }
    factory = factories[norm]
    return factory(out_channels)
class NaiveSyncBatchNorm(BatchNorm2d):
    """
    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
    when the batch size on each worker is different.
    (e.g., when scale augmentation is used, or when it is applied to mask head).

    This is a slower but correct alternative to `nn.SyncBatchNorm`.

    Note:
        There isn't a single definition of Sync BatchNorm.

        When ``stats_mode==""``, this module computes overall statistics by using
        statistics of each worker with equal weight.  The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (N, H, W). This mode does not support inputs with zero batch size.

        When ``stats_mode=="N"``, this module computes overall statistics by weighting
        the statistics of each worker by their ``N``. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (H, W). It is slower than ``stats_mode==""``.

        Even though the result of this module may not be the true statistics of all samples,
        it may still be reasonable because it might be preferable to assign equal weights
        to all workers, regardless of their (H, W) dimension, instead of putting larger weight
        on larger images. From preliminary experiments, little difference is found between such
        a simplified implementation and an accurate computation of overall mean & variance.
    """

    def __init__(self, *args, stats_mode="", **kwargs):
        """
        Args:
            *args, **kwargs: forwarded unchanged to the ``BatchNorm2d`` base class.
            stats_mode (str): "" or "N"; see the class docstring.
        """
        super().__init__(*args, **kwargs)
        assert stats_mode in ["", "N"]
        self._stats_mode = stats_mode

    def forward(self, input):
        # Single-process or eval mode: nothing to synchronize, use plain BN.
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        half_input = input.dtype == torch.float16
        if half_input:
            # fp16 does not have good enough numerics for the reduction here
            input = input.float()
        # Per-channel E[x] and E[x^2] on this worker.
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        if self._stats_mode == "":
            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            # Pack both statistics into one vector so a single reduction is needed,
            # then divide by the world size (equal weight per worker).
            vec = torch.cat([mean, meansqr], dim=0)
            vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(vec, C)
            momentum = self.momentum
        else:
            if B == 0:
                vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
                vec = vec + input.sum()  # make sure there is gradient w.r.t input
            else:
                # The trailing 1 becomes B after the `vec * B` below, so the
                # last slot of the reduced vector carries the batch count.
                vec = torch.cat(
                    [
                        mean,
                        meansqr,
                        torch.ones([1], device=mean.device, dtype=mean.dtype),
                    ],
                    dim=0,
                )
            # NOTE(review): presumably differentiable_all_reduce sums across
            # workers (its definition is not in this file) — confirm.
            vec = differentiable_all_reduce(vec * B)

            total_batch = vec[-1].detach()
            momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
            mean, meansqr, _ = torch.split(vec / total_batch.clamp(min=1), C)  # avoid div-by-zero

        # Var[x] = E[x^2] - E[x]^2
        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        # Fold normalization and affine parameters into one scale/bias pair.
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)

        # In-place moving-average update of the running statistics
        # (detached, so they do not take part in autograd).
        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)
        ret = input * scale + bias
        if half_input:
            # Return the same dtype the caller passed in.
            ret = ret.half()
        return ret
class CycleBatchNormList(nn.ModuleList):
    """
    Domain-specific BatchNorm implemented by cycling through N BN layers.

    When one BatchNorm layer serves multiple input domains or input
    features, it may need separate test-time statistics per domain.
    See Sec 5.2 in :paper:`rethinking-batchnorm`.

    This module holds N separate BN layers and advances to the next one
    every time forward() is called.

    NOTE: The caller of this module MUST guarantee to always call
    this module by multiple of N times. Otherwise its test-time statistics
    will be incorrect.
    """

    def __init__(self, length: int, bn_class=nn.BatchNorm2d, **kwargs):
        """
        Args:
            length: number of BatchNorm layers to cycle.
            bn_class: the BatchNorm class to use
            kwargs: arguments of the BatchNorm class, such as num_features.
        """
        self._affine = kwargs.pop("affine", True)
        # The per-domain layers never own an affine transform; if requested,
        # a single shared one is created below.
        layers = []
        for _ in range(length):
            layers.append(bn_class(**kwargs, affine=False))
        super().__init__(layers)
        if self._affine:
            # shared affine, domain-specific BN
            num_channels = self[0].num_features
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
        self._pos = 0

    def forward(self, x):
        normalized = self[self._pos](x)
        # Advance only after the layer ran, wrapping around at the end.
        self._pos = (self._pos + 1) % len(self)

        if not self._affine:
            return normalized
        shared_scale = self.weight.reshape(1, -1, 1, 1)
        shared_shift = self.bias.reshape(1, -1, 1, 1)
        return normalized * shared_scale + shared_shift

    def extra_repr(self):
        return f"affine={self._affine}"
class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that normalizes mean and
    variance point-wise over the channel dimension for inputs of shape
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        # Statistics are taken over dim 1 (channels) at every spatial location.
        channel_mean = x.mean(1, keepdim=True)
        centered = x - channel_mean
        channel_var = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(channel_var + self.eps)
        # Per-channel affine, broadcast over height and width.
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
class CNNBlockBase(nn.Module):
    """
    Base class for CNN blocks described by input channels, output channels
    and a stride.

    `forward()` of a subclass takes and returns NCHW tensors. It may perform
    arbitrary computation but must match the given channels and stride
    specification.

    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels, self.out_channels, self.stride = in_channels, out_channels, stride

    def freeze(self):
        """
        Make this block not trainable.
        This method sets all parameters to `requires_grad=False`,
        and convert all BatchNorm layers to FrozenBatchNorm

        Returns:
            the block itself
        """
        for param in self.parameters():
            param.requires_grad = False
        # Children are replaced in place, so the return value is not needed.
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self


class DepthwiseSeparableConv2d(nn.Module):
    """
    A kxk depthwise convolution + a 1x1 convolution.

    In :paper:`xception`, norm & activation are applied on the second conv.
    :paper:`mobilenet` uses norm & activation on both convs.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        *,
        norm1=None,
        activation1=None,
        norm2=None,
        activation2=None,
    ):
        """
        Args:
            norm1, norm2 (str or callable): normalization for the two conv layers.
            activation1, activation2 (callable(Tensor) -> Tensor): activation
                function for the two conv layers.
        """
        super().__init__()
        # Depthwise stage: one filter per input channel (groups == in_channels).
        # A bias is only used when no norm layer follows the conv.
        depthwise_norm = get_norm(norm1, in_channels)
        self.depthwise = Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=not norm1,
            norm=depthwise_norm,
            activation=activation1,
        )
        # Pointwise stage: 1x1 conv mixing channels.
        pointwise_norm = get_norm(norm2, out_channels)
        self.pointwise = Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=not norm2,
            norm=pointwise_norm,
            activation=activation2,
        )

        # default initialization
        for conv in (self.depthwise, self.pointwise):
            weight_init.c2_msra_fill(conv)

    def forward(self, x):
        depthwise_out = self.depthwise(x)
        return self.pointwise(depthwise_out)
To add a new Op:

1. Create a new directory
2. Implement new ops there
3. Declare its Python interface in `vision.cpp`.
namespace detectron2 {

// CPU forward pass of rotated ROIAlign.
// `rois` holds 6 values per ROI; the CPU implementation reads them as
// (batch index, center x, center y, width, height, angle in degrees).
// Returns a tensor of shape (num_rois, channels, pooled_height, pooled_width).
at::Tensor ROIAlignRotated_forward_cpu(
    const at::Tensor& input,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio);

// CPU backward pass: scatters `grad` (gradient w.r.t. the pooled output)
// back into a zero-initialized tensor of shape
// (batch_size, channels, height, width).
at::Tensor ROIAlignRotated_backward_cpu(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int batch_size,
    const int channels,
    const int height,
    const int width,
    const int sampling_ratio);

// GPU variants are only declared when the extension is built with CUDA/HIP.
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor ROIAlignRotated_forward_cuda(
    const at::Tensor& input,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio);

at::Tensor ROIAlignRotated_backward_cuda(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int batch_size,
    const int channels,
    const int height,
    const int width,
    const int sampling_ratio);
#endif

// Interface for Python
// Dispatches on the device of `input`: CUDA tensors go to the GPU kernel
// (or raise if the build has no GPU support), everything else to the CPU
// implementation. Argument types are double/int64_t here and are narrowed
// to float/int when forwarded.
inline at::Tensor ROIAlignRotated_forward(
    const at::Tensor& input,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio) {
  if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
    return ROIAlignRotated_forward_cuda(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio);
#else
    AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
  }
  return ROIAlignRotated_forward_cpu(
      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}

// Backward counterpart of ROIAlignRotated_forward; dispatches on the device
// of `grad` in the same way.
inline at::Tensor ROIAlignRotated_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t batch_size,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t sampling_ratio) {
  if (grad.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
    return ROIAlignRotated_backward_cuda(
        grad,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        batch_size,
        channels,
        height,
        width,
        sampling_ratio);
#else
    AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
  }
  return ROIAlignRotated_backward_cpu(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio);
}

} // namespace detectron2
// Cached bilinear-interpolation data for one sampling point: the flat
// offsets (row * width + col) of the four neighbouring pixels and the
// weight applied to each of them.
template <typename T>
struct PreCalc {
  int pos1;
  int pos2;
  int pos3;
  int pos4;
  T w1;
  T w2;
  T w3;
  T w4;
};

// Fills `pre_calc` with the interpolation offsets/weights of every sampling
// point of one ROI, in (ph, pw, iy, ix) order. The result depends only on
// the ROI geometry, so it can be shared by all channels.
//
// `pre_calc` must already hold at least
// pooled_height * pooled_width * iy_upper * ix_upper entries.
// Sampling points that fall outside [-1, height] x [-1, width] get all-zero
// weights and offsets, so they contribute nothing.
template <typename T>
void pre_calc_for_bilinear_interpolate(
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int iy_upper,
    const int ix_upper,
    T roi_start_h,
    T roi_start_w,
    T bin_size_h,
    T bin_size_w,
    int roi_bin_grid_h,
    int roi_bin_grid_w,
    T roi_center_h,
    T roi_center_w,
    T cos_theta,
    T sin_theta,
    std::vector<PreCalc<T>>& pre_calc) {
  int pre_calc_index = 0;
  for (int ph = 0; ph < pooled_height; ph++) {
    for (int pw = 0; pw < pooled_width; pw++) {
      for (int iy = 0; iy < iy_upper; iy++) {
        const T yy = roi_start_h + ph * bin_size_h +
            static_cast<T>(iy + .5f) * bin_size_h /
                static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
        for (int ix = 0; ix < ix_upper; ix++) {
          const T xx = roi_start_w + pw * bin_size_w +
              static_cast<T>(ix + .5f) * bin_size_w /
                  static_cast<T>(roi_bin_grid_w);

          // Rotate by theta around the center and translate
          // In image space, (y, x) is the order for Right Handed System,
          // and this is essentially multiplying the point by a rotation matrix
          // to rotate it counterclockwise through angle theta.
          T y = yy * cos_theta - xx * sin_theta + roi_center_h;
          T x = yy * sin_theta + xx * cos_theta + roi_center_w;
          // deal with: inverse elements are out of feature map boundary
          if (y < -1.0 || y > height || x < -1.0 || x > width) {
            // empty
            PreCalc<T> pc;
            pc.pos1 = 0;
            pc.pos2 = 0;
            pc.pos3 = 0;
            pc.pos4 = 0;
            pc.w1 = 0;
            pc.w2 = 0;
            pc.w3 = 0;
            pc.w4 = 0;
            pre_calc[pre_calc_index] = pc;
            pre_calc_index += 1;
            continue;
          }

          if (y < 0) {
            y = 0;
          }
          if (x < 0) {
            x = 0;
          }

          int y_low = (int)y;
          int x_low = (int)x;
          int y_high;
          int x_high;

          // Clamp to the last row/column so the "high" neighbour stays
          // inside the feature map.
          if (y_low >= height - 1) {
            y_high = y_low = height - 1;
            y = (T)y_low;
          } else {
            y_high = y_low + 1;
          }

          if (x_low >= width - 1) {
            x_high = x_low = width - 1;
            x = (T)x_low;
          } else {
            x_high = x_low + 1;
          }

          T ly = y - y_low;
          T lx = x - x_low;
          T hy = 1. - ly, hx = 1. - lx;
          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

          // save weights and indices
          PreCalc<T> pc;
          pc.pos1 = y_low * width + x_low;
          pc.pos2 = y_low * width + x_high;
          pc.pos3 = y_high * width + x_low;
          pc.pos4 = y_high * width + x_high;
          pc.w1 = w1;
          pc.w2 = w2;
          pc.w3 = w3;
          pc.w4 = w4;
          pre_calc[pre_calc_index] = pc;

          pre_calc_index += 1;
        }
      }
    }
  }
}
// Computes, for the sampling point (y, x), the four neighbouring pixel
// coordinates and their bilinear weights, as needed by the backward pass.
//
// Outputs are written through the reference parameters. When the point lies
// outside [-1, height] x [-1, width], all weights are 0 and all indices are
// -1, which the caller uses as the "skip this sample" marker.
template <typename T>
void bilinear_interpolate_gradient(
    const int height,
    const int width,
    T y,
    T x,
    T& w1,
    T& w2,
    T& w3,
    T& w4,
    int& x_low,
    int& x_high,
    int& y_low,
    int& y_high) {
  // deal with cases that inverse elements are out of feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  if (y < 0) {
    y = 0;
  }

  if (x < 0) {
    x = 0;
  }

  y_low = (int)y;
  x_low = (int)x;

  // Clamp to the last row/column so the "high" neighbour stays inside
  // the feature map.
  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // reference in forward
  // T v1 = input[y_low * width + x_low];
  // T v2 = input[y_low * width + x_high];
  // T v3 = input[y_high * width + x_low];
  // T v4 = input[y_high * width + x_high];
  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  w1 = hy * hx;
  w2 = hy * lx;
  w3 = ly * hx;
  w4 = ly * lx;
}
1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +template +inline void add(T* address, const T& val) { + *address += val; +} + +} // namespace + +template +void ROIAlignRotatedForward( + const int nthreads, + const T* input, + const T& spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const T* rois, + T* output) { + int n_rois = nthreads / channels / pooled_width / pooled_height; + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) { + int index_n = n * channels * pooled_width * pooled_height; + + const T* current_roi = rois + n * 6; + int roi_batch_ind = current_roi[0]; + + // Do not use rounding; this implementation detail is critical + // ROIAlignRotated supports align == true, i.e., continuous coordinate + // by default, thus the 0.5 offset + T offset = (T)0.5; + T roi_center_w = current_roi[1] * spatial_scale - offset; + T roi_center_h = current_roi[2] * spatial_scale - offset; + T roi_width = current_roi[3] * spatial_scale; + T roi_height = current_roi[4] * spatial_scale; + T theta = current_roi[5] * M_PI / 180.0; + T cos_theta = cos(theta); + T sin_theta = sin(theta); + + AT_ASSERTM( + roi_width >= 0 && roi_height >= 0, + "ROIs in ROIAlignRotated do not have non-negative size!"); + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic 
// CPU backward kernel of rotated ROIAlign.
//
// Iterates over every element (n, c, ph, pw) of the pooled output gradient
// and scatters it onto the four input pixels surrounding each sampling
// point, using the same bilinear weights as the forward pass.
// `grad_output` may be non-contiguous and is therefore indexed through the
// explicit strides; `grad_input` is indexed as a contiguous
// (N, channels, height, width) buffer and must be zero-initialized by the
// caller because values are accumulated into it.
template <typename T>
void ROIAlignRotatedBackward(
    const int nthreads,
    // may not be contiguous. should index using n_stride, etc
    const T* grad_output,
    const T& spatial_scale,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio,
    T* grad_input,
    const T* rois,
    const int n_stride,
    const int c_stride,
    const int h_stride,
    const int w_stride) {
  for (int index = 0; index < nthreads; index++) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* current_roi = rois + n * 6;
    int roi_batch_ind = current_roi[0];

    // Do not use rounding; this implementation detail is critical
    // ROIAlignRotated supports align == true, i.e., continuous coordinate
    // by default, thus the 0.5 offset
    T offset = (T)0.5;
    T roi_center_w = current_roi[1] * spatial_scale - offset;
    T roi_center_h = current_roi[2] * spatial_scale - offset;
    T roi_width = current_roi[3] * spatial_scale;
    T roi_height = current_roi[4] * spatial_scale;
    T theta = current_roi[5] * M_PI / 180.0; // degrees -> radians
    T cos_theta = cos(theta);
    T sin_theta = sin(theta);

    AT_ASSERTM(
        roi_width >= 0 && roi_height >= 0,
        "ROIs in ROIAlignRotated do not have non-negative size!");

    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    T* offset_grad_input =
        grad_input + ((roi_batch_ind * channels + c) * height * width);

    int output_offset = n * n_stride + c * c_stride;
    const T* offset_grad_output = grad_output + output_offset;
    const T grad_output_this_bin =
        offset_grad_output[ph * h_stride + pw * w_stride];

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
    // Appropriate translation needs to be applied after.
    T roi_start_h = -roi_height / 2.0;
    T roi_start_w = -roi_width / 2.0;

    // We do average (integral) pooling inside a bin
    // (count is only used inside the loops below, which do not run when
    // either grid dimension is 0, so no division by zero occurs).
    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
      const T yy = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T xx = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        // Rotate by theta around the center and translate
        T y = yy * cos_theta - xx * sin_theta + roi_center_h;
        T x = yy * sin_theta + xx * cos_theta + roi_center_w;

        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient(
            height, width, y, x, w1, w2, w3, w4, x_low, x_high, y_low, y_high);

        T g1 = grad_output_this_bin * w1 / count;
        T g2 = grad_output_this_bin * w2 / count;
        T g3 = grad_output_this_bin * w3 / count;
        T g4 = grad_output_this_bin * w4 / count;

        // Indices are -1 for sampling points outside the feature map.
        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          // atomic add is not needed for now since it is single threaded
          add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
          add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
          add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
          add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
        } // if
      } // ix
    } // iy
  } // for
} // ROIAlignRotatedBackward
sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. + T roi_start_h = -roi_height / 2.0; + T roi_start_w = -roi_width / 2.0; + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + // Rotate by theta around the center and translate + T y = yy * cos_theta - xx * sin_theta + roi_center_h; + T x = yy * sin_theta + xx * cos_theta + roi_center_w; + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, width, y, x, w1, w2, w3, w4, x_low, x_high, y_low, y_high); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + // atomic add is not needed for now since it is single threaded + add(offset_grad_input + y_low * width + x_low, static_cast(g1)); + add(offset_grad_input + y_low * width + x_high, static_cast(g2)); + add(offset_grad_input + y_high * width + x_low, static_cast(g3)); + add(offset_grad_input + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } // for +} // ROIAlignRotatedBackward + +at::Tensor ROIAlignRotated_forward_cpu( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + 
const int sampling_ratio) { + AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ROIAlign_forward_cpu"; + at::checkAllSameType(c, {input_t, rois_t}); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + + auto output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return output; + } + + auto input_ = input.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "ROIAlignRotated_forward", [&] { + ROIAlignRotatedForward( + output_size, + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois_.data_ptr(), + output.data_ptr()); + }); + return output; +} + +at::Tensor ROIAlignRotated_backward_cpu( + const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio) { + AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ROIAlignRotated_backward_cpu"; + at::checkAllSameType(c, {grad_t, rois_t}); + + at::Tensor grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + // get stride values to ensure indexing into gradients is correct. 
+ int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "ROIAlignRotated_forward", [&] { + ROIAlignRotatedBackward( + grad.numel(), + grad.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + grad_input.data_ptr(), + rois_.data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + }); + return grad_input; +} + +} // namespace detectron2 diff --git a/data_processing/detectron2/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu b/data_processing/detectron2/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu new file mode 100644 index 0000000..fca1865 --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu @@ -0,0 +1,443 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#include +#include +#include +#include + +// TODO make it in a common file +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + +// Note: this implementation originates from the Caffe2 ROIAlignRotated Op +// and PyTorch ROIAlign (non-rotated) Op implementations. +// The key difference between this implementation and those ones is +// we don't do "legacy offset" in this version, as there aren't many previous +// works, if any, using the "legacy" ROIAlignRotated Op. +// This would make the interface a bit cleaner. 
+
+namespace detectron2 {
+
+namespace {
+
+// Bilinearly sample one height x width plane of `input` at the continuous
+// location (y, x). Locations more than one pixel outside the plane return 0;
+// negative coordinates are clamped to 0 and the last row/column is clamped
+// so that the four neighbours always index inside the plane.
+template <typename T>
+__device__ T bilinear_interpolate(
+    const T* input,
+    const int height,
+    const int width,
+    T y,
+    T x) {
+  // deal with cases that inverse elements are out of feature map boundary
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    // empty
+    return 0;
+  }
+
+  if (y < 0) {
+    y = 0;
+  }
+
+  if (x < 0) {
+    x = 0;
+  }
+
+  int y_low = (int)y;
+  int x_low = (int)x;
+  int y_high;
+  int x_high;
+
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T)y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T)x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+
+  T ly = y - y_low;
+  T lx = x - x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+  // do bilinear interpolation
+  T v1 = input[y_low * width + x_low];
+  T v2 = input[y_low * width + x_high];
+  T v3 = input[y_high * width + x_low];
+  T v4 = input[y_high * width + x_high];
+  T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+  return val;
+}
+
+// Compute the four bilinear weights (w1..w4) and neighbour indices
+// (x_low, x_high, y_low, y_high) for location (y, x), using the same
+// clamping rules as bilinear_interpolate. For locations more than one pixel
+// outside the plane all weights are 0 and all indices are -1, which callers
+// use as the "skip this sample" signal.
+template <typename T>
+__device__ void bilinear_interpolate_gradient(
+    const int height,
+    const int width,
+    T y,
+    T x,
+    T& w1,
+    T& w2,
+    T& w3,
+    T& w4,
+    int& x_low,
+    int& x_high,
+    int& y_low,
+    int& y_high) {
+  // deal with cases that inverse elements are out of feature map boundary
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    // empty
+    w1 = w2 = w3 = w4 = 0.;
+    x_low = x_high = y_low = y_high = -1;
+    return;
+  }
+
+  if (y < 0) {
+    y = 0;
+  }
+
+  if (x < 0) {
+    x = 0;
+  }
+
+  y_low = (int)y;
+  x_low = (int)x;
+
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T)y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T)x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+
+  T ly = y - y_low;
+  T lx = x - x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+
+  // reference in forward
+  // T v1 = input[y_low * width + x_low];
+  // T v2 = input[y_low * width + x_high];
+  // T v3 = input[y_high * width + x_low];
+  // T v4 = input[y_high * width + x_high];
+  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+  return;
+}
+
+} // namespace
+
+// Forward kernel: each loop iteration produces one pooled output element
+// (n, c, ph, pw). `rois` holds 6 values per RoI, read below as
+// (batch index, center x, center y, width, height, angle in degrees).
+// The bin is sampled on a roi_bin_grid_h x roi_bin_grid_w grid, each sample
+// is rotated by theta around the RoI center, and the samples are averaged.
+template <typename T>
+__global__ void RoIAlignRotatedForward(
+    const int nthreads,
+    const T* input,
+    const T spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    const T* rois,
+    T* top_data) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* current_roi = rois + n * 6;
+    int roi_batch_ind = current_roi[0];
+
+    // Do not use rounding; this implementation detail is critical
+    // ROIAlignRotated supports align == true, i.e., continuous coordinate
+    // by default, thus the 0.5 offset
+    T offset = (T)0.5;
+    T roi_center_w = current_roi[1] * spatial_scale - offset;
+    T roi_center_h = current_roi[2] * spatial_scale - offset;
+    T roi_width = current_roi[3] * spatial_scale;
+    T roi_height = current_roi[4] * spatial_scale;
+    T theta = current_roi[5] * M_PI / 180.0;
+    T cos_theta = cos(theta);
+    T sin_theta = sin(theta);
+
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    const T* offset_input =
+        input + (roi_batch_ind * channels + c) * height * width;
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+    // Appropriate translation needs to be applied after.
+    T roi_start_h = -roi_height / 2.0;
+    T roi_start_w = -roi_width / 2.0;
+
+    // We do average (integral) pooling inside a bin
+    const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
+
+    T output_val = 0.;
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
+    {
+      const T yy = roi_start_h + ph * bin_size_h +
+          static_cast<T>(iy + .5f) * bin_size_h /
+              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T xx = roi_start_w + pw * bin_size_w +
+            static_cast<T>(ix + .5f) * bin_size_w /
+                static_cast<T>(roi_bin_grid_w);
+
+        // Rotate by theta around the center and translate
+        T y = yy * cos_theta - xx * sin_theta + roi_center_h;
+        T x = yy * sin_theta + xx * cos_theta + roi_center_w;
+
+        T val = bilinear_interpolate(offset_input, height, width, y, x);
+        output_val += val;
+      }
+    }
+    output_val /= count;
+
+    top_data[index] = output_val;
+  }
+}
+
+// Backward kernel: each loop iteration takes the gradient of one pooled
+// element (n, c, ph, pw) and scatters it to the four neighbours of every
+// sample point that the forward pass read, weighted by the bilinear weights
+// and divided by the sample count. atomicAdd is used because different
+// pooled elements can accumulate into the same input location.
+template <typename T>
+__global__ void RoIAlignRotatedBackwardFeature(
+    const int nthreads,
+    const T* top_diff,
+    const int num_rois,
+    const T spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    T* bottom_diff,
+    const T* rois) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* current_roi = rois + n * 6;
+    int roi_batch_ind = current_roi[0];
+
+    // Do not use rounding; this implementation detail is critical
+    // ROIAlignRotated supports align == true, i.e., continuous coordinate
+    // by default, thus the 0.5 offset
+    T offset = (T)0.5;
+    T roi_center_w = current_roi[1] * spatial_scale - offset;
+    T roi_center_h = current_roi[2] * spatial_scale - offset;
+    T roi_width = current_roi[3] * spatial_scale;
+    T roi_height = current_roi[4] * spatial_scale;
+    T theta = current_roi[5] * M_PI / 180.0;
+    T cos_theta = cos(theta);
+    T sin_theta = sin(theta);
+
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    T* offset_bottom_diff =
+        bottom_diff + (roi_batch_ind * channels + c) * height * width;
+
+    int top_offset = (n * channels + c) * pooled_height * pooled_width;
+    const T* offset_top_diff = top_diff + top_offset;
+    const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+    // Appropriate translation needs to be applied after.
+    T roi_start_h = -roi_height / 2.0;
+    T roi_start_w = -roi_width / 2.0;
+
+    // We do average (integral) pooling inside a bin
+    // (count is only used inside the loops below, so a zero count is never
+    // divided by)
+    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
+    {
+      const T yy = roi_start_h + ph * bin_size_h +
+          static_cast<T>(iy + .5f) * bin_size_h /
+              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T xx = roi_start_w + pw * bin_size_w +
+            static_cast<T>(ix + .5f) * bin_size_w /
+                static_cast<T>(roi_bin_grid_w);
+
+        // Rotate by theta around the center and translate
+        T y = yy * cos_theta - xx * sin_theta + roi_center_h;
+        T x = yy * sin_theta + xx * cos_theta + roi_center_w;
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(
+            height, width, y, x, w1, w2, w3, w4, x_low, x_high, y_low, y_high);
+
+        T g1 = top_diff_this_bin * w1 / count;
+        T g2 = top_diff_this_bin * w2 / count;
+        T g3 = top_diff_this_bin * w3 / count;
+        T g4 = top_diff_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          atomicAdd(
+              offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1));
+          atomicAdd(
+              offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2));
+          atomicAdd(
+              offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
+          atomicAdd(
+              offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
+        } // if
+      } // ix
+    } // iy
+  } // CUDA_1D_KERNEL_LOOP
+} // RoIAlignRotatedBackward
+
+// Host entry point for the forward pass.
+// Checks that `input` and `rois` are CUDA tensors on the same GPU with the
+// same dtype, then launches RoIAlignRotatedForward on the current stream.
+// Returns a {num_rois, channels, pooled_height, pooled_width} tensor; when
+// that tensor is empty no kernel is launched.
+at::Tensor ROIAlignRotated_forward_cuda(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio) {
+  AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "ROIAlignRotated_forward_cuda";
+  at::checkAllSameGPU(c, {input_t, rois_t});
+  at::checkAllSameType(c, {input_t, rois_t});
+  at::cuda::CUDAGuard device_guard(input.device());
+
+  auto num_rois = rois.size(0);
+  auto channels = input.size(1);
+  auto height = input.size(2);
+  auto width = input.size(3);
+
+  auto output = at::empty(
+      {num_rois, channels, pooled_height, pooled_width}, input.options());
+  auto output_size = num_rois * pooled_height * pooled_width * channels;
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // 512 threads per block, at most 4096 blocks; CUDA_1D_KERNEL_LOOP covers
+  // any elements beyond grid * block.
+  dim3 grid(std::min(
+      at::cuda::ATenCeilDiv(
+          static_cast<int64_t>(output_size), static_cast<int64_t>(512)),
+      static_cast<int64_t>(4096)));
+  dim3 block(512);
+
+  if (output.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return output;
+  }
+
+  auto input_ = input.contiguous(), rois_ = rois.contiguous();
+  AT_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "ROIAlignRotated_forward", [&] {
+        RoIAlignRotatedForward<scalar_t><<<grid, block, 0, stream>>>(
+            output_size,
+            input_.data_ptr<scalar_t>(),
+            spatial_scale,
+            channels,
+            height,
+            width,
+            pooled_height,
+            pooled_width,
+            sampling_ratio,
+            rois_.data_ptr<scalar_t>(),
+            output.data_ptr<scalar_t>());
+      });
+  cudaDeviceSynchronize();
+  AT_CUDA_CHECK(cudaGetLastError());
+  return output;
+}
+
+// Host entry point for the backward pass.
+// Checks that `grad` and `rois` are CUDA tensors on the same GPU with the
+// same dtype, then launches RoIAlignRotatedBackwardFeature on the current
+// stream. Returns a zero-initialised {batch_size, channels, height, width}
+// tensor with the scattered gradients; an empty `grad` returns the zeros.
+// TODO remove the dependency on input and use instead its sizes -> save memory
+at::Tensor ROIAlignRotated_backward_cuda(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width,
+    const int batch_size,
+    const int channels,
+    const int height,
+    const int width,
+    const int sampling_ratio) {
+  AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+
+  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
+  // Label used in check failures; must name this op, not plain ROIAlign.
+  at::CheckedFrom c = "ROIAlignRotated_backward_cuda";
+  at::checkAllSameGPU(c, {grad_t, rois_t});
+  at::checkAllSameType(c, {grad_t, rois_t});
+  at::cuda::CUDAGuard device_guard(grad.device());
+
+  auto num_rois = rois.size(0);
+  auto grad_input =
+      at::zeros({batch_size, channels, height, width}, grad.options());
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(
+      at::cuda::ATenCeilDiv(
+          static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
+      static_cast<int64_t>(4096)));
+  dim3 block(512);
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return grad_input;
+  }
+
+  auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
+  AT_DISPATCH_FLOATING_TYPES(
+      grad.scalar_type(), "ROIAlignRotated_backward", [&] {
+        RoIAlignRotatedBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
+            grad.numel(),
+            grad_.data_ptr<scalar_t>(),
+            num_rois,
+            spatial_scale,
+            channels,
+            height,
+            width,
+            pooled_height,
+            pooled_width,
+            sampling_ratio,
+            grad_input.data_ptr<scalar_t>(),
+            rois_.data_ptr<scalar_t>());
+      });
+  AT_CUDA_CHECK(cudaGetLastError());
+  return grad_input;
+}
+
+} // namespace detectron2
diff --git a/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h b/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
new file mode 100644
index 0000000..3bf383b
--- /dev/null
+++ b/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
@@ -0,0 +1,35 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+#pragma once
+#include <torch/types.h>
+
+namespace detectron2 {
+
+// CPU implementation (box_iou_rotated_cpu.cpp).
+at::Tensor box_iou_rotated_cpu(
+    const at::Tensor& boxes1,
+    const at::Tensor& boxes2);
+
+#if defined(WITH_CUDA) || defined(WITH_HIP)
+// GPU implementation; only declared when built with CUDA or HIP support.
+at::Tensor box_iou_rotated_cuda(
+    const at::Tensor& boxes1,
+    const at::Tensor& boxes2);
+#endif
+
+// Interface for Python
+// Dispatches to the CUDA or CPU implementation based on the device of
+// boxes1, passing contiguous copies of both inputs. Raises if the two inputs
+// are not on the same kind of device, or if they are CUDA tensors and the
+// extension was built without GPU support.
+// inline is needed to prevent multiple function definitions when this header is
+// included by different cpps
+inline at::Tensor box_iou_rotated(
+    const at::Tensor& boxes1,
+    const at::Tensor& boxes2) {
+  // Checked unconditionally: a plain assert() is compiled out under NDEBUG,
+  // which would let mixed CPU/CUDA inputs reach the wrong implementation.
+  AT_ASSERTM(
+      boxes1.device().is_cuda() == boxes2.device().is_cuda(),
+      "boxes1 and boxes2 must both be CPU tensors or both be CUDA tensors");
+  if (boxes1.device().is_cuda()) {
+#if defined(WITH_CUDA) || defined(WITH_HIP)
+    return box_iou_rotated_cuda(boxes1.contiguous(), boxes2.contiguous());
+#else
+    AT_ERROR("Detectron2 is not compiled with GPU support!");
+#endif
+  }
+
+  return box_iou_rotated_cpu(boxes1.contiguous(), boxes2.contiguous());
+}
+
+} // namespace detectron2
diff --git a/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp b/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
new file mode 100644
index 0000000..c843487
--- /dev/null
+++ b/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
@@ -0,0 +1,39 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+#include "box_iou_rotated.h"
+#include "box_iou_rotated_utils.h"
+
+namespace detectron2 {
+
+// Fill the flat tensor `ious` (num_boxes1 * num_boxes2 elements, row-major)
+// with the IoU of every (boxes1[i], boxes2[j]) pair. Each box row is read
+// through data_ptr<T>(), so rows must be contiguous and hold scalars of
+// type T.
+template <typename T>
+void box_iou_rotated_cpu_kernel(
+    const at::Tensor& boxes1,
+    const at::Tensor& boxes2,
+    at::Tensor& ious) {
+  // size() returns int64_t; indices use the same width so the loop bounds
+  // and i * num_boxes2 + j are never compared or computed in a narrower type.
+  const int64_t num_boxes1 = boxes1.size(0);
+  const int64_t num_boxes2 = boxes2.size(0);
+
+  for (int64_t i = 0; i < num_boxes1; i++) {
+    // Row i of boxes1 is the same for every j: index it once per row.
+    const at::Tensor box1 = boxes1[i];
+    const T* box1_data = box1.data_ptr<T>();
+    for (int64_t j = 0; j < num_boxes2; j++) {
+      ious[i * num_boxes2 + j] =
+          single_box_iou_rotated<T>(box1_data, boxes2[j].data_ptr<T>());
+    }
+  }
+}
+
+// Pairwise IoU of two sets of rotated boxes on the CPU.
+// Returns a float tensor of shape {num_boxes1, num_boxes2}. The kernel is
+// instantiated for float only, so both inputs are read as float data.
+at::Tensor box_iou_rotated_cpu(
+    // input must be contiguous:
+    const at::Tensor& boxes1,
+    const at::Tensor& boxes2) {
+  auto num_boxes1 = boxes1.size(0);
+  auto num_boxes2 = boxes2.size(0);
+  at::Tensor ious =
+      at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));
+
+  box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious);
+
+  // reshape from 1d array to 2d array
+  auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
+  return ious.reshape(shape);
+}
+
+} // namespace detectron2
diff --git a/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu b/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
new file mode 100644
index 0000000..952710e
--- /dev/null
+++ b/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
@@ -0,0 +1,130 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include "box_iou_rotated_utils.h"
+
+namespace detectron2 {
+
+// 2D block with 32 * 16 = 512 threads per block
+const int BLOCK_DIM_X = 32;
+const int BLOCK_DIM_Y = 16;
+
+// Each block computes a tile of up to BLOCK_DIM_X x BLOCK_DIM_Y IoUs:
+// threadIdx.x selects a box from dev_boxes1 and threadIdx.y a box from
+// dev_boxes2 (5 values per box). The threads with threadIdx.y == 0 first
+// stage the tile's boxes in shared memory; after the barrier every in-range
+// thread writes one entry of the row-major n_boxes1 x n_boxes2 result.
+template <typename T>
+__global__ void box_iou_rotated_cuda_kernel(
+    const int n_boxes1,
+    const int n_boxes2,
+    const T* dev_boxes1,
+    const T* dev_boxes2,
+    T* dev_ious) {
+  const int row_start = blockIdx.x * blockDim.x;
+  const int col_start = blockIdx.y * blockDim.y;
+
+  // Number of valid boxes in this tile (the last tile may be partial).
+  const int row_size = min(n_boxes1 - row_start, blockDim.x);
+  const int col_size = min(n_boxes2 - col_start, blockDim.y);
+
+  __shared__ float block_boxes1[BLOCK_DIM_X * 5];
+  __shared__ float block_boxes2[BLOCK_DIM_Y * 5];
+
+  // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
+  if (threadIdx.x < row_size && threadIdx.y == 0) {
+    block_boxes1[threadIdx.x * 5 + 0] =
+        dev_boxes1[(row_start + threadIdx.x) * 5 + 0];
+    block_boxes1[threadIdx.x * 5 + 1] =
+        dev_boxes1[(row_start + threadIdx.x) * 5 + 1];
+    block_boxes1[threadIdx.x * 5 + 2] =
+        dev_boxes1[(row_start + threadIdx.x) * 5 + 2];
+    block_boxes1[threadIdx.x * 5 + 3] =
+        dev_boxes1[(row_start + threadIdx.x) * 5 + 3];
+    block_boxes1[threadIdx.x * 5 + 4] =
+        dev_boxes1[(row_start + threadIdx.x) * 5 + 4];
+  }
+
+  if (threadIdx.x < col_size && threadIdx.y == 0) {
+    block_boxes2[threadIdx.x * 5 + 0] =
+        dev_boxes2[(col_start + threadIdx.x) * 5 + 0];
+    block_boxes2[threadIdx.x * 5 + 1] =
+        dev_boxes2[(col_start + threadIdx.x) * 5 + 1];
+    block_boxes2[threadIdx.x * 5 + 2] =
+        dev_boxes2[(col_start + threadIdx.x) * 5 + 2];
+    block_boxes2[threadIdx.x * 5 + 3] =
+        dev_boxes2[(col_start + threadIdx.x) * 5 + 3];
+    block_boxes2[threadIdx.x * 5 + 4] =
+        dev_boxes2[(col_start + threadIdx.x) * 5 + 4];
+  }
+  __syncthreads();
+
+  if (threadIdx.x < row_size && threadIdx.y < col_size) {
+    int offset = (row_start + threadIdx.x) * n_boxes2 + col_start + threadIdx.y;
+    dev_ious[offset] = single_box_iou_rotated<T>(
+        block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
+  }
+}
+
+// Pairwise IoU of two sets of rotated boxes on the GPU.
+// Both inputs must be float CUDA tensors (asserted below). Returns a float
+// tensor of shape {boxes1.size(0), boxes2.size(0)}; when either input has no
+// boxes no kernel is launched.
+at::Tensor box_iou_rotated_cuda(
+    // input must be contiguous
+    const at::Tensor& boxes1,
+    const at::Tensor& boxes2) {
+  using scalar_t = float;
+  AT_ASSERTM(
+      boxes1.scalar_type() == at::kFloat, "boxes1 must be a float tensor");
+  AT_ASSERTM(
+      boxes2.scalar_type() == at::kFloat, "boxes2 must be a float tensor");
+  AT_ASSERTM(boxes1.is_cuda(), "boxes1 must be a CUDA tensor");
+  AT_ASSERTM(boxes2.is_cuda(), "boxes2 must be a CUDA tensor");
+  at::cuda::CUDAGuard device_guard(boxes1.device());
+
+  auto num_boxes1 = boxes1.size(0);
+  auto num_boxes2 = boxes2.size(0);
+
+  at::Tensor ious =
+      at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));
+
+  bool transpose = false;
+  if (num_boxes1 > 0 && num_boxes2 > 0) {
+    scalar_t *data1 = boxes1.data_ptr<scalar_t>(),
+             *data2 = boxes2.data_ptr<scalar_t>();
+
+    if (num_boxes2 > 65535 * BLOCK_DIM_Y) {
+      AT_ASSERTM(
+          num_boxes1 <= 65535 * BLOCK_DIM_Y,
+          "Too many boxes for box_iou_rotated_cuda!");
+      // x dim is allowed to be large, but y dim cannot,
+      // so we transpose the two to avoid "invalid configuration argument"
+      // error. We assume one of them is small. Otherwise the result is hard to
+      // fit in memory anyway.
+      std::swap(num_boxes1, num_boxes2);
+      std::swap(data1, data2);
+      transpose = true;
+    }
+
+    const int blocks_x =
+        at::cuda::ATenCeilDiv(static_cast<int>(num_boxes1), BLOCK_DIM_X);
+    const int blocks_y =
+        at::cuda::ATenCeilDiv(static_cast<int>(num_boxes2), BLOCK_DIM_Y);
+
+    dim3 blocks(blocks_x, blocks_y);
+    dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    box_iou_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+        num_boxes1,
+        num_boxes2,
+        data1,
+        data2,
+        (scalar_t*)ious.data_ptr<scalar_t>());
+
+    AT_CUDA_CHECK(cudaGetLastError());
+  }
+
+  // reshape from 1d array to 2d array
+  // (if the inputs were swapped above, the kernel wrote the transposed
+  // matrix, so transpose the view back to {boxes1, boxes2} order)
+  auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
+  if (transpose) {
+    return ious.view(shape).t();
+  } else {
+    return ious.view(shape);
+  }
+}
+
+} // namespace detectron2
diff --git a/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h b/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
new file mode 100644
index 0000000..b54a5dd
--- /dev/null
+++ b/data_processing/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
@@ -0,0 +1,370 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+#pragma once
+
+#include <cassert>
+#include <cmath>
+
+#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
+// Designates functions callable from the host (CPU) and the device (GPU)
+#define HOST_DEVICE __host__ __device__
+#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
+#else
+#include <algorithm>
+#define HOST_DEVICE
+#define HOST_DEVICE_INLINE HOST_DEVICE inline
+#endif
+
+namespace detectron2 {
+
+namespace {
+
+// A rotated box: center (x_ctr, y_ctr), size (w, h) and angle `a`, which
+// get_rotated_vertices converts from degrees to radians.
+template <typename T>
+struct RotatedBox {
+  T x_ctr, y_ctr, w, h, a;
+};
+
+// Minimal 2D point/vector with the arithmetic needed below.
+template <typename T>
+struct Point {
+  T x, y;
+  HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
+  HOST_DEVICE_INLINE Point operator+(const Point& p) const {
+    return Point(x + p.x, y + p.y);
+  }
+  HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
+    x += p.x;
+    y += p.y;
+    return *this;
+  }
+  HOST_DEVICE_INLINE Point operator-(const Point& p) const {
+    return Point(x - p.x, y - p.y);
+  }
+  HOST_DEVICE_INLINE Point operator*(const T coeff) const {
+    return Point(x * coeff, y * coeff);
+  }
+};
+
+// Dot product of two 2D vectors.
+template <typename T>
+HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
+  return A.x * B.x + A.y * B.y;
+}
+
+// Z component of the cross product of two 2D vectors.
+// R: result type. can be different from input type
+template <typename T, typename R = T>
+HOST_DEVICE_INLINE R cross_2d(const Point<T>& A, const Point<T>& B) {
+  return static_cast<R>(A.x) * static_cast<R>(B.y) -
+      static_cast<R>(B.x) * static_cast<R>(A.y);
+}
+
+// Write the four corners of `box` into `pts`. pts[2] and pts[3] are the
+// reflections of pts[0] and pts[1] through the box center.
+template <typename T>
+HOST_DEVICE_INLINE void get_rotated_vertices(
+    const RotatedBox<T>& box,
+    Point<T> (&pts)[4]) {
+  // M_PI / 180. == 0.01745329251
+  double theta = box.a * 0.01745329251;
+  T cosTheta2 = (T)cos(theta) * 0.5f;
+  T sinTheta2 = (T)sin(theta) * 0.5f;
+
+  // y: top --> down; x: left --> right
+  pts[0].x = box.x_ctr + sinTheta2 * box.h + cosTheta2 * box.w;
+  pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
+  pts[1].x = box.x_ctr - sinTheta2 * box.h + cosTheta2 * box.w;
+  pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
+  pts[2].x = 2 * box.x_ctr - pts[0].x;
+  pts[2].y = 2 * box.y_ctr - pts[0].y;
+  pts[3].x = 2 * box.x_ctr - pts[1].x;
+  pts[3].y = 2 * box.y_ctr - pts[1].y;
+}
+
+// Collect candidate vertices of the intersection polygon of two rectangles:
+// edge/edge crossings plus corners of either rectangle that lie inside the
+// other. Duplicates are allowed. Returns the number of points written to
+// `intersections` (at most 4 x 4 + 4 + 4 = 24).
+template <typename T>
+HOST_DEVICE_INLINE int get_intersection_points(
+    const Point<T> (&pts1)[4],
+    const Point<T> (&pts2)[4],
+    Point<T> (&intersections)[24]) {
+  // Line vector
+  // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
+  Point<T> vec1[4], vec2[4];
+  for (int i = 0; i < 4; i++) {
+    vec1[i] = pts1[(i + 1) % 4] - pts1[i];
+    vec2[i] = pts2[(i + 1) % 4] - pts2[i];
+  }
+
+  // When computing the intersection area, it doesn't hurt if we have
+  // more (duplicated/approximate) intersections/vertices than needed,
+  // while it can cause drastic difference if we miss an intersection/vertex.
+  // Therefore, we add an epsilon to relax the comparisons between
+  // the float point numbers that decide the intersection points.
+  double EPS = 1e-5;
+
+  // Line test - test all line combos for intersection
+  int num = 0; // number of intersections
+  for (int i = 0; i < 4; i++) {
+    for (int j = 0; j < 4; j++) {
+      // Solve for 2x2 Ax=b
+      T det = cross_2d<T>(vec2[j], vec1[i]);
+
+      // This takes care of parallel lines
+      if (fabs(det) <= 1e-14) {
+        continue;
+      }
+
+      auto vec12 = pts2[j] - pts1[i];
+
+      T t1 = cross_2d<T>(vec2[j], vec12) / det;
+      T t2 = cross_2d<T>(vec1[i], vec12) / det;
+
+      if (t1 > -EPS && t1 < 1.0f + EPS && t2 > -EPS && t2 < 1.0f + EPS) {
+        intersections[num++] = pts1[i] + vec1[i] * t1;
+      }
+    }
+  }
+
+  // Check for vertices of rect1 inside rect2
+  {
+    const auto& AB = vec2[0];
+    const auto& DA = vec2[3];
+    auto ABdotAB = dot_2d<T>(AB, AB);
+    auto ADdotAD = dot_2d<T>(DA, DA);
+    for (int i = 0; i < 4; i++) {
+      // assume ABCD is the rectangle, and P is the point to be judged
+      // P is inside ABCD iff. P's projection on AB lies within AB
+      // and P's projection on AD lies within AD
+
+      auto AP = pts1[i] - pts2[0];
+
+      auto APdotAB = dot_2d<T>(AP, AB);
+      auto APdotAD = -dot_2d<T>(AP, DA);
+
+      if ((APdotAB > -EPS) && (APdotAD > -EPS) && (APdotAB < ABdotAB + EPS) &&
+          (APdotAD < ADdotAD + EPS)) {
+        intersections[num++] = pts1[i];
+      }
+    }
+  }
+
+  // Reverse the check - check for vertices of rect2 inside rect1
+  {
+    const auto& AB = vec1[0];
+    const auto& DA = vec1[3];
+    auto ABdotAB = dot_2d<T>(AB, AB);
+    auto ADdotAD = dot_2d<T>(DA, DA);
+    for (int i = 0; i < 4; i++) {
+      auto AP = pts2[i] - pts1[0];
+
+      auto APdotAB = dot_2d<T>(AP, AB);
+      auto APdotAD = -dot_2d<T>(AP, DA);
+
+      if ((APdotAB > -EPS) && (APdotAD > -EPS) && (APdotAB < ABdotAB + EPS) &&
+          (APdotAD < ADdotAD + EPS)) {
+        intersections[num++] = pts2[i];
+      }
+    }
+  }
+
+  return num;
+}
+
+// Graham scan over the first `num_in` points of `p`; writes the hull
+// vertices to `q` and returns their count. With shift_to_zero == true the
+// output stays translated so that the starting point is the origin, which is
+// enough for area computation.
+template <typename T>
+HOST_DEVICE_INLINE int convex_hull_graham(
+    const Point<T> (&p)[24],
+    const int& num_in,
+    Point<T> (&q)[24],
+    bool shift_to_zero = false) {
+  assert(num_in >= 2);
+
+  // Step 1:
+  // Find point with minimum y
+  // if more than 1 points have the same minimum y,
+  // pick the one with the minimum x.
+  int t = 0;
+  for (int i = 1; i < num_in; i++) {
+    if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
+      t = i;
+    }
+  }
+  auto& start = p[t]; // starting point
+
+  // Step 2:
+  // Subtract starting point from every points (for sorting in the next step)
+  for (int i = 0; i < num_in; i++) {
+    q[i] = p[i] - start;
+  }
+
+  // Swap the starting point to position 0
+  auto tmp = q[0];
+  q[0] = q[t];
+  q[t] = tmp;
+
+  // Step 3:
+  // Sort point 1 ~ num_in according to their relative cross-product values
+  // (essentially sorting according to angles)
+  // If the angles are the same, sort according to their distance to origin
+  T dist[24];
+#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
+  // compute distance to origin before sort, and sort them together with the
+  // points
+  for (int i = 0; i < num_in; i++) {
+    dist[i] = dot_2d<T>(q[i], q[i]);
+  }
+
+  // CUDA version
+  // In the future, we can potentially use thrust
+  // for sorting here to improve speed (though not guaranteed)
+  for (int i = 1; i < num_in - 1; i++) {
+    for (int j = i + 1; j < num_in; j++) {
+      T crossProduct = cross_2d<T>(q[i], q[j]);
+      if ((crossProduct < -1e-6) ||
+          (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
+        auto q_tmp = q[i];
+        q[i] = q[j];
+        q[j] = q_tmp;
+        auto dist_tmp = dist[i];
+        dist[i] = dist[j];
+        dist[j] = dist_tmp;
+      }
+    }
+  }
+#else
+  // CPU version
+  std::sort(
+      q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
+        T temp = cross_2d<T>(A, B);
+        if (fabs(temp) < 1e-6) {
+          return dot_2d<T>(A, A) < dot_2d<T>(B, B);
+        } else {
+          return temp > 0;
+        }
+      });
+  // compute distance to origin after sort, since the points are now different.
+  for (int i = 0; i < num_in; i++) {
+    dist[i] = dot_2d<T>(q[i], q[i]);
+  }
+#endif
+
+  // Step 4:
+  // Make sure there are at least 2 points (that don't overlap with each other)
+  // in the stack
+  int k; // index of the non-overlapped second point
+  for (k = 1; k < num_in; k++) {
+    if (dist[k] > 1e-8) {
+      break;
+    }
+  }
+  if (k == num_in) {
+    // We reach the end, which means the convex hull is just one point
+    q[0] = p[t];
+    return 1;
+  }
+  q[1] = q[k];
+  int m = 2; // 2 points in the stack
+  // Step 5:
+  // Finally we can start the scanning process.
+  // When a non-convex relationship between the 3 points is found
+  // (either concave shape or duplicated points),
+  // we pop the previous point from the stack
+  // until the 3-point relationship is convex again, or
+  // until the stack only contains two points
+  for (int i = k + 1; i < num_in; i++) {
+    while (m > 1) {
+      auto q1 = q[i] - q[m - 2], q2 = q[m - 1] - q[m - 2];
+      // cross_2d() uses FMA and therefore computes round(round(q1.x*q2.y) -
+      // q2.x*q1.y) So it may not return 0 even when q1==q2. Therefore we
+      // compare round(q1.x*q2.y) and round(q2.x*q1.y) directly. (round means
+      // round to nearest floating point).
+      if (q1.x * q2.y >= q2.x * q1.y)
+        m--;
+      else
+        break;
+    }
+    // Using double also helps, but float can solve the issue for now.
+    // while (m > 1 && cross_2d<T, double>(q[i] - q[m - 2], q[m - 1] - q[m - 2])
+    // >= 0) {
+    //     m--;
+    // }
+    q[m++] = q[i];
+  }
+
+  // Step 6 (Optional):
+  // In general sense we need the original coordinates, so we
+  // need to shift the points back (reverting Step 2)
+  // But if we're only interested in getting the area/perimeter of the shape
+  // We can simply return.
+  if (!shift_to_zero) {
+    for (int i = 0; i < m; i++) {
+      q[i] += start;
+    }
+  }
+
+  return m;
+}
+
+// Area of the polygon given by the first `m` points of `q`, computed as a
+// triangle fan around q[0]. Fewer than three points give 0.
+template <typename T>
+HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
+  if (m <= 2) {
+    return 0;
+  }
+
+  T area = 0;
+  for (int i = 1; i < m - 1; i++) {
+    area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
+  }
+
+  return area / 2.0;
+}
+
+// Area of the intersection of two rotated boxes; 0 when at most two
+// candidate intersection points are found.
+template <typename T>
+HOST_DEVICE_INLINE T rotated_boxes_intersection(
+    const RotatedBox<T>& box1,
+    const RotatedBox<T>& box2) {
+  // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
+  // from rotated_rect_intersection_pts
+  Point<T> intersectPts[24], orderedPts[24];
+
+  Point<T> pts1[4];
+  Point<T> pts2[4];
+  get_rotated_vertices<T>(box1, pts1);
+  get_rotated_vertices<T>(box2, pts2);
+
+  int num = get_intersection_points<T>(pts1, pts2, intersectPts);
+
+  if (num <= 2) {
+    return 0.0;
+  }
+
+  // Convex Hull to order the intersection points in clockwise order and find
+  // the contour area.
+  int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
+  return polygon_area<T>(orderedPts, num_convex);
+}
+
+} // namespace
+
+// IoU of two rotated boxes, each given as 5 consecutive values read below as
+// (x_ctr, y_ctr, w, h, angle). Returns 0 when either box has an area below
+// 1e-14.
+template <typename T>
+HOST_DEVICE_INLINE T
+single_box_iou_rotated(T const* const box1_raw, T const* const box2_raw) {
+  // shift center to the middle point to achieve higher precision in result
+  RotatedBox<T> box1, box2;
+  auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
+  auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
+  box1.x_ctr = box1_raw[0] - center_shift_x;
+  box1.y_ctr = box1_raw[1] - center_shift_y;
+  box1.w = box1_raw[2];
+  box1.h = box1_raw[3];
+  box1.a = box1_raw[4];
+  box2.x_ctr = box2_raw[0] - center_shift_x;
+  box2.y_ctr = box2_raw[1] - center_shift_y;
+  box2.w = box2_raw[2];
+  box2.h = box2_raw[3];
+  box2.a = box2_raw[4];
+
+  T area1 = box1.w * box1.h;
+  T area2 = box2.w * box2.h;
+  if (area1 < 1e-14 || area2 < 1e-14) {
+    return 0.f;
+  }
+
+  T intersection = rotated_boxes_intersection<T>(box1, box2);
+  T iou = intersection / (area1 + area2 - intersection);
+  return iou;
+}
+
+}
// namespace detectron2 diff --git a/data_processing/detectron2/detectron2/layers/csrc/cocoeval/cocoeval.cpp b/data_processing/detectron2/detectron2/layers/csrc/cocoeval/cocoeval.cpp new file mode 100644 index 0000000..0a5b7b9 --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/csrc/cocoeval/cocoeval.cpp @@ -0,0 +1,507 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#include "cocoeval.h" +#include +#include +#include +#include + +using namespace pybind11::literals; + +namespace detectron2 { + +namespace COCOeval { + +// Sort detections from highest score to lowest, such that +// detection_instances[detection_sorted_indices[t]] >= +// detection_instances[detection_sorted_indices[t+1]]. Use stable_sort to match +// original COCO API +void SortInstancesByDetectionScore( + const std::vector& detection_instances, + std::vector* detection_sorted_indices) { + detection_sorted_indices->resize(detection_instances.size()); + std::iota( + detection_sorted_indices->begin(), detection_sorted_indices->end(), 0); + std::stable_sort( + detection_sorted_indices->begin(), + detection_sorted_indices->end(), + [&detection_instances](size_t j1, size_t j2) { + return detection_instances[j1].score > detection_instances[j2].score; + }); +} + +// Partition the ground truth objects based on whether or not to ignore them +// based on area +void SortInstancesByIgnore( + const std::array& area_range, + const std::vector& ground_truth_instances, + std::vector* ground_truth_sorted_indices, + std::vector* ignores) { + ignores->clear(); + ignores->reserve(ground_truth_instances.size()); + for (auto o : ground_truth_instances) { + ignores->push_back( + o.ignore || o.area < area_range[0] || o.area > area_range[1]); + } + + ground_truth_sorted_indices->resize(ground_truth_instances.size()); + std::iota( + ground_truth_sorted_indices->begin(), + ground_truth_sorted_indices->end(), + 0); + std::stable_sort( + ground_truth_sorted_indices->begin(), + 
ground_truth_sorted_indices->end(), + [&ignores](size_t j1, size_t j2) { + return (int)(*ignores)[j1] < (int)(*ignores)[j2]; + }); +} + +// For each IOU threshold, greedily match each detected instance to a ground +// truth instance (if possible) and store the results +void MatchDetectionsToGroundTruth( + const std::vector& detection_instances, + const std::vector& detection_sorted_indices, + const std::vector& ground_truth_instances, + const std::vector& ground_truth_sorted_indices, + const std::vector& ignores, + const std::vector>& ious, + const std::vector& iou_thresholds, + const std::array& area_range, + ImageEvaluation* results) { + // Initialize memory to store return data matches and ignore + const int num_iou_thresholds = iou_thresholds.size(); + const int num_ground_truth = ground_truth_sorted_indices.size(); + const int num_detections = detection_sorted_indices.size(); + std::vector ground_truth_matches( + num_iou_thresholds * num_ground_truth, 0); + std::vector& detection_matches = results->detection_matches; + std::vector& detection_ignores = results->detection_ignores; + std::vector& ground_truth_ignores = results->ground_truth_ignores; + detection_matches.resize(num_iou_thresholds * num_detections, 0); + detection_ignores.resize(num_iou_thresholds * num_detections, false); + ground_truth_ignores.resize(num_ground_truth); + for (auto g = 0; g < num_ground_truth; ++g) { + ground_truth_ignores[g] = ignores[ground_truth_sorted_indices[g]]; + } + + for (auto t = 0; t < num_iou_thresholds; ++t) { + for (auto d = 0; d < num_detections; ++d) { + // information about best match so far (match=-1 -> unmatched) + double best_iou = std::min(iou_thresholds[t], 1 - 1e-10); + int match = -1; + for (auto g = 0; g < num_ground_truth; ++g) { + // if this ground truth instance is already matched and not a + // crowd, it cannot be matched to another detection + if (ground_truth_matches[t * num_ground_truth + g] > 0 && + 
!ground_truth_instances[ground_truth_sorted_indices[g]].is_crowd) { + continue; + } + + // if detected instance matched to a regular ground truth + // instance, we can break on the first ground truth instance + // tagged as ignore (because they are sorted by the ignore tag) + if (match >= 0 && !ground_truth_ignores[match] && + ground_truth_ignores[g]) { + break; + } + + // if IOU overlap is the best so far, store the match appropriately + if (ious[d][ground_truth_sorted_indices[g]] >= best_iou) { + best_iou = ious[d][ground_truth_sorted_indices[g]]; + match = g; + } + } + // if match was made, store id of match for both detection and + // ground truth + if (match >= 0) { + detection_ignores[t * num_detections + d] = ground_truth_ignores[match]; + detection_matches[t * num_detections + d] = + ground_truth_instances[ground_truth_sorted_indices[match]].id; + ground_truth_matches[t * num_ground_truth + match] = + detection_instances[detection_sorted_indices[d]].id; + } + + // set unmatched detections outside of area range to ignore + const InstanceAnnotation& detection = + detection_instances[detection_sorted_indices[d]]; + detection_ignores[t * num_detections + d] = + detection_ignores[t * num_detections + d] || + (detection_matches[t * num_detections + d] == 0 && + (detection.area < area_range[0] || detection.area > area_range[1])); + } + } + + // store detection score results + results->detection_scores.resize(detection_sorted_indices.size()); + for (size_t d = 0; d < detection_sorted_indices.size(); ++d) { + results->detection_scores[d] = + detection_instances[detection_sorted_indices[d]].score; + } +} + +std::vector EvaluateImages( + const std::vector>& area_ranges, + int max_detections, + const std::vector& iou_thresholds, + const ImageCategoryInstances>& image_category_ious, + const ImageCategoryInstances& + image_category_ground_truth_instances, + const ImageCategoryInstances& + image_category_detection_instances) { + const int num_area_ranges = 
area_ranges.size(); + const int num_images = image_category_ground_truth_instances.size(); + const int num_categories = + image_category_ious.size() > 0 ? image_category_ious[0].size() : 0; + std::vector detection_sorted_indices; + std::vector ground_truth_sorted_indices; + std::vector ignores; + std::vector results_all( + num_images * num_area_ranges * num_categories); + + // Store results for each image, category, and area range combination. Results + // for each IOU threshold are packed into the same ImageEvaluation object + for (auto i = 0; i < num_images; ++i) { + for (auto c = 0; c < num_categories; ++c) { + const std::vector& ground_truth_instances = + image_category_ground_truth_instances[i][c]; + const std::vector& detection_instances = + image_category_detection_instances[i][c]; + + SortInstancesByDetectionScore( + detection_instances, &detection_sorted_indices); + if ((int)detection_sorted_indices.size() > max_detections) { + detection_sorted_indices.resize(max_detections); + } + + for (size_t a = 0; a < area_ranges.size(); ++a) { + SortInstancesByIgnore( + area_ranges[a], + ground_truth_instances, + &ground_truth_sorted_indices, + &ignores); + + MatchDetectionsToGroundTruth( + detection_instances, + detection_sorted_indices, + ground_truth_instances, + ground_truth_sorted_indices, + ignores, + image_category_ious[i][c], + iou_thresholds, + area_ranges[a], + &results_all + [c * num_area_ranges * num_images + a * num_images + i]); + } + } + } + + return results_all; +} + +// Convert a python list to a vector +template +std::vector list_to_vec(const py::list& l) { + std::vector v(py::len(l)); + for (int i = 0; i < (int)py::len(l); ++i) { + v[i] = l[i].cast(); + } + return v; +} + +// Helper function to Accumulate() +// Considers the evaluation results applicable to a particular category, area +// range, and max_detections parameter setting, which begin at +// evaluations[evaluation_index]. 
Extracts a sorted list of length n of all +// applicable detection instances concatenated across all images in the dataset, +// which are represented by the outputs evaluation_indices, detection_scores, +// image_detection_indices, and detection_sorted_indices--all of which are +// length n. evaluation_indices[i] stores the applicable index into +// evaluations[] for instance i, which has detection score detection_score[i], +// and is the image_detection_indices[i]'th of the list of detections +// for the image containing i. detection_sorted_indices[] defines a sorted +// permutation of the 3 other outputs +int BuildSortedDetectionList( + const std::vector& evaluations, + const int64_t evaluation_index, + const int64_t num_images, + const int max_detections, + std::vector* evaluation_indices, + std::vector* detection_scores, + std::vector* detection_sorted_indices, + std::vector* image_detection_indices) { + assert(evaluations.size() >= evaluation_index + num_images); + + // Extract a list of object instances of the applicable category, area + // range, and max detections requirements such that they can be sorted + image_detection_indices->clear(); + evaluation_indices->clear(); + detection_scores->clear(); + image_detection_indices->reserve(num_images * max_detections); + evaluation_indices->reserve(num_images * max_detections); + detection_scores->reserve(num_images * max_detections); + int num_valid_ground_truth = 0; + for (auto i = 0; i < num_images; ++i) { + const ImageEvaluation& evaluation = evaluations[evaluation_index + i]; + + for (int d = 0; + d < (int)evaluation.detection_scores.size() && d < max_detections; + ++d) { // detected instances + evaluation_indices->push_back(evaluation_index + i); + image_detection_indices->push_back(d); + detection_scores->push_back(evaluation.detection_scores[d]); + } + for (auto ground_truth_ignore : evaluation.ground_truth_ignores) { + if (!ground_truth_ignore) { + ++num_valid_ground_truth; + } + } + } + + // Sort 
detections by decreasing score, using stable sort to match + // python implementation + detection_sorted_indices->resize(detection_scores->size()); + std::iota( + detection_sorted_indices->begin(), detection_sorted_indices->end(), 0); + std::stable_sort( + detection_sorted_indices->begin(), + detection_sorted_indices->end(), + [&detection_scores](size_t j1, size_t j2) { + return (*detection_scores)[j1] > (*detection_scores)[j2]; + }); + + return num_valid_ground_truth; +} + +// Helper function to Accumulate() +// Compute a precision recall curve given a sorted list of detected instances +// encoded in evaluations, evaluation_indices, detection_scores, +// detection_sorted_indices, image_detection_indices (see +// BuildSortedDetectionList()). Using vectors precisions and recalls +// and temporary storage, output the results into precisions_out, recalls_out, +// and scores_out, which are large buffers containing many precision/recall curves +// for all possible parameter settings, with precisions_out_index and +// recalls_out_index defining the applicable indices to store results. 
+void ComputePrecisionRecallCurve( + const int64_t precisions_out_index, + const int64_t precisions_out_stride, + const int64_t recalls_out_index, + const std::vector& recall_thresholds, + const int iou_threshold_index, + const int num_iou_thresholds, + const int num_valid_ground_truth, + const std::vector& evaluations, + const std::vector& evaluation_indices, + const std::vector& detection_scores, + const std::vector& detection_sorted_indices, + const std::vector& image_detection_indices, + std::vector* precisions, + std::vector* recalls, + std::vector* precisions_out, + std::vector* scores_out, + std::vector* recalls_out) { + assert(recalls_out->size() > recalls_out_index); + + // Compute precision/recall for each instance in the sorted list of detections + int64_t true_positives_sum = 0, false_positives_sum = 0; + precisions->clear(); + recalls->clear(); + precisions->reserve(detection_sorted_indices.size()); + recalls->reserve(detection_sorted_indices.size()); + assert(!evaluations.empty() || detection_sorted_indices.empty()); + for (auto detection_sorted_index : detection_sorted_indices) { + const ImageEvaluation& evaluation = + evaluations[evaluation_indices[detection_sorted_index]]; + const auto num_detections = + evaluation.detection_matches.size() / num_iou_thresholds; + const auto detection_index = iou_threshold_index * num_detections + + image_detection_indices[detection_sorted_index]; + assert(evaluation.detection_matches.size() > detection_index); + assert(evaluation.detection_ignores.size() > detection_index); + const int64_t detection_match = + evaluation.detection_matches[detection_index]; + const bool detection_ignores = + evaluation.detection_ignores[detection_index]; + const auto true_positive = detection_match > 0 && !detection_ignores; + const auto false_positive = detection_match == 0 && !detection_ignores; + if (true_positive) { + ++true_positives_sum; + } + if (false_positive) { + ++false_positives_sum; + } + + const double recall = + 
static_cast(true_positives_sum) / num_valid_ground_truth; + recalls->push_back(recall); + const int64_t num_valid_detections = + true_positives_sum + false_positives_sum; + const double precision = num_valid_detections > 0 + ? static_cast(true_positives_sum) / num_valid_detections + : 0.0; + precisions->push_back(precision); + } + + (*recalls_out)[recalls_out_index] = !recalls->empty() ? recalls->back() : 0; + + for (int64_t i = static_cast(precisions->size()) - 1; i > 0; --i) { + if ((*precisions)[i] > (*precisions)[i - 1]) { + (*precisions)[i - 1] = (*precisions)[i]; + } + } + + // Sample the per instance precision/recall list at each recall threshold + for (size_t r = 0; r < recall_thresholds.size(); ++r) { + // first index in recalls >= recall_thresholds[r] + std::vector::iterator low = std::lower_bound( + recalls->begin(), recalls->end(), recall_thresholds[r]); + size_t precisions_index = low - recalls->begin(); + + const auto results_ind = precisions_out_index + r * precisions_out_stride; + assert(results_ind < precisions_out->size()); + assert(results_ind < scores_out->size()); + if (precisions_index < precisions->size()) { + (*precisions_out)[results_ind] = (*precisions)[precisions_index]; + (*scores_out)[results_ind] = + detection_scores[detection_sorted_indices[precisions_index]]; + } else { + (*precisions_out)[results_ind] = 0; + (*scores_out)[results_ind] = 0; + } + } +} +py::dict Accumulate( + const py::object& params, + const std::vector& evaluations) { + const std::vector recall_thresholds = + list_to_vec(params.attr("recThrs")); + const std::vector max_detections = + list_to_vec(params.attr("maxDets")); + const int num_iou_thresholds = py::len(params.attr("iouThrs")); + const int num_recall_thresholds = py::len(params.attr("recThrs")); + const int num_categories = params.attr("useCats").cast() == 1 + ? 
py::len(params.attr("catIds")) + : 1; + const int num_area_ranges = py::len(params.attr("areaRng")); + const int num_max_detections = py::len(params.attr("maxDets")); + const int num_images = py::len(params.attr("imgIds")); + + std::vector precisions_out( + num_iou_thresholds * num_recall_thresholds * num_categories * + num_area_ranges * num_max_detections, + -1); + std::vector recalls_out( + num_iou_thresholds * num_categories * num_area_ranges * + num_max_detections, + -1); + std::vector scores_out( + num_iou_thresholds * num_recall_thresholds * num_categories * + num_area_ranges * num_max_detections, + -1); + + // Consider the list of all detected instances in the entire dataset in one + // large list. evaluation_indices, detection_scores, + // image_detection_indices, and detection_sorted_indices all have the same + // length as this list, such that each entry corresponds to one detected + // instance + std::vector evaluation_indices; // indices into evaluations[] + std::vector detection_scores; // detection scores of each instance + std::vector detection_sorted_indices; // sorted indices of all + // instances in the dataset + std::vector + image_detection_indices; // indices into the list of detected instances in + // the same image as each instance + std::vector precisions, recalls; + + for (auto c = 0; c < num_categories; ++c) { + for (auto a = 0; a < num_area_ranges; ++a) { + for (auto m = 0; m < num_max_detections; ++m) { + // The COCO PythonAPI assumes evaluations[] (the return value of + // COCOeval::EvaluateImages() is one long list storing results for each + // combination of category, area range, and image id, with categories in + // the outermost loop and images in the innermost loop. 
+ const int64_t evaluations_index = + c * num_area_ranges * num_images + a * num_images; + int num_valid_ground_truth = BuildSortedDetectionList( + evaluations, + evaluations_index, + num_images, + max_detections[m], + &evaluation_indices, + &detection_scores, + &detection_sorted_indices, + &image_detection_indices); + + if (num_valid_ground_truth == 0) { + continue; + } + + for (auto t = 0; t < num_iou_thresholds; ++t) { + // recalls_out is a flattened vector representing a + // num_iou_thresholds X num_categories X num_area_ranges X + // num_max_detections matrix + const int64_t recalls_out_index = + t * num_categories * num_area_ranges * num_max_detections + + c * num_area_ranges * num_max_detections + + a * num_max_detections + m; + + // precisions_out and scores_out are flattened vectors + // representing a num_iou_thresholds X num_recall_thresholds X + // num_categories X num_area_ranges X num_max_detections matrix + const int64_t precisions_out_stride = + num_categories * num_area_ranges * num_max_detections; + const int64_t precisions_out_index = t * num_recall_thresholds * + num_categories * num_area_ranges * num_max_detections + + c * num_area_ranges * num_max_detections + + a * num_max_detections + m; + + ComputePrecisionRecallCurve( + precisions_out_index, + precisions_out_stride, + recalls_out_index, + recall_thresholds, + t, + num_iou_thresholds, + num_valid_ground_truth, + evaluations, + evaluation_indices, + detection_scores, + detection_sorted_indices, + image_detection_indices, + &precisions, + &recalls, + &precisions_out, + &scores_out, + &recalls_out); + } + } + } + } + + time_t rawtime; + struct tm local_time; + std::array buffer; + time(&rawtime); +#ifdef _WIN32 + localtime_s(&local_time, &rawtime); +#else + localtime_r(&rawtime, &local_time); +#endif + strftime( + buffer.data(), 200, "%Y-%m-%d %H:%M:%S", &local_time); + return py::dict( + "params"_a = params, + "counts"_a = std::vector( + {num_iou_thresholds, 
num_recall_thresholds, + num_categories, + num_area_ranges, + num_max_detections}), + "date"_a = buffer, + "precision"_a = precisions_out, + "recall"_a = recalls_out, + "scores"_a = scores_out); +} + +} // namespace COCOeval + +} // namespace detectron2 diff --git a/data_processing/detectron2/detectron2/layers/csrc/cocoeval/cocoeval.h b/data_processing/detectron2/detectron2/layers/csrc/cocoeval/cocoeval.h new file mode 100644 index 0000000..db246e4 --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/csrc/cocoeval/cocoeval.h @@ -0,0 +1,88 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#pragma once + +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace detectron2 { + +namespace COCOeval { + +// Annotation data for a single object instance in an image +struct InstanceAnnotation { + InstanceAnnotation( + uint64_t id, + double score, + double area, + bool is_crowd, + bool ignore) + : id{id}, score{score}, area{area}, is_crowd{is_crowd}, ignore{ignore} {} + uint64_t id; + double score = 0.; + double area = 0.; + bool is_crowd = false; + bool ignore = false; +}; + +// Stores intermediate results for evaluating detection results for a single +// image that has D detected instances and G ground truth instances. 
This stores +// matches between detected and ground truth instances +struct ImageEvaluation { + // For each of the D detected instances, the id of the matched ground truth + // instance, or 0 if unmatched + std::vector detection_matches; + + // The detection score of each of the D detected instances + std::vector detection_scores; + + // Marks whether or not each of G instances was ignored from evaluation (e.g., + // because it's outside area_range) + std::vector ground_truth_ignores; + + // Marks whether or not each of D instances was ignored from evaluation (e.g., + // because it's outside aRng) + std::vector detection_ignores; +}; + +template +using ImageCategoryInstances = std::vector>>; + +// C++ implementation of COCO API cocoeval.py::COCOeval.evaluateImg(). For each +// combination of image, category, area range settings, and IOU thresholds to +// evaluate, it matches detected instances to ground truth instances and stores +// the results into a vector of ImageEvaluation results, which will be +// interpreted by the COCOeval::Accumulate() function to produce precision-recall +// curves. 
The parameters of nested vectors have the following semantics: +// image_category_ious[i][c][d][g] is the intersection over union of the d'th +// detected instance and g'th ground truth instance of +// category category_ids[c] in image image_ids[i] +// image_category_ground_truth_instances[i][c] is a vector of ground truth +// instances in image image_ids[i] of category category_ids[c] +// image_category_detection_instances[i][c] is a vector of detected +// instances in image image_ids[i] of category category_ids[c] +std::vector EvaluateImages( + const std::vector>& area_ranges, // vector of 2-tuples + int max_detections, + const std::vector& iou_thresholds, + const ImageCategoryInstances>& image_category_ious, + const ImageCategoryInstances& + image_category_ground_truth_instances, + const ImageCategoryInstances& + image_category_detection_instances); + +// C++ implementation of COCOeval.accumulate(), which generates precision +// recall curves for each set of category, IOU threshold, detection area range, +// and max number of detections parameters. It is assumed that the parameter +// evaluations is the return value of the function COCOeval::EvaluateImages(), +// which was called with the same parameter settings params +py::dict Accumulate( + const py::object& params, + const std::vector& evalutations); + +} // namespace COCOeval +} // namespace detectron2 diff --git a/data_processing/detectron2/detectron2/layers/csrc/cuda_version.cu b/data_processing/detectron2/detectron2/layers/csrc/cuda_version.cu new file mode 100644 index 0000000..6dfe1b9 --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/csrc/cuda_version.cu @@ -0,0 +1,26 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +#include + +namespace detectron2 { +int get_cudart_version() { +// Not a ROCM platform: Either HIP is not used, or +// it is used, but platform is not ROCM (i.e. 
it is CUDA) +#if !defined(__HIP_PLATFORM_HCC__) + return CUDART_VERSION; +#else + int version = 0; + +#if HIP_VERSION_MAJOR != 0 + // Create a convention similar to that of CUDA, as assumed by other + // parts of the code. + + version = HIP_VERSION_MINOR; + version += (HIP_VERSION_MAJOR * 100); +#else + hipRuntimeGetVersion(&version); +#endif + return version; +#endif +} +} // namespace detectron2 diff --git a/data_processing/detectron2/detectron2/layers/csrc/deformable/deform_conv.h b/data_processing/detectron2/detectron2/layers/csrc/deformable/deform_conv.h new file mode 100644 index 0000000..965c1bf --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/csrc/deformable/deform_conv.h @@ -0,0 +1,377 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#pragma once +#include + +namespace detectron2 { + +#if defined(WITH_CUDA) || defined(WITH_HIP) +int deform_conv_forward_cuda( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor output, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step); + +int deform_conv_backward_input_cuda( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradInput, + at::Tensor gradOffset, + at::Tensor weight, + at::Tensor columns, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + float scale, + int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, + at::Tensor weight, + 
at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor output, + at::Tensor columns, + int kernel_h, + int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int group, + const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor columns, + at::Tensor grad_input, + at::Tensor grad_weight, + at::Tensor grad_bias, + at::Tensor grad_offset, + at::Tensor grad_mask, + at::Tensor grad_output, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int group, + int deformable_group, + const bool with_bias); + +#endif + +inline int deform_conv_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor output, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) { + if (input.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); + TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); + return deform_conv_forward_cuda( + input, + weight, + offset, + output, + columns, + ones, + kW, + kH, + dW, + dH, + padW, + padH, + dilationW, + dilationH, + group, + deformable_group, + im2col_step); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + AT_ERROR("This operator is not implemented on CPU"); +} + +inline int deform_conv_backward_input( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradInput, + at::Tensor gradOffset, + at::Tensor weight, + at::Tensor columns, + int kW, + int kH, + int dW, + int dH, + int padW, + 
int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) { + if (gradOutput.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!"); + TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); + TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); + return deform_conv_backward_input_cuda( + input, + offset, + gradOutput, + gradInput, + gradOffset, + weight, + columns, + kW, + kH, + dW, + dH, + padW, + padH, + dilationW, + dilationH, + group, + deformable_group, + im2col_step); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + AT_ERROR("This operator is not implemented on CPU"); +} + +inline int deform_conv_backward_filter( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + float scale, + int im2col_step) { + if (gradOutput.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!"); + TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); + return deform_conv_backward_parameters_cuda( + input, + offset, + gradOutput, + gradWeight, + columns, + ones, + kW, + kH, + dW, + dH, + padW, + padH, + dilationW, + dilationH, + group, + deformable_group, + scale, + im2col_step); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + AT_ERROR("This operator is not implemented on CPU"); +} + +inline void modulated_deform_conv_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor output, + at::Tensor columns, + int kernel_h, + int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int 
pad_w, + const int dilation_h, + const int dilation_w, + const int group, + const int deformable_group, + const bool with_bias) { + if (input.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); + TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!"); + TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); + return modulated_deform_conv_cuda_forward( + input, + weight, + bias, + ones, + offset, + mask, + output, + columns, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + group, + deformable_group, + with_bias); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + AT_ERROR("This operator is not implemented on CPU"); +} + +inline void modulated_deform_conv_backward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor columns, + at::Tensor grad_input, + at::Tensor grad_weight, + at::Tensor grad_bias, + at::Tensor grad_offset, + at::Tensor grad_mask, + at::Tensor grad_output, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int group, + int deformable_group, + const bool with_bias) { + if (grad_output.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!"); + TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); + TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!"); + TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); + return modulated_deform_conv_cuda_backward( + input, + weight, + bias, + ones, + offset, + mask, + columns, + grad_input, + grad_weight, + grad_bias, + grad_offset, + grad_mask, + grad_output, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + group, + deformable_group, + with_bias); +#else + 
AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + AT_ERROR("This operator is not implemented on CPU"); +} + +} // namespace detectron2 diff --git a/data_processing/detectron2/detectron2/layers/csrc/deformable/deform_conv_cuda.cu b/data_processing/detectron2/detectron2/layers/csrc/deformable/deform_conv_cuda.cu new file mode 100644 index 0000000..2072bb8 --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/csrc/deformable/deform_conv_cuda.cu @@ -0,0 +1,1223 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +// modified from +// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp +// Original license: Apache 2.0 + +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c +// Original license: Apache 2.0 + +#include + +#include "deform_conv.h" + +#include +#include + +namespace detectron2 { + +void deformable_im2col( + const at::Tensor data_im, + const at::Tensor data_offset, + const int channels, + const int height, + const int width, + const int ksize_h, + const int ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor data_col); + +void deformable_col2im( + const at::Tensor data_col, + const at::Tensor data_offset, + const int channels, + const int height, + const int width, + const int ksize_h, + const int ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, + const at::Tensor data_im, + const at::Tensor data_offset, + 
const int channels, + const int height, + const int width, + const int ksize_h, + const int ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int batch_size, + const int channels, + const int height_im, + const int width_im, + const int height_col, + const int width_col, + const int kernel_h, + const int kenerl_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int batch_size, + const int channels, + const int height_im, + const int width_im, + const int height_col, + const int width_col, + const int kernel_h, + const int kenerl_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, + const at::Tensor data_im, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int batch_size, + const int channels, + const int height_im, + const int width_im, + const int height_col, + const int width_col, + const int kernel_h, + const int kenerl_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check( + at::Tensor input, + at::Tensor offset, + at::Tensor* gradOutput, + at::Tensor weight, + int kH, + int kW, 
+ int dH, + int dW, + int padH, + int padW, + int dilationH, + int dilationW, + int group, + int deformable_group) { + TORCH_CHECK( + weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK( + kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", + kH, + kW); + + TORCH_CHECK( + (weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", + kH, + kW, + weight.size(2), + weight.size(3)); + + TORCH_CHECK( + dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", + dH, + dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, + dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK( + ndim == 3 || ndim == 4, + "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK( + nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). 
Output size is too small", + nInputPlane, + inputHeight, + inputWidth, + nOutputPlane, + outputHeight, + outputWidth); + + TORCH_CHECK( + input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, + input.size(1)); + + TORCH_CHECK( + (inputHeight + 2 * padH >= kH && inputWidth + 2 * padW >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK( + (offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, + outputWidth, + offset.size(2), + offset.size(3)); + + TORCH_CHECK( + (offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK( + gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, + gradOutput->size(dimf)); + + TORCH_CHECK( + (gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, + outputWidth, + gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor output, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) { + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check( + input, + offset, + NULL, + weight, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + group, + deformable_group); + + input = input.contiguous(); + 
offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view( + {batchSize / im2col_step, + im2col_step, + nOutputPlane, + outputHeight, + outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view( + {batchSize / im2col_step, + im2col_step, + nInputPlane, + inputHeight, + inputWidth}); + offset = offset.view( + {batchSize / im2col_step, + im2col_step, + deformable_group * 2 * kH * kW, + outputHeight, + outputWidth}); + + at::Tensor output_buffer = at::zeros( + {batchSize / im2col_step, + nOutputPlane, + im2col_step * outputHeight, + outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), + group, + output_buffer.size(1) / group, + output_buffer.size(2), + output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col( + input[elt], + offset[elt], + nInputPlane, + inputHeight, + inputWidth, + kH, + kW, + padH, + padW, + dH, + dW, + dilationH, + dilationW, + im2col_step, + deformable_group, + columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + 
weight = weight.view( + {group, + weight.size(0) / group, + weight.size(1), + weight.size(2), + weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), + output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), + output_buffer.size(4)}); + + output_buffer = output_buffer.view( + {batchSize / im2col_step, + nOutputPlane, + im2col_step, + outputHeight, + outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradInput, + at::Tensor gradOffset, + at::Tensor weight, + at::Tensor columns, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) { + shape_check( + input, + offset, + &gradOutput, + weight, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + group, + deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, 
offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view( + {batchSize / im2col_step, + im2col_step, + nOutputPlane, + outputHeight, + outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view( + {batchSize / im2col_step, + im2col_step, + nInputPlane, + inputHeight, + inputWidth}); + input = input.view( + {batchSize / im2col_step, + im2col_step, + nInputPlane, + inputHeight, + inputWidth}); + gradOffset = gradOffset.view( + {batchSize / im2col_step, + im2col_step, + deformable_group * 2 * kH * kW, + outputHeight, + outputWidth}); + offset = offset.view( + {batchSize / im2col_step, + im2col_step, + deformable_group * 2 * kH * kW, + outputHeight, + outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view( + {group, + weight.size(0) / group, + weight.size(1), + weight.size(2), + weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), + group, + gradOutput.size(1) / group, + gradOutput.size(2), + gradOutput.size(3), + gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_( 
+ weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), + 0.0f, + 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), + gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), + gradOutput.size(4), + gradOutput.size(5)}); + + deformable_col2im_coord( + columns, + input[elt], + offset[elt], + nInputPlane, + inputHeight, + inputWidth, + kH, + kW, + padH, + padW, + dH, + dW, + dilationH, + dilationW, + im2col_step, + deformable_group, + gradOffset[elt]); + + deformable_col2im( + columns, + offset[elt], + nInputPlane, + inputHeight, + inputWidth, + kH, + kW, + padH, + padW, + dH, + dW, + dilationH, + dilationW, + im2col_step, + deformable_group, + gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + 
int deformable_group, + float scale, + int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check( + input, + offset, + &gradOutput, + gradWeight, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + group, + deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view( + {batchSize / im2col_step, + im2col_step, + nOutputPlane, + outputHeight, + outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = gradOutputBuffer.view( + {batchSize / im2col_step, + nOutputPlane, + im2col_step, + outputHeight, + outputWidth}); + gradOutputBuffer.copy_(gradOutput); + // gradOutput is not contiguous, so we do reshape (instead of view) next + gradOutputBuffer = gradOutputBuffer.reshape( + {batchSize / im2col_step, + nOutputPlane, + im2col_step * outputHeight, + outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = 
input.view( + {batchSize / im2col_step, + im2col_step, + nInputPlane, + inputHeight, + inputWidth}); + offset = offset.view( + {batchSize / im2col_step, + im2col_step, + deformable_group * 2 * kH * kW, + outputHeight, + outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col( + input[elt], + offset[elt], + nInputPlane, + inputHeight, + inputWidth, + kH, + kW, + padH, + padW, + dH, + dW, + dilationH, + dilationW, + im2col_step, + deformable_group, + columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + group, + gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), + gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = gradWeight.view( + {group, + gradWeight.size(0) / group, + gradWeight.size(1), + gradWeight.size(2), + gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_( + gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), + 1.0, + scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), + gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view( + {gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), + gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, + at::Tensor weight, + 
at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor output, + at::Tensor columns, + int kernel_h, + int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int group, + const int deformable_group, + const bool with_bias) { + shape_check( + input, + offset, + NULL, + weight, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + group, + deformable_group); + + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR( + "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, + kernel_w, + kernel_h_, + kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR( + "Input shape and kernel channels wont match: (%d vs %d).", + channels, + channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + // mask shape check + TORCH_CHECK( + (mask.size(2) == height_out && mask.size(3) == width_out), + "invalid spatial size of mask, expected height: %d width: %d, but " + "got height: %d width: %d", + height_out, + width_out, + mask.size(2), + mask.size(3)); + + TORCH_CHECK( + (mask.size(1) == deformable_group * kernel_h * kernel_w), + "invalid number of channels of mask"); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < 
height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = at::zeros( + {channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view( + {output.size(0), + group, + output.size(1) / group, + output.size(2), + output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], + offset[b], + mask[b], + 1, + channels, + height, + width, + height_out, + width_out, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + deformable_group, + columns); + + // divide into group + weight = weight.view( + {group, + weight.size(0) / group, + weight.size(1), + weight.size(2), + weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view( + {weight.size(0) * weight.size(1), + weight.size(2), + weight.size(3), + weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view( + {output.size(0), + output.size(1) * output.size(2), + output.size(3), + output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor columns, + at::Tensor grad_input, + at::Tensor grad_weight, + at::Tensor grad_bias, + at::Tensor grad_offset, + at::Tensor grad_mask, + at::Tensor grad_output, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int 
dilation_h, + int dilation_w, + int group, + int deformable_group, + const bool with_bias) { + shape_check( + input, + offset, + &grad_output, + weight, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + group, + deformable_group); + + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR( + "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, + kernel_w, + kernel_h_, + kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR( + "Input shape and kernel channels wont match: (%d vs %d).", + channels, + channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + // mask shape check + TORCH_CHECK( + (mask.size(2) == height_out && mask.size(3) == width_out), + "invalid spatial size of mask, expected height: %d width: %d, but " + "got height: %d width: %d", + height_out, + width_out, + mask.size(2), + mask.size(3)); + + TORCH_CHECK( + (mask.size(1) == deformable_group * kernel_h * kernel_w), + "invalid number of channels of mask"); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... 
+ ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros( + {channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = grad_output.view( + {grad_output.size(0), + group, + grad_output.size(1) / group, + grad_output.size(2), + grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view( + {group, + weight.size(0) / group, + weight.size(1), + weight.size(2), + weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_( + weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), + 0.0f, + 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view( + {weight.size(0) * weight.size(1), + weight.size(2), + weight.size(3), + weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, + input[b], + offset[b], + mask[b], + 1, + channels, + height, + width, + height_out, + width_out, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + deformable_group, + grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, + offset[b], + mask[b], + 1, + channels, + height, + width, + height_out, + width_out, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + deformable_group, + grad_input[b]); + + // gradient w.r.t. 
weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], + offset[b], + mask[b], + 1, + channels, + height, + width, + height_out, + width_out, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + deformable_group, + columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view( + {group, + grad_weight.size(0) / group, + grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view( + {grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view( + {grad_output.size(0) * grad_output.size(1), + grad_output.size(2), + grad_output.size(3), + grad_output.size(4)}); +} + +} // namespace detectron2 diff --git a/data_processing/detectron2/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu b/data_processing/detectron2/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu new file mode 100644 index 0000000..f299c7a --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu @@ -0,0 +1,1288 @@ +// Copyright (c) Facebook, Inc. and its affiliates. 
+ +// modified from +// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu +// Original license: Apache 2.0 +// clang-format off + +// modified from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer ***************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ********************* + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. 
+ * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +#include +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + + +namespace { + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) { + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +} + +template +__device__ scalar_t deformable_im2col_bilinear( + const scalar_t* bottom_data, + const int data_width, + const int height, + const int width, + scalar_t h, + scalar_t w) { + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight( + scalar_t argmax_h, + scalar_t argmax_w, + const int h, + const int w, + const int height, + const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = 
argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight( + scalar_t argmax_h, + scalar_t argmax_w, + const int height, + const int width, + const scalar_t* im_data, + const int data_width, + const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } else if (bp_dir == 1) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight 
+= -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel( + const int n, + const scalar_t* data_im, + const scalar_t* data_offset, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int num_channels, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* data_col) { + CUDA_KERNEL_LOOP(index, n) { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t* data_col_ptr = data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + // const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * + // height + h_in) * width + w_in; + const scalar_t* data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t* data_offset_ptr = data_offset + + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * 
width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + // const scalar_t map_h = i * dilation_h + offset_h; + // const scalar_t map_w = j * dilation_w + offset_w; + // const int cur_height = height - h_in; + // const int cur_width = width - w_in; + // val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, + // cur_width, map_h, map_w); + val = deformable_im2col_bilinear( + data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, + const scalar_t* data_col, + const scalar_t* data_offset, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* grad_im) { + CUDA_KERNEL_LOOP(index, n) { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = + (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / 
height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight( + cur_inv_h_data, + cur_inv_w_data, + cur_h + dy, + cur_w + dx, + height, + width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + + +template +__global__ void deformable_col2im_coord_gpu_kernel( + const int n, + const scalar_t* data_col, + const scalar_t* data_im, + const scalar_t* data_offset, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int offset_channels, + const int deformable_group, + const int height_col, + const int width_col, 
+ scalar_t* grad_offset) { + CUDA_KERNEL_LOOP(index, n) { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t* data_col_ptr = data_col + + deformable_group_index * channel_per_deformable_group * batch_size * + width_col * height_col; + const scalar_t* data_im_ptr = data_im + + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || 
inv_w <= -1 || inv_h >= height || inv_w >= width) { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, + inv_w, + height, + width, + data_im_ptr + cnt * height * width, + width, + bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + + +namespace detectron2 { + +void deformable_im2col( + const at::Tensor data_im, + const at::Tensor data_offset, + const int channels, + const int height, + const int width, + const int ksize_h, + const int ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor data_col) { + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + at::cuda::CUDAGuard device_guard(data_im.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t* data_im_ = data_im.data_ptr(); + const scalar_t* data_offset_ = data_offset.data_ptr(); + scalar_t* data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS, + 0, + stream>>>( + num_kernels, + data_im_, + data_offset_, + height, + width, + ksize_h, + ksize_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + parallel_imgs, + channels, + deformable_group, + height_col, + width_col, + data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + 
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + + +void deformable_col2im( + const at::Tensor data_col, + const at::Tensor data_offset, + const int channels, + const int height, + const int width, + const int ksize_h, + const int ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor grad_im) { + // todo: make sure parallel_imgs is passed in correctly + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = + channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + at::cuda::CUDAGuard device_guard(data_col.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t* data_col_ = data_col.data_ptr(); + const scalar_t* data_offset_ = data_offset.data_ptr(); + scalar_t* grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS, + 0, + stream>>>( + num_kernels, + data_col_, + data_offset_, + channels, + height, + width, + ksize_h, + ksize_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + parallel_imgs, + deformable_group, + height_col, + width_col, + grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + + +void deformable_col2im_coord( + const at::Tensor data_col, + const at::Tensor data_im, + const at::Tensor data_offset, + const int channels, + const int height, + const int width, + const int ksize_h, + const int 
ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor grad_offset) { + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * + deformable_group * parallel_imgs; + int channel_per_deformable_group = + channels * ksize_h * ksize_w / deformable_group; + + at::cuda::CUDAGuard device_guard(data_col.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t* data_col_ = data_col.data_ptr(); + const scalar_t* data_im_ = data_im.data_ptr(); + const scalar_t* data_offset_ = data_offset.data_ptr(); + scalar_t* grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS, + 0, + stream>>>( + num_kernels, + data_col_, + data_im_, + data_offset_, + channels, + height, + width, + ksize_h, + ksize_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + parallel_imgs, + 2 * ksize_h * ksize_w * deformable_group, + deformable_group, + height_col, + width_col, + grad_offset_); + })); +} + +} // namespace detectron2 + + +template +__device__ scalar_t dmcn_im2col_bilinear( + const scalar_t* bottom_data, + const int data_width, + const int height, + const int width, + scalar_t h, + scalar_t w) { + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 
= 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight( + scalar_t argmax_h, + scalar_t argmax_w, + const int h, + const int w, + const int height, + const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight( + scalar_t argmax_h, + scalar_t argmax_w, + const int height, + const int width, + const scalar_t* im_data, + const int data_width, + const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * + 
im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } else if (bp_dir == 1) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel( + const int n, + const scalar_t* data_im, + const scalar_t* data_offset, + const scalar_t* data_mask, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int num_channels, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* data_col) { + CUDA_KERNEL_LOOP(index, n) { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const 
int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t* data_col_ptr = data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + // const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * + // height + h_in) * width + w_in; + const scalar_t* data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t* data_offset_ptr = data_offset + + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + + const scalar_t* data_mask_ptr = data_mask + + (b_col * deformable_group + deformable_group_index) * kernel_h * + kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + // if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + // const float map_h = i * dilation_h + offset_h; + // const float map_w = j * dilation_w + offset_w; + // const int cur_height = height - h_in; + // const int cur_width = width - w_in; + // val = 
dmcn_im2col_bilinear(data_im_ptr, width, cur_height, + // cur_width, map_h, map_w); + val = dmcn_im2col_bilinear( + data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + // data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel( + const int n, + const scalar_t* data_col, + const scalar_t* data_offset, + const scalar_t* data_mask, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* grad_im) { + CUDA_KERNEL_LOOP(index, n) { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = + (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + const scalar_t* data_mask_ptr = data_mask + + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * + height_col * width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * 
width_col + w_out; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight( + cur_inv_h_data, + cur_inv_w_data, + cur_h + dy, + cur_w + dx, + height, + width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel( + const int n, + const scalar_t* data_col, + const scalar_t* data_im, + const scalar_t* data_offset, + const scalar_t* data_mask, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int offset_channels, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* grad_offset, + scalar_t* grad_mask) { + CUDA_KERNEL_LOOP(index, n) { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + 
int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t* data_col_ptr = data_col + + deformable_group_index * channel_per_deformable_group * batch_size * + width_col * height_col; + const scalar_t* data_im_ptr = data_im + + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + const scalar_t* data_mask_ptr = data_mask + + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * + height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const int data_mask_hw_ptr = + (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + 
// Host-side launcher for modulated_deformable_im2col_gpu_kernel.
//
// The kernel is launched over
//   num_kernels = channels * batch_size * height_col * width_col
// work items on the current CUDA stream of data_im's device, and its output
// is written into data_col.
//
// A launch failure is reported on stdout only; the function still returns
// normally, so callers cannot detect the failure from the return path.
void modulated_deformable_im2col_cuda(
    const at::Tensor data_im,
    const at::Tensor data_offset,
    const at::Tensor data_mask,
    const int batch_size,
    const int channels,
    const int height_im,
    const int width_im,
    const int height_col,
    const int width_col,
    const int kernel_h,
    const int kernel_w,
    const int pad_h,
    const int pad_w,
    const int stride_h,
    const int stride_w,
    const int dilation_h,
    const int dilation_w,
    const int deformable_group,
    at::Tensor data_col) {
  // num_axes should be smaller than block size
  const int channel_per_deformable_group = channels / deformable_group;
  const int num_kernels = channels * batch_size * height_col * width_col;

  // Make data_im's device current for the duration of the launch.
  at::cuda::CUDAGuard device_guard(data_im.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
        // data_ptr must be instantiated with the dispatched scalar type:
        // the untyped overload yields void*, which does not convert to
        // const scalar_t*.
        const scalar_t* data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t* data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t* data_col_ = data_col.data_ptr<scalar_t>();

        modulated_deformable_im2col_gpu_kernel<<<
            GET_BLOCKS(num_kernels),
            CUDA_NUM_THREADS,
            0,
            stream>>>(
            num_kernels,
            data_im_,
            data_offset_,
            data_mask_,
            height_im,
            width_im,
            kernel_h,
            kernel_w,
            pad_h,
            pad_w,
            stride_h,
            stride_w,
            dilation_h,
            dilation_w,
            channel_per_deformable_group,
            batch_size,
            channels,
            deformable_group,
            height_col,
            width_col,
            data_col_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf(
        "error in modulated_deformable_im2col_cuda: %s\n",
        cudaGetErrorString(err));
  }
}
height_im, + width_im, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + batch_size, + deformable_group, + height_col, + width_col, + grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf( + "error in modulated_deformable_col2im_cuda: %s\n", + cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, + const at::Tensor data_im, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int batch_size, + const int channels, + const int height_im, + const int width_im, + const int height_col, + const int width_col, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, + at::Tensor grad_mask) { + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * + kernel_w * deformable_group; + const int channel_per_deformable_group = + channels * kernel_h * kernel_w / deformable_group; + + at::cuda::CUDAGuard device_guard(data_col.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t* data_col_ = data_col.data_ptr(); + const scalar_t* data_im_ = data_im.data_ptr(); + const scalar_t* data_offset_ = data_offset.data_ptr(); + const scalar_t* data_mask_ = data_mask.data_ptr(); + scalar_t* grad_offset_ = grad_offset.data_ptr(); + scalar_t* grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS, + 0, + stream>>>( + num_kernels, + data_col_, + data_im_, + data_offset_, + data_mask_, + channels, + height_im, + width_im, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + 
dilation_h, + dilation_w, + channel_per_deformable_group, + batch_size, + 2 * kernel_h * kernel_w * deformable_group, + deformable_group, + height_col, + width_col, + grad_offset_, + grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf( + "error in modulated_deformable_col2im_coord_cuda: %s\n", + cudaGetErrorString(err)); + } +} + +} // namespace detectron2 diff --git a/data_processing/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated.h b/data_processing/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated.h new file mode 100644 index 0000000..12aca38 --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated.h @@ -0,0 +1,39 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#pragma once +#include + +namespace detectron2 { + +at::Tensor nms_rotated_cpu( + const at::Tensor& dets, + const at::Tensor& scores, + const double iou_threshold); + +#if defined(WITH_CUDA) || defined(WITH_HIP) +at::Tensor nms_rotated_cuda( + const at::Tensor& dets, + const at::Tensor& scores, + const double iou_threshold); +#endif + +// Interface for Python +// inline is needed to prevent multiple function definitions when this header is +// included by different cpps +inline at::Tensor nms_rotated( + const at::Tensor& dets, + const at::Tensor& scores, + const double iou_threshold) { + assert(dets.device().is_cuda() == scores.device().is_cuda()); + if (dets.device().is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + return nms_rotated_cuda( + dets.contiguous(), scores.contiguous(), iou_threshold); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + + return nms_rotated_cpu(dets.contiguous(), scores.contiguous(), iou_threshold); +} + +} // namespace detectron2 diff --git a/data_processing/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp b/data_processing/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp new file mode 100644 
index 0000000..d7556e6 --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp @@ -0,0 +1,75 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#include "../box_iou_rotated/box_iou_rotated_utils.h" +#include "nms_rotated.h" + +namespace detectron2 { + +template +at::Tensor nms_rotated_cpu_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + const double iou_threshold) { + // nms_rotated_cpu_kernel is modified from torchvision's nms_cpu_kernel, + // however, the code in this function is much shorter because + // we delegate the IoU computation for rotated boxes to + // the single_box_iou_rotated function in box_iou_rotated_utils.h + AT_ASSERTM(dets.device().is_cpu(), "dets must be a CPU tensor"); + AT_ASSERTM(scores.device().is_cpu(), "scores must be a CPU tensor"); + AT_ASSERTM( + dets.scalar_type() == scores.scalar_type(), + "dets should have the same type as scores"); + + if (dets.numel() == 0) { + return at::empty({0}, dets.options().dtype(at::kLong)); + } + + auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); + + auto ndets = dets.size(0); + at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); + at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); + + auto suppressed = suppressed_t.data_ptr(); + auto keep = keep_t.data_ptr(); + auto order = order_t.data_ptr(); + + int64_t num_to_keep = 0; + + for (int64_t _i = 0; _i < ndets; _i++) { + auto i = order[_i]; + if (suppressed[i] == 1) { + continue; + } + + keep[num_to_keep++] = i; + + for (int64_t _j = _i + 1; _j < ndets; _j++) { + auto j = order[_j]; + if (suppressed[j] == 1) { + continue; + } + + auto ovr = single_box_iou_rotated( + dets[i].data_ptr(), dets[j].data_ptr()); + if (ovr >= iou_threshold) { + suppressed[j] = 1; + } + } + } + return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); +} + +at::Tensor nms_rotated_cpu( + // input must be contiguous + const at::Tensor& 
dets, + const at::Tensor& scores, + const double iou_threshold) { + auto result = at::empty({0}, dets.options()); + + AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated", [&] { + result = nms_rotated_cpu_kernel(dets, scores, iou_threshold); + }); + return result; +} + +} // namespace detectron2 diff --git a/data_processing/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu b/data_processing/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu new file mode 100644 index 0000000..2a3db5c --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu @@ -0,0 +1,145 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#include +#include +#include +#include +#ifdef WITH_CUDA +#include "../box_iou_rotated/box_iou_rotated_utils.h" +#endif +// TODO avoid this when pytorch supports "same directory" hipification +#ifdef WITH_HIP +#include "box_iou_rotated/box_iou_rotated_utils.h" +#endif + +using namespace detectron2; + +namespace { +int const threadsPerBlock = sizeof(unsigned long long) * 8; +} + +template +__global__ void nms_rotated_cuda_kernel( + const int n_boxes, + const double iou_threshold, + const T* dev_boxes, + unsigned long long* dev_mask) { + // nms_rotated_cuda_kernel is modified from torchvision's nms_cuda_kernel + + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + // Compared to nms_cuda_kernel, where each box is represented with 4 values + // (x1, y1, x2, y2), each rotated box is represented with 5 values + // (x_center, y_center, width, height, angle_degrees) here. 
// Non-maximum suppression over rotated boxes on the GPU.
//
// The boxes are sorted by descending score, the kernel fills a pairwise
// suppression bitmask (one 64-bit word per box per block of threadsPerBlock
// columns), and the mask is then reduced sequentially on the host.
//
// dets and scores must be CUDA tensors. Returns the kept indices (into the
// original, unsorted dets) as a kLong tensor on the device of dets. An empty
// dets yields an empty result, matching nms_rotated_cpu_kernel.
at::Tensor nms_rotated_cuda(
    // input must be contiguous
    const at::Tensor& dets,
    const at::Tensor& scores,
    double iou_threshold) {
  AT_ASSERTM(dets.is_cuda(), "dets must be a CUDA tensor");
  AT_ASSERTM(scores.is_cuda(), "scores must be a CUDA tensor");
  at::cuda::CUDAGuard device_guard(dets.device());

  // Without this guard, zero boxes give col_blocks == 0, i.e. a launch with
  // a dim3(0, 0) grid and an access to element 0 of an empty host vector.
  if (dets.numel() == 0) {
    return at::empty({0}, dets.options().dtype(at::kLong));
  }

  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
  auto dets_sorted = dets.index_select(0, order_t);

  auto dets_num = dets.size(0);

  const int col_blocks =
      at::cuda::ATenCeilDiv(static_cast<int>(dets_num), threadsPerBlock);

  // One 64-bit mask word per (box, column block); kLong is used purely as
  // 64-bit storage and reinterpreted as unsigned long long below.
  at::Tensor mask =
      at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));

  dim3 blocks(col_blocks, col_blocks);
  dim3 threads(threadsPerBlock);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(
      dets_sorted.scalar_type(), "nms_rotated_kernel_cuda", [&] {
        nms_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            dets_num,
            iou_threshold,
            dets_sorted.data_ptr<scalar_t>(),
            (unsigned long long*)mask.data_ptr<int64_t>());
      });
  // Surface launch-configuration errors before consuming the mask.
  AT_CUDA_CHECK(cudaGetLastError());

  at::Tensor mask_cpu = mask.to(at::kCPU);
  unsigned long long* mask_host =
      (unsigned long long*)mask_cpu.data_ptr<int64_t>();

  // remv[b] accumulates the suppression bits for column block b.
  std::vector<unsigned long long> remv(col_blocks, 0ULL);

  at::Tensor keep =
      at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_out = keep.data_ptr<int64_t>();

  int num_to_keep = 0;
  for (int i = 0; i < dets_num; i++) {
    int nblock = i / threadsPerBlock;
    int inblock = i % threadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      keep_out[num_to_keep++] = i;
      // Fold in everything box i suppresses; only blocks at or after its own
      // matter, since earlier boxes have already been decided.
      unsigned long long* p = mask_host + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }

  AT_CUDA_CHECK(cudaGetLastError());
  // keep holds positions in the sorted order; map them back through order_t.
  return order_t.index(
      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
           .to(order_t.device(), keep.scalar_type())});
}
// Describe the compiler this translation unit was built with.
//
// Modelled on
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
//
// Returns "GCC <major>.<minor>", "clang <major>.<minor>.<patch>" or
// "MSVC <full version>"; the string is empty when none of the probed
// compiler macros is defined.
std::string get_compiler_version() {
  std::ostringstream version;

  // clang also defines __GNUC__, so the GCC branch must exclude it.
#if defined(__GNUC__) && !defined(__clang__)
#if ((__GNUC__ <= 4) && (__GNUC_MINOR__ <= 8))
#error "GCC >= 4.9 is required!"
#endif
  version << "GCC " << __GNUC__ << "." << __GNUC_MINOR__;
#endif

#if defined(__clang_major__)
  version << "clang " << __clang_major__ << "." << __clang_minor__ << "."
          << __clang_patchlevel__;
#endif

#if defined(_MSC_VER)
  version << "MSVC " << _MSC_FULL_VER;
#endif

  return version.str();
}
class _DeformConv(Function):
    """Autograd function implementing (non-modulated) deformable convolution.

    The forward pass runs torchvision's ``deform_conv2d`` for CPU inputs and
    the compiled ``_C.deform_conv_forward`` for CUDA inputs. The backward pass
    is only available on CUDA.
    """

    @staticmethod
    def forward(
        ctx,
        input,
        offset,
        weight,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        deformable_groups=1,
        im2col_step=64,
    ):
        """Compute the deformable convolution of ``input`` with ``weight``.

        Args:
            ctx: autograd context; receives the hyper-parameters, the saved
                tensors and two scratch buffers (``ctx.bufs_``).
            input: 4D input tensor.
            offset: sampling offsets, forwarded unchanged to the backend.
            weight: convolution weight.
            stride, padding, dilation: int or pair; expanded with ``_pair``.
            groups: number of convolution groups.
            deformable_groups: number of offset groups (must be 1 on CPU).
            im2col_step: preferred number of images processed per im2col call.

        Returns:
            The convolution output tensor.

        Raises:
            ValueError: if ``input`` is not 4D, or the output would be empty.
            NotImplementedError: on CPU when ``deformable_groups != 1``.
        """
        if input is not None and input.dim() != 4:
            raise ValueError(
                "Expected 4D tensor as input, got {}D tensor instead.".format(input.dim())
            )
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(
            _DeformConv._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)
        )

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            # TODO: let torchvision support full features of our deformconv.
            if deformable_groups != 1:
                raise NotImplementedError(
                    "Deformable Conv with deformable_groups != 1 is not supported on CPUs!"
                )
            return deform_conv2d(
                input, offset, weight, stride=stride, padding=padding, dilation=dilation
            )
        else:
            cur_im2col_step = _DeformConv._cal_im2col_step(input.shape[0], ctx.im2col_step)
            assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize"

            # The compiled op takes width before height for every 2D parameter.
            _C.deform_conv_forward(
                input,
                weight,
                offset,
                output,
                ctx.bufs_[0],
                ctx.bufs_[1],
                weight.size(3),
                weight.size(2),
                ctx.stride[1],
                ctx.stride[0],
                ctx.padding[1],
                ctx.padding[0],
                ctx.dilation[1],
                ctx.dilation[0],
                ctx.groups,
                ctx.deformable_groups,
                cur_im2col_step,
            )
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. ``input``, ``offset`` and ``weight``.

        Returns a 9-tuple aligned with the arguments of :meth:`forward`; the
        six non-tensor arguments get ``None``. Gradients that are not needed
        are returned as ``None`` as well.

        Raises:
            NotImplementedError: if ``grad_output`` is not a CUDA tensor.
        """
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError("Deformable Conv is not supported on CPUs!")
        else:
            cur_im2col_step = _DeformConv._cal_im2col_step(input.shape[0], ctx.im2col_step)
            assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize"

            # Input and offset gradients come from a single backend call, so
            # both are produced if either one is required.
            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                grad_input = torch.zeros_like(input)
                grad_offset = torch.zeros_like(offset)
                _C.deform_conv_backward_input(
                    input,
                    offset,
                    grad_output,
                    grad_input,
                    grad_offset,
                    weight,
                    ctx.bufs_[0],
                    weight.size(3),
                    weight.size(2),
                    ctx.stride[1],
                    ctx.stride[0],
                    ctx.padding[1],
                    ctx.padding[0],
                    ctx.dilation[1],
                    ctx.dilation[0],
                    ctx.groups,
                    ctx.deformable_groups,
                    cur_im2col_step,
                )

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
                _C.deform_conv_backward_filter(
                    input,
                    offset,
                    grad_output,
                    grad_weight,
                    ctx.bufs_[0],
                    ctx.bufs_[1],
                    weight.size(3),
                    weight.size(2),
                    ctx.stride[1],
                    ctx.stride[0],
                    ctx.padding[1],
                    ctx.padding[0],
                    ctx.dilation[1],
                    ctx.dilation[0],
                    ctx.groups,
                    ctx.deformable_groups,
                    1,
                    cur_im2col_step,
                )

        return grad_input, grad_offset, grad_weight, None, None, None, None, None, None

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        """Return the output shape ``(N, C_out, *spatial)`` as a tuple.

        ``padding``, ``dilation`` and ``stride`` are indexed per spatial
        dimension, so they must be sequences.

        Raises:
            ValueError: if any resulting dimension is not positive.
        """
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            # Extent covered by the dilated kernel along this dimension.
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1,)
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError(
                "convolution input is too small (output would be {})".format(
                    "x".join(map(str, output_size))
                )
            )
        return output_size

    @staticmethod
    @lru_cache(maxsize=128)
    def _cal_im2col_step(input_size, default_size):
        """
        Calculate proper im2col step size, which should divide input_size and not be larger
        than default_size. Meanwhile the step size should be as large as possible to be more
        efficient. So we choose the largest one among all divisors of input_size which are
        not larger than default_size.
        :param input_size: input batch size .
        :param default_size: default preferred im2col step size.
        :return: the largest proper step size.
        """
        if input_size <= default_size:
            return input_size
        best_step = 1
        # Divisors come in pairs (step, input_size // step) with step <= sqrt(input_size).
        # The upper bound is inclusive of default_size: default_size itself may be the
        # answer (e.g. input_size=4096, default_size=64).
        for step in range(2, min(int(math.sqrt(input_size)) + 1, default_size + 1)):
            if input_size % step == 0:
                # Cofactors shrink as step grows, so the first one that fits is the
                # largest admissible divisor overall.
                if input_size // step <= default_size:
                    return input_size // step
                best_step = step

        return best_step
ctx._bufs[1], + grad_input, + grad_weight, + grad_bias, + grad_offset, + grad_mask, + grad_output, + weight.shape[2], + weight.shape[3], + ctx.stride, + ctx.stride, + ctx.padding, + ctx.padding, + ctx.dilation, + ctx.dilation, + ctx.groups, + ctx.deformable_groups, + ctx.with_bias, + ) + if not ctx.with_bias: + grad_bias = None + + return ( + grad_input, + grad_offset, + grad_mask, + grad_weight, + grad_bias, + None, + None, + None, + None, + None, + ) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = ( + height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1) + ) // ctx.stride + 1 + width_out = ( + width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1) + ) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = _DeformConv.apply +modulated_deform_conv = _ModulatedDeformConv.apply + + +class DeformConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False, + norm=None, + activation=None, + ): + """ + Deformable convolution from :paper:`deformconv`. + + Arguments are similar to :class:`Conv2D`. Extra arguments: + + Args: + deformable_groups (int): number of groups used in deformable convolution. 
+ norm (nn.Module, optional): a normalization layer + activation (callable(Tensor) -> Tensor): a callable activation function + """ + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, "in_channels {} cannot be divisible by groups {}".format( + in_channels, groups + ) + assert ( + out_channels % groups == 0 + ), "out_channels {} cannot be divisible by groups {}".format(out_channels, groups) + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + self.norm = norm + self.activation = activation + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size) + ) + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + + def forward(self, x, offset): + if x.numel() == 0: + # When input is empty, we want to return a empty tensor with "correct" shape, + # So that the following operations will not panic + # if they check for the shape of the tensor. 
class ModulatedDeformConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        deformable_groups=1,
        bias=True,
        norm=None,
        activation=None,
    ):
        """
        Modulated deformable convolution from :paper:`deformconv2`.

        Arguments are similar to :class:`Conv2D`. Extra arguments:

        Args:
            deformable_groups (int): number of groups used in deformable convolution.
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        # stride / padding / dilation are stored exactly as given (not expanded
        # to pairs): they are handed unchanged to `modulated_deform_conv`.
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias
        self.norm = norm
        self.activation = activation

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def _empty_output_shape(self, x):
        """Return ``[N, C_out, H_out, W_out]`` for an input shaped like ``x``.

        The stored stride / padding / dilation may be plain ints, so they are
        expanded with ``_pair`` here; zipping the raw ints raises ``TypeError``.
        """
        spatial = [
            (size + 2 * pad - (dil * (k - 1) + 1)) // step + 1
            for size, pad, dil, k, step in zip(
                x.shape[-2:],
                _pair(self.padding),
                _pair(self.dilation),
                self.kernel_size,
                _pair(self.stride),
            )
        ]
        return [x.shape[0], self.weight.shape[0]] + spatial

    def forward(self, x, offset, mask):
        """Apply the convolution, then the optional norm and activation.

        An input with zero elements short-circuits to an empty tensor of the
        shape the convolution would have produced.
        """
        if x.numel() == 0:
            # When input is empty, return an empty tensor with the "correct"
            # shape so that downstream shape checks keep working.
            return _NewEmptyTensorOp.apply(x, self._empty_output_shape(x))

        x = modulated_deform_conv(
            x,
            offset,
            mask,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.deformable_groups,
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

    def extra_repr(self):
        """Return the layer hyper-parameters for ``repr(module)``."""
        tmpstr = "in_channels=" + str(self.in_channels)
        tmpstr += ", out_channels=" + str(self.out_channels)
        tmpstr += ", kernel_size=" + str(self.kernel_size)
        tmpstr += ", stride=" + str(self.stride)
        tmpstr += ", padding=" + str(self.padding)
        tmpstr += ", dilation=" + str(self.dilation)
        tmpstr += ", groups=" + str(self.groups)
        tmpstr += ", deformable_groups=" + str(self.deformable_groups)
        tmpstr += ", bias=" + str(self.with_bias)
        return tmpstr
+except ImportError: + # TODO: register ops natively so there is no need to import _C. + _msg = "detectron2 is not compiled successfully, please build following the instructions!" + _args = ("detectron2._C", _msg) + DeformConv = create_dummy_class("DeformConv", *_args) + ModulatedDeformConv = create_dummy_class("ModulatedDeformConv", *_args) + deform_conv = create_dummy_func("deform_conv", *_args) + modulated_deform_conv = create_dummy_func("modulated_deform_conv", *_args) diff --git a/data_processing/detectron2/detectron2/layers/losses.py b/data_processing/detectron2/detectron2/layers/losses.py new file mode 100644 index 0000000..850a852 --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/losses.py @@ -0,0 +1,133 @@ +import math +import torch + + +def diou_loss( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + reduction: str = "none", + eps: float = 1e-7, +) -> torch.Tensor: + """ + Distance Intersection over Union Loss (Zhaohui Zheng et. al) + https://arxiv.org/abs/1911.08287 + Args: + boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,). + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. 
def ciou_loss(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    reduction: str = "none",
    eps: float = 1e-7,
) -> torch.Tensor:
    """
    Complete Intersection over Union Loss (Zhaohui Zheng et. al)
    https://arxiv.org/abs/1911.08287

    Args:
        boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,).
            The aspect-ratio term treats ``boxes1`` as the prediction and
            ``boxes2`` as the target.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
            Any other value behaves like 'none'.
        eps (float): small number to prevent division by zero

    Returns:
        Tensor: the per-box loss, or a scalar when reduced. With 'mean' and no
        boxes the result is a zero that is still connected to the graph.

    Raises:
        AssertionError: if any box in ``boxes1`` has x2 < x1 or y2 < y1.
    """

    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    # TODO: use torch._assert_async() when pytorch 1.8 support is dropped
    assert (x2 >= x1).all(), "bad box: x1 larger than x2"
    assert (y2 >= y1).all(), "bad box: y1 larger than y2"

    # Intersection keypoints
    xkis1 = torch.max(x1, x1g)
    ykis1 = torch.max(y1, y1g)
    xkis2 = torch.min(x2, x2g)
    ykis2 = torch.min(y2, y2g)

    # Only pairs that really overlap get a non-zero intersection area.
    intsct = torch.zeros_like(x1)
    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
    intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
    iou = intsct / union

    # smallest enclosing box
    xc1 = torch.min(x1, x1g)
    yc1 = torch.min(y1, y1g)
    xc2 = torch.max(x2, x2g)
    yc2 = torch.max(y2, y2g)
    diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps

    # centers of boxes
    x_p = (x2 + x1) / 2
    y_p = (y2 + y1) / 2
    x_g = (x1g + x2g) / 2
    y_g = (y1g + y2g) / 2
    distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)

    # width and height of boxes
    w_pred = x2 - x1
    h_pred = y2 - y1
    w_gt = x2g - x1g
    h_gt = y2g - y1g
    # atan2(w, h) equals atan(w / h) whenever h > 0, but stays finite for
    # degenerate boxes: h == 0 no longer divides by zero and a 0x0 box yields
    # an angle of 0 instead of NaN. (The gradient at exactly w == h == 0 is
    # still undefined.)
    v = (4 / (math.pi**2)) * torch.pow(
        torch.atan2(w_gt, h_gt) - torch.atan2(w_pred, h_pred), 2
    )
    # alpha is a weighting coefficient and is deliberately kept out of the graph.
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)

    # Eqn. (10)
    loss = 1 - iou + (distance / diag_len) + alpha * v
    if reduction == "mean":
        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
    """
    Resample every mask into image coordinates with one ``F.grid_sample`` call.

    Args:
        masks: N, 1, H, W
        boxes: N, 4
        img_h, img_w (int):
        skip_empty (bool): only paste masks within the region that
            tightly bound all boxes, and returns the results this region only.
            An important optimization for CPU. Ignored (treated as False)
            when running under TorchScript scripting.

    Returns:
        Always a 2-tuple ``(pasted, region)``:

        * if the skip_empty path is not taken, ``pasted`` has shape
          (N, img_h, img_w) and ``region`` is the empty tuple ``()``;
        * if the skip_empty path is taken, ``pasted`` has shape (N, h', w')
          and ``region`` is ``(slice(y0, y1), slice(x0, x1))``, the window of
          the full image that ``pasted`` corresponds to.
    """
    # On GPU, paste all masks together (up to chunk size)
    # by using the entire image to sample the masks
    # Compared to pasting them one by one,
    # this has more operations but is faster on COCO-scale dataset.
    device = masks.device

    if skip_empty and not torch.jit.is_scripting():
        # Integer window enclosing all boxes, grown by one pixel on each side
        # and clamped to the image bounds.
        x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
            dtype=torch.int32
        )
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
    else:
        # Sample over the whole image.
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    # Pixel centres (index + 0.5) of the sampled window ...
    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    # ... mapped so that each box spans [-1, 1], the coordinate range
    # grid_sample uses for the extent of its input.
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    # img_x, img_y have shapes (N, w), (N, h)

    # Broadcast the per-axis coordinates into an (N, h, w, 2) sampling grid
    # whose last dimension is ordered (x, y).
    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    if not torch.jit.is_scripting():
        # Non-float masks are promoted to float before sampling; this
        # conversion is skipped under scripting.
        if not masks.dtype.is_floating_point:
            masks = masks.float()
    img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

    # Drop the singleton channel dimension in both return paths.
    if skip_empty and not torch.jit.is_scripting():
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()
+ + Args: + masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of + detected object instances in the image and Hmask, Wmask are the mask width and mask + height of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1]. + boxes (Boxes or Tensor): A Boxes of length Bimg or Tensor of shape (Bimg, 4). + boxes[i] and masks[i] correspond to the same object instance. + image_shape (tuple): height, width + threshold (float): A threshold in [0, 1] for converting the (soft) masks to + binary masks. + + Returns: + img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the + number of detected object instances and Himage, Wimage are the image width + and height. img_masks[i] is a binary mask for object instance i. + """ + + assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported" + N = len(masks) + if N == 0: + return masks.new_empty((0,) + image_shape, dtype=torch.uint8) + if not isinstance(boxes, torch.Tensor): + boxes = boxes.tensor + device = boxes.device + assert len(boxes) == N, boxes.shape + + img_h, img_w = image_shape + + # The actual implementation split the input into chunks, + # and paste them chunk by chunk. + if device.type == "cpu" or torch.jit.is_scripting(): + # CPU is most efficient when they are pasted one by one with skip_empty=True + # so that it performs minimal number of operations. 
def paste_mask_in_image_old(mask, box, img_h, img_w, threshold):
    """
    Paste a single mask in an image.
    This is a per-box implementation of :func:`paste_masks_in_image`.
    This function has larger quantization error due to incorrect pixel
    modeling and is not used any more.

    Args:
        mask (Tensor): A tensor of shape (Hmask, Wmask) storing the mask of a single
            object instance. Values are in [0, 1].
        box (Tensor): A tensor of shape (4, ) storing the x0, y0, x1, y1 box corners
            of the object instance.
        img_h, img_w (int): Image height and width.
        threshold (float): Mask binarization threshold in [0, 1]. A negative
            value disables binarization and returns the resampled mask
            scaled by 255 instead.

    Returns:
        im_mask (Tensor):
            The resized and binarized object mask pasted into the original
            image plane (a uint8 tensor of shape (img_h, img_w)).
    """
    # Conversion from continuous box coordinates to discrete pixel coordinates
    # via truncation (cast to int32). This determines which pixels to paste the
    # mask onto. The corners are pulled out as plain Python ints once, so PIL
    # and the slicing below never receive 0-dim tensors.
    box_x0, box_y0, box_x1, box_y1 = box.to(dtype=torch.int32).tolist()
    # An example (1D) box with continuous coordinates (x0=0.7, x1=4.3) will map to
    # a discrete coordinates (x0=0, x1=4). Note that box is mapped to 5 = x1 - x0 + 1
    # pixels (not x1 - x0 pixels).
    samples_w = box_x1 - box_x0 + 1  # Number of pixel samples, *not* geometric width
    samples_h = box_y1 - box_y0 + 1  # Number of pixel samples, *not* geometric height

    # Resample the mask from its original grid to the new samples_w x samples_h grid
    mask = Image.fromarray(mask.cpu().numpy())
    mask = mask.resize((samples_w, samples_h), resample=Image.BILINEAR)
    # asarray copies only when it has to; `np.array(..., copy=False)` means
    # "never copy" under NumPy >= 2 and raises whenever a copy is required.
    mask = np.asarray(mask)

    if threshold >= 0:
        mask = np.array(mask > threshold, dtype=np.uint8)
        mask = torch.from_numpy(mask)
    else:
        # for visualization and debugging, we also
        # allow it to return an unmodified mask
        mask = torch.from_numpy(mask * 255).to(torch.uint8)

    im_mask = torch.zeros((img_h, img_w), dtype=torch.uint8)
    # Clip the paste window to the image; the same clipping, expressed relative
    # to the box origin, selects the matching part of the resampled mask.
    x_0 = max(box_x0, 0)
    x_1 = min(box_x1 + 1, img_w)
    y_0 = max(box_y0, 0)
    y_1 = min(box_y1 + 1, img_h)

    im_mask[y_0:y_1, x_0:x_1] = mask[
        (y_0 - box_y0) : (y_1 - box_y0), (x_0 - box_x0) : (x_1 - box_x0)
    ]
    return im_mask
def pad_masks(masks, padding):
    """
    Surround each square mask with a border of zeros.

    Args:
        masks (tensor): A tensor of shape (B, M, M) representing B masks.
        padding (int): Number of cells to pad on all sides. ``0`` is allowed
            and yields an unpadded copy.

    Returns:
        The padded masks and the scale factor of the padding size / original size.
    """
    B = masks.shape[0]
    M = masks.shape[-1]
    pad2 = 2 * padding
    scale = float(M + pad2) / M
    padded_masks = masks.new_zeros((B, M + pad2, M + pad2))
    # Explicit end index: `padding:-padding` degenerates to the empty slice
    # `0:0` when padding == 0, which made the assignment fail.
    padded_masks[:, padding : padding + M, padding : padding + M] = masks
    return padded_masks, scale


def scale_boxes(boxes, scale):
    """
    Grow or shrink boxes about their own centers.

    Args:
        boxes (tensor): A tensor of shape (B, 4) representing B boxes with 4
            coords representing the corners x0, y0, x1, y1,
        scale (float): The box scaling factor.

    Returns:
        Scaled boxes, as a new tensor; ``boxes`` itself is not modified.
    """
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    # In-place is safe here: w_half / h_half are fresh tensors, not views of `boxes`.
    w_half *= scale
    h_half *= scale

    scaled_boxes = torch.zeros_like(boxes)
    scaled_boxes[:, 0] = x_c - w_half
    scaled_boxes[:, 2] = x_c + w_half
    scaled_boxes[:, 1] = y_c - h_half
    scaled_boxes[:, 3] = y_c + h_half
    return scaled_boxes
def batched_nms(
    boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
):
    """
    Same as torchvision.ops.boxes.batched_nms, but with float().

    ``boxes`` must have a last dimension of size 4. It is cast to float32
    before being handed to torchvision; ``scores``, ``idxs`` and
    ``iou_threshold`` are passed through unchanged, and torchvision's result
    is returned as-is.
    """
    assert boxes.shape[-1] == 4
    # Note: Torchvision already has a strategy (https://github.com/pytorch/vision/issues/1311)
    # to decide whether to use coordinate trick or for loop to implement batched_nms. So we
    # just call it directly.
    # Fp16 does not have enough range for batched NMS, so adding float().
    return box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)
+ + As for the question of whether rotated-NMS should treat them as faraway boxes + even though their IOU is 1, it depends on the application and/or ground truth annotation. + + As an extreme example, consider a single character v and the square box around it. + + If the angle is 0 degree, the object (text) would be read as 'v'; + + If the angle is 90 degrees, the object (text) would become '>'; + + If the angle is 180 degrees, the object (text) would become '^'; + + If the angle is 270/-90 degrees, the object (text) would become '<' + + All of these cases have IoU of 1 to each other, and rotated NMS that only + uses IoU as criterion would only keep one of them with the highest score - + which, practically, still makes sense in most cases because typically + only one of theses orientations is the correct one. Also, it does not matter + as much if the box is only used to classify the object (instead of transcribing + them with a sequential OCR recognition model) later. + + On the other hand, when we use IoU to filter proposals that are close to the + ground truth during training, we should definitely take the angle into account if + we know the ground truth is labeled with the strictly correct orientation (as in, + upside-down words are annotated with -180 degrees even though they can be covered + with a 0/90/-90 degree box, etc.) + + The way the original dataset is annotated also matters. For example, if the dataset + is a 4-point polygon dataset that does not enforce ordering of vertices/orientation, + we can estimate a minimum rotated bounding box to this polygon, but there's no way + we can tell the correct angle with 100% confidence (as shown above, there could be 4 different + rotated boxes, with angles differed by 90 degrees to each other, covering the exactly + same region). 
@torch.jit.script_if_tracing
def batched_nms_rotated(
    boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
):
    """
    Run rotated NMS separately for every category in one call.

    Boxes that belong to different categories never suppress each other.

    Args:
        boxes (Tensor[N, 5]):
           rotated boxes in (x_ctr, y_ctr, width, height, angle_degrees) format
        scores (Tensor[N]):
           one score per box
        idxs (Tensor[N]):
           category index of each box
        iou_threshold (float):
           a lower-scoring box is dropped when its IoU with a kept box of the
           same category is greater than this value

    Returns:
        Tensor:
            int64 tensor with the indices of the elements that have been kept
            by NMS, sorted in decreasing order of scores
    """
    assert boxes.shape[-1] == 5

    # Nothing to suppress.
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # fp16 does not have enough range for the shifted coordinates used below
    boxes = boxes.float()

    # Per-class NMS via the coordinate trick: every category is translated by
    # its own multiple of a distance larger than the span of all boxes, so
    # boxes of different categories can no longer overlap.
    #
    # torchvision's batched_nms derives that distance from the maximum
    # coordinate only, which mishandles negative coordinates; using the
    # minimum as well keeps those correct.
    centers_x = boxes[:, 0]
    centers_y = boxes[:, 1]
    half_extent = torch.max(boxes[:, 2], boxes[:, 3]) / 2
    upper_bound = (torch.max(centers_x, centers_y) + half_extent).max()
    lower_bound = (torch.min(centers_x, centers_y) - half_extent).min()
    class_shift = idxs.to(boxes) * (upper_bound - lower_bound + 1)

    # Shift a copy so the caller's tensor keeps its values; only the center
    # coordinates move.
    shifted = boxes.clone()
    shifted[:, :2] += class_shift[:, None]
    return nms_rotated(shifted, scores, iou_threshold)
class ROIAlign(nn.Module):
    """Module wrapper around torchvision's ``roi_align`` that defaults to ``aligned=True``."""

    def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
        """
        Args:
            output_size (tuple): h, w
            spatial_scale (float): factor the input boxes are multiplied by
            sampling_ratio (int): how many input samples are taken per output
                sample; 0 samples densely.
            aligned (bool): True selects the pixel-aligned variant, False the
                legacy Detectron behaviour.

        Note:
            What ``aligned=True`` means:

            In the pixel model used here, a continuous coordinate c has the
            neighbouring pixel indices floor(c - 0.5) and ceil(c - 0.5); e.g.
            c=1.3 lies between indices [0] and [1], which sample the signal
            at continuous coordinates 0.5 and 1.5. The legacy roi_align
            (aligned=False) omits the 0.5 subtraction, so its bilinear
            interpolation uses slightly misaligned neighbours.

            With ``aligned=True`` the ROI is scaled and then shifted by -0.5
            before roi_align runs, which gives the correct neighbours; see
            detectron2/tests/test_roi_align.py for verification.

            When ROIAlign is followed by conv layers the choice does not
            change model performance.
        """
        super().__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

        # The `aligned` argument needs torchvision >= 0.7:
        # https://github.com/pytorch/vision/pull/2438
        from torchvision import __version__ as tv_version

        major_minor = tuple(map(int, tv_version.split(".")[:2]))
        assert major_minor >= (0, 7), "Require torchvision >= 0.7"

    def forward(self, input, rois):
        """
        Args:
            input: NCHW images
            rois: Bx5 boxes. Column 0 indexes into N; columns 1-4 are xyxy.
        """
        assert rois.dim() == 2 and rois.size(1) == 5
        # Quantized feature maps are dequantized before pooling.
        features = input.dequantize() if input.is_quantized else input
        # The boxes are cast to the dtype of the features.
        boxes = rois.to(dtype=features.dtype)
        return roi_align(
            features,
            boxes,
            self.output_size,
            self.spatial_scale,
            self.sampling_ratio,
            self.aligned,
        )

    def __repr__(self):
        fields = (
            f"output_size={self.output_size!s}"
            f", spatial_scale={self.spatial_scale!s}"
            f", sampling_ratio={self.sampling_ratio!s}"
            f", aligned={self.aligned!s}"
        )
        return f"{self.__class__.__name__}({fields})"
class _ROIAlignRotated(Function):
    """Autograd bridge to the ``detectron2`` rotated ROI-align forward/backward ops."""

    @staticmethod
    def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
        """Pool ``input`` over the rotated boxes in ``roi``.

        ``output_size`` may be a single int or an (h, w) pair; everything else
        is forwarded unchanged to ``roi_align_rotated_forward``.
        """
        ctx.save_for_backward(roi)
        ctx.output_size = _pair(output_size)
        ctx.spatial_scale = spatial_scale
        ctx.sampling_ratio = sampling_ratio
        ctx.input_shape = input.size()
        # Index the normalised pair, not the raw argument, so that an int
        # output_size works here exactly as it does in backward().
        pooled_h, pooled_w = ctx.output_size
        output = torch.ops.detectron2.roi_align_rotated_forward(
            input, roi, spatial_scale, pooled_h, pooled_w, sampling_ratio
        )
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input``; the other four inputs get None."""
        (rois,) = ctx.saved_tensors
        output_size = ctx.output_size
        spatial_scale = ctx.spatial_scale
        sampling_ratio = ctx.sampling_ratio
        bs, ch, h, w = ctx.input_shape
        grad_input = torch.ops.detectron2.roi_align_rotated_backward(
            grad_output,
            rois,
            spatial_scale,
            output_size[0],
            output_size[1],
            bs,
            ch,
            h,
            w,
            sampling_ratio,
        )
        # One entry per forward() argument after ctx:
        # input, roi, output_size, spatial_scale, sampling_ratio.
        return grad_input, None, None, None, None


roi_align_rotated = _ROIAlignRotated.apply
+ """ + super(ROIAlignRotated, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + + def forward(self, input, rois): + """ + Args: + input: NCHW images + rois: Bx6 boxes. First column is the index into N. + The other 5 columns are (x_ctr, y_ctr, width, height, angle_degrees). + """ + assert rois.dim() == 2 and rois.size(1) == 6 + orig_dtype = input.dtype + if orig_dtype == torch.float16: + input = input.float() + rois = rois.float() + output_size = _pair(self.output_size) + + # Scripting for Autograd is currently unsupported. + # This is a quick fix without having to rewrite code on the C++ side + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return torch.ops.detectron2.roi_align_rotated_forward( + input, rois, self.spatial_scale, output_size[0], output_size[1], self.sampling_ratio + ).to(dtype=orig_dtype) + + return roi_align_rotated( + input, rois, self.output_size, self.spatial_scale, self.sampling_ratio + ).to(dtype=orig_dtype) + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ")" + return tmpstr diff --git a/data_processing/detectron2/detectron2/layers/rotated_boxes.py b/data_processing/detectron2/detectron2/layers/rotated_boxes.py new file mode 100644 index 0000000..03f73b3 --- /dev/null +++ b/data_processing/detectron2/detectron2/layers/rotated_boxes.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import, division, print_function, unicode_literals +import torch + + +def pairwise_iou_rotated(boxes1, boxes2): + """ + Return intersection-over-union (Jaccard index) of boxes. + + Both sets of boxes are expected to be in + (x_center, y_center, width, height, angle) format. 
@dataclass
class ShapeSpec:
    """
    A simple structure that contains basic shape specification about a tensor.
    It is often used as the auxiliary inputs/outputs of models,
    to complement the lack of shape inference ability among pytorch modules.

    Every field is optional and defaults to ``None``. The field order below is
    also the positional-argument order of the generated constructor.
    """

    # NOTE(review): meanings are taken from the field names only - confirm
    # against the code that constructs ShapeSpec instances.
    channels: Optional[int] = None
    height: Optional[int] = None
    width: Optional[int] = None
    stride: Optional[int] = None
def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Turn a list of integer scalars or integer Tensor scalars into a vector,
    in a way that's both traceable and scriptable.

    In tracing, `x` should be a list of scalar Tensor, so the output can trace to the inputs.
    In scripting or eager, `x` should be a list of int.

    Args:
        x: the scalars to pack, in order.
        device: device for the resulting tensor; ``None`` is passed through to
            ``torch.as_tensor`` on the scripting and eager paths.

    Returns:
        A tensor holding the elements of ``x``.
    """
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if torch.jit.is_tracing():
        assert all(
            [isinstance(t, torch.Tensor) for t in x]
        ), "Shape should be tensor during tracing!"
        # as_tensor should not be used in tracing because it records a constant
        ret = torch.stack(x)
        if ret.device != device:  # avoid recording a hard-coded device if not necessary
            ret = ret.to(device=device)
        return ret
    return torch.as_tensor(x, device=device)


def check_if_dynamo_compiling() -> bool:
    """
    Return whether ``torch._dynamo`` reports that it is currently compiling.

    ``torch._dynamo`` is imported only when ``TORCH_VERSION >= (1, 14)``;
    for older versions the answer is always False.
    """
    if TORCH_VERSION >= (1, 14):
        # Imported lazily so the module is never touched on older versions.
        from torch._dynamo import is_compiling

        return is_compiling()
    else:
        return False


def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list

    Note that in the single-element case the returned tensor is the very same
    object as ``tensors[0]``, not a copy. ``tensors`` must be a list or tuple.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)
class _NewEmptyTensorOp(torch.autograd.Function):
    """
    Autograd function that returns an uninitialized tensor of a requested shape,
    allocated from the input via ``new_empty``.

    Backward applies the same op to the incoming gradient, producing an
    uninitialized tensor shaped like the forward input.
    """

    @staticmethod
    def forward(ctx, x, new_shape):
        # keep the input shape around: backward has to emit a gradient of that shape
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        grad_input = _NewEmptyTensorOp.apply(grad, ctx.shape)
        # second slot corresponds to `new_shape`, which receives no gradient
        return grad_input, None
def nonzero_tuple(x):
    """
    An 'as_tuple=True' flavour of torch.nonzero that also works under torchscript,
    needed because of https://github.com/pytorch/pytorch/issues/38718
    """
    if not torch.jit.is_scripting():
        return x.nonzero(as_tuple=True)
    # Scripting path: build the index matrix, then split it into one tensor per dim.
    # A 0-dim input is first given a leading axis.
    indices = x.unsqueeze(0).nonzero() if x.dim() == 0 else x.nonzero()
    return indices.unbind(1)
class _ModelZooUrls(object):
    """
    Mapping from names to officially released Detectron2 pre-trained models.

    ``CONFIG_PATH_TO_URL_SUFFIX`` is keyed by the config path without extension;
    :meth:`query` combines ``S3_PREFIX``, that key and the stored suffix into a
    full download URL.
    """

    # Official download host. This must point at the real fbaipublicfiles domain,
    # not at a proxy/mirror rewrite of it, since checkpoints are fetched from here.
    S3_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/"

    # format: {config_path.yaml} -> model_id/model_final_{commit}.pkl
    CONFIG_PATH_TO_URL_SUFFIX = {
        # COCO Detection with Faster R-CNN
        "COCO-Detection/faster_rcnn_R_50_C4_1x": "137257644/model_final_721ade.pkl",
        "COCO-Detection/faster_rcnn_R_50_DC5_1x": "137847829/model_final_51d356.pkl",
        "COCO-Detection/faster_rcnn_R_50_FPN_1x": "137257794/model_final_b275ba.pkl",
        "COCO-Detection/faster_rcnn_R_50_C4_3x": "137849393/model_final_f97cb7.pkl",
        "COCO-Detection/faster_rcnn_R_50_DC5_3x": "137849425/model_final_68d202.pkl",
        "COCO-Detection/faster_rcnn_R_50_FPN_3x": "137849458/model_final_280758.pkl",
        "COCO-Detection/faster_rcnn_R_101_C4_3x": "138204752/model_final_298dad.pkl",
        "COCO-Detection/faster_rcnn_R_101_DC5_3x": "138204841/model_final_3e0943.pkl",
        "COCO-Detection/faster_rcnn_R_101_FPN_3x": "137851257/model_final_f6e8b1.pkl",
        "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x": "139173657/model_final_68b088.pkl",
        # COCO Detection with RetinaNet
        "COCO-Detection/retinanet_R_50_FPN_1x": "190397773/model_final_bfca0b.pkl",
        "COCO-Detection/retinanet_R_50_FPN_3x": "190397829/model_final_5bd44e.pkl",
        "COCO-Detection/retinanet_R_101_FPN_3x": "190397697/model_final_971ab9.pkl",
        # COCO Detection with RPN and Fast R-CNN
        "COCO-Detection/rpn_R_50_C4_1x": "137258005/model_final_450694.pkl",
        "COCO-Detection/rpn_R_50_FPN_1x": "137258492/model_final_02ce48.pkl",
        "COCO-Detection/fast_rcnn_R_50_FPN_1x": "137635226/model_final_e5f7ce.pkl",
        # COCO Instance Segmentation Baselines with Mask R-CNN
        "COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x": "137259246/model_final_9243eb.pkl",
        "COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_1x": "137260150/model_final_4f86c3.pkl",
        "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x": "137260431/model_final_a54504.pkl",
        "COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x": "137849525/model_final_4ce675.pkl",
        "COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x": "137849551/model_final_84107b.pkl",
        "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x": "137849600/model_final_f10217.pkl",
        "COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x": "138363239/model_final_a2914c.pkl",
        "COCO-InstanceSegmentation/mask_rcnn_R_101_DC5_3x": "138363294/model_final_0464b7.pkl",
        "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x": "138205316/model_final_a3ec72.pkl",
        "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x": "139653917/model_final_2d9806.pkl",  # noqa
        # New baselines using Large-Scale Jitter and Longer Training Schedule
        "new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ": "42047764/model_final_bb69de.pkl",
        "new_baselines/mask_rcnn_R_50_FPN_200ep_LSJ": "42047638/model_final_89a8d3.pkl",
        "new_baselines/mask_rcnn_R_50_FPN_400ep_LSJ": "42019571/model_final_14d201.pkl",
        "new_baselines/mask_rcnn_R_101_FPN_100ep_LSJ": "42025812/model_final_4f7b58.pkl",
        "new_baselines/mask_rcnn_R_101_FPN_200ep_LSJ": "42131867/model_final_0bb7ae.pkl",
        "new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ": "42073830/model_final_f96b26.pkl",
        "new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_100ep_LSJ": "42047771/model_final_b7fbab.pkl",  # noqa
        "new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_200ep_LSJ": "42132721/model_final_5d87c1.pkl",  # noqa
        "new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_400ep_LSJ": "42025447/model_final_f1362d.pkl",  # noqa
        "new_baselines/mask_rcnn_regnety_4gf_dds_FPN_100ep_LSJ": "42047784/model_final_6ba57e.pkl",  # noqa
        "new_baselines/mask_rcnn_regnety_4gf_dds_FPN_200ep_LSJ": "42047642/model_final_27b9c1.pkl",  # noqa
        "new_baselines/mask_rcnn_regnety_4gf_dds_FPN_400ep_LSJ": "42045954/model_final_ef3a80.pkl",  # noqa
        # COCO Person Keypoint Detection Baselines with Keypoint R-CNN
        "COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x": "137261548/model_final_04e291.pkl",
        "COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x": "137849621/model_final_a6e10b.pkl",
        "COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x": "138363331/model_final_997cc7.pkl",
        "COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x": "139686956/model_final_5ad38f.pkl",
        # COCO Panoptic Segmentation Baselines with Panoptic FPN
        "COCO-PanopticSegmentation/panoptic_fpn_R_50_1x": "139514544/model_final_dbfeb4.pkl",
        "COCO-PanopticSegmentation/panoptic_fpn_R_50_3x": "139514569/model_final_c10459.pkl",
        "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x": "139514519/model_final_cafdb1.pkl",
        # LVIS Instance Segmentation Baselines with Mask R-CNN
        "LVISv0.5-InstanceSegmentation/mask_rcnn_R_50_FPN_1x": "144219072/model_final_571f7c.pkl",  # noqa
        "LVISv0.5-InstanceSegmentation/mask_rcnn_R_101_FPN_1x": "144219035/model_final_824ab5.pkl",  # noqa
        "LVISv0.5-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x": "144219108/model_final_5e3439.pkl",  # noqa
        # Cityscapes & Pascal VOC Baselines
        "Cityscapes/mask_rcnn_R_50_FPN": "142423278/model_final_af9cf5.pkl",
        "PascalVOC-Detection/faster_rcnn_R_50_C4": "142202221/model_final_b1acc2.pkl",
        # Other Settings
        "Misc/mask_rcnn_R_50_FPN_1x_dconv_c3-c5": "138602867/model_final_65c703.pkl",
        "Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5": "144998336/model_final_821d0b.pkl",
        "Misc/cascade_mask_rcnn_R_50_FPN_1x": "138602847/model_final_e9d89b.pkl",
        "Misc/cascade_mask_rcnn_R_50_FPN_3x": "144998488/model_final_480dd8.pkl",
        "Misc/mask_rcnn_R_50_FPN_3x_syncbn": "169527823/model_final_3b3c51.pkl",
        "Misc/mask_rcnn_R_50_FPN_3x_gn": "138602888/model_final_dc5d9e.pkl",
        "Misc/scratch_mask_rcnn_R_50_FPN_3x_gn": "138602908/model_final_01ca85.pkl",
        "Misc/scratch_mask_rcnn_R_50_FPN_9x_gn": "183808979/model_final_da7b4c.pkl",
        "Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn": "184226666/model_final_5ce33e.pkl",
        "Misc/panoptic_fpn_R_101_dconv_cascade_gn_3x": "139797668/model_final_be35db.pkl",
        "Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv": "18131413/model_0039999_e76410.pkl",  # noqa
        # D1 Comparisons
        "Detectron1-Comparisons/faster_rcnn_R_50_FPN_noaug_1x": "137781054/model_final_7ab50c.pkl",  # noqa
        "Detectron1-Comparisons/mask_rcnn_R_50_FPN_noaug_1x": "137781281/model_final_62ca52.pkl",  # noqa
        "Detectron1-Comparisons/keypoint_rcnn_R_50_FPN_1x": "137781195/model_final_cce136.pkl",
    }

    @staticmethod
    def query(config_path: str) -> Optional[str]:
        """
        Look up the download URL of the released model for a config.

        Args:
            config_path: relative config filename, with or without a trailing
                ".yaml" / ".py" extension

        Returns:
            The full checkpoint URL, or None if the config has no released model.
        """
        name = config_path
        # Strip a trailing extension only. Removing ".yaml"/".py" anywhere in the
        # path could turn an unrelated name into a key of the table.
        for ext in (".yaml", ".py"):
            if name.endswith(ext):
                name = name[: -len(ext)]
                break
        suffix = _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX.get(name)
        if suffix is None:
            return None
        return _ModelZooUrls.S3_PREFIX + name + "/" + suffix
def get(config_path, trained: bool = False, device: Optional[str] = None):
    """
    Build a model from a config located under Detectron2's official ``configs/`` directory.

    Args:
        config_path (str): config file name relative to detectron2's "configs/"
            directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
        trained (bool): see :func:`get_config`.
        device (str or None): overwrite the device in config, if given.
            When omitted and CUDA is not available, "cpu" is used.

    Returns:
        nn.Module: a detectron2 model. This function does not switch it to eval mode.

    Example:
    ::
        from detectron2 import model_zoo
        model = model_zoo.get("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml", trained=True)
    """
    cfg = get_config(config_path, trained)
    if device is None and not torch.cuda.is_available():
        device = "cpu"

    # yacs-style config: the device lives in the config and build_model consumes it
    if isinstance(cfg, CfgNode):
        if device is not None:
            cfg.MODEL.DEVICE = device
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
        return model

    # lazy config: instantiate first, then move the module to the device
    model = instantiate(cfg.model)
    if device is not None:
        model = model.to(device)
    has_init_checkpoint = "train" in cfg and "init_checkpoint" in cfg.train
    if has_init_checkpoint:
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    return model
+from detectron2.layers import ShapeSpec + +from .anchor_generator import build_anchor_generator, ANCHOR_GENERATOR_REGISTRY +from .backbone import ( + BACKBONE_REGISTRY, + FPN, + Backbone, + ResNet, + ResNetBlockBase, + build_backbone, + build_resnet_backbone, + make_stage, + ViT, + SimpleFeaturePyramid, + get_vit_lr_decay_rate, + MViT, + SwinTransformer, +) +from .meta_arch import ( + META_ARCH_REGISTRY, + SEM_SEG_HEADS_REGISTRY, + GeneralizedRCNN, + PanopticFPN, + ProposalNetwork, + RetinaNet, + SemanticSegmentor, + build_model, + build_sem_seg_head, + FCOS, +) +from .postprocessing import detector_postprocess +from .proposal_generator import ( + PROPOSAL_GENERATOR_REGISTRY, + build_proposal_generator, + RPN_HEAD_REGISTRY, + build_rpn_head, +) +from .roi_heads import ( + ROI_BOX_HEAD_REGISTRY, + ROI_HEADS_REGISTRY, + ROI_KEYPOINT_HEAD_REGISTRY, + ROI_MASK_HEAD_REGISTRY, + ROIHeads, + StandardROIHeads, + BaseMaskRCNNHead, + BaseKeypointRCNNHead, + FastRCNNOutputLayers, + build_box_head, + build_keypoint_head, + build_mask_head, + build_roi_heads, +) +from .test_time_augmentation import DatasetMapperTTA, GeneralizedRCNNWithTTA +from .mmdet_wrapper import MMDetBackbone, MMDetDetector + +_EXCLUDE = {"ShapeSpec"} +__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")] + + +from detectron2.utils.env import fixup_module_metadata + +fixup_module_metadata(__name__, globals(), __all__) +del fixup_module_metadata diff --git a/data_processing/detectron2/detectron2/modeling/anchor_generator.py b/data_processing/detectron2/detectron2/modeling/anchor_generator.py new file mode 100644 index 0000000..ac94e72 --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/anchor_generator.py @@ -0,0 +1,386 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
class BufferList(nn.Module):
    """
    Counterpart of nn.ParameterList that holds buffers instead of parameters.
    """

    def __init__(self, buffers):
        super().__init__()
        for index, tensor in enumerate(buffers):
            # Registered as non-persistent so the values are not saved in checkpoint
            self.register_buffer(str(index), tensor, persistent=False)

    def __len__(self):
        return len(self._buffers)

    def __iter__(self):
        # iterate over the buffer tensors, in registration order
        return iter(self._buffers.values())
@ANCHOR_GENERATOR_REGISTRY.register()
class DefaultAnchorGenerator(nn.Module):
    """
    Anchor generator following the standard scheme of
    "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks".
    """

    box_dim: torch.jit.Final[int] = 4
    """
    the dimension of each anchor box.
    """

    @configurable
    def __init__(self, *, sizes, aspect_ratios, strides, offset=0.5):
        """
        This interface is experimental.

        Args:
            sizes (list[list[float]] or list[float]):
                When given as list[list[float]], ``sizes[i]`` lists the anchor sizes
                (i.e. sqrt of anchor area) for the i-th feature map.
                When given as list[float], the same ``sizes`` apply to all feature maps.
                Anchor sizes are absolute lengths in units of the input image;
                they do not dynamically scale if the input image size changes.
            aspect_ratios (list[list[float]] or list[float]): aspect ratios
                (i.e. height / width) of the anchors. Broadcast the same way as `sizes`.
            strides (list[int]): stride of each input feature.
            offset (float): Relative offset between the center of the first anchor and
                the top-left corner of the image. Value has to be in [0, 1).
                Recommend to use 0.5, which means half stride.
        """
        super().__init__()

        self.strides = strides
        self.num_features = len(self.strides)
        per_level_sizes = _broadcast_params(sizes, self.num_features, "sizes")
        per_level_ratios = _broadcast_params(aspect_ratios, self.num_features, "aspect_ratios")
        self.cell_anchors = self._calculate_anchors(per_level_sizes, per_level_ratios)

        self.offset = offset
        assert 0.0 <= self.offset < 1.0, self.offset

    @classmethod
    def from_config(cls, cfg, input_shape: List[ShapeSpec]):
        anchor_cfg = cfg.MODEL.ANCHOR_GENERATOR
        return {
            "sizes": anchor_cfg.SIZES,
            "aspect_ratios": anchor_cfg.ASPECT_RATIOS,
            "strides": [x.stride for x in input_shape],
            "offset": anchor_cfg.OFFSET,
        }

    def _calculate_anchors(self, sizes, aspect_ratios):
        # one float tensor of canonical anchors per feature map
        per_level = []
        for level_sizes, level_ratios in zip(sizes, aspect_ratios):
            per_level.append(self.generate_cell_anchors(level_sizes, level_ratios).float())
        return BufferList(per_level)

    @property
    @torch.jit.unused
    def num_cell_anchors(self):
        """
        Alias of `num_anchors`.
        """
        return self.num_anchors

    @property
    @torch.jit.unused
    def num_anchors(self):
        """
        Returns:
            list[int]: Each int is the number of anchors at every pixel
                location, on that feature map.
                For example, with 3 aspect ratios and 5 sizes at every pixel,
                the number of anchors is 15.
                (See also ANCHOR_GENERATOR.SIZES and ANCHOR_GENERATOR.ASPECT_RATIOS in config)

                In standard RPN models, `num_anchors` on every feature map is the same.
        """
        return [len(level_anchors) for level_anchors in self.cell_anchors]

    def _grid_anchors(self, grid_sizes: List[List[int]]):
        """
        Returns:
            list[Tensor]: #featuremap tensors, each is (#locations x #cell_anchors) x 4
        """
        anchors = []
        # buffers() not supported by torchscript. use named_buffers() instead
        buffers: List[torch.Tensor] = [x[1] for x in self.cell_anchors.named_buffers()]
        for size, stride, base_anchors in zip(grid_sizes, self.strides, buffers):
            shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors)
            # the same (x, y) shift is applied to both corners of an XYXY box
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # broadcast: every grid location x every canonical anchor
            shifted = shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)
            anchors.append(shifted.reshape(-1, 4))

        return anchors

    def generate_cell_anchors(self, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)):
        """
        Build the canonical anchor boxes: one box per (size, aspect_ratio) pair,
        all centered at (0, 0). The anchors of a full feature map are obtained
        later by shifting and tiling this tensor (see `meth:_grid_anchors`).

        Args:
            sizes (tuple[float]):
            aspect_ratios (tuple[float]]):

        Returns:
            Tensor of shape (len(sizes) * len(aspect_ratios), 4) storing anchor boxes
                in XYXY format.
        """

        # This differs from the anchor generator of the original Faster R-CNN code
        # and of Detectron. Both yield the same AP, but the old version defines cell
        # anchors in a less natural way, with a shift relative to the feature grid and
        # quantization that gives slightly different sizes for different aspect ratios.
        # See also https://github.com/facebookresearch/Detectron/issues/227

        boxes = []
        for size in sizes:
            area = size**2.0
            for ratio in aspect_ratios:
                # s * s = w * h
                # a = h / w
                # ... some algebra ...
                # w = sqrt(s * s / a)
                # h = a * w
                width = math.sqrt(area / ratio)
                height = ratio * width
                half_w, half_h = width / 2.0, height / 2.0
                boxes.append([-half_w, -half_h, half_w, half_h])
        return torch.tensor(boxes)

    def forward(self, features: List[torch.Tensor]):
        """
        Args:
            features (list[Tensor]): list of backbone feature maps on which to generate anchors.

        Returns:
            list[Boxes]: a list of Boxes containing all the anchors for each feature map
                (i.e. the cell anchors repeated over all locations in the feature map).
                The number of anchors of each feature map is Hi x Wi x num_cell_anchors,
                where Hi, Wi are resolution of the feature map divided by anchor stride.
        """
        grid_sizes = [feature_map.shape[-2:] for feature_map in features]
        anchors_per_level = self._grid_anchors(grid_sizes)
        return [Boxes(x) for x in anchors_per_level]
+ """ + super().__init__() + + self.strides = strides + self.num_features = len(self.strides) + sizes = _broadcast_params(sizes, self.num_features, "sizes") + aspect_ratios = _broadcast_params(aspect_ratios, self.num_features, "aspect_ratios") + angles = _broadcast_params(angles, self.num_features, "angles") + self.cell_anchors = self._calculate_anchors(sizes, aspect_ratios, angles) + + self.offset = offset + assert 0.0 <= self.offset < 1.0, self.offset + + @classmethod + def from_config(cls, cfg, input_shape: List[ShapeSpec]): + return { + "sizes": cfg.MODEL.ANCHOR_GENERATOR.SIZES, + "aspect_ratios": cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS, + "strides": [x.stride for x in input_shape], + "offset": cfg.MODEL.ANCHOR_GENERATOR.OFFSET, + "angles": cfg.MODEL.ANCHOR_GENERATOR.ANGLES, + } + + def _calculate_anchors(self, sizes, aspect_ratios, angles): + cell_anchors = [ + self.generate_cell_anchors(size, aspect_ratio, angle).float() + for size, aspect_ratio, angle in zip(sizes, aspect_ratios, angles) + ] + return BufferList(cell_anchors) + + @property + def num_cell_anchors(self): + """ + Alias of `num_anchors`. + """ + return self.num_anchors + + @property + def num_anchors(self): + """ + Returns: + list[int]: Each int is the number of anchors at every pixel + location, on that feature map. + For example, if at every pixel we use anchors of 3 aspect + ratios, 2 sizes and 5 angles, the number of anchors is 30. + (See also ANCHOR_GENERATOR.SIZES, ANCHOR_GENERATOR.ASPECT_RATIOS + and ANCHOR_GENERATOR.ANGLES in config) + + In standard RRPN models, `num_anchors` on every feature map is the same. 
+ """ + return [len(cell_anchors) for cell_anchors in self.cell_anchors] + + def _grid_anchors(self, grid_sizes): + anchors = [] + for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors): + shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors) + zeros = torch.zeros_like(shift_x) + shifts = torch.stack((shift_x, shift_y, zeros, zeros, zeros), dim=1) + + anchors.append((shifts.view(-1, 1, 5) + base_anchors.view(1, -1, 5)).reshape(-1, 5)) + + return anchors + + def generate_cell_anchors( + self, + sizes=(32, 64, 128, 256, 512), + aspect_ratios=(0.5, 1, 2), + angles=(-90, -60, -30, 0, 30, 60, 90), + ): + """ + Generate a tensor storing canonical anchor boxes, which are all anchor + boxes of different sizes, aspect_ratios, angles centered at (0, 0). + We can later build the set of anchors for a full feature map by + shifting and tiling these tensors (see `meth:_grid_anchors`). + + Args: + sizes (tuple[float]): + aspect_ratios (tuple[float]]): + angles (tuple[float]]): + + Returns: + Tensor of shape (len(sizes) * len(aspect_ratios) * len(angles), 5) + storing anchor boxes in (x_ctr, y_ctr, w, h, angle) format. + """ + anchors = [] + for size in sizes: + area = size**2.0 + for aspect_ratio in aspect_ratios: + # s * s = w * h + # a = h / w + # ... some algebra ... + # w = sqrt(s * s / a) + # h = a * w + w = math.sqrt(area / aspect_ratio) + h = aspect_ratio * w + anchors.extend([0, 0, w, h, a] for a in angles) + + return torch.tensor(anchors) + + def forward(self, features): + """ + Args: + features (list[Tensor]): list of backbone feature maps on which to generate anchors. + + Returns: + list[RotatedBoxes]: a list of Boxes containing all the anchors for each feature map + (i.e. the cell anchors repeated over all locations in the feature map). + The number of anchors of each feature map is Hi x Wi x num_cell_anchors, + where Hi, Wi are resolution of the feature map divided by anchor stride. 
class Backbone(nn.Module, metaclass=ABCMeta):
    """
    Abstract base class that all network backbones derive from.
    """

    def __init__(self):
        """
        Subclasses are free to declare their own set of `__init__` arguments.
        """
        super().__init__()

    @abstractmethod
    def forward(self):
        """
        Must be overridden by subclasses, which should keep the same return type.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
        """

    @property
    def size_divisibility(self) -> int:
        """
        Integer that the input height and width must be divisible by.

        Some backbones need this, typically encoder / decoder type networks with
        lateral connection (e.g., FPN), whose feature maps have to match in
        dimension between the "bottom up" and "top down" paths. 0 means no specific
        input size divisibility is required, which is what this base class returns.
        """
        return 0

    @property
    def padding_constraints(self) -> Dict[str, int]:
        """
        Generalization of size_divisibility. Some backbones and training recipes
        need specific padding constraints, such as divisibility by a specific
        integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale
        jitter in :paper:vitdet). `padding_constraints` holds optional items like:
        {
            "size_divisibility": int,
            "square_size": int,
            # Future options are possible
        }
        `size_divisibility` will read from here if presented and `square_size` indicates the
        square padding size if `square_size` > 0.

        TODO: use type of Dict[str, int] to avoid torchscript issues. The type of
        padding_constraints could be generalized as TypedDict (Python 3.8+) to support
        more types in the future.
        """
        return {}

    def output_shape(self):
        """
        Returns:
            dict[str->ShapeSpec]
        """
        # this is a backward-compatible default, built from the subclass'
        # `_out_features`, `_out_feature_channels` and `_out_feature_strides`
        shapes = {}
        for name in self._out_features:
            shapes[name] = ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
        return shapes
+import math +import fvcore.nn.weight_init as weight_init +import torch +import torch.nn.functional as F +from torch import nn + +from detectron2.layers import Conv2d, ShapeSpec, get_norm + +from .backbone import Backbone +from .build import BACKBONE_REGISTRY +from .resnet import build_resnet_backbone + +__all__ = ["build_resnet_fpn_backbone", "build_retinanet_resnet_fpn_backbone", "FPN"] + + +class FPN(Backbone): + """ + This module implements :paper:`FPN`. + It creates pyramid features built on top of some input feature maps. + """ + + _fuse_type: torch.jit.Final[str] + + def __init__( + self, + bottom_up, + in_features, + out_channels, + norm="", + top_block=None, + fuse_type="sum", + square_pad=0, + ): + """ + Args: + bottom_up (Backbone): module representing the bottom up subnetwork. + Must be a subclass of :class:`Backbone`. The multi-scale feature + maps generated by the bottom up network, and listed in `in_features`, + are used to generate FPN levels. + in_features (list[str]): names of the input feature maps coming + from the backbone to which FPN is attached. For example, if the + backbone produces ["res2", "res3", "res4"], any *contiguous* sublist + of these may be used; order must be from high to low resolution. + out_channels (int): number of channels in the output feature maps. + norm (str): the normalization to use. + top_block (nn.Module or None): if provided, an extra operation will + be performed on the output of the last (smallest resolution) + FPN output, and the result will extend the result list. The top_block + further downsamples the feature map. It must have an attribute + "num_levels", meaning the number of extra FPN levels added by + this block, and "in_feature", which is a string representing + its input feature (e.g., p5). + fuse_type (str): types for fusing the top down features and the lateral + ones. It can be "sum" (default), which sums up element-wise; or "avg", + which takes the element-wise mean of the two. 
+ square_pad (int): If > 0, require input images to be padded to specific square size. + """ + super(FPN, self).__init__() + assert isinstance(bottom_up, Backbone) + assert in_features, in_features + + # Feature map strides and channels from the bottom up network (e.g. ResNet) + input_shapes = bottom_up.output_shape() + strides = [input_shapes[f].stride for f in in_features] + in_channels_per_feature = [input_shapes[f].channels for f in in_features] + + _assert_strides_are_log2_contiguous(strides) + lateral_convs = [] + output_convs = [] + + use_bias = norm == "" + for idx, in_channels in enumerate(in_channels_per_feature): + lateral_norm = get_norm(norm, out_channels) + output_norm = get_norm(norm, out_channels) + + lateral_conv = Conv2d( + in_channels, out_channels, kernel_size=1, bias=use_bias, norm=lateral_norm + ) + output_conv = Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + ) + weight_init.c2_xavier_fill(lateral_conv) + weight_init.c2_xavier_fill(output_conv) + stage = int(math.log2(strides[idx])) + self.add_module("fpn_lateral{}".format(stage), lateral_conv) + self.add_module("fpn_output{}".format(stage), output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward clearer. + self.lateral_convs = lateral_convs[::-1] + self.output_convs = output_convs[::-1] + self.top_block = top_block + self.in_features = tuple(in_features) + self.bottom_up = bottom_up + # Return feature names are "p", like ["p2", "p3", ..., "p6"] + self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides} + # top block output feature maps. 
+ if self.top_block is not None: + for s in range(stage, stage + self.top_block.num_levels): + self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1) + + self._out_features = list(self._out_feature_strides.keys()) + self._out_feature_channels = {k: out_channels for k in self._out_features} + self._size_divisibility = strides[-1] + self._square_pad = square_pad + assert fuse_type in {"avg", "sum"} + self._fuse_type = fuse_type + + @property + def size_divisibility(self): + return self._size_divisibility + + @property + def padding_constraints(self): + return {"square_size": self._square_pad} + + def forward(self, x): + """ + Args: + input (dict[str->Tensor]): mapping feature map name (e.g., "res5") to + feature map tensor for each feature level in high to low resolution order. + + Returns: + dict[str->Tensor]: + mapping from feature map name to FPN feature map tensor + in high to low resolution order. Returned feature names follow the FPN + paper convention: "p", where stage has stride = 2 ** stage e.g., + ["p2", "p3", ..., "p6"]. 
+ """ + bottom_up_features = self.bottom_up(x) + results = [] + prev_features = self.lateral_convs[0](bottom_up_features[self.in_features[-1]]) + results.append(self.output_convs[0](prev_features)) + + # Reverse feature maps into top-down order (from low to high resolution) + for idx, (lateral_conv, output_conv) in enumerate( + zip(self.lateral_convs, self.output_convs) + ): + # Slicing of ModuleList is not supported https://github.com/pytorch/pytorch/issues/47336 + # Therefore we loop over all modules but skip the first one + if idx > 0: + features = self.in_features[-idx - 1] + features = bottom_up_features[features] + top_down_features = F.interpolate(prev_features, scale_factor=2.0, mode="nearest") + lateral_features = lateral_conv(features) + prev_features = lateral_features + top_down_features + if self._fuse_type == "avg": + prev_features /= 2 + results.insert(0, output_conv(prev_features)) + + if self.top_block is not None: + if self.top_block.in_feature in bottom_up_features: + top_block_in_feature = bottom_up_features[self.top_block.in_feature] + else: + top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)] + results.extend(self.top_block(top_block_in_feature)) + assert len(self._out_features) == len(results) + return {f: res for f, res in zip(self._out_features, results)} + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + +def _assert_strides_are_log2_contiguous(strides): + """ + Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2". 
+ """ + for i, stride in enumerate(strides[1:], 1): + assert stride == 2 * strides[i - 1], "Strides {} {} are not log2 contiguous".format( + stride, strides[i - 1] + ) + + +class LastLevelMaxPool(nn.Module): + """ + This module is used in the original FPN to generate a downsampled + P6 feature from P5. + """ + + def __init__(self): + super().__init__() + self.num_levels = 1 + self.in_feature = "p5" + + def forward(self, x): + return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)] + + +class LastLevelP6P7(nn.Module): + """ + This module is used in RetinaNet to generate extra layers, P6 and P7 from + C5 feature. + """ + + def __init__(self, in_channels, out_channels, in_feature="res5"): + super().__init__() + self.num_levels = 2 + self.in_feature = in_feature + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + for module in [self.p6, self.p7]: + weight_init.c2_xavier_fill(module) + + def forward(self, c5): + p6 = self.p6(c5) + p7 = self.p7(F.relu(p6)) + return [p6, p7] + + +@BACKBONE_REGISTRY.register() +def build_resnet_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelMaxPool(), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_retinanet_resnet_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 
+ """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + in_channels_p6p7 = bottom_up.output_shape()["res5"].channels + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7(in_channels_p6p7, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone diff --git a/data_processing/detectron2/detectron2/modeling/backbone/mvit.py b/data_processing/detectron2/detectron2/modeling/backbone/mvit.py new file mode 100644 index 0000000..50667a8 --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/backbone/mvit.py @@ -0,0 +1,448 @@ +import logging +import numpy as np +import torch +import torch.nn as nn + +from .backbone import Backbone +from .utils import ( + PatchEmbed, + add_decomposed_rel_pos, + get_abs_pos, + window_partition, + window_unpartition, +) + +logger = logging.getLogger(__name__) + + +__all__ = ["MViT"] + + +def attention_pool(x, pool, norm=None): + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H1, W1) -> (B, H1, W1, C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + """Multiscale Multi-head Attention block.""" + + def __init__( + self, + dim, + dim_out, + num_heads, + qkv_bias=True, + norm_layer=nn.LayerNorm, + pool_kernel=(3, 3), + stride_q=1, + stride_kv=1, + residual_pooling=True, + window_size=0, + use_rel_pos=False, + rel_pos_zero_init=True, + input_size=None, + ): + """ + Args: + dim (int): Number of input channels. + dim_out (int): Number of output channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + pool_kernel (tuple): kernel size for qkv pooling layers. + stride_q (int): stride size for q pooling layer. 
+ stride_kv (int): stride size for kv pooling layer. + residual_pooling (bool): If true, enable residual pooling. + use_rel_pos (bool): If True, add relative postional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim_out // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias) + self.proj = nn.Linear(dim_out, dim_out) + + # qkv pooling + pool_padding = [k // 2 for k in pool_kernel] + dim_conv = dim_out // num_heads + self.pool_q = nn.Conv2d( + dim_conv, + dim_conv, + pool_kernel, + stride=stride_q, + padding=pool_padding, + groups=dim_conv, + bias=False, + ) + self.norm_q = norm_layer(dim_conv) + self.pool_k = nn.Conv2d( + dim_conv, + dim_conv, + pool_kernel, + stride=stride_kv, + padding=pool_padding, + groups=dim_conv, + bias=False, + ) + self.norm_k = norm_layer(dim_conv) + self.pool_v = nn.Conv2d( + dim_conv, + dim_conv, + pool_kernel, + stride=stride_kv, + padding=pool_padding, + groups=dim_conv, + bias=False, + ) + self.norm_v = norm_layer(dim_conv) + + self.window_size = window_size + if window_size: + self.q_win_size = window_size // stride_q + self.kv_win_size = window_size // stride_kv + self.residual_pooling = residual_pooling + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + # initialize relative positional embeddings + assert input_size[0] == input_size[1] + size = input_size[0] + rel_dim = 2 * max(size // stride_q, size // stride_kv) - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(rel_dim, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_dim, head_dim)) + + if not rel_pos_zero_init: + nn.init.trunc_normal_(self.rel_pos_h, std=0.02) + nn.init.trunc_normal_(self.rel_pos_w, std=0.02) + + def forward(self, x): + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H, W, C) + qkv = self.qkv(x).reshape(B, H, 
W, 3, self.num_heads, -1).permute(3, 0, 4, 1, 2, 5) + # q, k, v with shape (B * nHead, H, W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H, W, -1).unbind(0) + + q = attention_pool(q, self.pool_q, self.norm_q) + k = attention_pool(k, self.pool_k, self.norm_k) + v = attention_pool(v, self.pool_v, self.norm_v) + + ori_q = q + if self.window_size: + q, q_hw_pad = window_partition(q, self.q_win_size) + k, kv_hw_pad = window_partition(k, self.kv_win_size) + v, _ = window_partition(v, self.kv_win_size) + q_hw = (self.q_win_size, self.q_win_size) + kv_hw = (self.kv_win_size, self.kv_win_size) + else: + q_hw = q.shape[1:3] + kv_hw = k.shape[1:3] + + q = q.view(q.shape[0], np.prod(q_hw), -1) + k = k.view(k.shape[0], np.prod(kv_hw), -1) + v = v.view(v.shape[0], np.prod(kv_hw), -1) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, q_hw, kv_hw) + + attn = attn.softmax(dim=-1) + x = attn @ v + + x = x.view(x.shape[0], q_hw[0], q_hw[1], -1) + + if self.window_size: + x = window_unpartition(x, self.q_win_size, q_hw_pad, ori_q.shape[1:3]) + + if self.residual_pooling: + x += ori_q + + H, W = x.shape[1], x.shape[2] + x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + """Multiscale Transformer blocks""" + + def __init__( + self, + dim, + dim_out, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + qkv_pool_kernel=(3, 3), + stride_q=1, + stride_kv=1, + residual_pooling=True, + window_size=0, + use_rel_pos=False, + rel_pos_zero_init=True, + input_size=None, + ): + """ + Args: + dim (int): Number of input channels. + dim_out (int): Number of output channels. + num_heads (int): Number of attention heads in the MViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + qkv_pool_kernel (tuple): kernel size for qkv pooling layers. + stride_q (int): stride size for q pooling layer. + stride_kv (int): stride size for kv pooling layer. + residual_pooling (bool): If true, enable residual pooling. + window_size (int): Window size for window attention blocks. If it equals 0, then not + use window attention. + use_rel_pos (bool): If True, add relative postional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + pool_kernel=qkv_pool_kernel, + stride_q=stride_q, + stride_kv=stride_kv, + residual_pooling=residual_pooling, + window_size=window_size, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size, + ) + + from timm.models.layers import DropPath, Mlp + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim_out) + self.mlp = Mlp( + in_features=dim_out, + hidden_features=int(dim_out * mlp_ratio), + out_features=dim_out, + act_layer=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + if stride_q > 1: + kernel_skip = stride_q + 1 + padding_skip = int(kernel_skip // 2) + self.pool_skip = nn.MaxPool2d(kernel_skip, stride_q, padding_skip, ceil_mode=False) + + def forward(self, x): + x_norm = self.norm1(x) + x_block = self.attn(x_norm) + + if hasattr(self, "proj"): + x = self.proj(x_norm) + if hasattr(self, "pool_skip"): + x = attention_pool(x, self.pool_skip) + + x = x + self.drop_path(x_block) + x = x + 
self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class MViT(Backbone): + """ + This module implements Multiscale Vision Transformer (MViT) backbone in :paper:'mvitv2'. + """ + + def __init__( + self, + img_size=224, + patch_kernel=(7, 7), + patch_stride=(4, 4), + patch_padding=(3, 3), + in_chans=3, + embed_dim=96, + depth=16, + num_heads=1, + last_block_indexes=(0, 2, 11, 15), + qkv_pool_kernel=(3, 3), + adaptive_kv_stride=4, + adaptive_window_size=56, + residual_pooling=True, + mlp_ratio=4.0, + qkv_bias=True, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + use_abs_pos=False, + use_rel_pos=True, + rel_pos_zero_init=True, + use_act_checkpoint=False, + pretrain_img_size=224, + pretrain_use_cls_token=True, + out_features=("scale2", "scale3", "scale4", "scale5"), + ): + """ + Args: + img_size (int): Input image size. + patch_kernel (tuple): kernel size for patch embedding. + patch_stride (tuple): stride size for patch embedding. + patch_padding (tuple): padding size for patch embedding. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of MViT. + num_heads (int): Number of base attention heads in each MViT block. + last_block_indexes (tuple): Block indexes for last blocks in each stage. + qkv_pool_kernel (tuple): kernel size for qkv pooling layers. + adaptive_kv_stride (int): adaptive stride size for kv pooling. + adaptive_window_size (int): adaptive window size for window attention blocks. + residual_pooling (bool): If true, enable residual pooling. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative postional embeddings to the attention map. 
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + use_act_checkpoint (bool): If True, use activation checkpointing. + pretrain_img_size (int): input image size for pretraining models. + pretrain_use_cls_token (bool): If True, pretrainig models use class token. + out_features (tuple): name of the feature maps from each stage. + """ + super().__init__() + self.pretrain_use_cls_token = pretrain_use_cls_token + + self.patch_embed = PatchEmbed( + kernel_size=patch_kernel, + stride=patch_stride, + padding=patch_padding, + in_chans=in_chans, + embed_dim=embed_dim, + ) + + if use_abs_pos: + # Initialize absoluate positional embedding with pretrain image size. + num_patches = (pretrain_img_size // patch_stride[0]) * ( + pretrain_img_size // patch_stride[1] + ) + num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim)) + else: + self.pos_embed = None + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + dim_out = embed_dim + stride_kv = adaptive_kv_stride + window_size = adaptive_window_size + input_size = (img_size // patch_stride[0], img_size // patch_stride[1]) + stage = 2 + stride = patch_stride[0] + self._out_feature_strides = {} + self._out_feature_channels = {} + self.blocks = nn.ModuleList() + for i in range(depth): + # Multiply stride_kv by 2 if it's the last block of stage2 and stage3. + if i == last_block_indexes[1] or i == last_block_indexes[2]: + stride_kv_ = stride_kv * 2 + else: + stride_kv_ = stride_kv + # hybrid window attention: global attention in last three stages. 
+ window_size_ = 0 if i in last_block_indexes[1:] else window_size + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + qkv_pool_kernel=qkv_pool_kernel, + stride_q=2 if i - 1 in last_block_indexes else 1, + stride_kv=stride_kv_, + residual_pooling=residual_pooling, + window_size=window_size_, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size, + ) + if use_act_checkpoint: + # TODO: use torch.utils.checkpoint + from fairscale.nn.checkpoint import checkpoint_wrapper + + block = checkpoint_wrapper(block) + self.blocks.append(block) + + embed_dim = dim_out + if i in last_block_indexes: + name = f"scale{stage}" + if name in out_features: + self._out_feature_channels[name] = dim_out + self._out_feature_strides[name] = stride + self.add_module(f"{name}_norm", norm_layer(dim_out)) + + dim_out *= 2 + num_heads *= 2 + stride_kv = max(stride_kv // 2, 1) + stride *= 2 + stage += 1 + if i - 1 in last_block_indexes: + window_size = window_size // 2 + input_size = [s // 2 for s in input_size] + + self._out_features = out_features + self._last_block_indexes = last_block_indexes + + if self.pos_embed is not None: + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.patch_embed(x) + + if self.pos_embed is not None: + x = x + get_abs_pos(self.pos_embed, self.pretrain_use_cls_token, x.shape[1:3]) + + outputs = {} + stage = 2 + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in self._last_block_indexes: + name = f"scale{stage}" + if name in self._out_features: 
+ x_out = getattr(self, f"{name}_norm")(x) + outputs[name] = x_out.permute(0, 3, 1, 2) + stage += 1 + + return outputs diff --git a/data_processing/detectron2/detectron2/modeling/backbone/regnet.py b/data_processing/detectron2/detectron2/modeling/backbone/regnet.py new file mode 100644 index 0000000..3533d63 --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/backbone/regnet.py @@ -0,0 +1,452 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Implementation of RegNet models from :paper:`dds` and :paper:`scaling`. + +This code is adapted from https://github.com/facebookresearch/pycls with minimal modifications. +Some code duplication exists between RegNet and ResNets (e.g., ResStem) in order to simplify +model loading. +""" + +import numpy as np +from torch import nn + +from detectron2.layers import CNNBlockBase, ShapeSpec, get_norm + +from .backbone import Backbone + +__all__ = [ + "AnyNet", + "RegNet", + "ResStem", + "SimpleStem", + "VanillaBlock", + "ResBasicBlock", + "ResBottleneckBlock", +] + + +def conv2d(w_in, w_out, k, *, stride=1, groups=1, bias=False): + """Helper for building a conv2d layer.""" + assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues." + s, p, g, b = stride, (k - 1) // 2, groups, bias + return nn.Conv2d(w_in, w_out, k, stride=s, padding=p, groups=g, bias=b) + + +def gap2d(): + """Helper for building a global average pooling layer.""" + return nn.AdaptiveAvgPool2d((1, 1)) + + +def pool2d(k, *, stride=1): + """Helper for building a pool2d layer.""" + assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues." 
+ return nn.MaxPool2d(k, stride=stride, padding=(k - 1) // 2) + + +def init_weights(m): + """Performs ResNet-style weight initialization.""" + if isinstance(m, nn.Conv2d): + # Note that there is no bias due to BN + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.weight.data.normal_(mean=0.0, std=0.01) + m.bias.data.zero_() + + +class ResStem(CNNBlockBase): + """ResNet stem for ImageNet: 7x7, BN, AF, MaxPool.""" + + def __init__(self, w_in, w_out, norm, activation_class): + super().__init__(w_in, w_out, 4) + self.conv = conv2d(w_in, w_out, 7, stride=2) + self.bn = get_norm(norm, w_out) + self.af = activation_class() + self.pool = pool2d(3, stride=2) + + def forward(self, x): + for layer in self.children(): + x = layer(x) + return x + + +class SimpleStem(CNNBlockBase): + """Simple stem for ImageNet: 3x3, BN, AF.""" + + def __init__(self, w_in, w_out, norm, activation_class): + super().__init__(w_in, w_out, 2) + self.conv = conv2d(w_in, w_out, 3, stride=2) + self.bn = get_norm(norm, w_out) + self.af = activation_class() + + def forward(self, x): + for layer in self.children(): + x = layer(x) + return x + + +class SE(nn.Module): + """Squeeze-and-Excitation (SE) block: AvgPool, FC, Act, FC, Sigmoid.""" + + def __init__(self, w_in, w_se, activation_class): + super().__init__() + self.avg_pool = gap2d() + self.f_ex = nn.Sequential( + conv2d(w_in, w_se, 1, bias=True), + activation_class(), + conv2d(w_se, w_in, 1, bias=True), + nn.Sigmoid(), + ) + + def forward(self, x): + return x * self.f_ex(self.avg_pool(x)) + + +class VanillaBlock(CNNBlockBase): + """Vanilla block: [3x3 conv, BN, Relu] x2.""" + + def __init__(self, w_in, w_out, stride, norm, activation_class, _params): + super().__init__(w_in, w_out, stride) + self.a = conv2d(w_in, w_out, 3, stride=stride) + self.a_bn = 
get_norm(norm, w_out) + self.a_af = activation_class() + self.b = conv2d(w_out, w_out, 3) + self.b_bn = get_norm(norm, w_out) + self.b_af = activation_class() + + def forward(self, x): + for layer in self.children(): + x = layer(x) + return x + + +class BasicTransform(nn.Module): + """Basic transformation: [3x3 conv, BN, Relu] x2.""" + + def __init__(self, w_in, w_out, stride, norm, activation_class, _params): + super().__init__() + self.a = conv2d(w_in, w_out, 3, stride=stride) + self.a_bn = get_norm(norm, w_out) + self.a_af = activation_class() + self.b = conv2d(w_out, w_out, 3) + self.b_bn = get_norm(norm, w_out) + self.b_bn.final_bn = True + + def forward(self, x): + for layer in self.children(): + x = layer(x) + return x + + +class ResBasicBlock(CNNBlockBase): + """Residual basic block: x + f(x), f = basic transform.""" + + def __init__(self, w_in, w_out, stride, norm, activation_class, params): + super().__init__(w_in, w_out, stride) + self.proj, self.bn = None, None + if (w_in != w_out) or (stride != 1): + self.proj = conv2d(w_in, w_out, 1, stride=stride) + self.bn = get_norm(norm, w_out) + self.f = BasicTransform(w_in, w_out, stride, norm, activation_class, params) + self.af = activation_class() + + def forward(self, x): + x_p = self.bn(self.proj(x)) if self.proj else x + return self.af(x_p + self.f(x)) + + +class BottleneckTransform(nn.Module): + """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1.""" + + def __init__(self, w_in, w_out, stride, norm, activation_class, params): + super().__init__() + w_b = int(round(w_out * params["bot_mul"])) + w_se = int(round(w_in * params["se_r"])) + groups = w_b // params["group_w"] + self.a = conv2d(w_in, w_b, 1) + self.a_bn = get_norm(norm, w_b) + self.a_af = activation_class() + self.b = conv2d(w_b, w_b, 3, stride=stride, groups=groups) + self.b_bn = get_norm(norm, w_b) + self.b_af = activation_class() + self.se = SE(w_b, w_se, activation_class) if w_se else None + self.c = conv2d(w_b, w_out, 1) + self.c_bn = 
get_norm(norm, w_out) + self.c_bn.final_bn = True + + def forward(self, x): + for layer in self.children(): + x = layer(x) + return x + + +class ResBottleneckBlock(CNNBlockBase): + """Residual bottleneck block: x + f(x), f = bottleneck transform.""" + + def __init__(self, w_in, w_out, stride, norm, activation_class, params): + super().__init__(w_in, w_out, stride) + self.proj, self.bn = None, None + if (w_in != w_out) or (stride != 1): + self.proj = conv2d(w_in, w_out, 1, stride=stride) + self.bn = get_norm(norm, w_out) + self.f = BottleneckTransform(w_in, w_out, stride, norm, activation_class, params) + self.af = activation_class() + + def forward(self, x): + x_p = self.bn(self.proj(x)) if self.proj else x + return self.af(x_p + self.f(x)) + + +class AnyStage(nn.Module): + """AnyNet stage (sequence of blocks w/ the same output shape).""" + + def __init__(self, w_in, w_out, stride, d, block_class, norm, activation_class, params): + super().__init__() + for i in range(d): + block = block_class(w_in, w_out, stride, norm, activation_class, params) + self.add_module("b{}".format(i + 1), block) + stride, w_in = 1, w_out + + def forward(self, x): + for block in self.children(): + x = block(x) + return x + + +class AnyNet(Backbone): + """AnyNet model. See :paper:`dds`.""" + + def __init__( + self, + *, + stem_class, + stem_width, + block_class, + depths, + widths, + group_widths, + strides, + bottleneck_ratios, + se_ratio, + activation_class, + freeze_at=0, + norm="BN", + out_features=None, + ): + """ + Args: + stem_class (callable): A callable taking 4 arguments (channels in, channels out, + normalization, callable returning an activation function) that returns another + callable implementing the stem module. + stem_width (int): The number of output channels that the stem produces. 
+ block_class (callable): A callable taking 6 arguments (channels in, channels out, + stride, normalization, callable returning an activation function, a dict of + block-specific parameters) that returns another callable implementing the repeated + block module. + depths (list[int]): Number of blocks in each stage. + widths (list[int]): For each stage, the number of output channels of each block. + group_widths (list[int]): For each stage, the number of channels per group in group + convolution, if the block uses group convolution. + strides (list[int]): The stride that each network stage applies to its input. + bottleneck_ratios (list[float]): For each stage, the ratio of the number of bottleneck + channels to the number of block input channels (or, equivalently, output channels), + if the block uses a bottleneck. + se_ratio (float): The ratio of the number of channels used inside the squeeze-excitation + (SE) module to it number of input channels, if SE the block uses SE. + activation_class (callable): A callable taking no arguments that returns another + callable implementing an activation function. + freeze_at (int): The number of stages at the beginning to freeze. + see :meth:`freeze` for detailed explanation. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + out_features (list[str]): name of the layers whose outputs should + be returned in forward. RegNet's use "stem" and "s1", "s2", etc for the stages after + the stem. If None, will return the output of the last layer. 
+ """ + super().__init__() + self.stem = stem_class(3, stem_width, norm, activation_class) + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + self.stages_and_names = [] + prev_w = stem_width + + for i, (d, w, s, b, g) in enumerate( + zip(depths, widths, strides, bottleneck_ratios, group_widths) + ): + params = {"bot_mul": b, "group_w": g, "se_r": se_ratio} + stage = AnyStage(prev_w, w, s, d, block_class, norm, activation_class, params) + name = "s{}".format(i + 1) + self.add_module(name, stage) + self.stages_and_names.append((stage, name)) + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in stage.children()]) + ) + self._out_feature_channels[name] = list(stage.children())[-1].out_channels + prev_w = w + + self.apply(init_weights) + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, "Available children: {} does not include {}".format( + ", ".join(children), out_feature + ) + self.freeze(freeze_at) + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert x.dim() == 4, f"Model takes an input of shape (N, C, H, W). Got {x.shape} instead!" 
def adjust_block_compatibility(ws, bs, gs):
    """Adjust per-stage widths and group widths so that they are compatible.

    For every stage the bottleneck width ``v = w * b`` is rounded to a multiple
    of the group width ``g`` (and of ``b`` as well when ``b > 1``), the group
    width is capped at ``v``, and the stage width is recomputed as ``v / b``.

    Args:
        ws (list[int]): stage widths, all positive.
        bs (list[float]): bottleneck ratios, all positive. Ratios above 1 must
            be whole-valued (``2`` or ``2.0``, not ``1.5``).
        gs (list[int]): group widths, all positive.

    Returns:
        tuple[list[int], list, list[int]]: the adjusted widths, the bottleneck
        ratios as passed in, and the adjusted group widths.

    Raises:
        AssertionError: if the lists differ in length, contain non-positive
            values, contain a fractional ratio above 1, or the adjusted values
            are still incompatible.
    """
    assert len(ws) == len(bs) == len(gs)
    assert all(w > 0 and b > 0 and g > 0 for w, b, g in zip(ws, bs, gs))
    # np.lcm is defined for integers only, so a ratio above 1 has to be whole.
    assert all(b < 1 or b % 1 == 0 for b in bs)
    vs = [int(max(1, w * b)) for w, b in zip(ws, bs)]
    gs = [int(min(g, v)) for g, v in zip(gs, vs)]
    # Ratios are commonly floats (RegNet's default is 1.0); cast before np.lcm,
    # which otherwise raises TypeError for float input.
    ms = [int(np.lcm(g, int(b))) if b > 1 else g for g, b in zip(gs, bs)]
    vs = [max(m, int(round(v / m) * m)) for v, m in zip(vs, ms)]
    ws = [int(v / b) for v, b in zip(vs, bs)]
    assert all(w * b % g == 0 for w, b, g in zip(ws, bs, gs))
    return ws, bs, gs
class RegNet(AnyNet):
    """RegNet model. See :paper:`dds`."""

    def __init__(
        self,
        *,
        stem_class,
        stem_width,
        block_class,
        depth,
        w_a,
        w_0,
        w_m,
        group_width,
        stride=2,
        bottleneck_ratio=1.0,
        se_ratio=0.0,
        activation_class=None,
        freeze_at=0,
        norm="BN",
        out_features=None,
    ):
        """
        Build a RegNet from the parameterization described in :paper:`dds` Section 3.3.

        The per-stage widths and depths are derived from ``(w_a, w_0, w_m, depth)``;
        the same ``stride``, ``bottleneck_ratio`` and ``group_width`` are then used
        for every stage, after the widths and group widths have been made compatible.

        Args:
            See :class:`AnyNet` for arguments that are not listed here.
            depth (int): Total number of blocks in the RegNet.
            w_a (float): Factor by which block width would increase prior to quantizing
                block widths by stage. See :paper:`dds` Section 3.3.
            w_0 (int): Initial block width. See :paper:`dds` Section 3.3.
            w_m (float): Parameter controlling block width quantization.
                See :paper:`dds` Section 3.3.
            group_width (int): Number of channels per group in group convolution, if the
                block uses group convolution.
            bottleneck_ratio (float): The ratio of the number of bottleneck channels to the
                number of block input channels (or, equivalently, output channels), if the
                block uses a bottleneck.
            stride (int): The stride that each network stage applies to its input.
            activation_class (callable): if None, ``nn.ReLU(inplace=True)`` is used.
        """
        stage_widths, stage_depths = generate_regnet_parameters(w_a, w_0, w_m, depth)[:2]
        num_stages = len(stage_widths)
        stage_strides = [stride] * num_stages
        stage_widths, stage_ratios, stage_groups = adjust_block_compatibility(
            stage_widths,
            [bottleneck_ratio] * num_stages,
            [group_width] * num_stages,
        )

        def default_activation_class():
            return nn.ReLU(inplace=True)

        if activation_class is None:
            chosen_activation = default_activation_class
        else:
            chosen_activation = activation_class

        super().__init__(
            stem_class=stem_class,
            stem_width=stem_width,
            block_class=block_class,
            depths=stage_depths,
            widths=stage_widths,
            strides=stage_strides,
            group_widths=stage_groups,
            bottleneck_ratios=stage_ratios,
            se_ratio=se_ratio,
            activation_class=chosen_activation,
            freeze_at=freeze_at,
            norm=norm,
            out_features=out_features,
        )
class BottleneckBlock(CNNBlockBase):
    """
    Bottleneck residual block used by ResNet-50, 101 and 152 (:paper:`ResNet`).

    The residual branch is a 1x1, a 3x3 and a 1x1 conv. A 1x1 projection
    shortcut is created only when the input and output channel counts differ.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Args:
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
        """
        super().__init__(in_channels, out_channels, stride)

        # The shortcut is built first so layer creation order stays the same.
        self.shortcut = (
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
            if in_channels != out_channels
            else None
        )

        # The original MSRA ResNet models stride in the first 1x1 conv; the later
        # fb.torch.resnet and Caffe2 ResNe[X]t implementations stride in the 3x3.
        if stride_in_1x1:
            stride_1x1, stride_3x3 = stride, 1
        else:
            stride_1x1, stride_3x3 = 1, stride

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )
        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=get_norm(norm, bottleneck_channels),
        )
        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for conv in (self.conv1, self.conv2, self.conv3, self.shortcut):
            if conv is None:  # no projection shortcut for this block
                continue
            weight_init.c2_msra_fill(conv)

        # Zero-initializing the last norm of the residual branch would make each
        # block start as an identity (Sec 5.1 of "Accurate, Large Minibatch SGD:
        # Training ImageNet in 1 Hour"). It is deliberately left disabled:
        # TODO this somehow hurts performance when training GN models from scratch.
        # Add it as an option when we need to use this code to train a backbone.

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        branch = F.relu_(self.conv1(x))
        branch = F.relu_(self.conv2(branch))
        branch = self.conv3(branch)

        identity = x if self.shortcut is None else self.shortcut(x)

        branch += identity
        return F.relu_(branch)
class BasicStem(CNNBlockBase):
    """
    The standard ResNet stem (the layers before the first residual block):
    a 7x7 stride-2 conv, a ReLU and a 3x3 stride-2 max pool.
    """

    def __init__(self, in_channels=3, out_channels=64, norm="BN"):
        """
        Args:
            in_channels (int): channels of the input.
            out_channels (int): channels produced by the conv.
            norm (str or callable): norm after the first conv layer.
                See :func:`layers.get_norm` for supported format.
        """
        # The stem's stride is registered as 4: the stride-2 conv followed by
        # the stride-2 max pool in forward.
        super().__init__(in_channels, out_channels, 4)
        self.in_channels = in_channels
        stem_norm = get_norm(norm, out_channels)
        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            norm=stem_norm,
        )
        weight_init.c2_msra_fill(self.conv1)

    def forward(self, x):
        """Apply conv, in-place ReLU, then max pooling."""
        activated = F.relu_(self.conv1(x))
        return F.max_pool2d(activated, kernel_size=3, stride=2, padding=1)
+ num_classes (None or int): if None, will not perform classification. + Otherwise, will create a linear layer. + out_features (list[str]): name of the layers whose outputs should + be returned in forward. Can be anything in "stem", "linear", or "res2" ... + If None, will return the output of the last layer. + freeze_at (int): The number of stages at the beginning to freeze. + see :meth:`freeze` for detailed explanation. + """ + super().__init__() + self.stem = stem + self.num_classes = num_classes + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + + self.stage_names, self.stages = [], [] + + if out_features is not None: + # Avoid keeping unused layers in this module. They consume extra memory + # and may cause allreduce to fail + num_stages = max( + [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features] + ) + stages = stages[:num_stages] + for i, blocks in enumerate(stages): + assert len(blocks) > 0, len(blocks) + for block in blocks: + assert isinstance(block, CNNBlockBase), block + + name = "res" + str(i + 2) + stage = nn.Sequential(*blocks) + + self.add_module(name, stage) + self.stage_names.append(name) + self.stages.append(stage) + + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in blocks]) + ) + self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels + self.stage_names = tuple(self.stage_names) # Make it static for scripting + + if num_classes is not None: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(curr_channels, num_classes) + + # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "The 1000-way fully-connected layer is initialized by + # drawing weights from a zero-mean Gaussian with standard deviation of 0.01." 
+ nn.init.normal_(self.linear.weight, std=0.01) + name = "linear" + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, "Available children: {}".format(", ".join(children)) + self.freeze(freeze_at) + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!" + outputs = {} + x = self.stem(x) + if "stem" in self._out_features: + outputs["stem"] = x + for name, stage in zip(self.stage_names, self.stages): + x = stage(x) + if name in self._out_features: + outputs[name] = x + if self.num_classes is not None: + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.linear(x) + if "linear" in self._out_features: + outputs["linear"] = x + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + def freeze(self, freeze_at=0): + """ + Freeze the first several stages of the ResNet. Commonly used in + fine-tuning. + + Layers that produce the same feature map spatial size are defined as one + "stage" by :paper:`FPN`. + + Args: + freeze_at (int): number of stages to freeze. + `1` means freezing the stem. `2` means freezing the stem and + one residual stage, etc. 
+ + Returns: + nn.Module: this ResNet itself + """ + if freeze_at >= 1: + self.stem.freeze() + for idx, stage in enumerate(self.stages, start=2): + if freeze_at >= idx: + for block in stage.children(): + block.freeze() + return self + + @staticmethod + def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs): + """ + Create a list of blocks of the same type that forms one ResNet stage. + + Args: + block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this + stage. A module of this type must not change spatial resolution of inputs unless its + stride != 1. + num_blocks (int): number of blocks in this stage + in_channels (int): input channels of the entire stage. + out_channels (int): output channels of **every block** in the stage. + kwargs: other arguments passed to the constructor of + `block_class`. If the argument name is "xx_per_block", the + argument is a list of values to be passed to each block in the + stage. Otherwise, the same argument is passed to every block + in the stage. + + Returns: + list[CNNBlockBase]: a list of block module. + + Examples: + :: + stage = ResNet.make_stage( + BottleneckBlock, 3, in_channels=16, out_channels=64, + bottleneck_channels=16, num_groups=1, + stride_per_block=[2, 1, 1], + dilations_per_block=[1, 1, 2] + ) + + Usually, layers that produce the same feature map spatial size are defined as one + "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should + all be 1. + """ + blocks = [] + for i in range(num_blocks): + curr_kwargs = {} + for k, v in kwargs.items(): + if k.endswith("_per_block"): + assert len(v) == num_blocks, ( + f"Argument '{k}' of make_stage should have the " + f"same length as num_blocks={num_blocks}." + ) + newk = k[: -len("_per_block")] + assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!" 
+ curr_kwargs[newk] = v[i] + else: + curr_kwargs[k] = v + + blocks.append( + block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs) + ) + in_channels = out_channels + return blocks + + @staticmethod + def make_default_stages(depth, block_class=None, **kwargs): + """ + Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152). + If it doesn't create the ResNet variant you need, please use :meth:`make_stage` + instead for fine-grained customization. + + Args: + depth (int): depth of ResNet + block_class (type): the CNN block class. Has to accept + `bottleneck_channels` argument for depth > 50. + By default it is BasicBlock or BottleneckBlock, based on the + depth. + kwargs: + other arguments to pass to `make_stage`. Should not contain + stride and channels, as they are predefined for each depth. + + Returns: + list[list[CNNBlockBase]]: modules in all stages; see arguments of + :class:`ResNet.__init__`. + """ + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + if block_class is None: + block_class = BasicBlock if depth < 50 else BottleneckBlock + if depth < 50: + in_channels = [64, 64, 128, 256] + out_channels = [64, 128, 256, 512] + else: + in_channels = [64, 256, 512, 1024] + out_channels = [256, 512, 1024, 2048] + ret = [] + for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels): + if depth >= 50: + kwargs["bottleneck_channels"] = o // 4 + ret.append( + ResNet.make_stage( + block_class=block_class, + num_blocks=n, + stride_per_block=[s] + [1] * (n - 1), + in_channels=i, + out_channels=o, + **kwargs, + ) + ) + return ret + + +ResNetBlockBase = CNNBlockBase +""" +Alias for backward compatibiltiy. +""" + + +def make_stage(*args, **kwargs): + """ + Deprecated alias for backward compatibiltiy. 
@BACKBONE_REGISTRY.register()
def build_resnet_backbone(cfg, input_shape):
    """
    Create a ResNet instance from config.

    Reads ``cfg.MODEL.RESNETS`` and ``cfg.MODEL.BACKBONE.FREEZE_AT``, builds a
    :class:`BasicStem` plus four stages (res2..res5), and wraps them in a
    :class:`ResNet`.

    Args:
        cfg: config object exposing the ``MODEL.RESNETS`` and ``MODEL.BACKBONE`` values
            read below.
        input_shape: object whose ``channels`` attribute gives the stem's input channels.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    resnets = cfg.MODEL.RESNETS
    norm = resnets.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=resnets.STEM_OUT_CHANNELS,
        norm=norm,
    )

    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = resnets.OUT_FEATURES
    depth = resnets.DEPTH
    num_groups = resnets.NUM_GROUPS
    width_per_group = resnets.WIDTH_PER_GROUP
    in_channels = resnets.STEM_OUT_CHANNELS
    out_channels = resnets.RES2_OUT_CHANNELS
    stride_in_1x1 = resnets.STRIDE_IN_1X1
    res5_dilation = resnets.RES5_DILATION
    deform_on_per_stage = resnets.DEFORM_ON_PER_STAGE
    deform_modulated = resnets.DEFORM_MODULATED
    deform_num_groups = resnets.DEFORM_NUM_GROUPS
    bottleneck_channels = num_groups * width_per_group
    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    num_blocks_per_stage = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]

    uses_basic_block = depth in [18, 34]
    if uses_basic_block:
        assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
        assert not any(
            deform_on_per_stage
        ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
        assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
        assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"

    stages = []
    for idx in range(4):
        stage_idx = idx + 2  # stages are named res2..res5
        num_blocks = num_blocks_per_stage[idx]

        # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper
        dilation = res5_dilation if stage_idx == 5 else 1
        keeps_resolution = idx == 0 or (stage_idx == 5 and dilation == 2)
        first_stride = 1 if keeps_resolution else 2

        stage_kargs = {
            "num_blocks": num_blocks,
            "stride_per_block": [first_stride] + [1] * (num_blocks - 1),
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm": norm,
        }
        if uses_basic_block:
            # R18 and R34 are built from BasicBlock and take no bottleneck options.
            stage_kargs["block_class"] = BasicBlock
        else:
            stage_kargs.update(
                bottleneck_channels=bottleneck_channels,
                stride_in_1x1=stride_in_1x1,
                dilation=dilation,
                num_groups=num_groups,
            )
            if deform_on_per_stage[idx]:
                stage_kargs.update(
                    block_class=DeformBottleneckBlock,
                    deform_modulated=deform_modulated,
                    deform_num_groups=deform_num_groups,
                )
            else:
                stage_kargs["block_class"] = BottleneckBlock

        stages.append(ResNet.make_stage(**stage_kargs))

        # Each following stage doubles its channel counts.
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2
    return ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at)
def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C), ordered by batch,
        then window row, then window column.
    """
    batch, height, width, channels = x.shape
    rows, cols = height // window_size, width // window_size
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-grid axes together so each window is contiguous.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) with a learned relative position bias.
    Usable for both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
            Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        win_h, win_w = window_size[0], window_size[1]
        head_dim = dim // num_heads
        # A falsy qk_scale (None or 0) falls back to head_dim ** -0.5.
        self.scale = qk_scale or head_dim**-0.5

        # One learnable bias per head for each of the (2*Wh-1) * (2*Ww-1) relative offsets.
        num_offsets = (2 * win_h - 1) * (2 * win_w - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_offsets, num_heads))

        # For every (query, key) token pair in a window, the row of the bias table to use.
        grid = torch.stack(
            torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])
        )  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        delta = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        # Shift both offsets to start from 0, then flatten (dh, dw) to a single index.
        row_term = (delta[0] + (win_h - 1)) * (2 * win_w - 1)
        col_term = delta[1] + (win_w - 1)
        self.register_buffer("relative_position_index", row_term + col_term)  # Wh*Ww, Wh*Ww

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        heads = self.num_heads
        qkv = self.qkv(x).reshape(B_, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
        # Indexed rather than unpacked: torchscript cannot use a tensor as a tuple.
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = (q * self.scale) @ k.transpose(-2, -1)

        tokens = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(tokens, tokens, -1)  # Wh*Ww, Wh*Ww, nH
        bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        scores = scores + bias.unsqueeze(0)

        if mask is not None:
            # Add the per-window mask, broadcasting it over batch and heads.
            nW = mask.shape[0]
            scores = scores.view(B_ // nW, nW, heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            scores = scores.view(-1, heads, N, N)

        weights = self.attn_drop(self.softmax(scores))

        out = (weights @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj_drop(self.proj(out))
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=_to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + if drop_path > 0.0: + from timm.models.layers import DropPath + + self.drop_path = DropPath(drop_path) + else: + self.drop_path = nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop + ) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. 
class PatchMerging(nn.Module):
    """Patch merging layer: fuse every 2x2 patch neighbourhood into one token.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        # 4 neighbouring patches are concatenated (4*dim) and projected to 2*dim.
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Halve the spatial resolution of a token sequence.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            The normalized, linearly reduced tensor built from a
            (B, ceil(H/2)*ceil(W/2), 4*C) concatenation of 2x2 neighbourhoods.
        """
        batch, length, channels = x.shape
        assert length == H * W, "input feature has wrong size"

        grid = x.view(batch, H, W, channels)

        # Odd sizes are padded at the bottom/right so every patch has a 2x2 group.
        if H % 2 == 1 or W % 2 == 1:
            grid = F.pad(grid, (0, 0, 0, W % 2, 0, H % 2))

        # Channel order: (even row, even col), (odd row, even col),
        # (even row, odd col), (odd row, odd col).
        corners = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(corners, -1).view(batch, -1, 4 * channels)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))
def forward(self, x, H, W):
    """Run all blocks of this stage, then the optional downsampling layer.

    Args:
        x: Input feature, tensor size (B, H*W, C).
        H, W: Spatial resolution of the input feature.

    Returns:
        Tuple ``(x, H, W, x_down, Wh, Ww)``: the stage output with its resolution,
        followed by the downsampled output with its resolution. Without a
        downsample layer the last three entries repeat the first three.
    """
    win, shift = self.window_size, self.shift_size

    # Build the SW-MSA attention mask on the grid padded to a multiple of the window.
    padded_h = int(np.ceil(H / win)) * win
    padded_w = int(np.ceil(W / win)) * win
    region_ids = torch.zeros((1, padded_h, padded_w, 1), device=x.device)  # 1 Hp Wp 1

    # Three bands per axis give 3x3 regions, each labelled with a distinct id.
    bands = (slice(0, -win), slice(-win, -shift), slice(-shift, None))
    for row_band, rows in enumerate(bands):
        for col_band, cols in enumerate(bands):
            region_ids[:, rows, cols, :] = row_band * len(bands) + col_band

    # nW, window_size*window_size
    per_window = window_partition(region_ids, win).view(-1, win * win)
    # Token pairs from different regions get -100, pairs from the same region get 0.
    id_diff = per_window.unsqueeze(1) - per_window.unsqueeze(2)
    attn_mask = id_diff.masked_fill(id_diff != 0, float(-100.0)).masked_fill(
        id_diff == 0, float(0.0)
    )

    for blk in self.blocks:
        # Blocks read the grid size from attributes rather than from arguments.
        blk.H, blk.W = H, W
        x = checkpoint.checkpoint(blk, x, attn_mask) if self.use_checkpoint else blk(x, attn_mask)

    if self.downsample is None:
        return x, H, W, x, H, W

    x_down = self.downsample(x, H, W)
    return x, H, W, x_down, (H + 1) // 2, (W + 1) // 2
def forward(self, x):
    """Project an image into patch embeddings, padding it to whole patches first.

    The input is unpacked as a 4-D (B, C, H, W) tensor. The result of
    ``self.proj`` is returned, passed through ``self.norm`` when one is set.
    """
    patch_h, patch_w = self.patch_size[0], self.patch_size[1]
    _, _, height, width = x.size()

    # Pad on the right, then on the bottom, up to the next multiple of the patch size.
    extra_w = width % patch_w
    if extra_w != 0:
        x = F.pad(x, (0, patch_w - extra_w))
    extra_h = height % patch_h
    if extra_h != 0:
        x = F.pad(x, (0, 0, 0, patch_h - extra_h))

    embedded = self.proj(x)  # B C Wh Ww
    if self.norm is None:
        return embedded

    # The norm acts on the channel axis: flatten to tokens, normalize, restore the map.
    out_h, out_w = embedded.size(2), embedded.size(3)
    tokens = self.norm(embedded.flatten(2).transpose(1, 2))
    return tokens.transpose(1, 2).view(-1, self.embed_dim, out_h, out_w)
def __init__(
    self,
    pretrain_img_size=224,
    patch_size=4,
    in_chans=3,
    embed_dim=96,
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    window_size=7,
    mlp_ratio=4.0,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.2,
    norm_layer=nn.LayerNorm,
    ape=False,
    patch_norm=True,
    out_indices=(0, 1, 2, 3),
    frozen_stages=-1,
    use_checkpoint=False,
):
    """Build the patch embedding, the Swin stages and one norm per output stage.

    Arguments are described in the class docstring.
    """
    super().__init__()

    self.pretrain_img_size = pretrain_img_size
    self.num_layers = len(depths)
    self.embed_dim = embed_dim
    self.ape = ape
    self.patch_norm = patch_norm
    self.out_indices = out_indices
    self.frozen_stages = frozen_stages

    # split image into non-overlapping patches
    self.patch_embed = PatchEmbed(
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
        norm_layer=norm_layer if self.patch_norm else None,
    )

    # absolute position embedding, sized for the pretraining resolution
    if self.ape:
        pretrain_img_size = _to_2tuple(pretrain_img_size)
        patch_size = _to_2tuple(patch_size)
        patches_resolution = [
            pretrain_img_size[0] // patch_size[0],
            pretrain_img_size[1] // patch_size[1],
        ]

        self.absolute_pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
        )
        nn.init.trunc_normal_(self.absolute_pos_embed, std=0.02)

    self.pos_drop = nn.Dropout(p=drop_rate)

    # stochastic depth decay rule: rate grows linearly from 0 to drop_path_rate
    # over all blocks of all stages
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

    # build layers; the channel count doubles at every stage and every stage
    # except the last ends with a PatchMerging downsample
    self.layers = nn.ModuleList()
    for i_layer in range(self.num_layers):
        layer = BasicLayer(
            dim=int(embed_dim * 2**i_layer),
            depth=depths[i_layer],
            num_heads=num_heads[i_layer],
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            # slice of the global drop-path schedule belonging to this stage
            drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
            norm_layer=norm_layer,
            downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
            use_checkpoint=use_checkpoint,
        )
        self.layers.append(layer)

    num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
    self.num_features = num_features

    # add a norm layer for each output, registered as "norm<i>" so forward()
    # can fetch it by name
    for i_layer in out_indices:
        layer = norm_layer(num_features[i_layer])
        layer_name = f"norm{i_layer}"
        self.add_module(layer_name, layer)

    self._freeze_stages()
    self._out_features = ["p{}".format(i) for i in self.out_indices]
    self._out_feature_channels = {
        "p{}".format(i): self.embed_dim * 2**i for i in self.out_indices
    }
    self._out_feature_strides = {"p{}".format(i): 2 ** (i + 2) for i in self.out_indices}
    # Must match the attribute read by the ``size_divisibility`` property; the
    # previous spelling ``_size_devisibility`` made that property raise
    # AttributeError.
    self._size_divisibility = 32

    self.apply(self._init_weights)
def forward(self, x):
    """Run the backbone and collect the normalized output of each selected stage.

    Returns:
        dict mapping ``"p<i>"`` to the stage-``i`` feature map, rearranged to
        channels-first, for every ``i`` in ``self.out_indices``.
    """
    feat = self.patch_embed(x)

    grid_h, grid_w = feat.size(2), feat.size(3)
    if self.ape:
        # interpolate the position embedding to the corresponding size
        feat = feat + F.interpolate(
            self.absolute_pos_embed, size=(grid_h, grid_w), mode="bicubic"
        )
    tokens = self.pos_drop(feat.flatten(2).transpose(1, 2))  # B Wh*Ww C

    outs = {}
    for idx in range(self.num_layers):
        # Each stage returns its own output plus the (possibly downsampled)
        # input for the next stage, each with its spatial size.
        stage_out, out_h, out_w, tokens, grid_h, grid_w = self.layers[idx](
            tokens, grid_h, grid_w
        )
        if idx not in self.out_indices:
            continue

        normed = getattr(self, f"norm{idx}")(stage_out)
        feature_map = normed.view(-1, out_h, out_w, self.num_features[idx])
        outs["p{}".format(idx)] = feature_map.permute(0, 3, 1, 2).contiguous()

    return outs
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Merge windows back into full feature maps and strip any padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    height, width = hw

    # The batch size is recovered from the number of windows per padded image.
    windows_per_image = padded_h * padded_w // window_size // window_size
    batch = windows.shape[0] // windows_per_image

    rows, cols = padded_h // window_size, padded_w // window_size
    grid = windows.view(batch, rows, cols, window_size, window_size, -1)
    # Interleave (window row, in-window row) and (window col, in-window col).
    merged = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, padded_h, padded_w, -1)

    if padded_h <= height and padded_w <= width:
        return merged
    return merged[:, :height, :width, :].contiguous()
def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Return absolute positional embeddings laid out on an (H, W) token grid.
    The cls-token embedding is dropped when present, and the remaining square
    grid is resized with bicubic interpolation when it does not match ``hw``.
    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    target_h, target_w = hw
    grid_tokens = abs_pos[:, 1:] if has_cls_token else abs_pos

    # The stored embeddings must form a square grid.
    num_tokens = grid_tokens.shape[1]
    side = int(math.sqrt(num_tokens))
    assert side * side == num_tokens

    if side == target_h and side == target_w:
        return grid_tokens.reshape(1, target_h, target_w, -1)

    # Interpolation works channels-first; convert there and back.
    resized = F.interpolate(
        grid_tokens.reshape(1, side, side, -1).permute(0, 3, 1, 2),
        size=(target_h, target_w),
        mode="bicubic",
        align_corners=False,
    )
    return resized.permute(0, 2, 3, 1)
def forward(self, x):
    """Apply the projection layer and move its channel axis last."""
    # B C H W -> B H W C
    return self.proj(x).permute(0, 2, 3, 1)
def forward(self, x):
    """Multi-head self-attention over every position of a (B, H, W, C) input.

    Decomposed relative position terms are added to the attention logits when
    ``self.use_rel_pos`` is set. Returns a (B, H, W, C') tensor, where C' is
    the output width of ``self.proj``.
    """
    batch, height, width, _ = x.shape
    heads = self.num_heads
    seq_len = height * width

    # qkv with shape (3, B, nHead, H * W, C)
    packed = self.qkv(x).reshape(batch, seq_len, 3, heads, -1).permute(2, 0, 3, 1, 4)
    # q, k, v with shape (B * nHead, H * W, C)
    query, key, value = packed.reshape(3, batch * heads, seq_len, -1).unbind(0)

    scores = (query * self.scale) @ key.transpose(-2, -1)
    if self.use_rel_pos:
        grid = (height, width)
        scores = add_decomposed_rel_pos(
            scores, query, self.rel_pos_h, self.rel_pos_w, grid, grid
        )

    weights = scores.softmax(dim=-1)
    # Split batch and heads again, move heads next to channels, merge them.
    context = (weights @ value).view(batch, heads, height, width, -1)
    context = context.permute(0, 2, 3, 1, 4).reshape(batch, height, width, -1)
    return self.proj(context)
def forward(self, x):
    """Chain every child module in registration order and add the skip connection."""
    branch = x
    for stage in self.children():
        branch = stage(branch)
    return x + branch
def forward(self, x):
    """Apply attention and MLP sub-layers with residual connections.

    ``x`` is indexed as (B, H, W, C). When ``self.window_size > 0`` attention
    runs on windows of the input; when ``self.use_residual_block`` is set the
    result is additionally passed through ``self.residual`` in channels-first
    layout.
    """
    normed = self.norm1(x)

    if self.window_size > 0:
        # Window partition -> attention -> reverse window partition
        height, width = normed.shape[1], normed.shape[2]
        windows, pad_hw = window_partition(normed, self.window_size)
        attended = window_unpartition(
            self.attn(windows), self.window_size, pad_hw, (height, width)
        )
    else:
        attended = self.attn(normed)

    hidden = x + self.drop_path(attended)
    hidden = hidden + self.drop_path(self.mlp(self.norm2(hidden)))

    if not self.use_residual_block:
        return hidden
    # self.residual expects B C H W; convert there and back.
    return self.residual(hidden.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+ "Exploring Plain Vision Transformer Backbones for Object Detection", + https://arxiv.org/abs/2203.16527 + """ + + def __init__( + self, + img_size=1024, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + use_abs_pos=True, + use_rel_pos=False, + rel_pos_zero_init=True, + window_size=0, + window_block_indexes=(), + residual_block_indexes=(), + use_act_checkpoint=False, + pretrain_img_size=224, + pretrain_use_cls_token=True, + out_feature="last_feat", + ): + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + window_block_indexes (list): Indexes for blocks using window attention. + residual_block_indexes (list): Indexes for blocks using conv propagation. + use_act_checkpoint (bool): If True, use activation checkpointing. + pretrain_img_size (int): input image size for pretraining models. + pretrain_use_cls_token (bool): If True, pretrainig models use class token. + out_feature (str): name of the feature from the last block. 
def forward(self, x):
    """Embed the input, run all blocks and return the final feature map.

    Returns:
        dict with a single entry, keyed by ``self._out_features[0]``, holding
        the last block's output permuted from channels-last to channels-first.
    """
    tokens = self.patch_embed(x)

    if self.pos_embed is not None:
        grid_hw = (tokens.shape[1], tokens.shape[2])
        tokens = tokens + get_abs_pos(self.pos_embed, self.pretrain_use_cls_token, grid_hw)

    for blk in self.blocks:
        tokens = blk(tokens)

    return {self._out_features[0]: tokens.permute(0, 3, 1, 2)}
+ """ + super(SimpleFeaturePyramid, self).__init__() + assert isinstance(net, Backbone) + + self.scale_factors = scale_factors + + input_shapes = net.output_shape() + strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors] + _assert_strides_are_log2_contiguous(strides) + + dim = input_shapes[in_feature].channels + self.stages = [] + use_bias = norm == "" + for idx, scale in enumerate(scale_factors): + out_dim = dim + if scale == 4.0: + layers = [ + nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), + get_norm(norm, dim // 2), + nn.GELU(), + nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2), + ] + out_dim = dim // 4 + elif scale == 2.0: + layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)] + out_dim = dim // 2 + elif scale == 1.0: + layers = [] + elif scale == 0.5: + layers = [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + raise NotImplementedError(f"scale_factor={scale} is not supported yet.") + + layers.extend( + [ + Conv2d( + out_dim, + out_channels, + kernel_size=1, + bias=use_bias, + norm=get_norm(norm, out_channels), + ), + Conv2d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=use_bias, + norm=get_norm(norm, out_channels), + ), + ] + ) + layers = nn.Sequential(*layers) + + stage = int(math.log2(strides[idx])) + self.add_module(f"simfp_{stage}", layers) + self.stages.append(layers) + + self.net = net + self.in_feature = in_feature + self.top_block = top_block + # Return feature names are "p", like ["p2", "p3", ..., "p6"] + self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides} + # top block output feature maps. 
+ if self.top_block is not None: + for s in range(stage, stage + self.top_block.num_levels): + self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1) + + self._out_features = list(self._out_feature_strides.keys()) + self._out_feature_channels = {k: out_channels for k in self._out_features} + self._size_divisibility = strides[-1] + self._square_pad = square_pad + + @property + def padding_constraints(self): + return { + "size_divisiblity": self._size_divisibility, + "square_size": self._square_pad, + } + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + + Returns: + dict[str->Tensor]: + mapping from feature map name to pyramid feature map tensor + in high to low resolution order. Returned feature names follow the FPN + convention: "p", where stage has stride = 2 ** stage e.g., + ["p2", "p3", ..., "p6"]. + """ + bottom_up_features = self.net(x) + features = bottom_up_features[self.in_feature] + results = [] + + for stage in self.stages: + results.append(stage(features)) + + if self.top_block is not None: + if self.top_block.in_feature in bottom_up_features: + top_block_in_feature = bottom_up_features[self.top_block.in_feature] + else: + top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)] + results.extend(self.top_block(top_block_in_feature)) + assert len(self._out_features) == len(results) + return {f: res for f, res in zip(self._out_features, results)} + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone"): + if ".pos_embed" in name or ".patch_embed" in name: + layer_id = 0 + elif ".blocks." in name and ".residual." 
@torch.jit.script
class Box2BoxTransform(object):
    """
    The box-to-box transform defined in R-CNN. The transformation is parameterized
    by 4 deltas: (dx, dy, dw, dh). The transformation scales the box's width and height
    by exp(dw), exp(dh) and shifts a box's center by the offset (dx * width, dy * height).

    Boxes are indexed as (x1, y1, x2, y2) throughout this class.
    """

    def __init__(
        self, weights: Tuple[float, float, float, float], scale_clamp: float = _DEFAULT_SCALE_CLAMP
    ):
        """
        Args:
            weights (4-element tuple): Scaling factors that are applied to the
                (dx, dy, dw, dh) deltas. In Fast R-CNN, these were originally set
                such that the deltas have unit variance; now they are treated as
                hyperparameters of the system.
            scale_clamp (float): When predicting deltas, the predicted box scaling
                factors (dw and dh) are clamped such that they are <= scale_clamp.
        """
        self.weights = weights
        self.scale_clamp = scale_clamp

    def get_deltas(self, src_boxes, target_boxes):
        """
        Get box regression transformation deltas (dx, dy, dw, dh) that can be used
        to transform the `src_boxes` into the `target_boxes`. That is, the relation
        ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless
        any delta is too large and is clamped).

        Args:
            src_boxes (Tensor): source boxes, e.g., object proposals
            target_boxes (Tensor): target of the transformation, e.g., ground-truth
                boxes.

        Returns:
            Tensor: the four deltas stacked along dim 1, one row per box.
        """
        assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
        assert isinstance(target_boxes, torch.Tensor), type(target_boxes)

        # Convert both box sets from corner form to (center, size) form.
        src_widths = src_boxes[:, 2] - src_boxes[:, 0]
        src_heights = src_boxes[:, 3] - src_boxes[:, 1]
        src_ctr_x = src_boxes[:, 0] + 0.5 * src_widths
        src_ctr_y = src_boxes[:, 1] + 0.5 * src_heights

        target_widths = target_boxes[:, 2] - target_boxes[:, 0]
        target_heights = target_boxes[:, 3] - target_boxes[:, 1]
        target_ctr_x = target_boxes[:, 0] + 0.5 * target_widths
        target_ctr_y = target_boxes[:, 1] + 0.5 * target_heights

        # Center offsets are normalized by the source size; size ratios go through log.
        wx, wy, ww, wh = self.weights
        dx = wx * (target_ctr_x - src_ctr_x) / src_widths
        dy = wy * (target_ctr_y - src_ctr_y) / src_heights
        dw = ww * torch.log(target_widths / src_widths)
        dh = wh * torch.log(target_heights / src_heights)

        deltas = torch.stack((dx, dy, dw, dh), dim=1)
        # NOTE(review): only source widths are validated here; source heights and
        # the target boxes are not checked.
        assert (src_widths > 0).all().item(), "Input boxes to Box2BoxTransform are not valid!"
        return deltas

    def apply_deltas(self, deltas, boxes):
        """
        Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
                deltas[i] represents k potentially different class-specific
                box transformations for the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4)

        Returns:
            Tensor: transformed boxes with the same shape as ``deltas``, in fp32.
        """
        deltas = deltas.float()  # ensure fp32 for decoding precision
        boxes = boxes.to(deltas.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        # The stride-4 slices pick the same coordinate out of each of the k
        # groups of deltas, giving (N, k) tensors.
        wx, wy, ww, wh = self.weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        dw = deltas[:, 2::4] / ww
        dh = deltas[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)

        # `[:, None]` broadcasts each per-box quantity across its k predictions.
        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        # Back from (center, size) form to corner form.
        x1 = pred_ctr_x - 0.5 * pred_w
        y1 = pred_ctr_y - 0.5 * pred_h
        x2 = pred_ctr_x + 0.5 * pred_w
        y2 = pred_ctr_y + 0.5 * pred_h
        # Stacking on the last dim gives (N, k, 4); reshape restores (N, k*4).
        pred_boxes = torch.stack((x1, y1, x2, y2), dim=-1)
        return pred_boxes.reshape(deltas.shape)
+ """ + self.weights = weights + self.scale_clamp = scale_clamp + + def get_deltas(self, src_boxes, target_boxes): + """ + Get box regression transformation deltas (dx, dy, dw, dh, da) that can be used + to transform the `src_boxes` into the `target_boxes`. That is, the relation + ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless + any delta is too large and is clamped). + + Args: + src_boxes (Tensor): Nx5 source boxes, e.g., object proposals + target_boxes (Tensor): Nx5 target of the transformation, e.g., ground-truth + boxes. + """ + assert isinstance(src_boxes, torch.Tensor), type(src_boxes) + assert isinstance(target_boxes, torch.Tensor), type(target_boxes) + + src_ctr_x, src_ctr_y, src_widths, src_heights, src_angles = torch.unbind(src_boxes, dim=1) + + target_ctr_x, target_ctr_y, target_widths, target_heights, target_angles = torch.unbind( + target_boxes, dim=1 + ) + + wx, wy, ww, wh, wa = self.weights + dx = wx * (target_ctr_x - src_ctr_x) / src_widths + dy = wy * (target_ctr_y - src_ctr_y) / src_heights + dw = ww * torch.log(target_widths / src_widths) + dh = wh * torch.log(target_heights / src_heights) + # Angles of deltas are in radians while angles of boxes are in degrees. + # the conversion to radians serve as a way to normalize the values + da = target_angles - src_angles + da = (da + 180.0) % 360.0 - 180.0 # make it in [-180, 180) + da *= wa * math.pi / 180.0 + + deltas = torch.stack((dx, dy, dw, dh, da), dim=1) + assert ( + (src_widths > 0).all().item() + ), "Input boxes to Box2BoxTransformRotated are not valid!" + return deltas + + def apply_deltas(self, deltas, boxes): + """ + Apply transformation `deltas` (dx, dy, dw, dh, da) to `boxes`. + + Args: + deltas (Tensor): transformation deltas of shape (N, k*5). + deltas[i] represents box transformation for the single box boxes[i]. 
class Box2BoxTransformLinear(object):
    """
    The linear box-to-box transform defined in FCOS. The transformation is parameterized
    by the distance from the center of (square) src box to 4 edges of the target box.
    """

    def __init__(self, normalize_by_size=True):
        """
        Args:
            normalize_by_size: normalize deltas by the size of src (anchor) boxes.
        """
        self.normalize_by_size = normalize_by_size

    def get_deltas(self, src_boxes, target_boxes):
        """
        Get box regression transformation deltas (dx1, dy1, dx2, dy2) that can be used
        to transform the `src_boxes` into the `target_boxes`. That is, the relation
        ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true.
        The center of src must be inside target boxes.

        Args:
            src_boxes (Tensor): square source boxes, e.g., anchors
            target_boxes (Tensor): target of the transformation, e.g., ground-truth
                boxes.

        Returns:
            Tensor: deltas of shape (N, 4), ordered (left, top, right, bottom).
        """
        assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
        assert isinstance(target_boxes, torch.Tensor), type(target_boxes)

        src_ctr_x = 0.5 * (src_boxes[:, 0] + src_boxes[:, 2])
        src_ctr_y = 0.5 * (src_boxes[:, 1] + src_boxes[:, 3])

        # Distances from the source center to each edge of the target box.
        target_l = src_ctr_x - target_boxes[:, 0]
        target_t = src_ctr_y - target_boxes[:, 1]
        target_r = target_boxes[:, 2] - src_ctr_x
        target_b = target_boxes[:, 3] - src_ctr_y

        deltas = torch.stack((target_l, target_t, target_r, target_b), dim=1)
        if self.normalize_by_size:
            stride_w = src_boxes[:, 2] - src_boxes[:, 0]
            stride_h = src_boxes[:, 3] - src_boxes[:, 1]
            strides = torch.stack([stride_w, stride_h, stride_w, stride_h], dim=1)
            deltas = deltas / strides

        return deltas

    def apply_deltas(self, deltas, boxes):
        """
        Apply transformation `deltas` (dx1, dy1, dx2, dy2) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
                deltas[i] represents k potentially different class-specific
                box transformations for the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4)

        Returns:
            Tensor: predicted boxes (x1, y1, x2, y2) with the same shape as ``deltas``.
        """
        # Ensure the output is a valid box. See Sec 2.1 of https://arxiv.org/abs/2006.09214
        deltas = F.relu(deltas)
        boxes = boxes.to(deltas.dtype)

        ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2])
        ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3])

        # Each stride-4 slice is (N, k): one coordinate for each of the k predictions.
        left = deltas[:, 0::4]
        top = deltas[:, 1::4]
        right = deltas[:, 2::4]
        bottom = deltas[:, 3::4]

        if self.normalize_by_size:
            # Scale per coordinate with an (N, 1) column so the box size broadcasts
            # over all k predictions. Multiplying the full (N, k*4) tensor by an
            # (N, 4) stride tensor only works for k == 1.
            stride_w = (boxes[:, 2] - boxes[:, 0])[:, None]
            stride_h = (boxes[:, 3] - boxes[:, 1])[:, None]
            left = left * stride_w
            top = top * stride_h
            right = right * stride_w
            bottom = bottom * stride_h

        pred_boxes = torch.zeros_like(deltas)
        pred_boxes[:, 0::4] = ctr_x[:, None] - left  # x1
        pred_boxes[:, 1::4] = ctr_y[:, None] - top  # y1
        pred_boxes[:, 2::4] = ctr_x[:, None] + right  # x2
        pred_boxes[:, 3::4] = ctr_y[:, None] + bottom  # y2
        return pred_boxes
class Matcher(object):
    """
    Assign to each predicted "element" (e.g., a box) a ground-truth element.
    A prediction ends up with zero or one match, while a ground-truth element
    may be matched to any number of predictions (including none).

    Matching is driven by an MxN match_quality_matrix that scores every
    (ground-truth, prediction) pair; for boxes this is typically the pairwise
    intersection-over-union.

    Calling the matcher returns (a) a length-N vector with the index of the
    ground-truth element m in [0, M) chosen for prediction n in [0, N), and
    (b) a length-N vector with the label assigned to each prediction.
    """

    def __init__(
        self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
    ):
        """
        Args:
            thresholds (list): a list of thresholds used to stratify predictions
                into levels.
            labels (list): a list of values to label predictions belonging at
                each level. A label can be one of {-1, 0, 1} signifying
                {ignore, negative class, positive class}, respectively.
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions with maximum match quality lower than high_threshold.
                See set_low_quality_matches_ for more details.

            For example,
                thresholds = [0.3, 0.5]
                labels = [0, -1, 1]
                All predictions with iou < 0.3 will be marked with 0 and
                thus will be considered as false positives while training.
                All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
                thus will be ignored.
                All predictions with 0.5 <= iou will be marked with 1 and
                thus will be considered as true positives.
        """
        assert thresholds[0] > 0
        # Bracket a copy of the thresholds with -inf / +inf so that every quality
        # value falls into exactly one [low, high) interval.
        bounds = [-float("inf")] + thresholds[:] + [float("inf")]
        # Currently torchscript does not support all + generator
        assert all([lo <= hi for (lo, hi) in zip(bounds[:-1], bounds[1:])])
        assert all([lbl in [-1, 0, 1] for lbl in labels])
        assert len(labels) == len(bounds) - 1
        self.thresholds = bounds
        self.labels = labels
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted
                elements. All elements must be >= 0 (due to the use of `torch.nonzero`
                for selecting indices in :meth:`set_low_quality_matches_`).

        Returns:
            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched
                ground-truth index in [0, M)
            match_labels (Tensor[int8]): a vector of length N, where match_labels[i] indicates
                whether a prediction is a true or false positive or ignored
        """
        assert match_quality_matrix.dim() == 2
        num_predictions = match_quality_matrix.size(1)

        if match_quality_matrix.numel() == 0:
            # When no gt boxes exist, we define IOU = 0 and therefore set labels
            # to `self.labels[0]`, which usually defaults to background class 0
            # To choose to ignore instead, can make labels=[-1,0,-1,1] + set appropriate thresholds
            empty_matches = match_quality_matrix.new_full(
                (num_predictions,), 0, dtype=torch.int64
            )
            empty_labels = match_quality_matrix.new_full(
                (num_predictions,), self.labels[0], dtype=torch.int8
            )
            return empty_matches, empty_labels

        assert torch.all(match_quality_matrix >= 0)

        # The matrix is M (gt) x N (predicted): reducing over dim 0 yields, for
        # each prediction, its best quality and the gt index achieving it.
        best_quality, matches = match_quality_matrix.max(dim=0)

        match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)

        lower_bounds = self.thresholds[:-1]
        upper_bounds = self.thresholds[1:]
        for label, low, high in zip(self.labels, lower_bounds, upper_bounds):
            in_interval = (best_quality >= low) & (best_quality < high)
            match_labels[in_interval] = label

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(match_labels, match_quality_matrix)

        return matches, match_labels

    def set_low_quality_matches_(self, match_labels, match_quality_matrix):
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth G find the set of predictions that have
        maximum overlap with it (including ties); every prediction in that set is
        labeled positive (1) in ``match_labels``, which is modified in place.

        This function implements the RPN assignment case (i) in Sec. 3.1.2 of
        :paper:`Faster R-CNN`.
        """
        # Best quality reachable by each gt (one value per row).
        best_quality_per_gt, _ = match_quality_matrix.max(dim=1)
        # Mark every (gt, prediction) cell that attains its row maximum, ties
        # included. The qualities must be positive because `torch.nonzero` is
        # used to select indices.
        is_row_maximum = match_quality_matrix == best_quality_per_gt[:, None]
        _, best_prediction_inds = nonzero_tuple(is_row_maximum)
        # If an anchor was labeled positive only due to a low-quality match
        # with gt_A, but it has larger overlap with gt_B, its matched index will still be gt_B.
        # This follows the implementation in Detectron, and is found to have no significant impact.
        match_labels[best_prediction_inds] = 1
+ +from .build import META_ARCH_REGISTRY, build_model # isort:skip + +from .panoptic_fpn import PanopticFPN + +# import all the meta_arch, so they will be registered +from .rcnn import GeneralizedRCNN, ProposalNetwork +from .dense_detector import DenseDetector +from .retinanet import RetinaNet +from .fcos import FCOS +from .semantic_seg import SEM_SEG_HEADS_REGISTRY, SemanticSegmentor, build_sem_seg_head + + +__all__ = list(globals().keys()) diff --git a/data_processing/detectron2/detectron2/modeling/meta_arch/build.py b/data_processing/detectron2/detectron2/modeling/meta_arch/build.py new file mode 100644 index 0000000..3427215 --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/meta_arch/build.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch + +from detectron2.utils.logger import _log_api_usage +from detectron2.utils.registry import Registry + +META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip +META_ARCH_REGISTRY.__doc__ = """ +Registry for meta-architectures, i.e. the whole model. + +The registered object will be called with `obj(cfg)` +and expected to return a `nn.Module` object. +""" + + +def build_model(cfg): + """ + Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``. + Note that it does not load any weights from ``cfg``. + """ + meta_arch = cfg.MODEL.META_ARCHITECTURE + model = META_ARCH_REGISTRY.get(meta_arch)(cfg) + model.to(torch.device(cfg.MODEL.DEVICE)) + _log_api_usage("modeling.meta_arch." 
class DenseDetector(nn.Module):
    """
    Base class for dense detector. We define a dense detector as a fully-convolutional model that
    makes per-pixel (i.e. dense) predictions.

    Subclasses are expected to provide ``forward_training`` and ``forward_inference``;
    the decoding helpers below additionally read ``self.box2box_transform``, and
    ``visualize_training`` reads ``self.input_format``. None of these are set here.
    """

    def __init__(
        self,
        backbone: Backbone,
        head: nn.Module,
        head_in_features: Optional[List[str]] = None,
        *,
        pixel_mean,
        pixel_std,
    ):
        """
        Args:
            backbone: backbone module
            head: head module
            head_in_features: backbone features to use in head. Default to all backbone features.
            pixel_mean (Tuple[float]):
                Values to be used for image normalization (BGR order).
                To train on images of different number of channels, set different mean & std.
                Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
            pixel_std (Tuple[float]):
                When using pre-trained models in Detectron1 or any MSRA models,
                std has been absorbed into its conv1 weights, so the std needs to be set 1.
                Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
        """
        super().__init__()

        self.backbone = backbone
        self.head = head
        if head_in_features is None:
            # Use every backbone output, ordered from the smallest stride upwards.
            shapes = self.backbone.output_shape()
            self.head_in_features = sorted(shapes.keys(), key=lambda x: shapes[x].stride)
        else:
            self.head_in_features = head_in_features
        # Third positional argument is `persistent`: the buffers are not saved
        # in the state dict.
        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self):
        """Device on which the normalization buffers (and hence the model) live."""
        return self.pixel_mean.device

    def _move_to_current_device(self, x):
        """Move ``x`` to the device of ``self.pixel_mean``."""
        return move_device_like(x, self.pixel_mean)

    def forward(self, batched_inputs: List[Dict[str, Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances: Instances

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            In training, dict[str, Tensor]: mapping from a named loss to a tensor storing the
            loss. Used during training only. In inference, the standard output format, described
            in :doc:`/tutorials/models`.
        """
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        features = [features[f] for f in self.head_in_features]
        predictions = self.head(features)

        if self.training:
            assert not torch.jit.is_scripting(), "Not supported"
            assert "instances" in batched_inputs[0], "Instance annotations are missing in training!"
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
            return self.forward_training(images, features, predictions, gt_instances)
        else:
            results = self.forward_inference(images, features, predictions)
            if torch.jit.is_scripting():
                return results

            # Rescale each result to the requested output resolution, falling back
            # to the (pre-padding) input image size when none is given.
            processed_results = []
            for results_per_image, input_per_image, image_size in zip(
                results, batched_inputs, images.image_sizes
            ):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = detector_postprocess(results_per_image, height, width)
                processed_results.append({"instances": r})
            return processed_results

    def forward_training(self, images, features, predictions, gt_instances):
        """Compute training losses; must be implemented by subclasses."""
        raise NotImplementedError()

    def preprocess_image(self, batched_inputs: List[Dict[str, Tensor]]):
        """
        Normalize, pad and batch the input images.
        """
        images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(
            images,
            self.backbone.size_divisibility,
            padding_constraints=self.backbone.padding_constraints,
        )
        return images

    def _transpose_dense_predictions(
        self, predictions: List[List[Tensor]], dims_per_anchor: List[int]
    ) -> List[List[Tensor]]:
        """
        Transpose the dense per-level predictions.

        Args:
            predictions: a list of outputs, each is a list of per-level
                predictions with shape (N, Ai x K, Hi, Wi), where N is the
                number of images, Ai is the number of anchors per location on
                level i, K is the dimension of predictions per anchor.
            dims_per_anchor: the value of K for each predictions. e.g. 4 for
                box prediction, #classes for classification prediction.

        Returns:
            List[List[Tensor]]: each prediction is transposed to (N, Hi x Wi x Ai, K).
        """
        assert len(predictions) == len(dims_per_anchor)
        res: List[List[Tensor]] = []
        for pred, dim_per_anchor in zip(predictions, dims_per_anchor):
            pred = [permute_to_N_HWA_K(x, dim_per_anchor) for x in pred]
            res.append(pred)
        return res

    def _ema_update(self, name: str, value: float, initial_value: float, momentum: float = 0.9):
        """
        Apply EMA update to `self.name` using `value`.

        This is mainly used for loss normalizer. In Detectron1, loss is normalized by number
        of foreground samples in the batch. When batch size is 1 per GPU, #foreground has a
        large variance and using it lead to lower performance. Therefore we maintain an EMA of
        #foreground to stabilize the normalizer.

        Args:
            name: name of the normalizer
            value: the new value to update
            initial_value: the initial value to start with
            momentum: momentum of EMA

        Returns:
            float: the updated EMA value
        """
        if hasattr(self, name):
            old = getattr(self, name)
        else:
            old = initial_value
        new = old * momentum + value * (1 - momentum)
        setattr(self, name, new)
        return new

    def _decode_per_level_predictions(
        self,
        anchors: Boxes,
        pred_scores: Tensor,
        pred_deltas: Tensor,
        score_thresh: float,
        topk_candidates: int,
        image_size: Tuple[int, int],
    ) -> Instances:
        """
        Decode boxes and classification predictions of one feature level, by
        the following steps:
        1. filter the predictions based on score threshold and top K scores.
        2. transform the box regression outputs
        3. return the predicted scores, classes and boxes

        Args:
            anchors: Boxes, anchor for this feature level
            pred_scores: HxWxA,K
            pred_deltas: HxWxA,4
            score_thresh: predictions scoring at or below this value are dropped
            topk_candidates: maximum number of surviving predictions to keep
            image_size: size passed through to the returned ``Instances``

        Returns:
            Instances: with field "scores", "pred_boxes", "pred_classes".
        """
        # Apply two filtering to make NMS faster.
        # 1. Keep boxes with confidence score higher than threshold
        keep_idxs = pred_scores > score_thresh
        pred_scores = pred_scores[keep_idxs]
        topk_idxs = torch.nonzero(keep_idxs)  # Kx2

        # 2. Keep top k top scoring boxes only
        topk_idxs_size = topk_idxs.shape[0]
        if isinstance(topk_idxs_size, Tensor):
            # It's a tensor in tracing
            num_topk = torch.clamp(topk_idxs_size, max=topk_candidates)
        else:
            num_topk = min(topk_idxs_size, topk_candidates)
        pred_scores, idxs = pred_scores.topk(num_topk)
        topk_idxs = topk_idxs[idxs]

        # Each surviving row is an (anchor index, class index) pair.
        anchor_idxs, classes_idxs = topk_idxs.unbind(dim=1)

        pred_boxes = self.box2box_transform.apply_deltas(
            pred_deltas[anchor_idxs], anchors.tensor[anchor_idxs]
        )
        return Instances(
            image_size, pred_boxes=Boxes(pred_boxes), scores=pred_scores, pred_classes=classes_idxs
        )

    def _decode_multi_level_predictions(
        self,
        anchors: List[Boxes],
        pred_scores: List[Tensor],
        pred_deltas: List[Tensor],
        score_thresh: float,
        topk_candidates: int,
        image_size: Tuple[int, int],
    ) -> Instances:
        """
        Run `_decode_per_level_predictions` for all feature levels and concat the results.

        ``score_thresh`` and ``topk_candidates`` are forwarded unchanged to every
        per-level call, so the values supplied by the caller are the ones applied.
        """
        predictions = [
            self._decode_per_level_predictions(
                anchors_i,
                box_cls_i,
                box_reg_i,
                score_thresh,
                topk_candidates,
                image_size,
            )
            # Iterate over every feature level
            for box_cls_i, box_reg_i, anchors_i in zip(pred_scores, pred_deltas, anchors)
        ]
        return predictions[0].cat(predictions)  # 'Instances.cat' is not scriptable but this is

    def visualize_training(self, batched_inputs, results):
        """
        A function used to visualize ground truth images and final network predictions.
        It shows ground truth bounding boxes on the original image and up to 20
        predicted object bounding boxes on the original image.

        Args:
            batched_inputs (list): a list that contains input to the model.
            results (List[Instances]): a list of #images elements returned by forward_inference().
        """
        # Imported here rather than at module level, as in the original code.
        from detectron2.utils.visualizer import Visualizer

        assert len(batched_inputs) == len(
            results
        ), "Cannot visualize inputs and results of different sizes"
        storage = get_event_storage()
        max_boxes = 20

        image_index = 0  # only visualize a single image
        img = batched_inputs[image_index]["image"]
        # NOTE(review): `self.input_format` is not assigned anywhere in this class;
        # presumably subclasses set it. Confirm before calling this on a new subclass.
        img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)
        v_gt = Visualizer(img, None)
        v_gt = v_gt.overlay_instances(boxes=batched_inputs[image_index]["instances"].gt_boxes)
        anno_img = v_gt.get_image()
        processed_results = detector_postprocess(results[image_index], img.shape[0], img.shape[1])
        predicted_boxes = processed_results.pred_boxes.tensor.detach().cpu().numpy()

        v_pred = Visualizer(img, None)
        v_pred = v_pred.overlay_instances(boxes=predicted_boxes[0:max_boxes])
        prop_img = v_pred.get_image()
        # Ground truth on top, predictions below; then HWC -> CHW for the event storage.
        vis_img = np.vstack((anno_img, prop_img))
        vis_img = vis_img.transpose(2, 0, 1)
        vis_name = f"Top: GT bounding boxes; Bottom: {max_boxes} Highest Scoring Results"
        storage.put_image(vis_name, vis_img)
+ +import logging +from typing import List, Optional, Tuple +import torch +from fvcore.nn import sigmoid_focal_loss_jit +from torch import nn +from torch.nn import functional as F + +from detectron2.layers import ShapeSpec, batched_nms +from detectron2.structures import Boxes, ImageList, Instances, pairwise_point_box_distance +from detectron2.utils.events import get_event_storage + +from ..anchor_generator import DefaultAnchorGenerator +from ..backbone import Backbone +from ..box_regression import Box2BoxTransformLinear, _dense_box_regression_loss +from .dense_detector import DenseDetector +from .retinanet import RetinaNetHead + +__all__ = ["FCOS"] + +logger = logging.getLogger(__name__) + + +class FCOS(DenseDetector): + """ + Implement FCOS in :paper:`fcos`. + """ + + def __init__( + self, + *, + backbone: Backbone, + head: nn.Module, + head_in_features: Optional[List[str]] = None, + box2box_transform=None, + num_classes, + center_sampling_radius: float = 1.5, + focal_loss_alpha=0.25, + focal_loss_gamma=2.0, + test_score_thresh=0.2, + test_topk_candidates=1000, + test_nms_thresh=0.6, + max_detections_per_image=100, + pixel_mean, + pixel_std, + ): + """ + Args: + center_sampling_radius: radius of the "center" of a groundtruth box, + within which all anchor points are labeled positive. + Other arguments mean the same as in :class:`RetinaNet`. + """ + super().__init__( + backbone, head, head_in_features, pixel_mean=pixel_mean, pixel_std=pixel_std + ) + + self.num_classes = num_classes + + # FCOS uses one anchor point per location. + # We represent the anchor point by a box whose size equals the anchor stride. 
+ feature_shapes = backbone.output_shape() + fpn_strides = [feature_shapes[k].stride for k in self.head_in_features] + self.anchor_generator = DefaultAnchorGenerator( + sizes=[[k] for k in fpn_strides], aspect_ratios=[1.0], strides=fpn_strides + ) + + # FCOS parameterizes box regression by a linear transform, + # where predictions are normalized by anchor stride (equal to anchor size). + if box2box_transform is None: + box2box_transform = Box2BoxTransformLinear(normalize_by_size=True) + self.box2box_transform = box2box_transform + + self.center_sampling_radius = float(center_sampling_radius) + + # Loss parameters: + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + + # Inference parameters: + self.test_score_thresh = test_score_thresh + self.test_topk_candidates = test_topk_candidates + self.test_nms_thresh = test_nms_thresh + self.max_detections_per_image = max_detections_per_image + + def forward_training(self, images, features, predictions, gt_instances): + # Transpose the Hi*Wi*A dimension to the middle: + pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions( + predictions, [self.num_classes, 4, 1] + ) + anchors = self.anchor_generator(features) + gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances) + return self.losses( + anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness + ) + + @torch.no_grad() + def _match_anchors(self, gt_boxes: Boxes, anchors: List[Boxes]): + """ + Match ground-truth boxes to a set of multi-level anchors. + + Args: + gt_boxes: Ground-truth boxes from instances of an image. + anchors: List of anchors for each feature map (of different scales). + + Returns: + torch.Tensor + A tensor of shape `(M, R)`, given `M` ground-truth boxes and total + `R` anchor points from all feature levels, indicating the quality + of match between m-th box and r-th anchor. Higher value indicates + better match. 
+ """ + # Naming convention: (M = ground-truth boxes, R = anchor points) + # Anchor points are represented as square boxes of size = stride. + num_anchors_per_level = [len(x) for x in anchors] + anchors = Boxes.cat(anchors) # (R, 4) + anchor_centers = anchors.get_centers() # (R, 2) + anchor_sizes = anchors.tensor[:, 2] - anchors.tensor[:, 0] # (R, ) + + lower_bound = anchor_sizes * 4 + lower_bound[: num_anchors_per_level[0]] = 0 + upper_bound = anchor_sizes * 8 + upper_bound[-num_anchors_per_level[-1] :] = float("inf") + + gt_centers = gt_boxes.get_centers() + + # FCOS with center sampling: anchor point must be close enough to + # ground-truth box center. + center_dists = (anchor_centers[None, :, :] - gt_centers[:, None, :]).abs_() + sampling_regions = self.center_sampling_radius * anchor_sizes[None, :] + + match_quality_matrix = center_dists.max(dim=2).values < sampling_regions + + pairwise_dist = pairwise_point_box_distance(anchor_centers, gt_boxes) + pairwise_dist = pairwise_dist.permute(1, 0, 2) # (M, R, 4) + + # The original FCOS anchor matching rule: anchor point must be inside GT. + match_quality_matrix &= pairwise_dist.min(dim=2).values > 0 + + # Multilevel anchor matching in FCOS: each anchor is only responsible + # for certain scale range. + pairwise_dist = pairwise_dist.max(dim=2).values + match_quality_matrix &= (pairwise_dist > lower_bound[None, :]) & ( + pairwise_dist < upper_bound[None, :] + ) + # Match the GT box with minimum area, if there are multiple GT matches. + gt_areas = gt_boxes.area() # (M, ) + + match_quality_matrix = match_quality_matrix.to(torch.float32) + match_quality_matrix *= 1e8 - gt_areas[:, None] + return match_quality_matrix # (M, R) + + @torch.no_grad() + def label_anchors(self, anchors: List[Boxes], gt_instances: List[Instances]): + """ + Same interface as :meth:`RetinaNet.label_anchors`, but implemented with FCOS + anchor matching rule. + + Unlike RetinaNet, there are no ignored anchors. 
+ """ + + gt_labels, matched_gt_boxes = [], [] + + for inst in gt_instances: + if len(inst) > 0: + match_quality_matrix = self._match_anchors(inst.gt_boxes, anchors) + + # Find matched ground-truth box per anchor. Un-matched anchors are + # assigned -1. This is equivalent to using an anchor matcher as used + # in R-CNN/RetinaNet: `Matcher(thresholds=[1e-5], labels=[0, 1])` + match_quality, matched_idxs = match_quality_matrix.max(dim=0) + matched_idxs[match_quality < 1e-5] = -1 + + matched_gt_boxes_i = inst.gt_boxes.tensor[matched_idxs.clip(min=0)] + gt_labels_i = inst.gt_classes[matched_idxs.clip(min=0)] + + # Anchors with matched_idxs = -1 are labeled background. + gt_labels_i[matched_idxs < 0] = self.num_classes + else: + matched_gt_boxes_i = torch.zeros_like(Boxes.cat(anchors).tensor) + gt_labels_i = torch.full( + (len(matched_gt_boxes_i),), + fill_value=self.num_classes, + dtype=torch.long, + device=matched_gt_boxes_i.device, + ) + + gt_labels.append(gt_labels_i) + matched_gt_boxes.append(matched_gt_boxes_i) + + return gt_labels, matched_gt_boxes + + def losses( + self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness + ): + """ + This method is almost identical to :meth:`RetinaNet.losses`, with an extra + "loss_centerness" in the returned dict. 
+ """ + num_images = len(gt_labels) + gt_labels = torch.stack(gt_labels) # (M, R) + + pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes) + num_pos_anchors = pos_mask.sum().item() + get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images) + normalizer = self._ema_update("loss_normalizer", max(num_pos_anchors, 1), 300) + + # classification and regression loss + gt_labels_target = F.one_hot(gt_labels, num_classes=self.num_classes + 1)[ + :, :, :-1 + ] # no loss for the last (background) class + loss_cls = sigmoid_focal_loss_jit( + torch.cat(pred_logits, dim=1), + gt_labels_target.to(pred_logits[0].dtype), + alpha=self.focal_loss_alpha, + gamma=self.focal_loss_gamma, + reduction="sum", + ) + + loss_box_reg = _dense_box_regression_loss( + anchors, + self.box2box_transform, + pred_anchor_deltas, + gt_boxes, + pos_mask, + box_reg_loss_type="giou", + ) + + ctrness_targets = self.compute_ctrness_targets(anchors, gt_boxes) # (M, R) + pred_centerness = torch.cat(pred_centerness, dim=1).squeeze(dim=2) # (M, R) + ctrness_loss = F.binary_cross_entropy_with_logits( + pred_centerness[pos_mask], ctrness_targets[pos_mask], reduction="sum" + ) + return { + "loss_fcos_cls": loss_cls / normalizer, + "loss_fcos_loc": loss_box_reg / normalizer, + "loss_fcos_ctr": ctrness_loss / normalizer, + } + + def compute_ctrness_targets(self, anchors: List[Boxes], gt_boxes: List[torch.Tensor]): + anchors = Boxes.cat(anchors).tensor # Rx4 + reg_targets = [self.box2box_transform.get_deltas(anchors, m) for m in gt_boxes] + reg_targets = torch.stack(reg_targets, dim=0) # NxRx4 + if len(reg_targets) == 0: + return reg_targets.new_zeros(len(reg_targets)) + left_right = reg_targets[:, :, [0, 2]] + top_bottom = reg_targets[:, :, [1, 3]] + ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * ( + top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0] + ) + return torch.sqrt(ctrness) + + def forward_inference( + self, + images: ImageList, + features: 
List[torch.Tensor], + predictions: List[List[torch.Tensor]], + ): + pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions( + predictions, [self.num_classes, 4, 1] + ) + anchors = self.anchor_generator(features) + + results: List[Instances] = [] + for img_idx, image_size in enumerate(images.image_sizes): + scores_per_image = [ + # Multiply and sqrt centerness & classification scores + # (See eqn. 4 in https://arxiv.org/abs/2006.09214) + torch.sqrt(x[img_idx].sigmoid_() * y[img_idx].sigmoid_()) + for x, y in zip(pred_logits, pred_centerness) + ] + deltas_per_image = [x[img_idx] for x in pred_anchor_deltas] + results_per_image = self.inference_single_image( + anchors, scores_per_image, deltas_per_image, image_size + ) + results.append(results_per_image) + return results + + def inference_single_image( + self, + anchors: List[Boxes], + box_cls: List[torch.Tensor], + box_delta: List[torch.Tensor], + image_size: Tuple[int, int], + ): + """ + Identical to :meth:`RetinaNet.inference_single_image. + """ + pred = self._decode_multi_level_predictions( + anchors, + box_cls, + box_delta, + self.test_score_thresh, + self.test_topk_candidates, + image_size, + ) + keep = batched_nms( + pred.pred_boxes.tensor, pred.scores, pred.pred_classes, self.test_nms_thresh + ) + return pred[keep[: self.max_detections_per_image]] + + +class FCOSHead(RetinaNetHead): + """ + The head used in :paper:`fcos`. It adds an additional centerness + prediction branch on top of :class:`RetinaNetHead`. + """ + + def __init__(self, *, input_shape: List[ShapeSpec], conv_dims: List[int], **kwargs): + super().__init__(input_shape=input_shape, conv_dims=conv_dims, num_anchors=1, **kwargs) + # Unlike original FCOS, we do not add an additional learnable scale layer + # because it's found to have no benefits after normalizing regression targets by stride. 
+ self._num_features = len(input_shape) + self.ctrness = nn.Conv2d(conv_dims[-1], 1, kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.ctrness.weight, std=0.01) + torch.nn.init.constant_(self.ctrness.bias, 0) + + def forward(self, features): + assert len(features) == self._num_features + logits = [] + bbox_reg = [] + ctrness = [] + for feature in features: + logits.append(self.cls_score(self.cls_subnet(feature))) + bbox_feature = self.bbox_subnet(feature) + bbox_reg.append(self.bbox_pred(bbox_feature)) + ctrness.append(self.ctrness(bbox_feature)) + return logits, bbox_reg, ctrness diff --git a/data_processing/detectron2/detectron2/modeling/meta_arch/panoptic_fpn.py b/data_processing/detectron2/detectron2/modeling/meta_arch/panoptic_fpn.py new file mode 100644 index 0000000..b31e1c8 --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/meta_arch/panoptic_fpn.py @@ -0,0 +1,269 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import logging +from typing import Dict, List +import torch +from torch import nn + +from detectron2.config import configurable +from detectron2.structures import ImageList + +from ..postprocessing import detector_postprocess, sem_seg_postprocess +from .build import META_ARCH_REGISTRY +from .rcnn import GeneralizedRCNN +from .semantic_seg import build_sem_seg_head + +__all__ = ["PanopticFPN"] + + +@META_ARCH_REGISTRY.register() +class PanopticFPN(GeneralizedRCNN): + """ + Implement the paper :paper:`PanopticFPN`. + """ + + @configurable + def __init__( + self, + *, + sem_seg_head: nn.Module, + combine_overlap_thresh: float = 0.5, + combine_stuff_area_thresh: float = 4096, + combine_instances_score_thresh: float = 0.5, + **kwargs, + ): + """ + NOTE: this interface is experimental. + + Args: + sem_seg_head: a module for the semantic segmentation head. 
+ combine_overlap_thresh: combine masks into one instances if + they have enough overlap + combine_stuff_area_thresh: ignore stuff areas smaller than this threshold + combine_instances_score_thresh: ignore instances whose score is + smaller than this threshold + + Other arguments are the same as :class:`GeneralizedRCNN`. + """ + super().__init__(**kwargs) + self.sem_seg_head = sem_seg_head + # options when combining instance & semantic outputs + self.combine_overlap_thresh = combine_overlap_thresh + self.combine_stuff_area_thresh = combine_stuff_area_thresh + self.combine_instances_score_thresh = combine_instances_score_thresh + + @classmethod + def from_config(cls, cfg): + ret = super().from_config(cfg) + ret.update( + { + "combine_overlap_thresh": cfg.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH, + "combine_stuff_area_thresh": cfg.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT, + "combine_instances_score_thresh": cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH, # noqa + } + ) + ret["sem_seg_head"] = build_sem_seg_head(cfg, ret["backbone"].output_shape()) + logger = logging.getLogger(__name__) + if not cfg.MODEL.PANOPTIC_FPN.COMBINE.ENABLED: + logger.warning( + "PANOPTIC_FPN.COMBINED.ENABLED is no longer used. " + " model.inference(do_postprocess=) should be used to toggle postprocessing." + ) + if cfg.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT != 1.0: + w = cfg.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT + logger.warning( + "PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT should be replaced by weights on each ROI head." + ) + + def update_weight(x): + if isinstance(x, dict): + return {k: v * w for k, v in x.items()} + else: + return x * w + + roi_heads = ret["roi_heads"] + roi_heads.box_predictor.loss_weight = update_weight(roi_heads.box_predictor.loss_weight) + roi_heads.mask_head.loss_weight = update_weight(roi_heads.mask_head.loss_weight) + return ret + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper`. 
+ Each item in the list contains the inputs for one image. + + For now, each item in the list is a dict that contains: + + * "image": Tensor, image in (C, H, W) format. + * "instances": Instances + * "sem_seg": semantic segmentation ground truth. + * Other information that's included in the original dicts, such as: + "height", "width" (int): the output resolution of the model, used in inference. + See :meth:`postprocess` for details. + + Returns: + list[dict]: + each dict has the results for one image. The dict contains the following keys: + + * "instances": see :meth:`GeneralizedRCNN.forward` for its format. + * "sem_seg": see :meth:`SemanticSegmentor.forward` for its format. + * "panoptic_seg": See the return value of + :func:`combine_semantic_and_instance_outputs` for its format. + """ + if not self.training: + return self.inference(batched_inputs) + images = self.preprocess_image(batched_inputs) + features = self.backbone(images.tensor) + + assert "sem_seg" in batched_inputs[0] + gt_sem_seg = [x["sem_seg"].to(self.device) for x in batched_inputs] + gt_sem_seg = ImageList.from_tensors( + gt_sem_seg, + self.backbone.size_divisibility, + self.sem_seg_head.ignore_value, + self.backbone.padding_constraints, + ).tensor + sem_seg_results, sem_seg_losses = self.sem_seg_head(features, gt_sem_seg) + + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) + detector_results, detector_losses = self.roi_heads( + images, features, proposals, gt_instances + ) + + losses = sem_seg_losses + losses.update(proposal_losses) + losses.update(detector_losses) + return losses + + def inference(self, batched_inputs: List[Dict[str, torch.Tensor]], do_postprocess: bool = True): + """ + Run inference on the given inputs. + + Args: + batched_inputs (list[dict]): same as in :meth:`forward` + do_postprocess (bool): whether to apply post-processing on the outputs. 
+ + Returns: + When do_postprocess=True, see docs in :meth:`forward`. + Otherwise, returns a (list[Instances], list[Tensor]) that contains + the raw detector outputs, and raw semantic segmentation outputs. + """ + images = self.preprocess_image(batched_inputs) + features = self.backbone(images.tensor) + sem_seg_results, sem_seg_losses = self.sem_seg_head(features, None) + proposals, _ = self.proposal_generator(images, features, None) + detector_results, _ = self.roi_heads(images, features, proposals, None) + + if do_postprocess: + processed_results = [] + for sem_seg_result, detector_result, input_per_image, image_size in zip( + sem_seg_results, detector_results, batched_inputs, images.image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height, width) + detector_r = detector_postprocess(detector_result, height, width) + + processed_results.append({"sem_seg": sem_seg_r, "instances": detector_r}) + + panoptic_r = combine_semantic_and_instance_outputs( + detector_r, + sem_seg_r.argmax(dim=0), + self.combine_overlap_thresh, + self.combine_stuff_area_thresh, + self.combine_instances_score_thresh, + ) + processed_results[-1]["panoptic_seg"] = panoptic_r + return processed_results + else: + return detector_results, sem_seg_results + + +def combine_semantic_and_instance_outputs( + instance_results, + semantic_results, + overlap_threshold, + stuff_area_thresh, + instances_score_thresh, +): + """ + Implement a simple combining logic following + "combine_semantic_and_instance_predictions.py" in panopticapi + to produce panoptic segmentation outputs. + + Args: + instance_results: output of :func:`detector_postprocess`. + semantic_results: an (H, W) tensor, each element is the contiguous semantic + category id + + Returns: + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. 
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`. + Each dict contains keys "id", "category_id", "isthing". + """ + panoptic_seg = torch.zeros_like(semantic_results, dtype=torch.int32) + + # sort instance outputs by scores + sorted_inds = torch.argsort(-instance_results.scores) + + current_segment_id = 0 + segments_info = [] + + instance_masks = instance_results.pred_masks.to(dtype=torch.bool, device=panoptic_seg.device) + + # Add instances one-by-one, check for overlaps with existing ones + for inst_id in sorted_inds: + score = instance_results.scores[inst_id].item() + if score < instances_score_thresh: + break + mask = instance_masks[inst_id] # H,W + mask_area = mask.sum().item() + + if mask_area == 0: + continue + + intersect = (mask > 0) & (panoptic_seg > 0) + intersect_area = intersect.sum().item() + + if intersect_area * 1.0 / mask_area > overlap_threshold: + continue + + if intersect_area > 0: + mask = mask & (panoptic_seg == 0) + + current_segment_id += 1 + panoptic_seg[mask] = current_segment_id + segments_info.append( + { + "id": current_segment_id, + "isthing": True, + "score": score, + "category_id": instance_results.pred_classes[inst_id].item(), + "instance_id": inst_id.item(), + } + ) + + # Add semantic results to remaining empty areas + semantic_labels = torch.unique(semantic_results).cpu().tolist() + for semantic_label in semantic_labels: + if semantic_label == 0: # 0 is a special "thing" class + continue + mask = (semantic_results == semantic_label) & (panoptic_seg == 0) + mask_area = mask.sum().item() + if mask_area < stuff_area_thresh: + continue + + current_segment_id += 1 + panoptic_seg[mask] = current_segment_id + segments_info.append( + { + "id": current_segment_id, + "isthing": False, + "category_id": semantic_label, + "area": mask_area, + } + ) + + return panoptic_seg, segments_info diff --git a/data_processing/detectron2/detectron2/modeling/meta_arch/rcnn.py 
b/data_processing/detectron2/detectron2/modeling/meta_arch/rcnn.py new file mode 100644 index 0000000..edcbda5 --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/meta_arch/rcnn.py @@ -0,0 +1,341 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import numpy as np +from typing import Dict, List, Optional, Tuple +import torch +from torch import nn + +from detectron2.config import configurable +from detectron2.data.detection_utils import convert_image_to_rgb +from detectron2.layers import move_device_like +from detectron2.structures import ImageList, Instances +from detectron2.utils.events import get_event_storage +from detectron2.utils.logger import log_first_n + +from ..backbone import Backbone, build_backbone +from ..postprocessing import detector_postprocess +from ..proposal_generator import build_proposal_generator +from ..roi_heads import build_roi_heads +from .build import META_ARCH_REGISTRY + +__all__ = ["GeneralizedRCNN", "ProposalNetwork"] + + +@META_ARCH_REGISTRY.register() +class GeneralizedRCNN(nn.Module): + """ + Generalized R-CNN. Any models that contains the following three components: + 1. Per-image feature extraction (aka backbone) + 2. Region proposal generation + 3. Per-region feature extraction and prediction + """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + proposal_generator: nn.Module, + roi_heads: nn.Module, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + input_format: Optional[str] = None, + vis_period: int = 0, + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + proposal_generator: a module that generates proposals using backbone features + roi_heads: a ROI head that performs per-region computation + pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + input_format: describe the meaning of channels of input. 
Needed by visualization + vis_period: the period to run visualization. Set to 0 to disable. + """ + super().__init__() + self.backbone = backbone + self.proposal_generator = proposal_generator + self.roi_heads = roi_heads + + self.input_format = input_format + self.vis_period = vis_period + if vis_period > 0: + assert input_format is not None, "input_format is required for visualization!" + + self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) + assert ( + self.pixel_mean.shape == self.pixel_std.shape + ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!" + + @classmethod + def from_config(cls, cfg): + backbone = build_backbone(cfg) + return { + "backbone": backbone, + "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()), + "roi_heads": build_roi_heads(cfg, backbone.output_shape()), + "input_format": cfg.INPUT.FORMAT, + "vis_period": cfg.VIS_PERIOD, + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + } + + @property + def device(self): + return self.pixel_mean.device + + def _move_to_current_device(self, x): + return move_device_like(x, self.pixel_mean) + + def visualize_training(self, batched_inputs, proposals): + """ + A function used to visualize images and proposals. It shows ground truth + bounding boxes on the original image and up to 20 top-scoring predicted + object proposals on the original image. Users can implement different + visualization functions for different models. + + Args: + batched_inputs (list): a list that contains input to the model. + proposals (list): a list that contains predicted proposals. Both + batched_inputs and proposals should have the same length. 
+ """ + from detectron2.utils.visualizer import Visualizer + + storage = get_event_storage() + max_vis_prop = 20 + + for input, prop in zip(batched_inputs, proposals): + img = input["image"] + img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format) + v_gt = Visualizer(img, None) + v_gt = v_gt.overlay_instances(boxes=input["instances"].gt_boxes) + anno_img = v_gt.get_image() + box_size = min(len(prop.proposal_boxes), max_vis_prop) + v_pred = Visualizer(img, None) + v_pred = v_pred.overlay_instances( + boxes=prop.proposal_boxes[0:box_size].tensor.cpu().numpy() + ) + prop_img = v_pred.get_image() + vis_img = np.concatenate((anno_img, prop_img), axis=1) + vis_img = vis_img.transpose(2, 0, 1) + vis_name = "Left: GT bounding boxes; Right: Predicted proposals" + storage.put_image(vis_name, vis_img) + break # only visualize one image in a batch + + def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper` . + Each item in the list contains the inputs for one image. + For now, each item in the list is a dict that contains: + + * image: Tensor, image in (C, H, W) format. + * instances (optional): groundtruth :class:`Instances` + * proposals (optional): :class:`Instances`, precomputed proposals. + + Other information that's included in the original dicts, such as: + + * "height", "width" (int): the output resolution of the model, used in inference. + See :meth:`postprocess` for details. + + Returns: + list[dict]: + Each dict is the output for one input image. + The dict contains one key "instances" whose value is a :class:`Instances`. 
+ The :class:`Instances` object has the following keys: + "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints" + """ + if not self.training: + return self.inference(batched_inputs) + + images = self.preprocess_image(batched_inputs) + if "instances" in batched_inputs[0]: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + else: + gt_instances = None + + features = self.backbone(images.tensor) + + if self.proposal_generator is not None: + proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) + else: + assert "proposals" in batched_inputs[0] + proposals = [x["proposals"].to(self.device) for x in batched_inputs] + proposal_losses = {} + + _, detector_losses = self.roi_heads(images, features, proposals, gt_instances) + if self.vis_period > 0: + storage = get_event_storage() + if storage.iter % self.vis_period == 0: + self.visualize_training(batched_inputs, proposals) + + losses = {} + losses.update(detector_losses) + losses.update(proposal_losses) + return losses + + def inference( + self, + batched_inputs: List[Dict[str, torch.Tensor]], + detected_instances: Optional[List[Instances]] = None, + do_postprocess: bool = True, + ): + """ + Run inference on the given inputs. + + Args: + batched_inputs (list[dict]): same as in :meth:`forward` + detected_instances (None or list[Instances]): if not None, it + contains an `Instances` object per image. The `Instances` + object contains "pred_boxes" and "pred_classes" which are + known boxes in the image. + The inference will then skip the detection of bounding boxes, + and only predict other per-ROI outputs. + do_postprocess (bool): whether to apply post-processing on the outputs. + + Returns: + When do_postprocess=True, same as in :meth:`forward`. + Otherwise, a list[Instances] containing raw network outputs. 
+ """ + assert not self.training + + images = self.preprocess_image(batched_inputs) + features = self.backbone(images.tensor) + + if detected_instances is None: + if self.proposal_generator is not None: + proposals, _ = self.proposal_generator(images, features, None) + else: + assert "proposals" in batched_inputs[0] + proposals = [x["proposals"].to(self.device) for x in batched_inputs] + + results, _ = self.roi_heads(images, features, proposals, None) + else: + detected_instances = [x.to(self.device) for x in detected_instances] + results = self.roi_heads.forward_with_given_boxes(features, detected_instances) + + if do_postprocess: + assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess." + return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes) + return results + + def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]): + """ + Normalize, pad and batch the input images. + """ + images = [self._move_to_current_device(x["image"]) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors( + images, + self.backbone.size_divisibility, + padding_constraints=self.backbone.padding_constraints, + ) + return images + + @staticmethod + def _postprocess(instances, batched_inputs: List[Dict[str, torch.Tensor]], image_sizes): + """ + Rescale the output instances to the target size. + """ + # note: private function; subject to changes + processed_results = [] + for results_per_image, input_per_image, image_size in zip( + instances, batched_inputs, image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + r = detector_postprocess(results_per_image, height, width) + processed_results.append({"instances": r}) + return processed_results + + +@META_ARCH_REGISTRY.register() +class ProposalNetwork(nn.Module): + """ + A meta architecture that only predicts object proposals. 
+ """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + proposal_generator: nn.Module, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + proposal_generator: a module that generates proposals using backbone features + pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + """ + super().__init__() + self.backbone = backbone + self.proposal_generator = proposal_generator + self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) + + @classmethod + def from_config(cls, cfg): + backbone = build_backbone(cfg) + return { + "backbone": backbone, + "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()), + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + } + + @property + def device(self): + return self.pixel_mean.device + + def _move_to_current_device(self, x): + return move_device_like(x, self.pixel_mean) + + def forward(self, batched_inputs): + """ + Args: + Same as in :class:`GeneralizedRCNN.forward` + + Returns: + list[dict]: + Each dict is the output for one input image. + The dict contains one key "proposals" whose value is a + :class:`Instances` with keys "proposal_boxes" and "objectness_logits". 
+ """ + images = [self._move_to_current_device(x["image"]) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors( + images, + self.backbone.size_divisibility, + padding_constraints=self.backbone.padding_constraints, + ) + features = self.backbone(images.tensor) + + if "instances" in batched_inputs[0]: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + elif "targets" in batched_inputs[0]: + log_first_n( + logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10 + ) + gt_instances = [x["targets"].to(self.device) for x in batched_inputs] + else: + gt_instances = None + proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) + # In training, the proposals are not useful at all but we generate them anyway. + # This makes RPN-only models about 5% slower. + if self.training: + return proposal_losses + + processed_results = [] + for results_per_image, input_per_image, image_size in zip( + proposals, batched_inputs, images.image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + r = detector_postprocess(results_per_image, height, width) + processed_results.append({"proposals": r}) + return processed_results diff --git a/data_processing/detectron2/detectron2/modeling/meta_arch/retinanet.py b/data_processing/detectron2/detectron2/modeling/meta_arch/retinanet.py new file mode 100644 index 0000000..bd72a8e --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/meta_arch/retinanet.py @@ -0,0 +1,439 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import logging +import math +from typing import List, Tuple +import torch +from fvcore.nn import sigmoid_focal_loss_jit +from torch import Tensor, nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import CycleBatchNormList, ShapeSpec, batched_nms, cat, get_norm +from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou +from detectron2.utils.events import get_event_storage + +from ..anchor_generator import build_anchor_generator +from ..backbone import Backbone, build_backbone +from ..box_regression import Box2BoxTransform, _dense_box_regression_loss +from ..matcher import Matcher +from .build import META_ARCH_REGISTRY +from .dense_detector import DenseDetector, permute_to_N_HWA_K # noqa + +__all__ = ["RetinaNet"] + + +logger = logging.getLogger(__name__) + + +@META_ARCH_REGISTRY.register() +class RetinaNet(DenseDetector): + """ + Implement RetinaNet in :paper:`RetinaNet`. + """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + head: nn.Module, + head_in_features, + anchor_generator, + box2box_transform, + anchor_matcher, + num_classes, + focal_loss_alpha=0.25, + focal_loss_gamma=2.0, + smooth_l1_beta=0.0, + box_reg_loss_type="smooth_l1", + test_score_thresh=0.05, + test_topk_candidates=1000, + test_nms_thresh=0.5, + max_detections_per_image=100, + pixel_mean, + pixel_std, + vis_period=0, + input_format="BGR", + ): + """ + NOTE: this interface is experimental. + + Args: + backbone: a backbone module, must follow detectron2's backbone interface + head (nn.Module): a module that predicts logits and regression deltas + for each level from a list of per-level features + head_in_features (Tuple[str]): Names of the input feature maps to be used in head + anchor_generator (nn.Module): a module that creates anchors from a + list of features. 
Usually an instance of :class:`AnchorGenerator` + box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to + instance boxes + anchor_matcher (Matcher): label the anchors by matching them with ground truth. + num_classes (int): number of classes. Used to label background proposals. + + # Loss parameters: + focal_loss_alpha (float): focal_loss_alpha + focal_loss_gamma (float): focal_loss_gamma + smooth_l1_beta (float): smooth_l1_beta + box_reg_loss_type (str): Options are "smooth_l1", "giou", "diou", "ciou" + + # Inference parameters: + test_score_thresh (float): Inference cls score threshold, only anchors with + score > INFERENCE_TH are considered for inference (to improve speed) + test_topk_candidates (int): Select topk candidates before NMS + test_nms_thresh (float): Overlap threshold used for non-maximum suppression + (suppress boxes with IoU >= this threshold) + max_detections_per_image (int): + Maximum number of detections to return per image during inference + (100 is based on the limit established for the COCO dataset). + + pixel_mean, pixel_std: see :class:`DenseDetector`. 
+ """ + super().__init__( + backbone, head, head_in_features, pixel_mean=pixel_mean, pixel_std=pixel_std + ) + self.num_classes = num_classes + + # Anchors + self.anchor_generator = anchor_generator + self.box2box_transform = box2box_transform + self.anchor_matcher = anchor_matcher + + # Loss parameters: + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.smooth_l1_beta = smooth_l1_beta + self.box_reg_loss_type = box_reg_loss_type + # Inference parameters: + self.test_score_thresh = test_score_thresh + self.test_topk_candidates = test_topk_candidates + self.test_nms_thresh = test_nms_thresh + self.max_detections_per_image = max_detections_per_image + # Vis parameters + self.vis_period = vis_period + self.input_format = input_format + + @classmethod + def from_config(cls, cfg): + backbone = build_backbone(cfg) + backbone_shape = backbone.output_shape() + feature_shapes = [backbone_shape[f] for f in cfg.MODEL.RETINANET.IN_FEATURES] + head = RetinaNetHead(cfg, feature_shapes) + anchor_generator = build_anchor_generator(cfg, feature_shapes) + return { + "backbone": backbone, + "head": head, + "anchor_generator": anchor_generator, + "box2box_transform": Box2BoxTransform(weights=cfg.MODEL.RETINANET.BBOX_REG_WEIGHTS), + "anchor_matcher": Matcher( + cfg.MODEL.RETINANET.IOU_THRESHOLDS, + cfg.MODEL.RETINANET.IOU_LABELS, + allow_low_quality_matches=True, + ), + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + "num_classes": cfg.MODEL.RETINANET.NUM_CLASSES, + "head_in_features": cfg.MODEL.RETINANET.IN_FEATURES, + # Loss parameters: + "focal_loss_alpha": cfg.MODEL.RETINANET.FOCAL_LOSS_ALPHA, + "focal_loss_gamma": cfg.MODEL.RETINANET.FOCAL_LOSS_GAMMA, + "smooth_l1_beta": cfg.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA, + "box_reg_loss_type": cfg.MODEL.RETINANET.BBOX_REG_LOSS_TYPE, + # Inference parameters: + "test_score_thresh": cfg.MODEL.RETINANET.SCORE_THRESH_TEST, + "test_topk_candidates": 
cfg.MODEL.RETINANET.TOPK_CANDIDATES_TEST, + "test_nms_thresh": cfg.MODEL.RETINANET.NMS_THRESH_TEST, + "max_detections_per_image": cfg.TEST.DETECTIONS_PER_IMAGE, + # Vis parameters + "vis_period": cfg.VIS_PERIOD, + "input_format": cfg.INPUT.FORMAT, + } + + def forward_training(self, images, features, predictions, gt_instances): + # Transpose the Hi*Wi*A dimension to the middle: + pred_logits, pred_anchor_deltas = self._transpose_dense_predictions( + predictions, [self.num_classes, 4] + ) + anchors = self.anchor_generator(features) + gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances) + return self.losses(anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes) + + def losses(self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes): + """ + Args: + anchors (list[Boxes]): a list of #feature level Boxes + gt_labels, gt_boxes: see output of :meth:`RetinaNet.label_anchors`. + Their shapes are (N, R) and (N, R, 4), respectively, where R is + the total number of anchors across levels, i.e. sum(Hi x Wi x Ai) + pred_logits, pred_anchor_deltas: both are list[Tensor]. Each element in the + list corresponds to one level and has shape (N, Hi * Wi * Ai, K or 4). + Where K is the number of classes used in `pred_logits`. + + Returns: + dict[str, Tensor]: + mapping from a named loss to a scalar tensor storing the loss. + Used during training only. 
The dict keys are: "loss_cls" and "loss_box_reg" + """ + num_images = len(gt_labels) + gt_labels = torch.stack(gt_labels) # (N, R) + + valid_mask = gt_labels >= 0 + pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes) + num_pos_anchors = pos_mask.sum().item() + get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images) + normalizer = self._ema_update("loss_normalizer", max(num_pos_anchors, 1), 100) + + # classification and regression loss + gt_labels_target = F.one_hot(gt_labels[valid_mask], num_classes=self.num_classes + 1)[ + :, :-1 + ] # no loss for the last (background) class + loss_cls = sigmoid_focal_loss_jit( + cat(pred_logits, dim=1)[valid_mask], + gt_labels_target.to(pred_logits[0].dtype), + alpha=self.focal_loss_alpha, + gamma=self.focal_loss_gamma, + reduction="sum", + ) + + loss_box_reg = _dense_box_regression_loss( + anchors, + self.box2box_transform, + pred_anchor_deltas, + gt_boxes, + pos_mask, + box_reg_loss_type=self.box_reg_loss_type, + smooth_l1_beta=self.smooth_l1_beta, + ) + + return { + "loss_cls": loss_cls / normalizer, + "loss_box_reg": loss_box_reg / normalizer, + } + + @torch.no_grad() + def label_anchors(self, anchors, gt_instances): + """ + Args: + anchors (list[Boxes]): A list of #feature level Boxes. + The Boxes contains anchors of this image on the specific feature level. + gt_instances (list[Instances]): a list of N `Instances`s. The i-th + `Instances` contains the ground-truth per-instance annotations + for the i-th input image. + + Returns: + list[Tensor]: List of #img tensors. i-th element is a vector of labels whose length is + the total number of anchors across all feature maps (sum(Hi * Wi * A)). + Label values are in {-1, 0, ..., K}, with -1 means ignore, and K means background. + + list[Tensor]: i-th element is a Rx4 tensor, where R is the total number of anchors + across feature maps. The values are the matched gt boxes for each anchor. 
+ Values are undefined for those anchors not labeled as foreground. + """ + anchors = Boxes.cat(anchors) # Rx4 + + gt_labels = [] + matched_gt_boxes = [] + for gt_per_image in gt_instances: + match_quality_matrix = pairwise_iou(gt_per_image.gt_boxes, anchors) + matched_idxs, anchor_labels = self.anchor_matcher(match_quality_matrix) + del match_quality_matrix + + if len(gt_per_image) > 0: + matched_gt_boxes_i = gt_per_image.gt_boxes.tensor[matched_idxs] + + gt_labels_i = gt_per_image.gt_classes[matched_idxs] + # Anchors with label 0 are treated as background. + gt_labels_i[anchor_labels == 0] = self.num_classes + # Anchors with label -1 are ignored. + gt_labels_i[anchor_labels == -1] = -1 + else: + matched_gt_boxes_i = torch.zeros_like(anchors.tensor) + gt_labels_i = torch.zeros_like(matched_idxs) + self.num_classes + + gt_labels.append(gt_labels_i) + matched_gt_boxes.append(matched_gt_boxes_i) + + return gt_labels, matched_gt_boxes + + def forward_inference( + self, images: ImageList, features: List[Tensor], predictions: List[List[Tensor]] + ): + pred_logits, pred_anchor_deltas = self._transpose_dense_predictions( + predictions, [self.num_classes, 4] + ) + anchors = self.anchor_generator(features) + + results: List[Instances] = [] + for img_idx, image_size in enumerate(images.image_sizes): + scores_per_image = [x[img_idx].sigmoid_() for x in pred_logits] + deltas_per_image = [x[img_idx] for x in pred_anchor_deltas] + results_per_image = self.inference_single_image( + anchors, scores_per_image, deltas_per_image, image_size + ) + results.append(results_per_image) + return results + + def inference_single_image( + self, + anchors: List[Boxes], + box_cls: List[Tensor], + box_delta: List[Tensor], + image_size: Tuple[int, int], + ): + """ + Single-image inference. Return bounding-box detection results by thresholding + on scores and applying non-maximum suppression (NMS). + + Arguments: + anchors (list[Boxes]): list of #feature levels. 
Each entry contains + a Boxes object, which contains all the anchors in that feature level. + box_cls (list[Tensor]): list of #feature levels. Each entry contains + tensor of size (H x W x A, K) + box_delta (list[Tensor]): Same shape as 'box_cls' except that K becomes 4. + image_size (tuple(H, W)): a tuple of the image height and width. + + Returns: + Same as `inference`, but for only one image. + """ + pred = self._decode_multi_level_predictions( + anchors, + box_cls, + box_delta, + self.test_score_thresh, + self.test_topk_candidates, + image_size, + ) + keep = batched_nms( # per-class NMS + pred.pred_boxes.tensor, pred.scores, pred.pred_classes, self.test_nms_thresh + ) + return pred[keep[: self.max_detections_per_image]] + + +class RetinaNetHead(nn.Module): + """ + The head used in RetinaNet for object classification and box regression. + It has two subnets for the two tasks, with a common structure but separate parameters. + """ + + @configurable + def __init__( + self, + *, + input_shape: List[ShapeSpec], + num_classes, + num_anchors, + conv_dims: List[int], + norm="", + prior_prob=0.01, + ): + """ + NOTE: this interface is experimental. + + Args: + input_shape (List[ShapeSpec]): input shape + num_classes (int): number of classes. Used to label background proposals. + num_anchors (int): number of generated anchors + conv_dims (List[int]): dimensions for each convolution layer + norm (str or callable): + Normalization for conv layers except for the two output layers. + See :func:`detectron2.layers.get_norm` for supported types. + prior_prob (float): Prior weight for computing bias + """ + super().__init__() + + self._num_features = len(input_shape) + if norm == "BN" or norm == "SyncBN": + logger.info( + f"Using domain-specific {norm} in RetinaNetHead with len={self._num_features}." 
+ ) + bn_class = nn.BatchNorm2d if norm == "BN" else nn.SyncBatchNorm + + def norm(c): + return CycleBatchNormList( + length=self._num_features, bn_class=bn_class, num_features=c + ) + + else: + norm_name = str(type(get_norm(norm, 32))) + if "BN" in norm_name: + logger.warning( + f"Shared BatchNorm (type={norm_name}) may not work well in RetinaNetHead." + ) + + cls_subnet = [] + bbox_subnet = [] + for in_channels, out_channels in zip( + [input_shape[0].channels] + list(conv_dims), conv_dims + ): + cls_subnet.append( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + ) + if norm: + cls_subnet.append(get_norm(norm, out_channels)) + cls_subnet.append(nn.ReLU()) + bbox_subnet.append( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + ) + if norm: + bbox_subnet.append(get_norm(norm, out_channels)) + bbox_subnet.append(nn.ReLU()) + + self.cls_subnet = nn.Sequential(*cls_subnet) + self.bbox_subnet = nn.Sequential(*bbox_subnet) + self.cls_score = nn.Conv2d( + conv_dims[-1], num_anchors * num_classes, kernel_size=3, stride=1, padding=1 + ) + self.bbox_pred = nn.Conv2d( + conv_dims[-1], num_anchors * 4, kernel_size=3, stride=1, padding=1 + ) + + # Initialization + for modules in [self.cls_subnet, self.bbox_subnet, self.cls_score, self.bbox_pred]: + for layer in modules.modules(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, mean=0, std=0.01) + torch.nn.init.constant_(layer.bias, 0) + + # Use prior in model initialization to improve stability + bias_value = -(math.log((1 - prior_prob) / prior_prob)) + torch.nn.init.constant_(self.cls_score.bias, bias_value) + + @classmethod + def from_config(cls, cfg, input_shape: List[ShapeSpec]): + num_anchors = build_anchor_generator(cfg, input_shape).num_cell_anchors + assert ( + len(set(num_anchors)) == 1 + ), "Using different number of anchors between levels is not currently supported!" 
+ num_anchors = num_anchors[0] + + return { + "input_shape": input_shape, + "num_classes": cfg.MODEL.RETINANET.NUM_CLASSES, + "conv_dims": [input_shape[0].channels] * cfg.MODEL.RETINANET.NUM_CONVS, + "prior_prob": cfg.MODEL.RETINANET.PRIOR_PROB, + "norm": cfg.MODEL.RETINANET.NORM, + "num_anchors": num_anchors, + } + + def forward(self, features: List[Tensor]): + """ + Arguments: + features (list[Tensor]): FPN feature map tensors in high to low resolution. + Each tensor in the list correspond to different feature levels. + + Returns: + logits (list[Tensor]): #lvl tensors, each has shape (N, AxK, Hi, Wi). + The tensor predicts the classification probability + at each spatial position for each of the A anchors and K object + classes. + bbox_reg (list[Tensor]): #lvl tensors, each has shape (N, Ax4, Hi, Wi). + The tensor predicts 4-vector (dx,dy,dw,dh) box + regression values for every anchor. These values are the + relative offset between the anchor and the ground truth box. + """ + assert len(features) == self._num_features + logits = [] + bbox_reg = [] + for feature in features: + logits.append(self.cls_score(self.cls_subnet(feature))) + bbox_reg.append(self.bbox_pred(self.bbox_subnet(feature))) + return logits, bbox_reg diff --git a/data_processing/detectron2/detectron2/modeling/meta_arch/semantic_seg.py b/data_processing/detectron2/detectron2/modeling/meta_arch/semantic_seg.py new file mode 100644 index 0000000..fefbecf --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/meta_arch/semantic_seg.py @@ -0,0 +1,267 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import numpy as np +from typing import Callable, Dict, Optional, Tuple, Union +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.structures import ImageList +from detectron2.utils.registry import Registry + +from ..backbone import Backbone, build_backbone +from ..postprocessing import sem_seg_postprocess +from .build import META_ARCH_REGISTRY + +__all__ = [ + "SemanticSegmentor", + "SEM_SEG_HEADS_REGISTRY", + "SemSegFPNHead", + "build_sem_seg_head", +] + + +SEM_SEG_HEADS_REGISTRY = Registry("SEM_SEG_HEADS") +SEM_SEG_HEADS_REGISTRY.__doc__ = """ +Registry for semantic segmentation heads, which make semantic segmentation predictions +from feature maps. +""" + + +@META_ARCH_REGISTRY.register() +class SemanticSegmentor(nn.Module): + """ + Main class for semantic segmentation architectures. + """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + sem_seg_head: nn.Module, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + sem_seg_head: a module that predicts semantic segmentation from backbone features + pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + """ + super().__init__() + self.backbone = backbone + self.sem_seg_head = sem_seg_head + self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) + + @classmethod + def from_config(cls, cfg): + backbone = build_backbone(cfg) + sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape()) + return { + "backbone": backbone, + "sem_seg_head": sem_seg_head, + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": 
cfg.MODEL.PIXEL_STD, + } + + @property + def device(self): + return self.pixel_mean.device + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper`. + Each item in the list contains the inputs for one image. + + For now, each item in the list is a dict that contains: + + * "image": Tensor, image in (C, H, W) format. + * "sem_seg": semantic segmentation ground truth + * Other information that's included in the original dicts, such as: + "height", "width" (int): the output resolution of the model (may be different + from input resolution), used in inference. + + + Returns: + list[dict]: + Each dict is the output for one input image. + The dict contains one key "sem_seg" whose value is a + Tensor that represents the + per-pixel segmentation prediced by the head. + The prediction has shape KxHxW that represents the logits of + each class for each pixel. + """ + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors( + images, + self.backbone.size_divisibility, + padding_constraints=self.backbone.padding_constraints, + ) + + features = self.backbone(images.tensor) + + if "sem_seg" in batched_inputs[0]: + targets = [x["sem_seg"].to(self.device) for x in batched_inputs] + targets = ImageList.from_tensors( + targets, + self.backbone.size_divisibility, + self.sem_seg_head.ignore_value, + self.backbone.padding_constraints, + ).tensor + else: + targets = None + results, losses = self.sem_seg_head(features, targets) + + if self.training: + return losses + + processed_results = [] + for result, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + r = sem_seg_postprocess(result, image_size, height, width) + processed_results.append({"sem_seg": r}) + return 
processed_results + + +def build_sem_seg_head(cfg, input_shape): + """ + Build a semantic segmentation head from `cfg.MODEL.SEM_SEG_HEAD.NAME`. + """ + name = cfg.MODEL.SEM_SEG_HEAD.NAME + return SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape) + + +@SEM_SEG_HEADS_REGISTRY.register() +class SemSegFPNHead(nn.Module): + """ + A semantic segmentation head described in :paper:`PanopticFPN`. + It takes a list of FPN features as input, and applies a sequence of + 3x3 convs and upsampling to scale all of them to the stride defined by + ``common_stride``. Then these features are added and used to make final + predictions by another 1x1 conv layer. + """ + + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + num_classes: int, + conv_dims: int, + common_stride: int, + loss_weight: float = 1.0, + norm: Optional[Union[str, Callable]] = None, + ignore_value: int = -1, + ): + """ + NOTE: this interface is experimental. + + Args: + input_shape: shapes (channels and stride) of the input features + num_classes: number of classes to predict + conv_dims: number of output channels for the intermediate conv layers. + common_stride: the common stride that all features will be upscaled to + loss_weight: loss weight + norm (str or callable): normalization for all conv layers + ignore_value: category id to be ignored during training. 
+ """ + super().__init__() + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + if not len(input_shape): + raise ValueError("SemSegFPNHead(input_shape=) cannot be empty!") + self.in_features = [k for k, v in input_shape] + feature_strides = [v.stride for k, v in input_shape] + feature_channels = [v.channels for k, v in input_shape] + + self.ignore_value = ignore_value + self.common_stride = common_stride + self.loss_weight = loss_weight + + self.scale_heads = [] + for in_feature, stride, channels in zip( + self.in_features, feature_strides, feature_channels + ): + head_ops = [] + head_length = max(1, int(np.log2(stride) - np.log2(self.common_stride))) + for k in range(head_length): + norm_module = get_norm(norm, conv_dims) + conv = Conv2d( + channels if k == 0 else conv_dims, + conv_dims, + kernel_size=3, + stride=1, + padding=1, + bias=not norm, + norm=norm_module, + activation=F.relu, + ) + weight_init.c2_msra_fill(conv) + head_ops.append(conv) + if stride != self.common_stride: + head_ops.append( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) + ) + self.scale_heads.append(nn.Sequential(*head_ops)) + self.add_module(in_feature, self.scale_heads[-1]) + self.predictor = Conv2d(conv_dims, num_classes, kernel_size=1, stride=1, padding=0) + weight_init.c2_msra_fill(self.predictor) + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + return { + "input_shape": { + k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + }, + "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, + "conv_dims": cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM, + "common_stride": cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE, + "norm": cfg.MODEL.SEM_SEG_HEAD.NORM, + "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, + } + + def forward(self, features, targets=None): + """ + Returns: + In training, returns (None, dict of losses) + In inference, returns (CxHxW logits, 
{}) + """ + x = self.layers(features) + if self.training: + return None, self.losses(x, targets) + else: + x = F.interpolate( + x, scale_factor=self.common_stride, mode="bilinear", align_corners=False + ) + return x, {} + + def layers(self, features): + for i, f in enumerate(self.in_features): + if i == 0: + x = self.scale_heads[i](features[f]) + else: + x = x + self.scale_heads[i](features[f]) + x = self.predictor(x) + return x + + def losses(self, predictions, targets): + predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163 + predictions = F.interpolate( + predictions, + scale_factor=self.common_stride, + mode="bilinear", + align_corners=False, + ) + loss = F.cross_entropy( + predictions, targets, reduction="mean", ignore_index=self.ignore_value + ) + losses = {"loss_sem_seg": loss * self.loss_weight} + return losses diff --git a/data_processing/detectron2/detectron2/modeling/mmdet_wrapper.py b/data_processing/detectron2/detectron2/modeling/mmdet_wrapper.py new file mode 100644 index 0000000..293b3e9 --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/mmdet_wrapper.py @@ -0,0 +1,273 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import itertools +import logging +import numpy as np +from collections import OrderedDict +from collections.abc import Mapping +from typing import Dict, List, Optional, Tuple, Union +import torch +from omegaconf import DictConfig, OmegaConf +from torch import Tensor, nn + +from detectron2.layers import ShapeSpec +from detectron2.structures import BitMasks, Boxes, ImageList, Instances +from detectron2.utils.events import get_event_storage + +from .backbone import Backbone + +logger = logging.getLogger(__name__) + + +def _to_container(cfg): + """ + mmdet will assert the type of dict/list. + So convert omegaconf objects to dict/list. 
+ """ + if isinstance(cfg, DictConfig): + cfg = OmegaConf.to_container(cfg, resolve=True) + from mmcv.utils import ConfigDict + + return ConfigDict(cfg) + + +class MMDetBackbone(Backbone): + """ + Wrapper of mmdetection backbones to use in detectron2. + + mmdet backbones produce list/tuple of tensors, while detectron2 backbones + produce a dict of tensors. This class wraps the given backbone to produce + output in detectron2's convention, so it can be used in place of detectron2 + backbones. + """ + + def __init__( + self, + backbone: Union[nn.Module, Mapping], + neck: Union[nn.Module, Mapping, None] = None, + *, + output_shapes: List[ShapeSpec], + output_names: Optional[List[str]] = None, + ): + """ + Args: + backbone: either a backbone module or a mmdet config dict that defines a + backbone. The backbone takes a 4D image tensor and returns a + sequence of tensors. + neck: either a backbone module or a mmdet config dict that defines a + neck. The neck takes outputs of backbone and returns a + sequence of tensors. If None, no neck is used. + output_shapes: shape for every output of the backbone (or neck, if given). + stride and channels are often needed. + output_names: names for every output of the backbone (or neck, if given). + By default, will use "out0", "out1", ... + """ + super().__init__() + if isinstance(backbone, Mapping): + from mmdet.models import build_backbone + + backbone = build_backbone(_to_container(backbone)) + self.backbone = backbone + + if isinstance(neck, Mapping): + from mmdet.models import build_neck + + neck = build_neck(_to_container(neck)) + self.neck = neck + + # "Neck" weights, if any, are part of neck itself. This is the interface + # of mmdet so we follow it. 
Reference: + # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/two_stage.py + logger.info("Initializing mmdet backbone weights...") + self.backbone.init_weights() + # train() in mmdet modules is non-trivial, and has to be explicitly + # called. Reference: + # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py + self.backbone.train() + if self.neck is not None: + logger.info("Initializing mmdet neck weights ...") + if isinstance(self.neck, nn.Sequential): + for m in self.neck: + m.init_weights() + else: + self.neck.init_weights() + self.neck.train() + + self._output_shapes = output_shapes + if not output_names: + output_names = [f"out{i}" for i in range(len(output_shapes))] + self._output_names = output_names + + def forward(self, x) -> Dict[str, Tensor]: + outs = self.backbone(x) + if self.neck is not None: + outs = self.neck(outs) + assert isinstance( + outs, (list, tuple) + ), "mmdet backbone should return a list/tuple of tensors!" + if len(outs) != len(self._output_shapes): + raise ValueError( + "Length of output_shapes does not match outputs from the mmdet backbone: " + f"{len(outs)} != {len(self._output_shapes)}" + ) + return {k: v for k, v in zip(self._output_names, outs)} + + def output_shape(self) -> Dict[str, ShapeSpec]: + return {k: v for k, v in zip(self._output_names, self._output_shapes)} + + +class MMDetDetector(nn.Module): + """ + Wrapper of a mmdetection detector model, for detection and instance segmentation. + Input/output formats of this class follow detectron2's convention, so a + mmdetection model can be trained and evaluated in detectron2. 
+ """ + + def __init__( + self, + detector: Union[nn.Module, Mapping], + *, + # Default is 32 regardless of model: + # https://github.com/open-mmlab/mmdetection/tree/master/configs/_base_/datasets + size_divisibility=32, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + ): + """ + Args: + detector: a mmdet detector, or a mmdet config dict that defines a detector. + size_divisibility: pad input images to multiple of this number + pixel_mean: per-channel mean to normalize input image + pixel_std: per-channel stddev to normalize input image + """ + super().__init__() + if isinstance(detector, Mapping): + from mmdet.models import build_detector + + detector = build_detector(_to_container(detector)) + self.detector = detector + self.detector.init_weights() + self.size_divisibility = size_divisibility + + self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) + assert ( + self.pixel_mean.shape == self.pixel_std.shape + ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!" 
+ + def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, size_divisibility=self.size_divisibility).tensor + metas = [] + rescale = {"height" in x for x in batched_inputs} + if len(rescale) != 1: + raise ValueError("Some inputs have original height/width, but some don't!") + rescale = list(rescale)[0] + output_shapes = [] + for input in batched_inputs: + meta = {} + c, h, w = input["image"].shape + meta["img_shape"] = meta["ori_shape"] = (h, w, c) + if rescale: + scale_factor = np.array( + [w / input["width"], h / input["height"]] * 2, dtype="float32" + ) + ori_shape = (input["height"], input["width"]) + output_shapes.append(ori_shape) + meta["ori_shape"] = ori_shape + (c,) + else: + scale_factor = 1.0 + output_shapes.append((h, w)) + meta["scale_factor"] = scale_factor + meta["flip"] = False + padh, padw = images.shape[-2:] + meta["pad_shape"] = (padh, padw, c) + metas.append(meta) + + if self.training: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + if gt_instances[0].has("gt_masks"): + from mmdet.core import PolygonMasks as mm_PolygonMasks, BitmapMasks as mm_BitMasks + + def convert_mask(m, shape): + # mmdet mask format + if isinstance(m, BitMasks): + return mm_BitMasks(m.tensor.cpu().numpy(), shape[0], shape[1]) + else: + return mm_PolygonMasks(m.polygons, shape[0], shape[1]) + + gt_masks = [convert_mask(x.gt_masks, x.image_size) for x in gt_instances] + losses_and_metrics = self.detector.forward_train( + images, + metas, + [x.gt_boxes.tensor for x in gt_instances], + [x.gt_classes for x in gt_instances], + gt_masks=gt_masks, + ) + else: + losses_and_metrics = self.detector.forward_train( + images, + metas, + [x.gt_boxes.tensor for x in gt_instances], + [x.gt_classes for x in gt_instances], + ) + return _parse_losses(losses_and_metrics) + else: + 
results = self.detector.simple_test(images, metas, rescale=rescale) + results = [ + {"instances": _convert_mmdet_result(r, shape)} + for r, shape in zip(results, output_shapes) + ] + return results + + @property + def device(self): + return self.pixel_mean.device + + +# Reference: show_result() in +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/base.py +def _convert_mmdet_result(result, shape: Tuple[int, int]) -> Instances: + if isinstance(result, tuple): + bbox_result, segm_result = result + if isinstance(segm_result, tuple): + segm_result = segm_result[0] + else: + bbox_result, segm_result = result, None + + bboxes = torch.from_numpy(np.vstack(bbox_result)) # Nx5 + bboxes, scores = bboxes[:, :4], bboxes[:, -1] + labels = [ + torch.full((bbox.shape[0],), i, dtype=torch.int32) for i, bbox in enumerate(bbox_result) + ] + labels = torch.cat(labels) + inst = Instances(shape) + inst.pred_boxes = Boxes(bboxes) + inst.scores = scores + inst.pred_classes = labels + + if segm_result is not None and len(labels) > 0: + segm_result = list(itertools.chain(*segm_result)) + segm_result = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in segm_result] + segm_result = torch.stack(segm_result, dim=0) + inst.pred_masks = segm_result + return inst + + +# reference: https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/base.py +def _parse_losses(losses: Dict[str, Tensor]) -> Dict[str, Tensor]: + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + if "loss" not in loss_name: + # put metrics to storage; don't return them + 
storage = get_event_storage() + value = log_vars.pop(loss_name).cpu().item() + storage.put_scalar(loss_name, value) + return log_vars diff --git a/data_processing/detectron2/detectron2/modeling/poolers.py b/data_processing/detectron2/detectron2/modeling/poolers.py new file mode 100644 index 0000000..3393794 --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/poolers.py @@ -0,0 +1,263 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import math +from typing import List, Optional +import torch +from torch import nn +from torchvision.ops import RoIPool + +from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor +from detectron2.structures import Boxes +from detectron2.utils.tracing import assert_fx_safe, is_fx_tracing + +""" +To export ROIPooler to torchscript, in this file, variables that should be annotated with +`Union[List[Boxes], List[RotatedBoxes]]` are only annotated with `List[Boxes]`. + +TODO: Correct these annotations when torchscript support `Union`. +https://github.com/pytorch/pytorch/issues/41412 +""" + +__all__ = ["ROIPooler"] + + +def assign_boxes_to_levels( + box_lists: List[Boxes], + min_level: int, + max_level: int, + canonical_box_size: int, + canonical_level: int, +): + """ + Map each box in `box_lists` to a feature map level index and return the assignment + vector. + + Args: + box_lists (list[Boxes] | list[RotatedBoxes]): A list of N Boxes or N RotatedBoxes, + where N is the number of images in the batch. + min_level (int): Smallest feature map level index. The input is considered index 0, + the output of stage 1 is index 1, and so. + max_level (int): Largest feature map level index. + canonical_box_size (int): A canonical box size in pixels (sqrt(box area)). + canonical_level (int): The feature map level index on which a canonically-sized box + should be placed. 
+ + Returns: + A tensor of length M, where M is the total number of boxes aggregated over all + N batch images. The memory layout corresponds to the concatenation of boxes + from all images. Each element is the feature map index, as an offset from + `self.min_level`, for the corresponding box (so value i means the box is at + `self.min_level + i`). + """ + box_sizes = torch.sqrt(cat([boxes.area() for boxes in box_lists])) + # Eqn.(1) in FPN paper + level_assignments = torch.floor( + canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8) + ) + # clamp level to (min, max), in case the box size is too large or too small + # for the available feature maps + level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level) + return level_assignments.to(torch.int64) - min_level + + +# script the module to avoid hardcoded device type +@torch.jit.script_if_tracing +def _convert_boxes_to_pooler_format(boxes: torch.Tensor, sizes: torch.Tensor) -> torch.Tensor: + sizes = sizes.to(device=boxes.device) + indices = torch.repeat_interleave( + torch.arange(len(sizes), dtype=boxes.dtype, device=boxes.device), sizes + ) + return cat([indices[:, None], boxes], dim=1) + + +def convert_boxes_to_pooler_format(box_lists: List[Boxes]): + """ + Convert all boxes in `box_lists` to the low-level format used by ROI pooling ops + (see description under Returns). + + Args: + box_lists (list[Boxes] | list[RotatedBoxes]): + A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch. + + Returns: + When input is list[Boxes]: + A tensor of shape (M, 5), where M is the total number of boxes aggregated over all + N batch images. + The 5 columns are (batch index, x0, y0, x1, y1), where batch index + is the index in [0, N) identifying which batch image the box with corners at + (x0, y0, x1, y1) comes from. + When input is list[RotatedBoxes]: + A tensor of shape (M, 6), where M is the total number of boxes aggregated over all + N batch images. 
+ The 6 columns are (batch index, x_ctr, y_ctr, width, height, angle_degrees), + where batch index is the index in [0, N) identifying which batch image the + rotated box (x_ctr, y_ctr, width, height, angle_degrees) comes from. + """ + boxes = torch.cat([x.tensor for x in box_lists], dim=0) + # __len__ returns Tensor in tracing. + sizes = shapes_to_tensor([x.__len__() for x in box_lists]) + return _convert_boxes_to_pooler_format(boxes, sizes) + + +@torch.jit.script_if_tracing +def _create_zeros( + batch_target: Optional[torch.Tensor], + channels: int, + height: int, + width: int, + like_tensor: torch.Tensor, +) -> torch.Tensor: + batches = batch_target.shape[0] if batch_target is not None else 0 + sizes = (batches, channels, height, width) + return torch.zeros(sizes, dtype=like_tensor.dtype, device=like_tensor.device) + + +class ROIPooler(nn.Module): + """ + Region of interest feature map pooler that supports pooling from one or more + feature maps. + """ + + def __init__( + self, + output_size, + scales, + sampling_ratio, + pooler_type, + canonical_box_size=224, + canonical_level=4, + ): + """ + Args: + output_size (int, tuple[int] or list[int]): output size of the pooled region, + e.g., 14 x 14. If tuple or list is given, the length must be 2. + scales (list[float]): The scale for each low-level pooling op relative to + the input image. For a feature map with stride s relative to the input + image, scale is defined as 1/s. The stride must be power of 2. + When there are multiple scales, they must form a pyramid, i.e. they must be + a monotically decreasing geometric sequence with a factor of 1/2. + sampling_ratio (int): The `sampling_ratio` parameter for the ROIAlign op. + pooler_type (string): Name of the type of pooling operation that should be applied. + For instance, "ROIPool" or "ROIAlignV2". + canonical_box_size (int): A canonical box size in pixels (sqrt(box area)). 
The default + is heuristically defined as 224 pixels in the FPN paper (based on ImageNet + pre-training). + canonical_level (int): The feature map level index from which a canonically-sized box + should be placed. The default is defined as level 4 (stride=16) in the FPN paper, + i.e., a box of size 224x224 will be placed on the feature with stride=16. + The box placement for all boxes will be determined from their sizes w.r.t + canonical_box_size. For example, a box whose area is 4x that of a canonical box + should be used to pool features from feature level ``canonical_level+1``. + + Note that the actual input feature maps given to this module may not have + sufficiently many levels for the input boxes. If the boxes are too large or too + small for the input feature maps, the closest level will be used. + """ + super().__init__() + + if isinstance(output_size, int): + output_size = (output_size, output_size) + assert len(output_size) == 2 + assert isinstance(output_size[0], int) and isinstance(output_size[1], int) + self.output_size = output_size + + if pooler_type == "ROIAlign": + self.level_poolers = nn.ModuleList( + ROIAlign( + output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=False + ) + for scale in scales + ) + elif pooler_type == "ROIAlignV2": + self.level_poolers = nn.ModuleList( + ROIAlign( + output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=True + ) + for scale in scales + ) + elif pooler_type == "ROIPool": + self.level_poolers = nn.ModuleList( + RoIPool(output_size, spatial_scale=scale) for scale in scales + ) + elif pooler_type == "ROIAlignRotated": + self.level_poolers = nn.ModuleList( + ROIAlignRotated(output_size, spatial_scale=scale, sampling_ratio=sampling_ratio) + for scale in scales + ) + else: + raise ValueError("Unknown pooler type: {}".format(pooler_type)) + + # Map scale (defined as 1 / stride) to its feature map level under the + # assumption that stride is a power of 2. 
+ min_level = -(math.log2(scales[0])) + max_level = -(math.log2(scales[-1])) + assert math.isclose(min_level, int(min_level)) and math.isclose( + max_level, int(max_level) + ), "Featuremap stride is not power of 2!" + self.min_level = int(min_level) + self.max_level = int(max_level) + assert ( + len(scales) == self.max_level - self.min_level + 1 + ), "[ROIPooler] Sizes of input featuremaps do not form a pyramid!" + assert 0 <= self.min_level and self.min_level <= self.max_level + self.canonical_level = canonical_level + assert canonical_box_size > 0 + self.canonical_box_size = canonical_box_size + + def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]): + """ + Args: + x (list[Tensor]): A list of feature maps of NCHW shape, with scales matching those + used to construct this module. + box_lists (list[Boxes] | list[RotatedBoxes]): + A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch. + The box coordinates are defined on the original image and + will be scaled by the `scales` argument of :class:`ROIPooler`. + + Returns: + Tensor: + A tensor of shape (M, C, output_size, output_size) where M is the total number of + boxes aggregated over all N batch images and C is the number of channels in `x`. 
+ """ + num_level_assignments = len(self.level_poolers) + + if not is_fx_tracing(): + torch._assert( + isinstance(x, list) and isinstance(box_lists, list), + "Arguments to pooler must be lists", + ) + assert_fx_safe( + len(x) == num_level_assignments, + "unequal value, num_level_assignments={}, but x is list of {} Tensors".format( + num_level_assignments, len(x) + ), + ) + assert_fx_safe( + len(box_lists) == x[0].size(0), + "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format( + x[0].size(0), len(box_lists) + ), + ) + if len(box_lists) == 0: + return _create_zeros(None, x[0].shape[1], *self.output_size, x[0]) + + pooler_fmt_boxes = convert_boxes_to_pooler_format(box_lists) + + if num_level_assignments == 1: + return self.level_poolers[0](x[0], pooler_fmt_boxes) + + level_assignments = assign_boxes_to_levels( + box_lists, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level + ) + + num_channels = x[0].shape[1] + output_size = self.output_size[0] + + output = _create_zeros(pooler_fmt_boxes, num_channels, output_size, output_size, x[0]) + + for level, pooler in enumerate(self.level_poolers): + inds = nonzero_tuple(level_assignments == level)[0] + pooler_fmt_boxes_level = pooler_fmt_boxes[inds] + # Use index_put_ instead of advance indexing, to avoid pytorch/issues/49852 + output.index_put_((inds,), pooler(x[level], pooler_fmt_boxes_level)) + + return output diff --git a/data_processing/detectron2/detectron2/modeling/postprocessing.py b/data_processing/detectron2/detectron2/modeling/postprocessing.py new file mode 100644 index 0000000..8451260 --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/postprocessing.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import torch +from torch.nn import functional as F + +from detectron2.structures import Instances, ROIMasks + + +# perhaps should rename to "resize_instance" +def detector_postprocess( + results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5 +): + """ + Resize the output instances. + The input images are often resized when entering an object detector. + As a result, we often need the outputs of the detector in a different + resolution from its inputs. + + This function will resize the raw outputs of an R-CNN detector + to produce outputs according to the desired output resolution. + + Args: + results (Instances): the raw outputs from the detector. + `results.image_size` contains the input image resolution the detector sees. + This object might be modified in-place. + output_height, output_width: the desired output resolution. + Returns: + Instances: the resized output from the model, based on the output resolution + """ + if isinstance(output_width, torch.Tensor): + # This shape might (but not necessarily) be tensors during tracing. + # Converts integer tensors to float temporaries to ensure true + # division is performed when computing scale_x and scale_y. + output_width_tmp = output_width.float() + output_height_tmp = output_height.float() + new_size = torch.stack([output_height, output_width]) + else: + new_size = (output_height, output_width) + output_width_tmp = output_width + output_height_tmp = output_height + + scale_x, scale_y = ( + output_width_tmp / results.image_size[1], + output_height_tmp / results.image_size[0], + ) + results = Instances(new_size, **results.get_fields()) + + if results.has("pred_boxes"): + output_boxes = results.pred_boxes + elif results.has("proposal_boxes"): + output_boxes = results.proposal_boxes + else: + output_boxes = None + assert output_boxes is not None, "Predictions must contain boxes!" 
+ + output_boxes.scale(scale_x, scale_y) + output_boxes.clip(results.image_size) + + results = results[output_boxes.nonempty()] + + if results.has("pred_masks"): + if isinstance(results.pred_masks, ROIMasks): + roi_masks = results.pred_masks + else: + # pred_masks is a tensor of shape (N, 1, M, M) + roi_masks = ROIMasks(results.pred_masks[:, 0, :, :]) + results.pred_masks = roi_masks.to_bitmasks( + results.pred_boxes, output_height, output_width, mask_threshold + ).tensor # TODO return ROIMasks/BitMask object in the future + + if results.has("pred_keypoints"): + results.pred_keypoints[:, :, 0] *= scale_x + results.pred_keypoints[:, :, 1] *= scale_y + + return results + + +def sem_seg_postprocess(result, img_size, output_height, output_width): + """ + Return semantic segmentation predictions in the original resolution. + + The input images are often resized when entering semantic segmentor. Moreover, in same + cases, they also padded inside segmentor to be divisible by maximum network stride. + As a result, we often need the predictions of the segmentor in a different + resolution from its inputs. + + Args: + result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W), + where C is the number of classes, and H, W are the height and width of the prediction. + img_size (tuple): image size that segmentor is taking as input. + output_height, output_width: the desired output resolution. + + Returns: + semantic segmentation prediction (Tensor): A tensor of the shape + (C, output_height, output_width) that contains per-pixel soft predictions. 
+ """ + result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1) + result = F.interpolate( + result, size=(output_height, output_width), mode="bilinear", align_corners=False + )[0] + return result diff --git a/data_processing/detectron2/detectron2/modeling/proposal_generator/__init__.py b/data_processing/detectron2/detectron2/modeling/proposal_generator/__init__.py new file mode 100644 index 0000000..3f4e4df --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/proposal_generator/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .build import PROPOSAL_GENERATOR_REGISTRY, build_proposal_generator +from .rpn import RPN_HEAD_REGISTRY, build_rpn_head, RPN, StandardRPNHead + +__all__ = list(globals().keys()) diff --git a/data_processing/detectron2/detectron2/modeling/proposal_generator/build.py b/data_processing/detectron2/detectron2/modeling/proposal_generator/build.py new file mode 100644 index 0000000..34eb12d --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/proposal_generator/build.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.utils.registry import Registry + +PROPOSAL_GENERATOR_REGISTRY = Registry("PROPOSAL_GENERATOR") +PROPOSAL_GENERATOR_REGISTRY.__doc__ = """ +Registry for proposal generator, which produces object proposals from feature maps. + +The registered object will be called with `obj(cfg, input_shape)`. +The call should return a `nn.Module` object. +""" + +from . import rpn, rrpn # noqa F401 isort:skip + + +def build_proposal_generator(cfg, input_shape): + """ + Build a proposal generator from `cfg.MODEL.PROPOSAL_GENERATOR.NAME`. + The name can be "PrecomputedProposals" to use no proposal generator. 
+ """ + name = cfg.MODEL.PROPOSAL_GENERATOR.NAME + if name == "PrecomputedProposals": + return None + + return PROPOSAL_GENERATOR_REGISTRY.get(name)(cfg, input_shape) diff --git a/data_processing/detectron2/detectron2/modeling/proposal_generator/proposal_utils.py b/data_processing/detectron2/detectron2/modeling/proposal_generator/proposal_utils.py new file mode 100644 index 0000000..0fdf5dc --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/proposal_generator/proposal_utils.py @@ -0,0 +1,205 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import math +from typing import List, Tuple, Union +import torch + +from detectron2.layers import batched_nms, cat, move_device_like +from detectron2.structures import Boxes, Instances + +logger = logging.getLogger(__name__) + + +def _is_tracing(): + # (fixed in TORCH_VERSION >= 1.9) + if torch.jit.is_scripting(): + # https://github.com/pytorch/pytorch/issues/47379 + return False + else: + return torch.jit.is_tracing() + + +def find_top_rpn_proposals( + proposals: List[torch.Tensor], + pred_objectness_logits: List[torch.Tensor], + image_sizes: List[Tuple[int, int]], + nms_thresh: float, + pre_nms_topk: int, + post_nms_topk: int, + min_box_size: float, + training: bool, +): + """ + For each feature map, select the `pre_nms_topk` highest scoring proposals, + apply NMS, clip proposals, and remove small boxes. Return the `post_nms_topk` + highest scoring proposals among all the feature maps for each image. + + Args: + proposals (list[Tensor]): A list of L tensors. Tensor i has shape (N, Hi*Wi*A, 4). + All proposal predictions on the feature maps. + pred_objectness_logits (list[Tensor]): A list of L tensors. Tensor i has shape (N, Hi*Wi*A). + image_sizes (list[tuple]): sizes (h, w) for each image + nms_thresh (float): IoU threshold to use for NMS + pre_nms_topk (int): number of top k scoring proposals to keep before applying NMS. 
+ When RPN is run on multiple feature maps (as in FPN) this number is per + feature map. + post_nms_topk (int): number of top k scoring proposals to keep after applying NMS. + When RPN is run on multiple feature maps (as in FPN) this number is total, + over all feature maps. + min_box_size (float): minimum proposal box side length in pixels (absolute units + wrt input images). + training (bool): True if proposals are to be used in training, otherwise False. + This arg exists only to support a legacy bug; look for the "NB: Legacy bug ..." + comment. + + Returns: + list[Instances]: list of N Instances. The i-th Instances + stores post_nms_topk object proposals for image i, sorted by their + objectness score in descending order. + """ + num_images = len(image_sizes) + device = ( + proposals[0].device + if torch.jit.is_scripting() + else ("cpu" if torch.jit.is_tracing() else proposals[0].device) + ) + + # 1. Select top-k anchor for every level and every image + topk_scores = [] # #lvl Tensor, each of shape N x topk + topk_proposals = [] + level_ids = [] # #lvl Tensor, each of shape (topk,) + batch_idx = move_device_like(torch.arange(num_images, device=device), proposals[0]) + for level_id, (proposals_i, logits_i) in enumerate(zip(proposals, pred_objectness_logits)): + Hi_Wi_A = logits_i.shape[1] + if isinstance(Hi_Wi_A, torch.Tensor): # it's a tensor in tracing + num_proposals_i = torch.clamp(Hi_Wi_A, max=pre_nms_topk) + else: + num_proposals_i = min(Hi_Wi_A, pre_nms_topk) + + topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1) + + # each is N x topk + topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx] # N x topk x 4 + + topk_proposals.append(topk_proposals_i) + topk_scores.append(topk_scores_i) + level_ids.append( + move_device_like( + torch.full((num_proposals_i,), level_id, dtype=torch.int64, device=device), + proposals[0], + ) + ) + + # 2. 
Concat all levels together + topk_scores = cat(topk_scores, dim=1) + topk_proposals = cat(topk_proposals, dim=1) + level_ids = cat(level_ids, dim=0) + + # 3. For each image, run a per-level NMS, and choose topk results. + results: List[Instances] = [] + for n, image_size in enumerate(image_sizes): + boxes = Boxes(topk_proposals[n]) + scores_per_img = topk_scores[n] + lvl = level_ids + + valid_mask = torch.isfinite(boxes.tensor).all(dim=1) & torch.isfinite(scores_per_img) + if not valid_mask.all(): + if training: + raise FloatingPointError( + "Predicted boxes or scores contain Inf/NaN. Training has diverged." + ) + boxes = boxes[valid_mask] + scores_per_img = scores_per_img[valid_mask] + lvl = lvl[valid_mask] + boxes.clip(image_size) + + # filter empty boxes + keep = boxes.nonempty(threshold=min_box_size) + if _is_tracing() or keep.sum().item() != len(boxes): + boxes, scores_per_img, lvl = boxes[keep], scores_per_img[keep], lvl[keep] + + keep = batched_nms(boxes.tensor, scores_per_img, lvl, nms_thresh) + # In Detectron1, there was different behavior during training vs. testing. + # (https://github.com/facebookresearch/Detectron/issues/459) + # During training, topk is over the proposals from *all* images in the training batch. + # During testing, it is over the proposals for each image separately. + # As a result, the training behavior becomes batch-dependent, + # and the configuration "POST_NMS_TOPK_TRAIN" end up relying on the batch size. + # This bug is addressed in Detectron2 to make the behavior independent of batch size. 
+ keep = keep[:post_nms_topk] # keep is already sorted + + res = Instances(image_size) + res.proposal_boxes = boxes[keep] + res.objectness_logits = scores_per_img[keep] + results.append(res) + return results + + +def add_ground_truth_to_proposals( + gt: Union[List[Instances], List[Boxes]], proposals: List[Instances] +) -> List[Instances]: + """ + Call `add_ground_truth_to_proposals_single_image` for all images. + + Args: + gt(Union[List[Instances], List[Boxes]): list of N elements. Element i is a Instances + representing the ground-truth for image i. + proposals (list[Instances]): list of N elements. Element i is a Instances + representing the proposals for image i. + + Returns: + list[Instances]: list of N Instances. Each is the proposals for the image, + with field "proposal_boxes" and "objectness_logits". + """ + assert gt is not None + + if len(proposals) != len(gt): + raise ValueError("proposals and gt should have the same length as the number of images!") + if len(proposals) == 0: + return proposals + + return [ + add_ground_truth_to_proposals_single_image(gt_i, proposals_i) + for gt_i, proposals_i in zip(gt, proposals) + ] + + +def add_ground_truth_to_proposals_single_image( + gt: Union[Instances, Boxes], proposals: Instances +) -> Instances: + """ + Augment `proposals` with `gt`. + + Args: + Same as `add_ground_truth_to_proposals`, but with gt and proposals + per image. + + Returns: + Same as `add_ground_truth_to_proposals`, but for only one image. + """ + if isinstance(gt, Boxes): + # convert Boxes to Instances + gt = Instances(proposals.image_size, gt_boxes=gt) + + gt_boxes = gt.gt_boxes + device = proposals.objectness_logits.device + # Assign all ground-truth boxes an objectness logit corresponding to + # P(object) = sigmoid(logit) =~ 1. 
+ gt_logit_value = math.log((1.0 - 1e-10) / (1 - (1.0 - 1e-10))) + gt_logits = gt_logit_value * torch.ones(len(gt_boxes), device=device) + + # Concatenating gt_boxes with proposals requires them to have the same fields + gt_proposal = Instances(proposals.image_size, **gt.get_fields()) + gt_proposal.proposal_boxes = gt_boxes + gt_proposal.objectness_logits = gt_logits + + for key in proposals.get_fields().keys(): + assert gt_proposal.has( + key + ), "The attribute '{}' in `proposals` does not exist in `gt`".format(key) + + # NOTE: Instances.cat only use fields from the first item. Extra fields in latter items + # will be thrown away. + new_proposals = Instances.cat([proposals, gt_proposal]) + + return new_proposals diff --git a/data_processing/detectron2/detectron2/modeling/proposal_generator/rpn.py b/data_processing/detectron2/detectron2/modeling/proposal_generator/rpn.py new file mode 100644 index 0000000..99cd536 --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/proposal_generator/rpn.py @@ -0,0 +1,533 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+from typing import Dict, List, Optional, Tuple, Union +import torch +import torch.nn.functional as F +from torch import nn + +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, cat +from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou +from detectron2.utils.events import get_event_storage +from detectron2.utils.memory import retry_if_cuda_oom +from detectron2.utils.registry import Registry + +from ..anchor_generator import build_anchor_generator +from ..box_regression import Box2BoxTransform, _dense_box_regression_loss +from ..matcher import Matcher +from ..sampling import subsample_labels +from .build import PROPOSAL_GENERATOR_REGISTRY +from .proposal_utils import find_top_rpn_proposals + +RPN_HEAD_REGISTRY = Registry("RPN_HEAD") +RPN_HEAD_REGISTRY.__doc__ = """ +Registry for RPN heads, which take feature maps and perform +objectness classification and bounding box regression for anchors. + +The registered object will be called with `obj(cfg, input_shape)`. +The call should return a `nn.Module` object. +""" + + +""" +Shape shorthand in this module: + + N: number of images in the minibatch + L: number of feature maps per image on which RPN is run + A: number of cell anchors (must be the same for all feature maps) + Hi, Wi: height and width of the i-th feature map + B: size of the box parameterization + +Naming convention: + + objectness: refers to the binary classification of an anchor as object vs. not object. + + deltas: refers to the 4-d (dx, dy, dw, dh) deltas that parameterize the box2box + transform (see :class:`box_regression.Box2BoxTransform`), or 5d for rotated boxes. + + pred_objectness_logits: predicted objectness scores in [-inf, +inf]; use + sigmoid(pred_objectness_logits) to estimate P(object). 
+ + gt_labels: ground-truth binary classification labels for objectness + + pred_anchor_deltas: predicted box2box transform deltas + + gt_anchor_deltas: ground-truth box2box transform deltas +""" + + +def build_rpn_head(cfg, input_shape): + """ + Build an RPN head defined by `cfg.MODEL.RPN.HEAD_NAME`. + """ + name = cfg.MODEL.RPN.HEAD_NAME + return RPN_HEAD_REGISTRY.get(name)(cfg, input_shape) + + +@RPN_HEAD_REGISTRY.register() +class StandardRPNHead(nn.Module): + """ + Standard RPN classification and regression heads described in :paper:`Faster R-CNN`. + Uses a 3x3 conv to produce a shared hidden state from which one 1x1 conv predicts + objectness logits for each anchor and a second 1x1 conv predicts bounding-box deltas + specifying how to deform each anchor into an object proposal. + """ + + @configurable + def __init__( + self, *, in_channels: int, num_anchors: int, box_dim: int = 4, conv_dims: List[int] = (-1,) + ): + """ + NOTE: this interface is experimental. + + Args: + in_channels (int): number of input feature channels. When using multiple + input features, they must have the same number of channels. + num_anchors (int): number of anchors to predict for *each spatial position* + on the feature map. The total number of anchors for each + feature map will be `num_anchors * H * W`. + box_dim (int): dimension of a box, which is also the number of box regression + predictions to make for each anchor. An axis aligned box has + box_dim=4, while a rotated box has box_dim=5. + conv_dims (list[int]): a list of integers representing the output channels + of N conv layers. Set it to -1 to use the same number of output channels + as input channels. + """ + super().__init__() + cur_channels = in_channels + # Keeping the old variable names and structure for backwards compatiblity. + # Otherwise the old checkpoints will fail to load. 
+ if len(conv_dims) == 1: + out_channels = cur_channels if conv_dims[0] == -1 else conv_dims[0] + # 3x3 conv for the hidden representation + self.conv = self._get_rpn_conv(cur_channels, out_channels) + cur_channels = out_channels + else: + self.conv = nn.Sequential() + for k, conv_dim in enumerate(conv_dims): + out_channels = cur_channels if conv_dim == -1 else conv_dim + if out_channels <= 0: + raise ValueError( + f"Conv output channels should be greater than 0. Got {out_channels}" + ) + conv = self._get_rpn_conv(cur_channels, out_channels) + self.conv.add_module(f"conv{k}", conv) + cur_channels = out_channels + # 1x1 conv for predicting objectness logits + self.objectness_logits = nn.Conv2d(cur_channels, num_anchors, kernel_size=1, stride=1) + # 1x1 conv for predicting box2box transform deltas + self.anchor_deltas = nn.Conv2d(cur_channels, num_anchors * box_dim, kernel_size=1, stride=1) + + # Keeping the order of weights initialization same for backwards compatiblility. + for layer in self.modules(): + if isinstance(layer, nn.Conv2d): + nn.init.normal_(layer.weight, std=0.01) + nn.init.constant_(layer.bias, 0) + + def _get_rpn_conv(self, in_channels, out_channels): + return Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + activation=nn.ReLU(), + ) + + @classmethod + def from_config(cls, cfg, input_shape): + # Standard RPN is shared across levels: + in_channels = [s.channels for s in input_shape] + assert len(set(in_channels)) == 1, "Each level must have the same channel!" + in_channels = in_channels[0] + + # RPNHead should take the same input as anchor generator + # NOTE: it assumes that creating an anchor generator does not have unwanted side effect. 
@configurable
def __init__(
    self,
    *,
    in_features: List[str],
    head: nn.Module,
    anchor_generator: nn.Module,
    anchor_matcher: Matcher,
    box2box_transform: Box2BoxTransform,
    batch_size_per_image: int,
    positive_fraction: float,
    pre_nms_topk: Tuple[float, float],
    post_nms_topk: Tuple[float, float],
    nms_thresh: float = 0.7,
    min_box_size: float = 0.0,
    anchor_boundary_thresh: float = -1.0,
    loss_weight: Union[float, Dict[str, float]] = 1.0,
    box_reg_loss_type: str = "smooth_l1",
    smooth_l1_beta: float = 0.0,
):
    """
    NOTE: this interface is experimental.

    Args:
        in_features (list[str]): list of names of input features to use
        head (nn.Module): a module that predicts logits and regression deltas
            for each level from a list of per-level features
        anchor_generator (nn.Module): a module that creates anchors from a
            list of features. Usually an instance of :class:`AnchorGenerator`
        anchor_matcher (Matcher): label the anchors by matching them with ground truth.
        box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to
            instance boxes
        batch_size_per_image (int): number of anchors per image to sample for training
        positive_fraction (float): fraction of foreground anchors to sample for training
        pre_nms_topk (tuple[float]): (train, test) that represents the
            number of top k proposals to select before NMS, in
            training and testing.
        post_nms_topk (tuple[float]): (train, test) that represents the
            number of top k proposals to select after NMS, in
            training and testing.
        nms_thresh (float): NMS threshold used to de-duplicate the predicted proposals
        min_box_size (float): remove proposal boxes with any side smaller than this threshold,
            in the unit of input image pixels
        anchor_boundary_thresh (float): legacy option
        loss_weight (float|int|dict): weights to use for losses. A single number weights
            all rpn losses together; a dict gives individual weightings. Valid dict keys are:
                "loss_rpn_cls" - applied to classification loss
                "loss_rpn_loc" - applied to box regression loss
        box_reg_loss_type (str): Loss type to use. Supported losses: "smooth_l1", "giou".
        smooth_l1_beta (float): beta parameter for the smooth L1 regression loss. Default to
            use L1 loss. Only used when `box_reg_loss_type` is "smooth_l1"
    """
    super().__init__()
    self.in_features = in_features
    self.rpn_head = head
    self.anchor_generator = anchor_generator
    self.anchor_matcher = anchor_matcher
    self.box2box_transform = box2box_transform
    self.batch_size_per_image = batch_size_per_image
    self.positive_fraction = positive_fraction
    # Map from self.training state to train/test settings
    self.pre_nms_topk = {True: pre_nms_topk[0], False: pre_nms_topk[1]}
    self.post_nms_topk = {True: post_nms_topk[0], False: post_nms_topk[1]}
    self.nms_thresh = nms_thresh
    self.min_box_size = float(min_box_size)
    self.anchor_boundary_thresh = anchor_boundary_thresh
    # `losses` looks weights up with `self.loss_weight.get(...)`, so a scalar must always
    # be expanded to a dict. Accept ints as well as floats: `loss_weight=1` used to slip
    # through the float-only check and crash later with an AttributeError.
    if isinstance(loss_weight, (int, float)):
        loss_weight = {"loss_rpn_cls": loss_weight, "loss_rpn_loc": loss_weight}
    self.loss_weight = loss_weight
    self.box_reg_loss_type = box_reg_loss_type
    self.smooth_l1_beta = smooth_l1_beta
def _subsample_labels(self, label):
    """
    Keep a random subset of positive and negative anchors and ignore the rest.

    The indices to keep come from ``subsample_labels``, called with
    ``self.batch_size_per_image`` and ``self.positive_fraction``. Every entry of
    ``label`` is first reset to the ignore value (-1); the sampled positive indices
    are then written as 1 and the sampled negative indices as 0.

    Args:
        label (Tensor): a vector of -1, 0, 1. Will be modified in-place and returned.

    Returns:
        Tensor: the same ``label`` tensor after the in-place update.
    """
    sampled_pos, sampled_neg = subsample_labels(
        label, self.batch_size_per_image, self.positive_fraction, 0
    )
    # Start from "ignore everything", then re-mark only the sampled anchors.
    label.fill_(-1)
    for sampled_idx, value in ((sampled_pos, 1), (sampled_neg, 0)):
        label.scatter_(0, sampled_idx, value)
    return label
@torch.jit.unused
def losses(
    self,
    anchors: List[Boxes],
    pred_objectness_logits: List[torch.Tensor],
    gt_labels: List[torch.Tensor],
    pred_anchor_deltas: List[torch.Tensor],
    gt_boxes: List[torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """
    Compute the RPN classification and localization losses.

    Args:
        anchors (list[Boxes or RotatedBoxes]): anchors for each feature map, each
            has shape (Hi*Wi*A, B), where B is box dimension (4 or 5).
        pred_objectness_logits (list[Tensor]): A list of L elements.
            Element i is a tensor of shape (N, Hi*Wi*A) representing
            the predicted objectness logits for all anchors.
        gt_labels (list[Tensor]): Output of :meth:`label_and_sample_anchors`.
        pred_anchor_deltas (list[Tensor]): A list of L elements. Element i is a tensor of
            shape (N, Hi*Wi*A, 4 or 5) representing the predicted "deltas" used to
            transform anchors to proposals.
        gt_boxes (list[Tensor]): Output of :meth:`label_and_sample_anchors`.

    Returns:
        dict[loss name -> loss value]: `loss_rpn_cls` for objectness classification and
            `loss_rpn_loc` for proposal localization. Each is divided by
            ``batch_size_per_image * num_images`` and multiplied by its entry in
            ``self.loss_weight`` (1.0 when the key is absent).
    """
    batch_size = len(gt_labels)
    labels = torch.stack(gt_labels)  # (N, sum(Hi*Wi*Ai))

    # Log the number of positive/negative anchors per-image that's used in training
    is_positive = labels == 1
    positive_count = is_positive.sum().item()
    negative_count = (labels == 0).sum().item()
    events = get_event_storage()
    for tag, count in (
        ("rpn/num_pos_anchors", positive_count),
        ("rpn/num_neg_anchors", negative_count),
    ):
        events.put_scalar(tag, count / batch_size)

    # Box regression is only supervised on the positive anchors.
    loc_loss = _dense_box_regression_loss(
        anchors,
        self.box2box_transform,
        pred_anchor_deltas,
        gt_boxes,
        is_positive,
        box_reg_loss_type=self.box_reg_loss_type,
        smooth_l1_beta=self.smooth_l1_beta,
    )

    # Anchors labelled -1 are ignored by the objectness loss.
    not_ignored = labels >= 0
    all_logits = cat(pred_objectness_logits, dim=1)
    cls_loss = F.binary_cross_entropy_with_logits(
        all_logits[not_ignored],
        labels[not_ignored].to(torch.float32),
        reduction="sum",
    )

    # The original Faster R-CNN paper uses a slightly different normalizer
    # for loc loss. But it doesn't matter in practice
    denominator = self.batch_size_per_image * batch_size
    unweighted = (("loss_rpn_cls", cls_loss), ("loss_rpn_loc", loc_loss))
    return {
        name: (value / denominator) * self.loss_weight.get(name, 1.0)
        for name, value in unweighted
    }
def predict_proposals(
    self,
    anchors: List[Boxes],
    pred_objectness_logits: List[torch.Tensor],
    pred_anchor_deltas: List[torch.Tensor],
    image_sizes: List[Tuple[int, int]],
):
    """
    Turn the predicted deltas into proposals and keep the best ones per image.

    The deltas are decoded with :meth:`_decode_proposals` and the result is handed
    to ``find_top_rpn_proposals`` together with the NMS threshold, the pre/post-NMS
    top-k values for the current ``self.training`` mode and the minimum box size.

    Returns:
        proposals (list[Instances]): list of N Instances. The i-th Instances
            stores post_nms_topk object proposals for image i, sorted by their
            objectness score in descending order.
    """
    is_training = self.training
    # The proposals are treated as fixed for joint training with roi heads.
    # This approach ignores the derivative w.r.t. the proposal boxes’ coordinates that
    # are also network responses.
    with torch.no_grad():
        decoded = self._decode_proposals(anchors, pred_anchor_deltas)
        return find_top_rpn_proposals(
            decoded,
            pred_objectness_logits,
            image_sizes,
            self.nms_thresh,
            self.pre_nms_topk[is_training],
            self.post_nms_topk[is_training],
            self.min_box_size,
            is_training,
        )
def find_top_rrpn_proposals(
    proposals,
    pred_objectness_logits,
    image_sizes,
    nms_thresh,
    pre_nms_topk,
    post_nms_topk,
    min_box_size,
    training,
):
    """
    Select the final rotated proposals for every image.

    For each feature map the `pre_nms_topk` highest scoring proposals of every image
    are kept. All levels are then concatenated and, image by image, entries with
    non-finite boxes or scores are dropped, boxes are clipped to the image, boxes
    rejected by ``nonempty(threshold=min_box_size)`` are removed, a per-level rotated
    NMS is run, and the first `post_nms_topk` surviving entries are returned.

    Args:
        proposals (list[Tensor]): A list of L tensors. Tensor i has shape (N, Hi*Wi*A, 5).
            All proposal predictions on the feature maps.
        pred_objectness_logits (list[Tensor]): A list of L tensors. Tensor i has shape
            (N, Hi*Wi*A).
        image_sizes (list[tuple]): sizes (h, w) for each image
        nms_thresh (float): IoU threshold to use for NMS
        pre_nms_topk (int): number of top k scoring proposals to keep before applying NMS.
            When RRPN is run on multiple feature maps (as in FPN) this number is per
            feature map.
        post_nms_topk (int): number of top k scoring proposals to keep after applying NMS.
            When RRPN is run on multiple feature maps (as in FPN) this number is total,
            over all feature maps.
        min_box_size(float): minimum proposal box side length in pixels (absolute units wrt
            input images).
        training (bool): True if proposals are to be used in training, otherwise False.
            In this function it only decides whether non-finite predictions raise
            (True) or are silently filtered out (False).

    Returns:
        proposals (list[Instances]): list of N Instances. The i-th Instances
            stores post_nms_topk object proposals for image i.

    Raises:
        FloatingPointError: if `training` is True and a box or score is Inf/NaN.
    """
    device = proposals[0].device
    image_index = torch.arange(len(image_sizes), device=device)

    # 1. Select top-k anchor for every level and every image
    scores_by_level = []  # #lvl Tensor, each of shape N x topk
    boxes_by_level = []
    ids_by_level = []  # #lvl Tensor, each of shape (topk,)
    for level_id, (level_boxes, level_logits) in enumerate(
        zip(proposals, pred_objectness_logits)
    ):
        anchors_on_level = level_logits.shape[1]
        if isinstance(anchors_on_level, torch.Tensor):  # it's a tensor in tracing
            keep_count = torch.clamp(anchors_on_level, max=pre_nms_topk)
        else:
            keep_count = min(anchors_on_level, pre_nms_topk)

        best_scores, best_idx = level_logits.topk(keep_count, dim=1)  # each is N x topk

        boxes_by_level.append(level_boxes[image_index[:, None], best_idx])  # N x topk x 5
        scores_by_level.append(best_scores)
        ids_by_level.append(
            torch.full((keep_count,), level_id, dtype=torch.int64, device=device)
        )

    # 2. Concat all levels together
    all_scores = cat(scores_by_level, dim=1)
    all_boxes = cat(boxes_by_level, dim=1)
    all_level_ids = cat(ids_by_level, dim=0)

    # 3. For each image, run a per-level NMS, and choose topk results.
    results = []
    for img_idx, image_size in enumerate(image_sizes):
        boxes = RotatedBoxes(all_boxes[img_idx])
        scores = all_scores[img_idx]
        levels = all_level_ids

        finite = torch.isfinite(boxes.tensor).all(dim=1) & torch.isfinite(scores)
        if not finite.all():
            if training:
                raise FloatingPointError(
                    "Predicted boxes or scores contain Inf/NaN. Training has diverged."
                )
            boxes, scores, levels = boxes[finite], scores[finite], levels[finite]
        boxes.clip(image_size)

        # filter empty boxes
        non_empty = boxes.nonempty(threshold=min_box_size)
        if _is_tracing() or non_empty.sum().item() != len(boxes):
            boxes, scores, levels = boxes[non_empty], scores[non_empty], levels[non_empty]

        # In Detectron1, there was different behavior during training vs. testing.
        # (https://github.com/facebookresearch/Detectron/issues/459)
        # During training, topk is over the proposals from *all* images in the training batch.
        # During testing, it is over the proposals for each image separately.
        # As a result, the training behavior becomes batch-dependent,
        # and the configuration "POST_NMS_TOPK_TRAIN" end up relying on the batch size.
        # This bug is addressed in Detectron2 to make the behavior independent of batch size.
        selected = batched_nms_rotated(boxes.tensor, scores, levels, nms_thresh)
        selected = selected[:post_nms_topk]

        image_result = Instances(image_size)
        image_result.proposal_boxes = boxes[selected]
        image_result.objectness_logits = scores[selected]
        results.append(image_result)
    return results
@torch.no_grad()
def predict_proposals(self, anchors, pred_objectness_logits, pred_anchor_deltas, image_sizes):
    """
    Decode the predicted deltas into rotated proposals and keep the top ones.

    Runs entirely without gradient tracking. The decoded proposals are passed to
    ``find_top_rrpn_proposals`` with the NMS threshold, the pre/post-NMS top-k values
    for the current ``self.training`` mode and the minimum box size.

    Returns:
        The value returned by ``find_top_rrpn_proposals``.
    """
    is_training = self.training
    selection_args = (
        self.nms_thresh,
        self.pre_nms_topk[is_training],
        self.post_nms_topk[is_training],
        self.min_box_size,
        is_training,
    )
    decoded = self._decode_proposals(anchors, pred_anchor_deltas)
    return find_top_rrpn_proposals(
        decoded, pred_objectness_logits, image_sizes, *selection_args
    )
+from .box_head import ROI_BOX_HEAD_REGISTRY, build_box_head, FastRCNNConvFCHead +from .keypoint_head import ( + ROI_KEYPOINT_HEAD_REGISTRY, + build_keypoint_head, + BaseKeypointRCNNHead, + KRCNNConvDeconvUpsampleHead, +) +from .mask_head import ( + ROI_MASK_HEAD_REGISTRY, + build_mask_head, + BaseMaskRCNNHead, + MaskRCNNConvUpsampleHead, +) +from .roi_heads import ( + ROI_HEADS_REGISTRY, + ROIHeads, + Res5ROIHeads, + StandardROIHeads, + build_roi_heads, + select_foreground_proposals, +) +from .cascade_rcnn import CascadeROIHeads +from .rotated_fast_rcnn import RROIHeads +from .fast_rcnn import FastRCNNOutputLayers + +from . import cascade_rcnn # isort:skip + +__all__ = list(globals().keys()) diff --git a/data_processing/detectron2/detectron2/modeling/roi_heads/box_head.py b/data_processing/detectron2/detectron2/modeling/roi_heads/box_head.py new file mode 100644 index 0000000..5d0370b --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/roi_heads/box_head.py @@ -0,0 +1,118 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +from typing import List +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn + +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.utils.registry import Registry + +__all__ = ["FastRCNNConvFCHead", "build_box_head", "ROI_BOX_HEAD_REGISTRY"] + +ROI_BOX_HEAD_REGISTRY = Registry("ROI_BOX_HEAD") +ROI_BOX_HEAD_REGISTRY.__doc__ = """ +Registry for box heads, which make box predictions from per-region features. + +The registered object will be called with `obj(cfg, input_shape)`. +""" + + +# To get torchscript support, we make the head a subclass of `nn.Sequential`. +# Therefore, to add new layers in this head class, please make sure they are +# added in the order they will be used in forward(). 
@ROI_BOX_HEAD_REGISTRY.register()
class FastRCNNConvFCHead(nn.Sequential):
    """
    A head with several 3x3 conv layers (each followed by norm & relu) and then
    several fc layers (each followed by relu).

    Sub-modules are registered with ``add_module`` in the order conv layers,
    an optional flatten, then fc/relu pairs; ``forward`` simply applies them
    in that registration order.
    """

    @configurable
    def __init__(
        self, input_shape: ShapeSpec, *, conv_dims: List[int], fc_dims: List[int], conv_norm=""
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature.
            conv_dims (list[int]): the output dimensions of the conv layers
            fc_dims (list[int]): the output dimensions of the fc layers
            conv_norm (str or callable): normalization for the conv layers.
                See :func:`detectron2.layers.get_norm` for supported types.
        """
        super().__init__()
        # At least one layer (conv or fc) is required.
        assert len(conv_dims) + len(fc_dims) > 0

        # Running output size: a (channels, height, width) tuple while only conv layers
        # have been added; replaced by a plain int once an fc layer is added.
        self._output_size = (input_shape.channels, input_shape.height, input_shape.width)

        # Plain Python list of the conv layers, used below for weight initialization.
        self.conv_norm_relus = []
        for k, conv_dim in enumerate(conv_dims):
            conv = Conv2d(
                self._output_size[0],
                conv_dim,
                kernel_size=3,
                padding=1,
                # The conv bias is only enabled when no normalization is configured.
                bias=not conv_norm,
                norm=get_norm(conv_norm, conv_dim),
                activation=nn.ReLU(),
            )
            self.add_module("conv{}".format(k + 1), conv)
            self.conv_norm_relus.append(conv)
            # Only the channel entry of the tracked size changes; height and width
            # are carried over unchanged.
            self._output_size = (conv_dim, self._output_size[1], self._output_size[2])

        # Plain Python list of the linear layers, used below for weight initialization.
        self.fcs = []
        for k, fc_dim in enumerate(fc_dims):
            if k == 0:
                # Flatten once, right before the first fc layer.
                self.add_module("flatten", nn.Flatten())
            # np.prod handles both the (C, H, W) tuple and the int left by a previous fc.
            fc = nn.Linear(int(np.prod(self._output_size)), fc_dim)
            self.add_module("fc{}".format(k + 1), fc)
            self.add_module("fc_relu{}".format(k + 1), nn.ReLU())
            self.fcs.append(fc)
            self._output_size = fc_dim

        # Conv layers and fc layers use different initializers.
        for layer in self.conv_norm_relus:
            weight_init.c2_msra_fill(layer)
        for layer in self.fcs:
            weight_init.c2_xavier_fill(layer)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Translate ``cfg.MODEL.ROI_BOX_HEAD`` into constructor keyword arguments.

        NUM_CONV / NUM_FC give the number of layers and CONV_DIM / FC_DIM the width
        shared by all layers of that kind; NORM is passed through as ``conv_norm``.
        """
        num_conv = cfg.MODEL.ROI_BOX_HEAD.NUM_CONV
        conv_dim = cfg.MODEL.ROI_BOX_HEAD.CONV_DIM
        num_fc = cfg.MODEL.ROI_BOX_HEAD.NUM_FC
        fc_dim = cfg.MODEL.ROI_BOX_HEAD.FC_DIM
        return {
            "input_shape": input_shape,
            "conv_dims": [conv_dim] * num_conv,
            "fc_dims": [fc_dim] * num_fc,
            "conv_norm": cfg.MODEL.ROI_BOX_HEAD.NORM,
        }

    def forward(self, x):
        """Apply every registered sub-module to ``x`` in order and return the result."""
        for layer in self:
            x = layer(x)
        return x

    @property
    @torch.jit.unused
    def output_shape(self):
        """
        Returns:
            ShapeSpec: the output feature shape. Only ``channels`` is set when the
                tracked output size is an int (i.e. after an fc layer); otherwise
                channels, height and width are all set.
        """
        o = self._output_size
        if isinstance(o, int):
            return ShapeSpec(channels=o)
        else:
            return ShapeSpec(channels=o[0], height=o[1], width=o[2])
@configurable
def __init__(
    self,
    *,
    box_in_features: List[str],
    box_pooler: ROIPooler,
    box_heads: List[nn.Module],
    box_predictors: List[nn.Module],
    proposal_matchers: List[Matcher],
    **kwargs,
):
    """
    NOTE: this interface is experimental.

    Args:
        box_in_features (list[str]): names of the feature maps used by the box head;
            forwarded unchanged to the base class.
        box_pooler (ROIPooler): pooler that extracts region features from given boxes
        box_heads (list[nn.Module]): box head for each cascade stage
        box_predictors (list[nn.Module]): box predictor for each cascade stage
        proposal_matchers (list[Matcher]): matcher with different IoU thresholds to
            match boxes with ground truth for each stage. The first matcher matches
            RPN proposals with ground truth, the other matchers use boxes predicted
            by the previous stage as proposals and match them with ground truth.
        **kwargs: remaining keyword arguments for the base class constructor. Must not
            contain "proposal_matcher"; it is filled in from ``proposal_matchers[0]``.
    """
    assert "proposal_matcher" not in kwargs, (
        "CascadeROIHeads takes 'proposal_matchers=' for each stage instead "
        "of one 'proposal_matcher='."
    )
    # The first matcher matches RPN proposals with ground truth, done in the base class
    kwargs["proposal_matcher"] = proposal_matchers[0]
    # The number of stages is defined by the number of box heads; predictors and
    # matchers must provide exactly one entry per stage (checked below).
    num_stages = self.num_cascade_stages = len(box_heads)
    # Wrap the per-stage modules in ModuleLists; these are what the base class
    # receives as its `box_head` / `box_predictor`.
    box_heads = nn.ModuleList(box_heads)
    box_predictors = nn.ModuleList(box_predictors)
    assert len(box_predictors) == num_stages, f"{len(box_predictors)} != {num_stages}!"
    assert len(proposal_matchers) == num_stages, f"{len(proposal_matchers)} != {num_stages}!"
    super().__init__(
        box_in_features=box_in_features,
        box_pooler=box_pooler,
        box_head=box_heads,
        box_predictor=box_predictors,
        **kwargs,
    )
    # Kept as a plain list and indexed by stage when matching proposals.
    self.proposal_matchers = proposal_matchers
def forward(self, images, features, proposals, targets=None):
    """
    Run the cascade ROI heads.

    Args:
        images: unused; dropped immediately.
        features: feature maps handed to the box, mask and keypoint branches.
        proposals: per-image proposals.
        targets: ground truth; only consumed in training mode.

    Returns:
        In training mode ``(proposals, losses)``, where ``proposals`` are the ones
        returned by ``label_and_sample_proposals`` and ``losses`` merges the box,
        mask and keypoint losses. Otherwise ``(pred_instances, {})``, where the box
        predictions have been passed through ``forward_with_given_boxes``.
    """
    del images

    if not self.training:
        detections = self._forward_box(features, proposals)
        detections = self.forward_with_given_boxes(features, detections)
        return detections, {}

    sampled = self.label_and_sample_proposals(proposals, targets)
    # Need targets to box head
    losses = self._forward_box(features, sampled, targets)
    losses.update(self._forward_mask(features, sampled))
    losses.update(self._forward_keypoint(features, sampled))
    return sampled, losses
+ Each has fields "proposal_boxes", and "objectness_logits", + "gt_classes", "gt_boxes". + """ + features = [features[f] for f in self.box_in_features] + head_outputs = [] # (predictor, predictions, proposals) + prev_pred_boxes = None + image_sizes = [x.image_size for x in proposals] + for k in range(self.num_cascade_stages): + if k > 0: + # The output boxes of the previous stage are used to create the input + # proposals of the next stage. + proposals = self._create_proposals_from_boxes(prev_pred_boxes, image_sizes) + if self.training: + proposals = self._match_and_label_boxes(proposals, k, targets) + predictions = self._run_stage(features, proposals, k) + prev_pred_boxes = self.box_predictor[k].predict_boxes(predictions, proposals) + head_outputs.append((self.box_predictor[k], predictions, proposals)) + + if self.training: + losses = {} + storage = get_event_storage() + for stage, (predictor, predictions, proposals) in enumerate(head_outputs): + with storage.name_scope("stage{}".format(stage)): + stage_losses = predictor.losses(predictions, proposals) + losses.update({k + "_stage{}".format(stage): v for k, v in stage_losses.items()}) + return losses + else: + # Each is a list[Tensor] of length #image. 
Each tensor is Ri x (K+1) + scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs] + + # Average the scores across heads + scores = [ + sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages) + for scores_per_image in zip(*scores_per_stage) + ] + # Use the boxes of the last head + predictor, predictions, proposals = head_outputs[-1] + boxes = predictor.predict_boxes(predictions, proposals) + pred_instances, _ = fast_rcnn_inference( + boxes, + scores, + image_sizes, + predictor.test_score_thresh, + predictor.test_nms_thresh, + predictor.test_topk_per_image, + ) + return pred_instances + + @torch.no_grad() + def _match_and_label_boxes(self, proposals, stage, targets): + """ + Match proposals with groundtruth using the matcher at the given stage. + Label the proposals as foreground or background based on the match. + + Args: + proposals (list[Instances]): One Instances for each image, with + the field "proposal_boxes". + stage (int): the current stage + targets (list[Instances]): the ground truth instances + + Returns: + list[Instances]: the same proposals, but with fields "gt_classes" and "gt_boxes" + """ + num_fg_samples, num_bg_samples = [], [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + match_quality_matrix = pairwise_iou( + targets_per_image.gt_boxes, proposals_per_image.proposal_boxes + ) + # proposal_labels are 0 or 1 + matched_idxs, proposal_labels = self.proposal_matchers[stage](match_quality_matrix) + if len(targets_per_image) > 0: + gt_classes = targets_per_image.gt_classes[matched_idxs] + # Label unmatched proposals (0 label from matcher) as background (label=num_classes) + gt_classes[proposal_labels == 0] = self.num_classes + gt_boxes = targets_per_image.gt_boxes[matched_idxs] + else: + gt_classes = torch.zeros_like(matched_idxs) + self.num_classes + gt_boxes = Boxes( + targets_per_image.gt_boxes.tensor.new_zeros((len(proposals_per_image), 4)) + ) + proposals_per_image.gt_classes = gt_classes + 
def _run_stage(self, features, proposals, stage):
    """
    Pool the proposal regions and run the box head and predictor of one stage.

    Args:
        features (list[Tensor]): #lvl input features to ROIHeads
        proposals (list[Instances]): #image Instances, with the field "proposal_boxes"
        stage (int): the current stage

    Returns:
        Same output as `FastRCNNOutputLayers.forward()`.
    """
    proposal_boxes = [per_image.proposal_boxes for per_image in proposals]
    pooled = self.box_pooler(features, proposal_boxes)
    if self.training:
        # The original implementation averages the losses among heads,
        # but scale up the parameter gradients of the heads.
        # This is equivalent to adding the losses among heads,
        # but scale down the gradients on features.
        grad_scale = 1.0 / self.num_cascade_stages
        pooled = _ScaleGradient.apply(pooled, grad_scale)
    head_output = self.box_head[stage](pooled)
    return self.box_predictor[stage](head_output)
+ """ + # Just like RPN, the proposals should not have gradients + boxes = [Boxes(b.detach()) for b in boxes] + proposals = [] + for boxes_per_image, image_size in zip(boxes, image_sizes): + boxes_per_image.clip(image_size) + if self.training: + # do not filter empty boxes at inference time, + # because the scores from each stage need to be aligned and added later + boxes_per_image = boxes_per_image[boxes_per_image.nonempty()] + prop = Instances(image_size) + prop.proposal_boxes = boxes_per_image + proposals.append(prop) + return proposals diff --git a/data_processing/detectron2/detectron2/modeling/roi_heads/fast_rcnn.py b/data_processing/detectron2/detectron2/modeling/roi_heads/fast_rcnn.py new file mode 100644 index 0000000..039e249 --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/roi_heads/fast_rcnn.py @@ -0,0 +1,569 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +from typing import Callable, Dict, List, Optional, Tuple, Union +import torch +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.data.detection_utils import get_fed_loss_cls_weights +from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple +from detectron2.modeling.box_regression import Box2BoxTransform, _dense_box_regression_loss +from detectron2.structures import Boxes, Instances +from detectron2.utils.events import get_event_storage + +__all__ = ["fast_rcnn_inference", "FastRCNNOutputLayers"] + + +logger = logging.getLogger(__name__) + +""" +Shape shorthand in this module: + + N: number of images in the minibatch + R: number of ROIs, combined over all images, in the minibatch + Ri: number of ROIs in image i + K: number of foreground classes. E.g.,there are 80 foreground classes in COCO. + +Naming convention: + + deltas: refers to the 4-d (dx, dy, dw, dh) deltas that parameterize the box2box + transform (see :class:`box_regression.Box2BoxTransform`). 
def fast_rcnn_inference(
    boxes: List[torch.Tensor],
    scores: List[torch.Tensor],
    image_shapes: List[Tuple[int, int]],
    score_thresh: float,
    nms_thresh: float,
    topk_per_image: int,
):
    """
    Call `fast_rcnn_inference_single_image` for all images.

    Args:
        boxes (list[Tensor]): A list of Tensors of predicted class-specific or class-agnostic
            boxes for each image. Element i has shape (Ri, K * 4) if doing
            class-specific regression, or (Ri, 4) if doing class-agnostic
            regression, where Ri is the number of predicted objects for image i.
            This is compatible with the output of :meth:`FastRCNNOutputLayers.predict_boxes`.
        scores (list[Tensor]): A list of Tensors of predicted class scores for each image.
            Element i has shape (Ri, K + 1), where Ri is the number of predicted objects
            for image i. Compatible with the output of :meth:`FastRCNNOutputLayers.predict_probs`.
        image_shapes (list[tuple]): A list of (height, width) tuples, one for each image in
            the batch. Callers pass `Instances.image_size`, and each tuple is forwarded
            unchanged to `Boxes.clip` and to the `Instances` constructor.
        score_thresh (float): Only return detections with a confidence score exceeding this
            threshold.
        nms_thresh (float): The threshold to use for box non-maximum suppression. Value in [0, 1].
        topk_per_image (int): The number of top scoring detections to return. Set < 0 to return
            all detections.

    Returns:
        instances: (list[Instances]): A list of N instances, one for each image in the batch,
            that stores the topk most confident detections.
        kept_indices: (list[Tensor]): A list of N 1D tensors; element i holds, for each
            detection kept in image i, its index in [0, Ri) into the input boxes/scores.
    """
    per_image_results = [
        fast_rcnn_inference_single_image(
            boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image
        )
        for boxes_per_image, scores_per_image, image_shape in zip(boxes, scores, image_shapes)
    ]
    instances = [result[0] for result in per_image_results]
    kept_indices = [result[1] for result in per_image_results]
    return instances, kept_indices
class FastRCNNOutputLayers(nn.Module):
    """
    Two linear layers for predicting Fast R-CNN outputs:

    1. proposal-to-detection box regression deltas
    2. classification scores
    """

    @configurable
    def __init__(
        self,
        input_shape: ShapeSpec,
        *,
        box2box_transform,
        num_classes: int,
        test_score_thresh: float = 0.0,
        test_nms_thresh: float = 0.5,
        test_topk_per_image: int = 100,
        cls_agnostic_bbox_reg: bool = False,
        smooth_l1_beta: float = 0.0,
        box_reg_loss_type: str = "smooth_l1",
        loss_weight: Union[float, Dict[str, float]] = 1.0,
        use_fed_loss: bool = False,
        use_sigmoid_ce: bool = False,
        get_fed_loss_cls_weights: Optional[Callable] = None,
        fed_loss_num_classes: int = 50,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature to this module
            box2box_transform (Box2BoxTransform or Box2BoxTransformRotated):
            num_classes (int): number of foreground classes
            test_score_thresh (float): threshold to filter predictions results.
            test_nms_thresh (float): NMS threshold for prediction results.
            test_topk_per_image (int): number of top predictions to produce per image.
            cls_agnostic_bbox_reg (bool): whether to use class agnostic for bbox regression
            smooth_l1_beta (float): transition point from L1 to L2 loss. Only used if
                `box_reg_loss_type` is "smooth_l1"
            box_reg_loss_type (str): Box regression loss type. One of: "smooth_l1", "giou",
                "diou", "ciou"
            loss_weight (float|dict): weights to use for losses. Can be a single number
                (float or int) for weighting all losses, or a dict of individual weightings.
                Valid dict keys are:
                    * "loss_cls": applied to classification loss
                    * "loss_box_reg": applied to box regression loss
            use_fed_loss (bool): whether to use federated loss which samples additional negative
                classes to calculate the loss
            use_sigmoid_ce (bool): whether to calculate the loss using weighted average of binary
                cross entropy with logits. This could be used together with federated loss
            get_fed_loss_cls_weights (Callable): a callable which takes dataset name and frequency
                weight power, and returns the probabilities to sample negative classes for
                federated loss. The implementation can be found in
                detectron2/data/detection_utils.py
            fed_loss_num_classes (int): number of federated classes to keep in total
        """
        super().__init__()
        if isinstance(input_shape, int):  # some backward compatibility
            input_shape = ShapeSpec(channels=input_shape)
        self.num_classes = num_classes
        input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1)
        # prediction layer for num_classes foreground classes and one background class (hence + 1)
        self.cls_score = nn.Linear(input_size, num_classes + 1)
        num_bbox_reg_classes = 1 if cls_agnostic_bbox_reg else num_classes
        box_dim = len(box2box_transform.weights)
        self.bbox_pred = nn.Linear(input_size, num_bbox_reg_classes * box_dim)

        nn.init.normal_(self.cls_score.weight, std=0.01)
        nn.init.normal_(self.bbox_pred.weight, std=0.001)
        for layer in [self.cls_score, self.bbox_pred]:
            nn.init.constant_(layer.bias, 0)

        self.box2box_transform = box2box_transform
        self.smooth_l1_beta = smooth_l1_beta
        self.test_score_thresh = test_score_thresh
        self.test_nms_thresh = test_nms_thresh
        self.test_topk_per_image = test_topk_per_image
        self.box_reg_loss_type = box_reg_loss_type
        # Accept any plain number, not only `float`: an int such as `loss_weight=1` must also
        # be expanded to a dict, because `losses()` looks weights up with `.get()`.
        if isinstance(loss_weight, (int, float)):
            loss_weight = {"loss_cls": loss_weight, "loss_box_reg": loss_weight}
        self.loss_weight = loss_weight
        self.use_fed_loss = use_fed_loss
        self.use_sigmoid_ce = use_sigmoid_ce
        self.fed_loss_num_classes = fed_loss_num_classes

        if self.use_fed_loss:
            assert self.use_sigmoid_ce, "Please use sigmoid cross entropy loss with federated loss"
            fed_loss_cls_weights = get_fed_loss_cls_weights()
            assert (
                len(fed_loss_cls_weights) == self.num_classes
            ), "Please check the provided fed_loss_cls_weights. Their size should match num_classes"
            self.register_buffer("fed_loss_cls_weights", fed_loss_cls_weights)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Translate a config node into the keyword arguments of :meth:`__init__`.
        """
        return {
            "input_shape": input_shape,
            "box2box_transform": Box2BoxTransform(weights=cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS),
            # fmt: off
            "num_classes"               : cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            "cls_agnostic_bbox_reg"     : cfg.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG,
            "smooth_l1_beta"            : cfg.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA,
            "test_score_thresh"         : cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST,
            "test_nms_thresh"           : cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST,
            "test_topk_per_image"       : cfg.TEST.DETECTIONS_PER_IMAGE,
            "box_reg_loss_type"         : cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE,
            "loss_weight"               : {"loss_box_reg": cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT},  # noqa
            "use_fed_loss"              : cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS,
            "use_sigmoid_ce"            : cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE,
            "get_fed_loss_cls_weights"  : lambda: get_fed_loss_cls_weights(dataset_names=cfg.DATASETS.TRAIN, freq_weight_power=cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT_POWER),  # noqa
            "fed_loss_num_classes"      : cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CLASSES,
            # fmt: on
        }

    def forward(self, x):
        """
        Args:
            x: per-region features of shape (N, ...) for N bounding boxes to predict.

        Returns:
            (Tensor, Tensor):
            First tensor: shape (N,K+1), scores for each of the N box. Each row contains the
            scores for K object categories and 1 background class.

            Second tensor: bounding box regression deltas for each box. Shape is (N,Kx4),
            or (N,4) for class-agnostic regression.
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        scores = self.cls_score(x)
        proposal_deltas = self.bbox_pred(x)
        return scores, proposal_deltas

    def losses(self, predictions, proposals):
        """
        Args:
            predictions: return values of :meth:`forward()`.
            proposals (list[Instances]): proposals that match the features that were used
                to compute predictions. The fields ``proposal_boxes``, ``gt_boxes``,
                ``gt_classes`` are expected.

        Returns:
            Dict[str, Tensor]: dict of losses
        """
        scores, proposal_deltas = predictions

        # parse classification outputs
        gt_classes = (
            cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0)
        )
        _log_classification_stats(scores, gt_classes)

        # parse box regression outputs
        if len(proposals):
            proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)  # Nx4
            assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
            # If "gt_boxes" does not exist, the proposals must be all negative and
            # should not be included in regression loss computation.
            # Here we just use proposal_boxes as an arbitrary placeholder because its
            # value won't be used in self.box_reg_loss().
            gt_boxes = cat(
                [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals],
                dim=0,
            )
        else:
            proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)

        if self.use_sigmoid_ce:
            loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes)
        else:
            loss_cls = cross_entropy(scores, gt_classes, reduction="mean")

        losses = {
            "loss_cls": loss_cls,
            "loss_box_reg": self.box_reg_loss(
                proposal_boxes, gt_boxes, proposal_deltas, gt_classes
            ),
        }
        # Losses without an explicit weight keep a weight of 1.0.
        return {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}

    # Implementation from https://github.com/xingyizhou/CenterNet2/blob/master/projects/CenterNet2/centernet/modeling/roi_heads/fed_loss.py  # noqa
    # with slight modifications
    def get_fed_loss_classes(self, gt_classes, num_fed_loss_classes, num_classes, weight):
        """
        Args:
            gt_classes: a long tensor of shape R that contains the gt class label of each proposal.
            num_fed_loss_classes: minimum number of classes to keep when calculating federated loss.
            Will sample negative classes if number of unique gt_classes is smaller than this value.
            num_classes: number of foreground classes
            weight: probabilities used to sample negative classes

        Returns:
            Tensor:
                classes to keep when calculating the federated loss, including both unique gt
                classes and sampled negative classes.
        """
        unique_gt_classes = torch.unique(gt_classes)
        prob = unique_gt_classes.new_ones(num_classes + 1).float()
        # The last slot is the background class; it is never sampled as a negative.
        prob[-1] = 0
        if len(unique_gt_classes) < num_fed_loss_classes:
            prob[:num_classes] = weight.float().clone()
            # Classes already present in the ground truth must not be drawn again.
            prob[unique_gt_classes] = 0
            sampled_negative_classes = torch.multinomial(
                prob, num_fed_loss_classes - len(unique_gt_classes), replacement=False
            )
            fed_loss_classes = torch.cat([unique_gt_classes, sampled_negative_classes])
        else:
            fed_loss_classes = unique_gt_classes
        return fed_loss_classes

    # Implementation from https://github.com/xingyizhou/CenterNet2/blob/master/projects/CenterNet2/centernet/modeling/roi_heads/custom_fast_rcnn.py#L113  # noqa
    # with slight modifications
    def sigmoid_cross_entropy_loss(self, pred_class_logits, gt_classes):
        """
        Args:
            pred_class_logits: shape (N, K+1), scores for each of the N box. Each row contains the
            scores for K object categories and 1 background class
            gt_classes: a long tensor of shape R that contains the gt class label of each proposal.

        Returns:
            Tensor: scalar loss, the (optionally federated-masked) binary cross entropy
            summed over foreground classes and divided by N.
        """
        if pred_class_logits.numel() == 0:
            return pred_class_logits.new_zeros([1])[0]

        N = pred_class_logits.shape[0]
        K = pred_class_logits.shape[1] - 1

        # One-hot targets over K+1 columns; the background column is dropped afterwards,
        # so background proposals end up with an all-zero target row.
        target = pred_class_logits.new_zeros(N, K + 1)
        target[range(len(gt_classes)), gt_classes] = 1
        target = target[:, :K]

        cls_loss = F.binary_cross_entropy_with_logits(
            pred_class_logits[:, :-1], target, reduction="none"
        )

        if self.use_fed_loss:
            fed_loss_classes = self.get_fed_loss_classes(
                gt_classes,
                num_fed_loss_classes=self.fed_loss_num_classes,
                num_classes=K,
                weight=self.fed_loss_cls_weights,
            )
            fed_loss_classes_mask = fed_loss_classes.new_zeros(K + 1)
            fed_loss_classes_mask[fed_loss_classes] = 1
            fed_loss_classes_mask = fed_loss_classes_mask[:K]
            weight = fed_loss_classes_mask.view(1, K).expand(N, K).float()
        else:
            weight = 1

        loss = torch.sum(cls_loss * weight) / N
        return loss

    def box_reg_loss(self, proposal_boxes, gt_boxes, pred_deltas, gt_classes):
        """
        Args:
            proposal_boxes/gt_boxes are tensors with the same shape (R, 4 or 5).
            pred_deltas has shape (R, 4 or 5), or (R, num_classes * (4 or 5)).
            gt_classes is a long tensor of shape R, the gt class label of each proposal.
            R shall be the number of proposals.

        Returns:
            Tensor: the box regression loss over foreground proposals, divided by R.
        """
        box_dim = proposal_boxes.shape[1]  # 4 or 5
        # Regression loss is only computed for foreground proposals (those matched to a GT)
        fg_inds = nonzero_tuple((gt_classes >= 0) & (gt_classes < self.num_classes))[0]
        if pred_deltas.shape[1] == box_dim:  # cls-agnostic regression
            fg_pred_deltas = pred_deltas[fg_inds]
        else:
            fg_pred_deltas = pred_deltas.view(-1, self.num_classes, box_dim)[
                fg_inds, gt_classes[fg_inds]
            ]

        loss_box_reg = _dense_box_regression_loss(
            [proposal_boxes[fg_inds]],
            self.box2box_transform,
            [fg_pred_deltas.unsqueeze(0)],
            [gt_boxes[fg_inds]],
            ...,
            self.box_reg_loss_type,
            self.smooth_l1_beta,
        )

        # The reg loss is normalized using the total number of regions (R), not the number
        # of foreground regions even though the box regression loss is only defined on
        # foreground regions. Why? Because doing so gives equal training influence to
        # each foreground example. To see how, consider two different minibatches:
        #  (1) Contains a single foreground region
        #  (2) Contains 100 foreground regions
        # If we normalize by the number of foreground regions, the single example in
        # minibatch (1) will be given 100 times as much influence as each foreground
        # example in minibatch (2). Normalizing by the total number of regions, R,
        # means that the single example in minibatch (1) and each of the 100 examples
        # in minibatch (2) are given equal influence.
        return loss_box_reg / max(gt_classes.numel(), 1.0)  # return 0 if empty

    def inference(self, predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[Instances]):
        """
        Args:
            predictions: return values of :meth:`forward()`.
            proposals (list[Instances]): proposals that match the features that were
                used to compute predictions. The ``proposal_boxes`` field is expected.

        Returns:
            list[Instances]: same as `fast_rcnn_inference`.
            list[Tensor]: same as `fast_rcnn_inference`.
        """
        boxes = self.predict_boxes(predictions, proposals)
        scores = self.predict_probs(predictions, proposals)
        image_shapes = [x.image_size for x in proposals]
        return fast_rcnn_inference(
            boxes,
            scores,
            image_shapes,
            self.test_score_thresh,
            self.test_nms_thresh,
            self.test_topk_per_image,
        )

    def predict_boxes_for_gt_classes(self, predictions, proposals):
        """
        Args:
            predictions: return values of :meth:`forward()`.
            proposals (list[Instances]): proposals that match the features that were used
                to compute predictions. The fields ``proposal_boxes``, ``gt_classes`` are expected.

        Returns:
            list[Tensor]:
                A list of Tensors of predicted boxes for GT classes in case of
                class-specific box head. Element i of the list has shape (Ri, B), where Ri is
                the number of proposals for image i and B is the box dimension (4 or 5)
        """
        if not len(proposals):
            return []
        scores, proposal_deltas = predictions
        proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)
        N, B = proposal_boxes.shape
        predict_boxes = self.box2box_transform.apply_deltas(
            proposal_deltas, proposal_boxes
        )  # Nx(KxB)

        K = predict_boxes.shape[1] // B
        if K > 1:
            gt_classes = torch.cat([p.gt_classes for p in proposals], dim=0)
            # Some proposals are ignored or have a background class. Their gt_classes
            # cannot be used as index.
            gt_classes = gt_classes.clamp_(0, K - 1)

            predict_boxes = predict_boxes.view(N, K, B)[
                torch.arange(N, dtype=torch.long, device=predict_boxes.device), gt_classes
            ]
        num_prop_per_image = [len(p) for p in proposals]
        return predict_boxes.split(num_prop_per_image)

    def predict_boxes(
        self, predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[Instances]
    ):
        """
        Args:
            predictions: return values of :meth:`forward()`.
            proposals (list[Instances]): proposals that match the features that were
                used to compute predictions. The ``proposal_boxes`` field is expected.

        Returns:
            list[Tensor]:
                A list of Tensors of predicted class-specific or class-agnostic boxes
                for each image. Element i has shape (Ri, K * B) or (Ri, B), where Ri is
                the number of proposals for image i and B is the box dimension (4 or 5)
        """
        if not len(proposals):
            return []
        _, proposal_deltas = predictions
        num_prop_per_image = [len(p) for p in proposals]
        proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)
        predict_boxes = self.box2box_transform.apply_deltas(
            proposal_deltas,
            proposal_boxes,
        )  # Nx(KxB)
        return predict_boxes.split(num_prop_per_image)

    def predict_probs(
        self, predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[Instances]
    ):
        """
        Args:
            predictions: return values of :meth:`forward()`.
            proposals (list[Instances]): proposals that match the features that were
                used to compute predictions.

        Returns:
            list[Tensor]:
                A list of Tensors of predicted class probabilities for each image.
                Element i has shape (Ri, K + 1), where Ri is the number of proposals for image i.
        """
        scores, _ = predictions
        num_inst_per_image = [len(p) for p in proposals]
        if self.use_sigmoid_ce:
            probs = scores.sigmoid()
        else:
            probs = F.softmax(scores, dim=-1)
        return probs.split(num_inst_per_image, dim=0)
def keypoint_rcnn_loss(pred_keypoint_logits, instances, normalizer):
    """
    Compute the keypoint loss: a softmax cross entropy over the S*S heatmap
    locations of every visible ground-truth keypoint.

    Arguments:
        pred_keypoint_logits (Tensor): A tensor of shape (N, K, S, S) where N is the total number
            of instances in the batch, K is the number of keypoints, and S is the side length
            of the keypoint heatmap. The values are spatial logits.
        instances (list[Instances]): A list of M Instances, where M is the batch size.
            These instances are predictions from the model
            that are in 1:1 correspondence with pred_keypoint_logits.
            Each Instances should contain a `gt_keypoints` field containing a `structures.Keypoint`
            instance.
        normalizer (float or None): Divide the summed loss by this amount.
            If None, the loss is divided by the number of visible keypoints in the minibatch.

    Returns a scalar tensor containing the loss.
    """
    global _TOTAL_SKIPPED

    side_len = pred_keypoint_logits.shape[2]
    target_chunks = []
    valid_chunks = []
    for per_image in instances:
        if len(per_image) == 0:
            continue
        targets_i, valid_i = per_image.gt_keypoints.to_heatmap(
            per_image.proposal_boxes.tensor, side_len
        )
        target_chunks.append(targets_i.view(-1))
        valid_chunks.append(valid_i.view(-1))

    if target_chunks:
        keypoint_targets = cat(target_chunks, dim=0)
        valid_flags = cat(valid_chunks, dim=0).to(dtype=torch.uint8)
        valid_idx = torch.nonzero(valid_flags).squeeze(1)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if not target_chunks or valid_idx.numel() == 0:
        _TOTAL_SKIPPED += 1
        get_event_storage().put_scalar(
            "kpts_num_skipped_batches", _TOTAL_SKIPPED, smoothing_hint=False
        )
        # Zero loss that still depends on the logits.
        return pred_keypoint_logits.sum() * 0

    num_inst, num_kpts, height, width = pred_keypoint_logits.shape
    flat_logits = pred_keypoint_logits.view(num_inst * num_kpts, height * width)

    summed_loss = F.cross_entropy(
        flat_logits[valid_idx], keypoint_targets[valid_idx], reduction="sum"
    )

    # If a normalizer isn't specified, normalize by the number of visible keypoints in the minibatch
    divisor = valid_idx.numel() if normalizer is None else normalizer
    return summed_loss / divisor
class BaseKeypointRCNNHead(nn.Module):
    """
    Implement the basic Keypoint R-CNN losses and inference logic described in
    Sec. 5 of :paper:`Mask R-CNN`.
    """

    @configurable
    def __init__(self, *, num_keypoints, loss_weight=1.0, loss_normalizer=1.0):
        """
        NOTE: this interface is experimental.

        Args:
            num_keypoints (int): number of keypoints to predict
            loss_weight (float): weight to multiply on the keypoint loss
            loss_normalizer (float or str):
                If a number (float or int), divide the loss by `loss_normalizer * #images`.
                If 'visible', the loss is normalized by the total number of
                visible keypoints across images.
        """
        super().__init__()
        self.num_keypoints = num_keypoints
        self.loss_weight = loss_weight
        # Any plain number is a valid divisor; an int such as 512 is accepted as well as a float.
        assert loss_normalizer == "visible" or isinstance(
            loss_normalizer, (int, float)
        ), loss_normalizer
        self.loss_normalizer = loss_normalizer

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Translate a config node into the keyword arguments of :meth:`__init__`.
        """
        ret = {
            "loss_weight": cfg.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT,
            "num_keypoints": cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS,
        }
        normalize_by_visible = (
            cfg.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS
        )  # noqa
        if not normalize_by_visible:
            # Static normalizer: keypoints per instance times the configured
            # per-image batch size times the positive sample fraction.
            batch_size_per_image = cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE
            positive_sample_fraction = cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION
            ret["loss_normalizer"] = (
                ret["num_keypoints"] * batch_size_per_image * positive_sample_fraction
            )
        else:
            ret["loss_normalizer"] = "visible"
        return ret

    def forward(self, x, instances: List[Instances]):
        """
        Args:
            x: input 4D region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): contains the boxes & labels corresponding
                to the input features.
                Exact format is up to its caller to decide.
                Typically, this is the foreground instances in training, with
                "proposal_boxes" field and other gt annotations.
                In inference, it contains boxes that are already predicted.

        Returns:
            A dict of losses if in training. The predicted "instances" if in inference.
        """
        x = self.layers(x)
        if self.training:
            num_images = len(instances)
            # None tells keypoint_rcnn_loss to normalize by the number of visible keypoints.
            normalizer = (
                None if self.loss_normalizer == "visible" else num_images * self.loss_normalizer
            )
            return {
                "loss_keypoint": keypoint_rcnn_loss(x, instances, normalizer=normalizer)
                * self.loss_weight
            }
        else:
            keypoint_rcnn_inference(x, instances)
            return instances

    def layers(self, x):
        """
        Neural network layers that makes predictions from regional input features.
        """
        raise NotImplementedError
+# Therefore, to add new layers in this head class, please make sure they are +# added in the order they will be used in forward(). +@ROI_KEYPOINT_HEAD_REGISTRY.register() +class KRCNNConvDeconvUpsampleHead(BaseKeypointRCNNHead, nn.Sequential): + """ + A standard keypoint head containing a series of 3x3 convs, followed by + a transpose convolution and bilinear interpolation for upsampling. + It is described in Sec. 5 of :paper:`Mask R-CNN`. + """ + + @configurable + def __init__(self, input_shape, *, num_keypoints, conv_dims, **kwargs): + """ + NOTE: this interface is experimental. + + Args: + input_shape (ShapeSpec): shape of the input feature + conv_dims: an iterable of output channel counts for each conv in the head + e.g. (512, 512, 512) for three convs outputting 512 channels. + """ + super().__init__(num_keypoints=num_keypoints, **kwargs) + + # default up_scale to 2.0 (this can be made an option) + up_scale = 2.0 + in_channels = input_shape.channels + + for idx, layer_channels in enumerate(conv_dims, 1): + module = Conv2d(in_channels, layer_channels, 3, stride=1, padding=1) + self.add_module("conv_fcn{}".format(idx), module) + self.add_module("conv_fcn_relu{}".format(idx), nn.ReLU()) + in_channels = layer_channels + + deconv_kernel = 4 + self.score_lowres = ConvTranspose2d( + in_channels, num_keypoints, deconv_kernel, stride=2, padding=deconv_kernel // 2 - 1 + ) + self.up_scale = up_scale + + for name, param in self.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name: + # Caffe2 implementation uses MSRAFill, which in fact + # corresponds to kaiming_normal_ in PyTorch + nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + + @classmethod + def from_config(cls, cfg, input_shape): + ret = super().from_config(cfg, input_shape) + ret["input_shape"] = input_shape + ret["conv_dims"] = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS + return ret + + def layers(self, x): + for layer in self: + x = layer(x) + x = 
@torch.jit.unused
def mask_rcnn_loss(pred_mask_logits: torch.Tensor, instances: List[Instances], vis_period: int = 0):
    """
    Compute the mask prediction loss defined in the Mask R-CNN paper.

    Besides returning the loss, this logs accuracy / false-positive /
    false-negative scalars (and, every `vis_period` steps, side-by-side
    prediction/GT images) to the current event storage.

    Args:
        pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or (B, 1, Hmask, Wmask)
            for class-specific or class-agnostic, where B is the total number of predicted masks
            in all images, C is the number of foreground classes, and Hmask, Wmask are the height
            and width of the mask predictions. The values are logits.
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. These instances are in 1:1
            correspondence with the pred_mask_logits. The ground-truth labels (class, box, mask,
            ...) associated with each instance are stored in fields.
        vis_period (int): the period (in steps) to dump visualization.

    Returns:
        mask_loss (Tensor): A scalar tensor containing the loss.
    """
    cls_agnostic_mask = pred_mask_logits.size(1) == 1
    total_num_masks = pred_mask_logits.size(0)
    mask_side_len = pred_mask_logits.size(2)
    assert pred_mask_logits.size(2) == pred_mask_logits.size(3), "Mask prediction must be square!"

    gt_classes = []
    gt_masks = []
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        if not cls_agnostic_mask:
            gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64)
            gt_classes.append(gt_classes_per_image)

        gt_masks_per_image = instances_per_image.gt_masks.crop_and_resize(
            instances_per_image.proposal_boxes.tensor, mask_side_len
        ).to(device=pred_mask_logits.device)
        # A tensor of shape (N, M, M), N=#instances in the image; M=mask_side_len
        gt_masks.append(gt_masks_per_image)

    if not gt_masks:
        # No instances at all: return a zero that still depends on the logits.
        return pred_mask_logits.sum() * 0

    gt_masks = cat(gt_masks, dim=0)

    if cls_agnostic_mask:
        pred_mask_logits = pred_mask_logits[:, 0]
    else:
        gt_classes = cat(gt_classes, dim=0)
        # Build the row index on the same device as the class index so the
        # advanced-indexing below does not mix a CPU index with a device index.
        indices = torch.arange(total_num_masks, device=gt_classes.device)
        pred_mask_logits = pred_mask_logits[indices, gt_classes]

    if gt_masks.dtype == torch.bool:
        gt_masks_bool = gt_masks
    else:
        # Here we allow gt_masks to be float as well (depend on the implementation of rasterize())
        gt_masks_bool = gt_masks > 0.5
    gt_masks = gt_masks.to(dtype=torch.float32)

    # Log the training accuracy (using gt classes and sigmoid(0.0) == 0.5 threshold)
    mask_incorrect = (pred_mask_logits > 0.0) != gt_masks_bool
    mask_accuracy = 1 - (mask_incorrect.sum().item() / max(mask_incorrect.numel(), 1.0))
    num_positive = gt_masks_bool.sum().item()
    false_positive = (mask_incorrect & ~gt_masks_bool).sum().item() / max(
        gt_masks_bool.numel() - num_positive, 1.0
    )
    false_negative = (mask_incorrect & gt_masks_bool).sum().item() / max(num_positive, 1.0)

    storage = get_event_storage()
    storage.put_scalar("mask_rcnn/accuracy", mask_accuracy)
    storage.put_scalar("mask_rcnn/false_positive", false_positive)
    storage.put_scalar("mask_rcnn/false_negative", false_negative)
    if vis_period > 0 and storage.iter % vis_period == 0:
        pred_masks = pred_mask_logits.sigmoid()
        # Concatenate along the width so prediction and GT sit side by side.
        vis_masks = torch.cat([pred_masks, gt_masks], dim=2)
        name = "Left: mask prediction; Right: mask GT"
        for idx, vis_mask in enumerate(vis_masks):
            # Replicate the single channel three times to form an RGB image.
            vis_mask = torch.stack([vis_mask] * 3, dim=0)
            storage.put_image(name + f" ({idx})", vis_mask)

    mask_loss = F.binary_cross_entropy_with_logits(pred_mask_logits, gt_masks, reduction="mean")
    return mask_loss
class BaseMaskRCNNHead(nn.Module):
    """
    Shared Mask R-CNN loss and inference plumbing, as described in :paper:`Mask R-CNN`.

    Concrete heads provide the actual network by overriding :meth:`layers`.
    """

    @configurable
    def __init__(self, *, loss_weight: float = 1.0, vis_period: int = 0):
        """
        NOTE: this interface is experimental.

        Args:
            loss_weight (float): multiplier of the loss
            vis_period (int): visualization period
        """
        super().__init__()
        self.loss_weight = loss_weight
        self.vis_period = vis_period

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Build the constructor kwargs from `cfg`; only the visualization period is read.
        """
        config_kwargs = {}
        config_kwargs["vis_period"] = cfg.VIS_PERIOD
        return config_kwargs

    def forward(self, x, instances: List[Instances]):
        """
        Args:
            x: input region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): contains the boxes & labels corresponding
                to the input features.
                Exact format is up to its caller to decide.
                Typically, this is the foreground instances in training, with
                "proposal_boxes" field and other gt annotations.
                In inference, it contains boxes that are already predicted.

        Returns:
            A dict of losses in training. The predicted "instances" in inference.
        """
        mask_logits = self.layers(x)
        if not self.training:
            # Inference: predictions are attached to `instances` in place.
            mask_rcnn_inference(mask_logits, instances)
            return instances
        unweighted_loss = mask_rcnn_loss(mask_logits, instances, self.vis_period)
        return {"loss_mask": unweighted_loss * self.loss_weight}

    def layers(self, x):
        """
        Neural network layers that makes predictions from input features.
        """
        raise NotImplementedError
def build_mask_head(cfg, input_shape):
    """
    Instantiate the mask head registered under `cfg.MODEL.ROI_MASK_HEAD.NAME`.

    The registered object is called as `obj(cfg, input_shape)` and its result returned.
    """
    head_factory = ROI_MASK_HEAD_REGISTRY.get(cfg.MODEL.ROI_MASK_HEAD.NAME)
    return head_factory(cfg, input_shape)
def select_proposals_with_visible_keypoints(proposals: List[Instances]) -> List[Instances]:
    """
    Keep only the proposals that contain at least one visible keypoint inside
    their box, and log the mean number kept per non-empty image.

    Args:
        proposals (list[Instances]): a list of N Instances, where N is the
            number of images.

    Returns:
        proposals: only contains proposals with at least one visible keypoint.
        Images with no proposals are passed through unchanged.

    Note that this is still slightly different from Detectron.
    In Detectron, proposals for training keypoint head are re-sampled from
    all the proposals with IOU>threshold & >=1 visible keypoint.

    Here, the proposals are first sampled from all proposals with
    IOU>threshold, then proposals with no visible keypoint are filtered out.
    This strategy seems to make no difference on Detectron and is easier to implement.
    """
    ret = []
    all_num_fg = []
    for proposals_per_image in proposals:
        # If empty/unannotated image (hard negatives), skip filtering for train
        if len(proposals_per_image) == 0:
            ret.append(proposals_per_image)
            continue
        gt_keypoints = proposals_per_image.gt_keypoints.tensor
        # #fg x K x 3
        vis_mask = gt_keypoints[:, :, 2] >= 1
        xs, ys = gt_keypoints[:, :, 0], gt_keypoints[:, :, 1]
        proposal_boxes = proposals_per_image.proposal_boxes.tensor.unsqueeze(dim=1)  # #fg x 1 x 4
        kp_in_box = (
            (xs >= proposal_boxes[:, :, 0])
            & (xs <= proposal_boxes[:, :, 2])
            & (ys >= proposal_boxes[:, :, 1])
            & (ys <= proposal_boxes[:, :, 3])
        )
        # A proposal is kept if any of its keypoints is both visible and inside the box.
        selection = (kp_in_box & vis_mask).any(dim=1)
        selection_idxs = nonzero_tuple(selection)[0]
        all_num_fg.append(selection_idxs.numel())
        ret.append(proposals_per_image[selection_idxs])

    storage = get_event_storage()
    # np.mean([]) yields NaN plus a RuntimeWarning; when every image was empty
    # (or there were no images) report zero foreground samples instead.
    mean_num_fg = np.mean(all_num_fg) if all_num_fg else 0.0
    storage.put_scalar("keypoint_head/num_fg_samples", mean_num_fg)
    return ret
background is not included) + batch_size_per_image (int): number of proposals to sample for training + positive_fraction (float): fraction of positive (foreground) proposals + to sample for training. + proposal_matcher (Matcher): matcher that matches proposals and ground truth + proposal_append_gt (bool): whether to include ground truth as proposals as well + """ + super().__init__() + self.batch_size_per_image = batch_size_per_image + self.positive_fraction = positive_fraction + self.num_classes = num_classes + self.proposal_matcher = proposal_matcher + self.proposal_append_gt = proposal_append_gt + + @classmethod + def from_config(cls, cfg): + return { + "batch_size_per_image": cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, + "positive_fraction": cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION, + "num_classes": cfg.MODEL.ROI_HEADS.NUM_CLASSES, + "proposal_append_gt": cfg.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT, + # Matcher to assign box proposals to gt boxes + "proposal_matcher": Matcher( + cfg.MODEL.ROI_HEADS.IOU_THRESHOLDS, + cfg.MODEL.ROI_HEADS.IOU_LABELS, + allow_low_quality_matches=False, + ), + } + + def _sample_proposals( + self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Based on the matching between N proposals and M groundtruth, + sample the proposals and set their classification labels. + + Args: + matched_idxs (Tensor): a vector of length N, each is the best-matched + gt index in [0, M) for each proposal. + matched_labels (Tensor): a vector of length N, the matcher's label + (one of cfg.MODEL.ROI_HEADS.IOU_LABELS) for each proposal. + gt_classes (Tensor): a vector of length M. + + Returns: + Tensor: a vector of indices of sampled proposals. Each is in [0, N). + Tensor: a vector of the same length, the classification label for + each sampled proposal. Each sample is labeled as either a category in + [0, num_classes) or the background (num_classes). 
+ """ + has_gt = gt_classes.numel() > 0 + # Get the corresponding GT for each proposal + if has_gt: + gt_classes = gt_classes[matched_idxs] + # Label unmatched proposals (0 label from matcher) as background (label=num_classes) + gt_classes[matched_labels == 0] = self.num_classes + # Label ignore proposals (-1 label) + gt_classes[matched_labels == -1] = -1 + else: + gt_classes = torch.zeros_like(matched_idxs) + self.num_classes + + sampled_fg_idxs, sampled_bg_idxs = subsample_labels( + gt_classes, self.batch_size_per_image, self.positive_fraction, self.num_classes + ) + + sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0) + return sampled_idxs, gt_classes[sampled_idxs] + + @torch.no_grad() + def label_and_sample_proposals( + self, proposals: List[Instances], targets: List[Instances] + ) -> List[Instances]: + """ + Prepare some proposals to be used to train the ROI heads. + It performs box matching between `proposals` and `targets`, and assigns + training labels to the proposals. + It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth + boxes, with a fraction of positives that is no larger than + ``self.positive_fraction``. + + Args: + See :meth:`ROIHeads.forward` + + Returns: + list[Instances]: + length `N` list of `Instances`s containing the proposals + sampled for training. Each `Instances` has the following fields: + + - proposal_boxes: the proposal boxes + - gt_boxes: the ground-truth box that the proposal is assigned to + (this is only meaningful if the proposal has a label > 0; if label = 0 + then the ground-truth box is random) + + Other fields such as "gt_classes", "gt_masks", that's included in `targets`. + """ + # Augment proposals with ground-truth boxes. + # In the case of learned proposals (e.g., RPN), when training starts + # the proposals will be low quality due to random initialization. 
+ # It's possible that none of these initial + # proposals have high enough overlap with the gt objects to be used + # as positive examples for the second stage components (box head, + # cls head, mask head). Adding the gt boxes to the set of proposals + # ensures that the second stage components will have some positive + # examples from the start of training. For RPN, this augmentation improves + # convergence and empirically improves box AP on COCO by about 0.5 + # points (under one tested configuration). + if self.proposal_append_gt: + proposals = add_ground_truth_to_proposals(targets, proposals) + + proposals_with_gt = [] + + num_fg_samples = [] + num_bg_samples = [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + has_gt = len(targets_per_image) > 0 + match_quality_matrix = pairwise_iou( + targets_per_image.gt_boxes, proposals_per_image.proposal_boxes + ) + matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix) + sampled_idxs, gt_classes = self._sample_proposals( + matched_idxs, matched_labels, targets_per_image.gt_classes + ) + + # Set target attributes of the sampled proposals: + proposals_per_image = proposals_per_image[sampled_idxs] + proposals_per_image.gt_classes = gt_classes + + if has_gt: + sampled_targets = matched_idxs[sampled_idxs] + # We index all the attributes of targets that start with "gt_" + # and have not been added to proposals yet (="gt_classes"). + # NOTE: here the indexing waste some compute, because heads + # like masks, keypoints, etc, will filter the proposals again, + # (by foreground/background, or number of keypoints in the image, etc) + # so we essentially index the data twice. + for (trg_name, trg_value) in targets_per_image.get_fields().items(): + if trg_name.startswith("gt_") and not proposals_per_image.has(trg_name): + proposals_per_image.set(trg_name, trg_value[sampled_targets]) + # If no GT is given in the image, we don't know what a dummy gt value can be. 
+ # Therefore the returned proposals won't have any gt_* fields, except for a + # gt_classes full of background label. + + num_bg_samples.append((gt_classes == self.num_classes).sum().item()) + num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1]) + proposals_with_gt.append(proposals_per_image) + + # Log the number of fg/bg samples that are selected for training ROI heads + storage = get_event_storage() + storage.put_scalar("roi_head/num_fg_samples", np.mean(num_fg_samples)) + storage.put_scalar("roi_head/num_bg_samples", np.mean(num_bg_samples)) + + return proposals_with_gt + + def forward( + self, + images: ImageList, + features: Dict[str, torch.Tensor], + proposals: List[Instances], + targets: Optional[List[Instances]] = None, + ) -> Tuple[List[Instances], Dict[str, torch.Tensor]]: + """ + Args: + images (ImageList): + features (dict[str,Tensor]): input data as a mapping from feature + map name to tensor. Axis 0 represents the number of images `N` in + the input data; axes 1-3 are channels, height, and width, which may + vary between feature maps (e.g., if a feature pyramid is used). + proposals (list[Instances]): length `N` list of `Instances`. The i-th + `Instances` contains object proposals for the i-th input image, + with fields "proposal_boxes" and "objectness_logits". + targets (list[Instances], optional): length `N` list of `Instances`. The i-th + `Instances` contains the ground-truth per-instance annotations + for the i-th input image. Specify `targets` during training only. + It may have the following fields: + + - gt_boxes: the bounding box of each instance. + - gt_classes: the label for each instance with a category ranging in [0, #class]. + - gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance. + - gt_keypoints: NxKx3, the groud-truth keypoints for each instance. + + Returns: + list[Instances]: length `N` list of `Instances` containing the + detected instances. 
Returned during inference only; may be [] during training. + + dict[str->Tensor]: + mapping from a named loss to a tensor storing the loss. Used during training only. + """ + raise NotImplementedError() + + +@ROI_HEADS_REGISTRY.register() +class Res5ROIHeads(ROIHeads): + """ + The ROIHeads in a typical "C4" R-CNN model, where + the box and mask head share the cropping and + the per-region feature computation by a Res5 block. + See :paper:`ResNet` Appendix A. + """ + + @configurable + def __init__( + self, + *, + in_features: List[str], + pooler: ROIPooler, + res5: nn.Module, + box_predictor: nn.Module, + mask_head: Optional[nn.Module] = None, + **kwargs, + ): + """ + NOTE: this interface is experimental. + + Args: + in_features (list[str]): list of backbone feature map names to use for + feature extraction + pooler (ROIPooler): pooler to extra region features from backbone + res5 (nn.Sequential): a CNN to compute per-region features, to be used by + ``box_predictor`` and ``mask_head``. Typically this is a "res5" + block from a ResNet. + box_predictor (nn.Module): make box predictions from the feature. + Should have the same interface as :class:`FastRCNNOutputLayers`. 
+ mask_head (nn.Module): transform features to make mask predictions + """ + super().__init__(**kwargs) + self.in_features = in_features + self.pooler = pooler + if isinstance(res5, (list, tuple)): + res5 = nn.Sequential(*res5) + self.res5 = res5 + self.box_predictor = box_predictor + self.mask_on = mask_head is not None + if self.mask_on: + self.mask_head = mask_head + + @classmethod + def from_config(cls, cfg, input_shape): + # fmt: off + ret = super().from_config(cfg) + in_features = ret["in_features"] = cfg.MODEL.ROI_HEADS.IN_FEATURES + pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION + pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE + pooler_scales = (1.0 / input_shape[in_features[0]].stride, ) + sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + mask_on = cfg.MODEL.MASK_ON + # fmt: on + assert not cfg.MODEL.KEYPOINT_ON + assert len(in_features) == 1 + + ret["pooler"] = ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type=pooler_type, + ) + + # Compatbility with old moco code. Might be useful. + # See notes in StandardROIHeads.from_config + if not inspect.ismethod(cls._build_res5_block): + logger.warning( + "The behavior of _build_res5_block may change. " + "Please do not depend on private methods." 
+ ) + cls._build_res5_block = classmethod(cls._build_res5_block) + + ret["res5"], out_channels = cls._build_res5_block(cfg) + ret["box_predictor"] = FastRCNNOutputLayers( + cfg, ShapeSpec(channels=out_channels, height=1, width=1) + ) + + if mask_on: + ret["mask_head"] = build_mask_head( + cfg, + ShapeSpec(channels=out_channels, width=pooler_resolution, height=pooler_resolution), + ) + return ret + + @classmethod + def _build_res5_block(cls, cfg): + # fmt: off + stage_channel_factor = 2 ** 3 # res5 is 8x res2 + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + bottleneck_channels = num_groups * width_per_group * stage_channel_factor + out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor + stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 + norm = cfg.MODEL.RESNETS.NORM + assert not cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE[-1], \ + "Deformable conv is not yet supported in res5 head." + # fmt: on + + blocks = ResNet.make_stage( + BottleneckBlock, + 3, + stride_per_block=[2, 1, 1], + in_channels=out_channels // 2, + bottleneck_channels=bottleneck_channels, + out_channels=out_channels, + num_groups=num_groups, + norm=norm, + stride_in_1x1=stride_in_1x1, + ) + return nn.Sequential(*blocks), out_channels + + def _shared_roi_transform(self, features: List[torch.Tensor], boxes: List[Boxes]): + x = self.pooler(features, boxes) + return self.res5(x) + + def forward( + self, + images: ImageList, + features: Dict[str, torch.Tensor], + proposals: List[Instances], + targets: Optional[List[Instances]] = None, + ): + """ + See :meth:`ROIHeads.forward`. 
def forward_with_given_boxes(
    self, features: Dict[str, torch.Tensor], instances: List[Instances]
) -> List[Instances]:
    """
    Run the non-box heads on boxes that are already known.

    Args:
        features: same as in `forward()`
        instances (list[Instances]): instances to predict other outputs for.
            The first element must have the fields "pred_boxes" and
            "pred_classes".

    Returns:
        list[Instances]: `instances` unchanged when the mask head is off,
        otherwise whatever the mask head returns for them.
    """
    assert not self.training
    first = instances[0]
    assert first.has("pred_boxes") and first.has("pred_classes")

    if not self.mask_on:
        return instances

    selected_features = [features[name] for name in self.in_features]
    boxes = [inst.pred_boxes for inst in instances]
    roi_features = self._shared_roi_transform(selected_features, boxes)
    return self.mask_head(roi_features, instances)
@classmethod
def from_config(cls, cfg, input_shape):
    """
    Assemble the constructor arguments from a config node.

    The parent's arguments are extended with ``train_on_pred_boxes`` and with
    the output of each ``_init_*_head`` hook that is a bound classmethod.

    Subclasses that predate from_config-style construction may have overridden
    those hooks as plain instance methods. Such overrides are not bound methods
    of ``cls`` (``inspect.ismethod`` is False for them), so they are skipped
    here and the subclass has to call them itself.
    """
    kwargs = super().from_config(cfg)
    kwargs["train_on_pred_boxes"] = cfg.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES
    for hook_name in ("_init_box_head", "_init_mask_head", "_init_keypoint_head"):
        hook = getattr(cls, hook_name)
        if inspect.ismethod(hook):
            kwargs.update(hook(cfg, input_shape))
    return kwargs
@classmethod
def _init_mask_head(cls, cfg, input_shape):
    """
    Build the constructor arguments for the mask branch.

    Returns:
        dict: empty when ``cfg.MODEL.MASK_ON`` is false. Otherwise it holds
        "mask_in_features", "mask_pooler" (None when no pooler type is
        configured) and "mask_head".
    """
    if not cfg.MODEL.MASK_ON:
        return {}

    in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
    head_cfg = cfg.MODEL.ROI_MASK_HEAD
    resolution = head_cfg.POOLER_RESOLUTION
    scales = tuple(1.0 / input_shape[name].stride for name in in_features)
    sampling_ratio = head_cfg.POOLER_SAMPLING_RATIO
    pooler_type = head_cfg.POOLER_TYPE

    # Only the first feature map's channel count is used for the head input.
    channels = [input_shape[name].channels for name in in_features][0]

    if pooler_type:
        pooler = ROIPooler(
            output_size=resolution,
            scales=scales,
            sampling_ratio=sampling_ratio,
            pooler_type=pooler_type,
        )
        head_input = ShapeSpec(channels=channels, width=resolution, height=resolution)
    else:
        # Without a pooler the head receives the selected feature maps directly.
        pooler = None
        head_input = {name: input_shape[name] for name in in_features}

    return {
        "mask_in_features": in_features,
        "mask_pooler": pooler,
        "mask_head": build_mask_head(cfg, head_input),
    }
def forward(
    self,
    images: ImageList,
    features: Dict[str, torch.Tensor],
    proposals: List[Instances],
    targets: Optional[List[Instances]] = None,
) -> Tuple[List[Instances], Dict[str, torch.Tensor]]:
    """
    See :class:`ROIHeads.forward`.

    Returns:
        In training, ``(proposals, losses)`` where ``proposals`` are the
        labelled and sampled proposals and ``losses`` merges the box, mask
        and keypoint losses. In inference, ``(pred_instances, {})``.
    """
    del images

    if not self.training:
        del targets
        detections = self._forward_box(features, proposals)
        # During inference cascaded prediction is used: the mask and keypoint
        # heads are only applied to the detections kept by the box head.
        detections = self.forward_with_given_boxes(features, detections)
        return detections, {}

    assert targets, "'targets' argument is required during training"
    proposals = self.label_and_sample_proposals(proposals, targets)
    del targets

    losses = self._forward_box(features, proposals)
    # Usually the mask and keypoint heads train on the same proposals as the
    # box head. When `self.train_on_pred_boxes` is True, `_forward_box` has
    # replaced the proposal boxes with boxes predicted by the box head.
    for branch in (self._forward_mask, self._forward_keypoint):
        losses.update(branch(features, proposals))
    return proposals, losses
def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instances]):
    """
    Run the box branch: pool, apply the box head, then the box predictor.

    When training with `self.train_on_pred_boxes` set, the `proposal_boxes`
    field of every element of `proposals` is overwritten in place with the
    boxes predicted for the ground-truth classes.

    Args:
        features (dict[str, Tensor]): mapping from feature map names to tensor.
        proposals (list[Instances]): the per-image object proposals. Each must
            provide "proposal_boxes"; training additionally relies on whatever
            fields the box predictor's loss needs.

    Returns:
        In training, a dict of losses.
        In inference, a list of `Instances`, the predicted instances.
    """
    pooled = self.box_pooler(
        [features[name] for name in self.box_in_features],
        [per_image.proposal_boxes for per_image in proposals],
    )
    predictions = self.box_predictor(self.box_head(pooled))
    del pooled

    if not self.training:
        detections, _ = self.box_predictor.inference(predictions, proposals)
        return detections

    # Losses are computed before the proposals are modified in place below.
    losses = self.box_predictor.losses(predictions, proposals)
    if self.train_on_pred_boxes:
        with torch.no_grad():
            refined = self.box_predictor.predict_boxes_for_gt_classes(predictions, proposals)
            for per_image, refined_per_image in zip(proposals, refined):
                per_image.proposal_boxes = Boxes(refined_per_image)
    return losses
+ In inference, they can be the boxes predicted by R-CNN box head. + + Returns: + In training, a dict of losses. + In inference, update `instances` with new fields "pred_keypoints" and return it. + """ + if not self.keypoint_on: + return {} if self.training else instances + + if self.training: + # head is only trained on positive proposals with >=1 visible keypoints. + instances, _ = select_foreground_proposals(instances, self.num_classes) + instances = select_proposals_with_visible_keypoints(instances) + + if self.keypoint_pooler is not None: + features = [features[f] for f in self.keypoint_in_features] + boxes = [x.proposal_boxes if self.training else x.pred_boxes for x in instances] + features = self.keypoint_pooler(features, boxes) + else: + features = {f: features[f] for f in self.keypoint_in_features} + return self.keypoint_head(features, instances) diff --git a/data_processing/detectron2/detectron2/modeling/roi_heads/rotated_fast_rcnn.py b/data_processing/detectron2/detectron2/modeling/roi_heads/rotated_fast_rcnn.py new file mode 100644 index 0000000..1e7bfab --- /dev/null +++ b/data_processing/detectron2/detectron2/modeling/roi_heads/rotated_fast_rcnn.py @@ -0,0 +1,271 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import logging +import numpy as np +import torch + +from detectron2.config import configurable +from detectron2.layers import ShapeSpec, batched_nms_rotated +from detectron2.structures import Instances, RotatedBoxes, pairwise_iou_rotated +from detectron2.utils.events import get_event_storage + +from ..box_regression import Box2BoxTransformRotated +from ..poolers import ROIPooler +from ..proposal_generator.proposal_utils import add_ground_truth_to_proposals +from .box_head import build_box_head +from .fast_rcnn import FastRCNNOutputLayers +from .roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads + +logger = logging.getLogger(__name__) + +""" +Shape shorthand in this module: + + N: number of images in the minibatch + R: number of ROIs, combined over all images, in the minibatch + Ri: number of ROIs in image i + K: number of foreground classes. E.g.,there are 80 foreground classes in COCO. + +Naming convention: + + deltas: refers to the 5-d (dx, dy, dw, dh, da) deltas that parameterize the box2box + transform (see :class:`box_regression.Box2BoxTransformRotated`). + + pred_class_logits: predicted class scores in [-inf, +inf]; use + softmax(pred_class_logits) to estimate P(class). + + gt_classes: ground-truth classification labels in [0, K], where [0, K) represent + foreground object classes and K represents the background class. + + pred_proposal_deltas: predicted rotated box2box transform deltas for transforming proposals + to detection box predictions. + + gt_proposal_deltas: ground-truth rotated box2box transform deltas +""" + + +def fast_rcnn_inference_rotated( + boxes, scores, image_shapes, score_thresh, nms_thresh, topk_per_image +): + """ + Call `fast_rcnn_inference_single_image_rotated` for all images. + + Args: + boxes (list[Tensor]): A list of Tensors of predicted class-specific or class-agnostic + boxes for each image. 
@torch.no_grad()
def fast_rcnn_inference_single_image_rotated(
    boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image
):
    """
    Single-image inference for rotated boxes: threshold on score, then apply
    per-class rotated non-maximum suppression.

    Args:
        Same as `fast_rcnn_inference_rotated`, but with the boxes, scores and
        image shape of a single image.

    Returns:
        Same as `fast_rcnn_inference_rotated`, but for only one image: an
        `Instances` with "pred_boxes", "scores" and "pred_classes", and the
        row indices of the kept predictions (indices refer to the rows left
        after non-finite rows were dropped).
    """
    box_dim = 5  # values per rotated box

    # Drop predictions that contain any non-finite box coordinate or score.
    finite_rows = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
    if not finite_rows.all():
        boxes, scores = boxes[finite_rows], scores[finite_rows]

    # The last score column is excluded from detection.
    class_scores = scores[:, :-1]
    num_bbox_reg_classes = boxes.shape[1] // box_dim

    # Wrap in RotatedBoxes only to reuse its `clip`.
    clipped = RotatedBoxes(boxes.reshape(-1, box_dim))
    clipped.clip(image_shape)
    per_class_boxes = clipped.tensor.view(-1, num_bbox_reg_classes, box_dim)  # R x C x B

    above_thresh = class_scores > score_thresh  # R x K
    # One row per surviving (prediction, class) pair:
    # column 0 is the prediction index, column 1 the class index.
    pair_inds = above_thresh.nonzero()
    if num_bbox_reg_classes == 1:
        # Class-agnostic regression: every class shares the single box.
        candidate_boxes = per_class_boxes[pair_inds[:, 0], 0]
    else:
        candidate_boxes = per_class_boxes[above_thresh]
    candidate_scores = class_scores[above_thresh]

    # Per-class rotated NMS, optionally truncated to the first topk results.
    keep = batched_nms_rotated(candidate_boxes, candidate_scores, pair_inds[:, 1], nms_thresh)
    if topk_per_image >= 0:
        keep = keep[:topk_per_image]
    kept_pairs = pair_inds[keep]

    result = Instances(image_shape)
    result.pred_boxes = RotatedBoxes(candidate_boxes[keep])
    result.scores = candidate_scores[keep]
    result.pred_classes = kept_pairs[:, 1]

    return result, kept_pairs[:, 0]
@classmethod
def _init_box_head(cls, cfg, input_shape):
    """
    Build the constructor arguments for the rotated box branch.

    Same recipe as the box head of `StandardROIHeads`, except that the box
    predictor is a `RotatedFastRCNNOutputLayers` and the pooler type must be
    "ROIAlignRotated".

    Returns:
        dict: "box_in_features", "box_pooler", "box_head", "box_predictor".
    """
    in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
    box_cfg = cfg.MODEL.ROI_BOX_HEAD
    resolution = box_cfg.POOLER_RESOLUTION
    scales = tuple(1.0 / input_shape[name].stride for name in in_features)
    sampling_ratio = box_cfg.POOLER_SAMPLING_RATIO
    pooler_type = box_cfg.POOLER_TYPE
    assert pooler_type in ["ROIAlignRotated"], pooler_type

    # All feature maps are assumed to share one channel count; take the first.
    channels = [input_shape[name].channels for name in in_features][0]

    pooler = ROIPooler(
        output_size=resolution,
        scales=scales,
        sampling_ratio=sampling_ratio,
        pooler_type=pooler_type,
    )
    head = build_box_head(
        cfg, ShapeSpec(channels=channels, height=resolution, width=resolution)
    )
    # The rotated predictor is the only difference from StandardROIHeads.
    predictor = RotatedFastRCNNOutputLayers(cfg, head.output_shape)
    return {
        "box_in_features": in_features,
        "box_pooler": pooler,
        "box_head": head,
        "box_predictor": predictor,
    }
def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
    """
    Randomly pick up to `num_samples` indices from `labels`, a mixture of
    positives and negatives.

    As many positives as possible are taken without exceeding
    `int(positive_fraction * num_samples)`; the remaining slots are then
    filled with negatives.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): the maximum total number of indices to return.
        positive_fraction (float): the number of positives sampled is
            `min(num_positives, int(positive_fraction * num_samples))`. The
            number of negatives sampled is
            `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is
            filled with negatives. If there are also not enough negatives,
            as many elements are sampled as possible.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D vectors of indices into `labels`. Their combined length is
            `num_samples` or fewer.
    """
    is_background = labels == bg_label
    fg_candidates = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
    bg_candidates = nonzero_tuple(is_background)[0]

    # Cap each quota by what is actually available.
    pos_quota = min(fg_candidates.numel(), int(num_samples * positive_fraction))
    neg_quota = min(bg_candidates.numel(), num_samples - pos_quota)

    # Positives are shuffled before negatives so the random stream is
    # consumed in a fixed order.
    fg_order = torch.randperm(fg_candidates.numel(), device=fg_candidates.device)
    bg_order = torch.randperm(bg_candidates.numel(), device=bg_candidates.device)

    return fg_candidates[fg_order[:pos_quota]], bg_candidates[bg_order[:neg_quota]]
class DatasetMapperTTA:
    """
    Test-time augmentation mapper for detection data.

    Calling an instance with one dataset dict returns a list of dataset dicts,
    one per configured augmentation (each short-edge resize, optionally also
    followed by a horizontal flip) of the input image.
    """

    @configurable
    def __init__(self, min_sizes: List[int], max_size: int, flip: bool):
        """
        Args:
            min_sizes: list of short-edge size to resize the image to
            max_size: maximum height or width of resized images
            flip: whether to apply flipping augmentation
        """
        self.min_sizes = min_sizes
        self.max_size = max_size
        self.flip = flip

    @classmethod
    def from_config(cls, cfg):
        """Read the constructor arguments from ``cfg.TEST.AUG``."""
        aug_cfg = cfg.TEST.AUG
        return {
            "min_sizes": aug_cfg.MIN_SIZES,
            "max_size": aug_cfg.MAX_SIZE,
            "flip": aug_cfg.FLIP,
        }

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): a dict in standard model input format; must
                contain "image" (a CHW tensor), "height" and "width".

        Returns:
            list[dict]:
                deep copies of the input, each with "image" replaced by one
                augmented version. The total number of dicts is
                ``len(min_sizes) * (2 if flip else 1)``.
                Each dict has a field "transforms" holding the transforms
                that were used to generate its image.
        """
        hwc_image = dataset_dict["image"].permute(1, 2, 0).numpy()
        image_shape = hwc_image.shape
        dataset_hw = (dataset_dict["height"], dataset_dict["width"])
        if image_shape[:2] == dataset_hw:
            pre_tfm = NoOpTransform()
        else:
            # Maps the "original" image in the dataset to the given input image.
            pre_tfm = ResizeTransform(
                dataset_hw[0], dataset_hw[1], image_shape[0], image_shape[1]
            )

        # Enumerate every augmentation pipeline: resize only, then resize + flip.
        pipelines = []
        for min_size in self.min_sizes:
            resize = ResizeShortestEdge(min_size, self.max_size)
            pipelines.append([resize])
            if self.flip:
                pipelines.append([resize, RandomFlip(prob=1.0)])

        results = []
        for pipeline in pipelines:
            # Each pipeline gets its own copy of the pixels.
            augmented, tfms = apply_augmentations(pipeline, np.copy(hwc_image))
            chw_tensor = torch.from_numpy(np.ascontiguousarray(augmented.transpose(2, 0, 1)))

            entry = copy.deepcopy(dataset_dict)
            entry["transforms"] = pre_tfm + tfms
            entry["image"] = chw_tensor
            results.append(entry)
        return results
@contextmanager
def _turn_off_roi_heads(self, attrs):
    """
    Open a context where some heads in `model.roi_heads` are temporarily turned off.

    Every attribute named in `attrs` that exists on `self.model.roi_heads` is
    set to False for the duration of the context and restored to its previous
    value on exit. Restoration also happens when the body of the `with`
    statement raises, so a failed inference pass cannot leave heads disabled.

    Args:
        attrs (list[str]): the attributes in `model.roi_heads` which can be used
            to turn off a specific head, e.g., "mask_on", "keypoint_on".
            Names that the ROI heads do not define are ignored.
    """
    roi_heads = self.model.roi_heads
    saved = {}
    for attr in attrs:
        try:
            saved[attr] = getattr(roi_heads, attr)
        except AttributeError:
            # The head may not be implemented in certain ROIHeads
            pass

    for attr in saved:
        setattr(roi_heads, attr, False)
    try:
        yield
    finally:
        # Always put the original flags back, even on error.
        for attr, value in saved.items():
            setattr(roi_heads, attr, value)
+ + Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference` + """ + if detected_instances is None: + detected_instances = [None] * len(batched_inputs) + + outputs = [] + inputs, instances = [], [] + for idx, input, instance in zip(count(), batched_inputs, detected_instances): + inputs.append(input) + instances.append(instance) + if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1: + outputs.extend( + self.model.inference( + inputs, + instances if instances[0] is not None else None, + do_postprocess=False, + ) + ) + inputs, instances = [], [] + return outputs + + def __call__(self, batched_inputs): + """ + Same input/output format as :meth:`GeneralizedRCNN.forward` + """ + + def _maybe_read_image(dataset_dict): + ret = copy.copy(dataset_dict) + if "image" not in ret: + image = read_image(ret.pop("file_name"), self.model.input_format) + image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) # CHW + ret["image"] = image + if "height" not in ret and "width" not in ret: + ret["height"] = image.shape[1] + ret["width"] = image.shape[2] + return ret + + return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs] + + def _inference_one_image(self, input): + """ + Args: + input (dict): one dataset dict with "image" field being a CHW tensor + + Returns: + dict: one output dict + """ + orig_shape = (input["height"], input["width"]) + augmented_inputs, tfms = self._get_augmented_inputs(input) + # Detect boxes from all augmented versions + with self._turn_off_roi_heads(["mask_on", "keypoint_on"]): + # temporarily disable roi heads + all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms) + # merge all detected boxes to obtain final predictions for boxes + merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape) + + if self.cfg.MODEL.MASK_ON: + # Use the detected boxes to obtain masks + augmented_instances = self._rescale_detected_boxes( + 
augmented_inputs, merged_instances, tfms + ) + # run forward on the detected boxes + outputs = self._batch_inference(augmented_inputs, augmented_instances) + # Delete now useless variables to avoid being out of memory + del augmented_inputs, augmented_instances + # average the predictions + merged_instances.pred_masks = self._reduce_pred_masks(outputs, tfms) + merged_instances = detector_postprocess(merged_instances, *orig_shape) + return {"instances": merged_instances} + else: + return {"instances": merged_instances} + + def _get_augmented_inputs(self, input): + augmented_inputs = self.tta_mapper(input) + tfms = [x.pop("transforms") for x in augmented_inputs] + return augmented_inputs, tfms + + def _get_augmented_boxes(self, augmented_inputs, tfms): + # 1: forward with all augmented images + outputs = self._batch_inference(augmented_inputs) + # 2: union the results + all_boxes = [] + all_scores = [] + all_classes = [] + for output, tfm in zip(outputs, tfms): + # Need to inverse the transforms on boxes, to obtain results on original image + pred_boxes = output.pred_boxes.tensor + original_pred_boxes = tfm.inverse().apply_box(pred_boxes.cpu().numpy()) + all_boxes.append(torch.from_numpy(original_pred_boxes).to(pred_boxes.device)) + + all_scores.extend(output.scores) + all_classes.extend(output.pred_classes) + all_boxes = torch.cat(all_boxes, dim=0) + return all_boxes, all_scores, all_classes + + def _merge_detections(self, all_boxes, all_scores, all_classes, shape_hw): + # select from the union of all results + num_boxes = len(all_boxes) + num_classes = self.cfg.MODEL.ROI_HEADS.NUM_CLASSES + # +1 because fast_rcnn_inference expects background scores as well + all_scores_2d = torch.zeros(num_boxes, num_classes + 1, device=all_boxes.device) + for idx, cls, score in zip(count(), all_classes, all_scores): + all_scores_2d[idx, cls] = score + + merged_instances, _ = fast_rcnn_inference_single_image( + all_boxes, + all_scores_2d, + shape_hw, + 1e-8, + 
self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST, + self.cfg.TEST.DETECTIONS_PER_IMAGE, + ) + + return merged_instances + + def _rescale_detected_boxes(self, augmented_inputs, merged_instances, tfms): + augmented_instances = [] + for input, tfm in zip(augmented_inputs, tfms): + # Transform the target box to the augmented image's coordinate space + pred_boxes = merged_instances.pred_boxes.tensor.cpu().numpy() + pred_boxes = torch.from_numpy(tfm.apply_box(pred_boxes)) + + aug_instances = Instances( + image_size=input["image"].shape[1:3], + pred_boxes=Boxes(pred_boxes), + pred_classes=merged_instances.pred_classes, + scores=merged_instances.scores, + ) + augmented_instances.append(aug_instances) + return augmented_instances + + def _reduce_pred_masks(self, outputs, tfms): + # Should apply inverse transforms on masks. + # We assume only resize & flip are used. pred_masks is a scale-invariant + # representation, so we handle flip specially + for output, tfm in zip(outputs, tfms): + if any(isinstance(t, HFlipTransform) for t in tfm.transforms): + output.pred_masks = output.pred_masks.flip(dims=[3]) + all_pred_masks = torch.stack([o.pred_masks for o in outputs], dim=0) + avg_pred_masks = torch.mean(all_pred_masks, dim=0) + return avg_pred_masks diff --git a/data_processing/detectron2/detectron2/projects/README.md b/data_processing/detectron2/detectron2/projects/README.md new file mode 100644 index 0000000..95afe7f --- /dev/null +++ b/data_processing/detectron2/detectron2/projects/README.md @@ -0,0 +1,2 @@ + +Projects live in the [`projects` directory](../../projects) under the root of this repository, but not here. diff --git a/data_processing/detectron2/detectron2/projects/__init__.py b/data_processing/detectron2/detectron2/projects/__init__.py new file mode 100644 index 0000000..b2d0540 --- /dev/null +++ b/data_processing/detectron2/detectron2/projects/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
class _D2ProjectsFinder(importlib.abc.MetaPathFinder):
    """
    Meta-path finder for ``detectron2.projects.<name>`` modules.

    The last dotted component of the requested module name is looked up in
    ``_PROJECTS`` and resolved to ``<_PROJECT_ROOT>/<dir>/<name>/__init__.py``.
    """

    def find_spec(self, name, path, target=None):
        """Return a module spec for a known project, or ``None`` to defer to other finders."""
        if name.startswith("detectron2.projects."):
            leaf = name.rpartition(".")[2]
            folder = _PROJECTS.get(leaf)
            if folder:
                init_file = _PROJECT_ROOT / folder / leaf / "__init__.py"
                if init_file.is_file():
                    return importlib.util.spec_from_file_location(name, init_file)
        return None
_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
_GradientClipper = Callable[[_GradientClipperInput], None]


class GradientClipType(Enum):
    """Supported values of the ``CLIP_TYPE`` config option."""

    VALUE = "value"
    NORM = "norm"


def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
    """
    Creates gradient clipping closure to clip by value or by norm,
    according to the provided config.

    Args:
        cfg: the gradient-clipping config node; ``CLIP_TYPE``, ``CLIP_VALUE`` and
            ``NORM_TYPE`` are read from it. It is deep-copied, so later changes
            to the caller's config do not affect the returned closure.

    Raises:
        ValueError: if ``cfg.CLIP_TYPE`` is not a valid :class:`GradientClipType` value.
    """
    cfg = copy.deepcopy(cfg)

    def clip_grad_norm(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)

    def clip_grad_value(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)

    _GRADIENT_CLIP_TYPE_TO_CLIPPER = {
        GradientClipType.VALUE: clip_grad_value,
        GradientClipType.NORM: clip_grad_norm,
    }
    return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]


def _generate_optimizer_class_with_gradient_clipping(
    optimizer: Type[torch.optim.Optimizer],
    *,
    per_param_clipper: Optional[_GradientClipper] = None,
    global_clipper: Optional[_GradientClipper] = None,
) -> Type[torch.optim.Optimizer]:
    """
    Dynamically creates a new type that inherits the type of a given instance
    and overrides the `step` method to add gradient clipping

    Args:
        optimizer: the optimizer class to derive from.
        per_param_clipper: called once with every parameter of every param group.
        global_clipper: called once with an iterable over all parameters.
            At most one of the two clippers may be given.

    Returns:
        A subclass of ``optimizer`` named ``<optimizer.__name__>WithGradientClip``.
    """
    assert (
        per_param_clipper is None or global_clipper is None
    ), "Not allowed to use both per-parameter clipping and global clipping"

    def optimizer_wgc_step(self, closure=None):
        if per_param_clipper is not None:
            for group in self.param_groups:
                for p in group["params"]:
                    per_param_clipper(p)
        elif global_clipper is not None:
            # global clipper for future use with detr
            # (https://github.com/facebookresearch/detr/pull/287)
            all_params = itertools.chain.from_iterable(g["params"] for g in self.param_groups)
            global_clipper(all_params)
        # Call the base class explicitly: `super(type(self), self)` would recurse
        # forever if the generated class is subclassed. Propagate the return value
        # so a closure's result reaches the caller.
        return optimizer.step(self, closure)

    OptimizerWithGradientClip = type(
        optimizer.__name__ + "WithGradientClip",
        (optimizer,),
        {"step": optimizer_wgc_step},
    )
    return OptimizerWithGradientClip


def maybe_add_gradient_clipping(
    cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
    """
    If gradient clipping is enabled through config options, wraps the existing
    optimizer type to become a new dynamically created class OptimizerWithGradientClip
    that inherits the given optimizer and overrides the `step` method to
    include gradient clipping.

    Args:
        cfg: CfgNode, configuration options
        optimizer: type. A subclass of torch.optim.Optimizer. An optimizer
            instance is also accepted; its class is then replaced in place.

    Return:
        type: either the input `optimizer` (if gradient clipping is disabled), or
            a subclass of it with gradient clipping included in the `step` method.
    """
    if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
        return optimizer
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer_type = type(optimizer)
    else:
        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
        optimizer_type = optimizer

    grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
    OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
        optimizer_type, per_param_clipper=grad_clipper
    )
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer.__class__ = OptimizerWithGradientClip  # a bit hacky, not recommended
        return optimizer
    else:
        return OptimizerWithGradientClip
def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    lr_factor_func: Optional[Callable] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
) -> List[Dict[str, Any]]:
    """
    Get default param list for optimizer, with support for a few types of
    overrides. If no overrides needed, this is equivalent to `model.parameters()`.

    Args:
        base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
        weight_decay: weight decay for every group by default. Can be omitted to use the one
            in optimizer.
        weight_decay_norm: override weight decay for params in normalization layers
        bias_lr_factor: multiplier of lr for bias parameters.
        weight_decay_bias: override weight decay for bias parameters.
        lr_factor_func: function to calculate lr decay rate by mapping the parameter names to
            corresponding lr decay rate. Note that setting this option requires
            also setting ``base_lr``.
        overrides: if not `None`, provides values for optimizer hyperparameters
            (LR, weight decay) for module parameters with a given name; e.g.
            ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
            weight decay values for all module parameters named `embedding`.
            The dict is not modified.

    For common detection models, ``weight_decay_norm`` is the only option
    needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
    from Detectron1 that are not found useful.

    Raises:
        ValueError: if ``bias_lr_factor`` or ``lr_factor_func`` is used without
            ``base_lr``, or if ``overrides`` already has a ``"bias"`` entry while
            bias-specific settings are also requested.

    Example:
    ::
        torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
                       lr=0.01, weight_decay=1e-4, momentum=0.9)
    """
    # Work on a copy: the "bias" entry added below must not leak into the
    # caller's dict (reusing that dict would otherwise raise a conflict).
    overrides = {} if overrides is None else dict(overrides)
    defaults = {}
    if base_lr is not None:
        defaults["lr"] = base_lr
    if weight_decay is not None:
        defaults["weight_decay"] = weight_decay
    bias_overrides = {}
    if bias_lr_factor is not None and bias_lr_factor != 1.0:
        # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
        # exactly the same as regular weights.
        if base_lr is None:
            raise ValueError("bias_lr_factor requires base_lr")
        bias_overrides["lr"] = base_lr * bias_lr_factor
    if weight_decay_bias is not None:
        bias_overrides["weight_decay"] = weight_decay_bias
    if bias_overrides:
        if "bias" in overrides:
            raise ValueError("Conflicting overrides for 'bias'")
        overrides["bias"] = bias_overrides
    if lr_factor_func is not None and base_lr is None:
        raise ValueError("lr_factor_func requires base_lr")
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    for module_name, module in model.named_modules():
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            hyperparams = copy.copy(defaults)
            if isinstance(module, norm_module_types) and weight_decay_norm is not None:
                hyperparams["weight_decay"] = weight_decay_norm
            if lr_factor_func is not None:
                hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")

            # Name-based overrides are applied last, so they win.
            hyperparams.update(overrides.get(module_param_name, {}))
            params.append({"params": [value], **hyperparams})
    return reduce_param_groups(params)


def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Transform parameter groups into per-parameter structure.

    Later items in `params` can overwrite parameters set in previous items.
    """
    ret = defaultdict(dict)
    for item in params:
        assert "params" in item
        cur_params = {x: y for x, y in item.items() if x != "params"}
        for param in item["params"]:
            ret[param].update({"params": [param], **cur_params})
    return list(ret.values())


def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Reorganize the parameter groups and merge duplicated groups.

    The number of parameter groups needs to be as small as possible in order
    to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
    of using a parameter_group per single parameter, we reorganize the
    parameter groups and merge duplicated groups. This approach speeds
    up multi-tensor optimizer significantly.
    """
    params = _expand_param_groups(params)
    groups = defaultdict(list)  # re-group all parameter groups by their hyperparams
    for item in params:
        cur_params = tuple((x, y) for x, y in item.items() if x != "params")
        groups[cur_params].extend(item["params"])
    ret = []
    for param_keys, param_values in groups.items():
        cur = dict(param_keys)
        cur["params"] = param_values
        ret.append(cur)
    return ret
+ """ + name = cfg.SOLVER.LR_SCHEDULER_NAME + + if name == "WarmupMultiStepLR": + steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER] + if len(steps) != len(cfg.SOLVER.STEPS): + logger = logging.getLogger(__name__) + logger.warning( + "SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. " + "These values will be ignored." + ) + sched = MultiStepParamScheduler( + values=[cfg.SOLVER.GAMMA**k for k in range(len(steps) + 1)], + milestones=steps, + num_updates=cfg.SOLVER.MAX_ITER, + ) + elif name == "WarmupCosineLR": + end_value = cfg.SOLVER.BASE_LR_END / cfg.SOLVER.BASE_LR + assert end_value >= 0.0 and end_value <= 1.0, end_value + sched = CosineParamScheduler(1, end_value) + elif name == "WarmupStepWithFixedGammaLR": + sched = StepWithFixedGammaParamScheduler( + base_value=1.0, + gamma=cfg.SOLVER.GAMMA, + num_decays=cfg.SOLVER.NUM_DECAYS, + num_updates=cfg.SOLVER.MAX_ITER, + ) + else: + raise ValueError("Unknown LR scheduler: {}".format(name)) + + sched = WarmupParamScheduler( + sched, + cfg.SOLVER.WARMUP_FACTOR, + min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0), + cfg.SOLVER.WARMUP_METHOD, + cfg.SOLVER.RESCALE_INTERVAL, + ) + return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER) diff --git a/data_processing/detectron2/detectron2/solver/lr_scheduler.py b/data_processing/detectron2/detectron2/solver/lr_scheduler.py new file mode 100644 index 0000000..d6aed2b --- /dev/null +++ b/data_processing/detectron2/detectron2/solver/lr_scheduler.py @@ -0,0 +1,246 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
class WarmupParamScheduler(CompositeParamScheduler):
    """
    Prepend a warmup stage to another scheduler.
    """

    def __init__(
        self,
        scheduler: ParamScheduler,
        warmup_factor: float,
        warmup_length: float,
        warmup_method: str = "linear",
        rescale_interval: bool = False,
    ):
        """
        Args:
            scheduler: the scheduler that runs after the warmup stage
            warmup_factor: multiplied with ``scheduler(0.0)`` to get the value
                warmup starts from, e.g. 0.001
            warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire
                training, e.g. 0.01
            warmup_method: one of "linear" or "constant"
            rescale_interval: whether the interval of ``scheduler`` is rescaled
                after warmup ("rescaled") or left as is ("fixed")

        Raises:
            ValueError: if ``warmup_method`` is neither "linear" nor "constant".
        """
        # Value reached at the end of warmup, and the value warmup starts from.
        warmup_end = scheduler(warmup_length)
        warmup_start = warmup_factor * scheduler(0.0)

        if warmup_method == "linear":
            warmup = LinearParamScheduler(warmup_start, warmup_end)
        elif warmup_method == "constant":
            warmup = ConstantParamScheduler(warmup_start)
        else:
            raise ValueError("Unknown warmup method: {}".format(warmup_method))

        main_scaling = "rescaled" if rescale_interval else "fixed"
        remaining_length = 1 - warmup_length
        super().__init__(
            [warmup, scheduler],
            interval_scaling=["rescaled", main_scaling],
            lengths=[warmup_length, remaining_length],
        )
+ + The absolute learning rate value of each parameter can be different. + This scheduler can be used as long as the relative scale among them do + not change during training. + + Examples: + :: + LRMultiplier( + opt, + WarmupParamScheduler( + MultiStepParamScheduler( + [1, 0.1, 0.01], + milestones=[60000, 80000], + num_updates=90000, + ), 0.001, 100 / 90000 + ), + max_iter=90000 + ) + """ + + # NOTES: in the most general case, every LR can use its own scheduler. + # Supporting this requires interaction with the optimizer when its parameter + # group is initialized. For example, classyvision implements its own optimizer + # that allows different schedulers for every parameter group. + # To avoid this complexity, we use this class to support the most common cases + # where the relative scale among all LRs stay unchanged during training. In this + # case we only need a total of one scheduler that defines the relative LR multiplier. + + def __init__( + self, + optimizer: torch.optim.Optimizer, + multiplier: ParamScheduler, + max_iter: int, + last_iter: int = -1, + ): + """ + Args: + optimizer, last_iter: See ``torch.optim.lr_scheduler.LRScheduler``. + ``last_iter`` is the same as ``last_epoch``. + multiplier: a fvcore ParamScheduler that defines the multiplier on + every LR of the optimizer + max_iter: the total number of training iterations + """ + if not isinstance(multiplier, ParamScheduler): + raise ValueError( + "_LRMultiplier(multiplier=) must be an instance of fvcore " + f"ParamScheduler. Got {multiplier} instead." + ) + self._multiplier = multiplier + self._max_iter = max_iter + super().__init__(optimizer, last_epoch=last_iter) + + def state_dict(self): + # fvcore schedulers are stateless. 
class WarmupMultiStepLR(LRScheduler):
    """
    Deprecated multi-step LR schedule with warmup.

    The LR of each group is ``base_lr * warmup_factor_at_iter * gamma ** k``, where
    ``k`` is the number of milestones that are <= the current iteration.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        milestones: List[int],
        gamma: float = 0.1,
        warmup_factor: float = 0.001,
        warmup_iters: int = 1000,
        warmup_method: str = "linear",
        last_epoch: int = -1,
    ):
        """
        Raises:
            ValueError: if ``milestones`` is not sorted in increasing order.
        """
        # Validate before emitting the deprecation warning or touching any state.
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
        logger.warning(
            "WarmupMultiStepLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
        )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Return the LR of every param group at iteration ``self.last_epoch``."""
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        # Number of milestones already passed determines the decay exponent.
        decay = self.gamma ** bisect_right(self.milestones, self.last_epoch)
        return [base_lr * warmup_factor * decay for base_lr in self.base_lrs]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()


class WarmupCosineLR(LRScheduler):
    """
    Deprecated half-cosine LR schedule multiplied by a warmup factor.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        max_iters: int,
        warmup_factor: float = 0.001,
        warmup_iters: int = 1000,
        warmup_method: str = "linear",
        last_epoch: int = -1,
    ):
        logger.warning(
            "WarmupCosineLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
        )
        self.max_iters = max_iters
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Return the LR of every param group at iteration ``self.last_epoch``."""
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        # Different definitions of half-cosine with warmup are possible. For
        # simplicity we multiply the standard half-cosine schedule by the warmup
        # factor. An alternative is to start the period of the cosine at warmup_iters
        # instead of at 0. In the case that warmup_iters << max_iters the two are
        # very close to each other.
        return [
            base_lr
            * warmup_factor
            * 0.5
            * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters))
            for base_lr in self.base_lrs
        ]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()


def _get_warmup_factor_at_iter(
    method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
    """
    Return the learning rate warmup factor at a specific iteration.
    See :paper:`ImageNet in 1h` for more details.

    Args:
        method (str): warmup method; either "constant" or "linear".
        iter (int): iteration at which to calculate the warmup factor.
        warmup_iters (int): the number of warmup iterations.
        warmup_factor (float): the base warmup factor (the meaning changes according
            to the method used).

    Returns:
        float: the effective warmup factor at the given iteration.

    Raises:
        ValueError: if ``iter < warmup_iters`` and ``method`` is not a known method.
    """
    if iter >= warmup_iters:
        # Warmup is over; ``method`` is not inspected in this case.
        return 1.0

    if method == "constant":
        return warmup_factor
    elif method == "linear":
        # Interpolate from warmup_factor (iter == 0) towards 1.0 (iter == warmup_iters).
        alpha = iter / warmup_iters
        return warmup_factor * (1 - alpha) + alpha
    else:
        raise ValueError("Unknown warmup method: {}".format(method))
_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]


@unique
class BoxMode(IntEnum):
    """
    Enum of different ways to represent a box.
    """

    XYXY_ABS = 0
    """
    (x0, y0, x1, y1) in absolute floating points coordinates.
    The coordinates in range [0, width or height].
    """
    XYWH_ABS = 1
    """
    (x0, y0, w, h) in absolute floating points coordinates.
    """
    XYXY_REL = 2
    """
    Not yet supported!
    (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
    """
    XYWH_REL = 3
    """
    Not yet supported!
    (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
    """
    XYWHA_ABS = 4
    """
    (xc, yc, w, h, a) in absolute floating points coordinates.
    (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
    """

    @staticmethod
    def convert(box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode") -> _RawBoxType:
        """
        Convert boxes between representations without modifying the input.

        Supported conversions: XYWH_ABS <-> XYXY_ABS, XYWHA_ABS -> XYXY_ABS and
        XYWH_ABS -> XYWHA_ABS.

        Args:
            box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
            from_mode, to_mode (BoxMode)

        Returns:
            The converted box of the same type. When ``from_mode == to_mode``
            the input object itself is returned.

        Raises:
            NotImplementedError: for any other pair of absolute modes.
            AssertionError: if a relative mode is involved, or the input does
                not have the expected number of elements.
        """
        if from_mode == to_mode:
            return box

        original_type = type(box)
        is_numpy = isinstance(box, np.ndarray)
        single_box = isinstance(box, (list, tuple))
        if single_box:
            assert len(box) == 4 or len(box) == 5, (
                "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
                " where k == 4 or 5"
            )
            arr = torch.tensor(box)[None, :]
        else:
            # avoid modifying the input box
            if is_numpy:
                arr = torch.from_numpy(np.asarray(box)).clone()
            else:
                arr = box.clone()

        assert to_mode not in [BoxMode.XYXY_REL, BoxMode.XYWH_REL] and from_mode not in [
            BoxMode.XYXY_REL,
            BoxMode.XYWH_REL,
        ], "Relative mode not yet supported!"

        if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
            assert (
                arr.shape[-1] == 5
            ), "The last dimension of input shape must be 5 for XYWHA format"
            original_dtype = arr.dtype
            # Compute in float64, then cast back to the input dtype.
            arr = arr.double()

            w = arr[:, 2]
            h = arr[:, 3]
            a = arr[:, 4]
            c = torch.abs(torch.cos(a * math.pi / 180.0))
            s = torch.abs(torch.sin(a * math.pi / 180.0))
            # This basically computes the horizontal bounding rectangle of the rotated box
            new_w = c * w + s * h
            new_h = c * h + s * w

            # convert center to top-left corner
            arr[:, 0] -= new_w / 2.0
            arr[:, 1] -= new_h / 2.0
            # bottom-right corner
            arr[:, 2] = arr[:, 0] + new_w
            arr[:, 3] = arr[:, 1] + new_h

            arr = arr[:, :4].to(dtype=original_dtype)
        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
            original_dtype = arr.dtype
            arr = arr.double()
            # top-left corner -> center
            arr[:, 0] += arr[:, 2] / 2.0
            arr[:, 1] += arr[:, 3] / 2.0
            # Axis-aligned boxes get angle 0. Allocate on arr's device so the
            # concatenation also works for tensors that are not on the default device.
            angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype, device=arr.device)
            arr = torch.cat((arr, angles), dim=1).to(dtype=original_dtype)
        else:
            if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
                arr[:, 2] += arr[:, 0]
                arr[:, 3] += arr[:, 1]
            elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
                arr[:, 2] -= arr[:, 0]
                arr[:, 3] -= arr[:, 1]
            else:
                raise NotImplementedError(
                    "Conversion from BoxMode {} to {} is not supported yet".format(
                        from_mode, to_mode
                    )
                )

        if single_box:
            return original_type(arr.flatten().tolist())
        if is_numpy:
            return arr.numpy()
        else:
            return arr
+ """ + box = self.tensor + area = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]) + return area + + def clip(self, box_size: Tuple[int, int]) -> None: + """ + Clip (in place) the boxes by limiting x coordinates to the range [0, width] + and y coordinates to the range [0, height]. + + Args: + box_size (height, width): The clipping box's size. + """ + assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!" + h, w = box_size + x1 = self.tensor[:, 0].clamp(min=0, max=w) + y1 = self.tensor[:, 1].clamp(min=0, max=h) + x2 = self.tensor[:, 2].clamp(min=0, max=w) + y2 = self.tensor[:, 3].clamp(min=0, max=h) + self.tensor = torch.stack((x1, y1, x2, y2), dim=-1) + + def nonempty(self, threshold: float = 0.0) -> torch.Tensor: + """ + Find boxes that are non-empty. + A box is considered empty, if either of its side is no larger than threshold. + + Returns: + Tensor: + a binary vector which represents whether each box is empty + (False) or non-empty (True). + """ + box = self.tensor + widths = box[:, 2] - box[:, 0] + heights = box[:, 3] - box[:, 1] + keep = (widths > threshold) & (heights > threshold) + return keep + + def __getitem__(self, item) -> "Boxes": + """ + Args: + item: int, slice, or a BoolTensor + + Returns: + Boxes: Create a new :class:`Boxes` by indexing. + + The following usage are allowed: + + 1. `new_boxes = boxes[3]`: return a `Boxes` which contains only one box. + 2. `new_boxes = boxes[2:10]`: return a slice of boxes. + 3. `new_boxes = boxes[vector]`, where vector is a torch.BoolTensor + with `length = len(boxes)`. Nonzero elements in the vector will be selected. + + Note that the returned Boxes might share storage with this Boxes, + subject to Pytorch's indexing semantics. 
+ """ + if isinstance(item, int): + return Boxes(self.tensor[item].view(1, -1)) + b = self.tensor[item] + assert b.dim() == 2, "Indexing on Boxes with {} failed to return a matrix!".format(item) + return Boxes(b) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __repr__(self) -> str: + return "Boxes(" + str(self.tensor) + ")" + + def inside_box(self, box_size: Tuple[int, int], boundary_threshold: int = 0) -> torch.Tensor: + """ + Args: + box_size (height, width): Size of the reference box. + boundary_threshold (int): Boxes that extend beyond the reference box + boundary by more than boundary_threshold are considered "outside". + + Returns: + a binary vector, indicating whether each box is inside the reference box. + """ + height, width = box_size + inds_inside = ( + (self.tensor[..., 0] >= -boundary_threshold) + & (self.tensor[..., 1] >= -boundary_threshold) + & (self.tensor[..., 2] < width + boundary_threshold) + & (self.tensor[..., 3] < height + boundary_threshold) + ) + return inds_inside + + def get_centers(self) -> torch.Tensor: + """ + Returns: + The box centers in a Nx2 array of (x, y). + """ + return (self.tensor[:, :2] + self.tensor[:, 2:]) / 2 + + def scale(self, scale_x: float, scale_y: float) -> None: + """ + Scale the box with horizontal and vertical scaling factors + """ + self.tensor[:, 0::2] *= scale_x + self.tensor[:, 1::2] *= scale_y + + @classmethod + def cat(cls, boxes_list: List["Boxes"]) -> "Boxes": + """ + Concatenates a list of Boxes into a single Boxes + + Arguments: + boxes_list (list[Boxes]) + + Returns: + Boxes: the concatenated Boxes + """ + assert isinstance(boxes_list, (list, tuple)) + if len(boxes_list) == 0: + return cls(torch.empty(0)) + assert all([isinstance(box, Boxes) for box in boxes_list]) + + # use torch.cat (v.s. 
def pairwise_intersection(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Compute the intersection area for every pair drawn from two sets of boxes.

    Both inputs hold boxes in (xmin, ymin, xmax, ymax) order.

    Args:
        boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.

    Returns:
        Tensor: intersection, sized [N,M].
    """
    first = boxes1.tensor
    second = boxes2.tensor
    # Broadcasting [N,1,2] against [M,2] yields all N x M corner pairs.
    lower_right = torch.min(first[:, None, 2:], second[:, 2:])  # [N,M,2]
    upper_left = torch.max(first[:, None, :2], second[:, :2])  # [N,M,2]
    # Non-overlapping pairs produce a negative extent; clamp it to zero.
    extent = (lower_right - upper_left).clamp_(min=0)  # [N,M,2]
    return extent.prod(dim=2)  # [N,M]
def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
    """
    Compute pairwise intersection over union (IOU) of two sets of matched
    boxes that have the same number of boxes.
    Similar to :func:`pairwise_iou`, but computes only diagonal elements of the matrix.

    Unlike :func:`pairwise_iou`, no special handling is applied when the union of
    a pair is zero; the division result is returned as is.

    Args:
        boxes1 (Boxes): bounding boxes, sized [N,4].
        boxes2 (Boxes): same length as boxes1
    Returns:
        Tensor: iou, sized [N].

    Raises:
        AssertionError: if the two inputs do not contain the same number of boxes.
    """
    # The message is a single literal: the previous implicit concatenation of two
    # literals dropped the space between "same" and "number".
    assert len(boxes1) == len(
        boxes2
    ), "boxlists should have the same number of entries, got {}, {}".format(
        len(boxes1), len(boxes2)
    )
    area1 = boxes1.area()  # [N]
    area2 = boxes2.area()  # [N]
    box1, box2 = boxes1.tensor, boxes2.tensor
    lt = torch.max(box1[:, :2], box2[:, :2])  # [N,2]
    rb = torch.min(box1[:, 2:], box2[:, 2:])  # [N,2]
    wh = (rb - lt).clamp(min=0)  # [N,2]
    inter = wh[:, 0] * wh[:, 1]  # [N]
    iou = inter / (area1 + area2 - inter)  # [N]
    return iou
    @staticmethod
    def from_tensors(
        tensors: List[torch.Tensor],
        size_divisibility: int = 0,
        pad_value: float = 0.0,
        padding_constraints: Optional[Dict[str, int]] = None,
    ) -> "ImageList":
        """
        Batch a list of images into one padded tensor wrapped in an `ImageList`.

        Each image is placed in the top-left corner of its slot; padding is added
        on the right and bottom only.

        Args:
            tensors: a tuple or list of `torch.Tensor`, each of shape (Hi, Wi) or
                (C_1, ..., C_K, Hi, Wi) where K >= 1. The Tensors will be padded
                to the same shape with `pad_value`.
            size_divisibility (int): If `size_divisibility > 0`, add padding to ensure
                the common height and width is divisible by `size_divisibility`.
                This depends on the model and many models need a divisibility of 32.
            pad_value (float): value to pad.
            padding_constraints (optional[Dict]): If given, it would follow the format as
                {"size_divisibility": int, "square_size": int}, where `size_divisibility` will
                overwrite the above one if presented and `square_size` indicates the
                square padding size if `square_size` > 0.
        Returns:
            an `ImageList`.
        """
        assert len(tensors) > 0
        assert isinstance(tensors, (tuple, list))
        for t in tensors:
            assert isinstance(t, torch.Tensor), type(t)
            # All leading (non-spatial) dimensions must agree with the first image.
            assert t.shape[:-2] == tensors[0].shape[:-2], t.shape

        image_sizes = [(im.shape[-2], im.shape[-1]) for im in tensors]
        # presumably shapes_to_tensor turns each (h, w) pair into a tensor — confirm in
        # detectron2.layers.wrappers
        image_sizes_tensor = [shapes_to_tensor(x) for x in image_sizes]
        # Element-wise maximum over all images: the common (H, W) before constraints.
        max_size = torch.stack(image_sizes_tensor).max(0).values

        if padding_constraints is not None:
            square_size = padding_constraints.get("square_size", 0)
            if square_size > 0:
                # pad to square.
                # NOTE(review): square_size is used as given; nothing here checks that it
                # is at least as large as the largest image.
                max_size[0] = max_size[1] = square_size
            if "size_divisibility" in padding_constraints:
                size_divisibility = padding_constraints["size_divisibility"]
        if size_divisibility > 1:
            stride = size_divisibility
            # the last two dims are H,W, both subject to divisibility requirement
            # (rounds each of them up to the next multiple of `stride`)
            max_size = (max_size + (stride - 1)).div(stride, rounding_mode="floor") * stride

        # handle weirdness of scripting and tracing ...
        if torch.jit.is_scripting():
            max_size: List[int] = max_size.to(dtype=torch.long).tolist()
        else:
            if torch.jit.is_tracing():
                # While tracing, the sizes stored on the result are tensors, not int tuples.
                image_sizes = image_sizes_tensor

        if len(tensors) == 1:
            # This seems slightly (2%) faster.
            # TODO: check whether it's faster for multiple images as well
            image_size = image_sizes[0]
            # F.pad order is (left, right, top, bottom) for the last two dimensions.
            padding_size = [0, max_size[-1] - image_size[1], 0, max_size[-2] - image_size[0]]
            batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0)
        else:
            # max_size can be a tensor in tracing mode, therefore convert to list
            batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
            device = (
                None if torch.jit.is_scripting() else ("cpu" if torch.jit.is_tracing() else None)
            )
            batched_imgs = tensors[0].new_full(batch_shape, pad_value, device=device)
            batched_imgs = move_device_like(batched_imgs, tensors[0])
            for i, img in enumerate(tensors):
                # Use `batched_imgs` directly instead of `img, pad_img = zip(tensors, batched_imgs)`
                # Tracing mode cannot capture `copy_()` of temporary locals
                batched_imgs[i, ..., : img.shape[-2], : img.shape[-1]].copy_(img)

        return ImageList(batched_imgs.contiguous(), image_sizes)
class Instances:
    """
    Per-image container for instance data (e.g., boxes, masks, labels, scores).

    Every piece of data is stored as a named "field", and all fields must report
    the same ``__len__``, which is the number of instances.

    Attributes that are not fields are private: their names start with '_' and
    they are not meant to be modified by a user.

    Some basic usage:

    1. Set/get/check a field:

       .. code-block:: python

          instances.gt_boxes = Boxes(...)
          print(instances.pred_masks)  # a tensor of shape (N, H, W)
          print('gt_masks' in instances)

    2. ``len(instances)`` returns the number of instances
    3. Indexing: ``instances[indices]`` applies the indexing to every field
       and returns a new :class:`Instances`.
       Typically, ``indices`` is an integer vector of indices,
       or a binary mask of length ``num_instances``

       .. code-block:: python

          category_3_detections = instances[instances.pred_classes == 3]
          confident_detections = instances[instances.scores > 0.9]
    """

    def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
        """
        Args:
            image_size (height, width): the spatial size of the image.
            kwargs: fields to add to this `Instances`.
        """
        self._image_size = image_size
        self._fields: Dict[str, Any] = {}
        for field_name, field_value in kwargs.items():
            self.set(field_name, field_value)

    @property
    def image_size(self) -> Tuple[int, int]:
        """
        Returns:
            tuple: height, width
        """
        return self._image_size

    def __setattr__(self, name: str, val: Any) -> None:
        # Public names are fields; only '_'-prefixed names become real attributes.
        if not name.startswith("_"):
            self.set(name, val)
            return
        super().__setattr__(name, val)

    def __getattr__(self, name: str) -> Any:
        # "_fields" is tested first so a missing `_fields` attribute cannot recurse
        # back into this method.
        if name != "_fields" and name in self._fields:
            return self._fields[name]
        raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))

    def set(self, name: str, value: Any) -> None:
        """
        Store `value` under the field `name`.
        Its length must equal the number of instances already held by this
        object, if any field exists.
        """
        with warnings.catch_warnings(record=True):
            data_len = len(value)
        if self._fields:
            current_len = len(self)
            assert (
                current_len == data_len
            ), "Adding a field of length {} to a Instances of length {}".format(
                data_len, current_len
            )
        self._fields[name] = value

    def has(self, name: str) -> bool:
        """
        Returns:
            bool: whether the field called `name` exists.
        """
        return name in self._fields

    def remove(self, name: str) -> None:
        """
        Delete the field called `name`.
        """
        del self._fields[name]

    def get(self, name: str) -> Any:
        """
        Return the data stored in the field called `name`.
        """
        return self._fields[name]

    def get_fields(self) -> Dict[str, Any]:
        """
        Returns:
            dict: the mapping from field names (str) to field data.

        The dict is the internal one, so modifying it modifies this instance.
        """
        return self._fields

    # Tensor-like methods
    def to(self, *args: Any, **kwargs: Any) -> "Instances":
        """
        Returns:
            Instances: a new object in which every field that has a `to` method
            is replaced by the result of calling it with the given arguments.
        """
        converted = Instances(self._image_size)
        for field_name, field_value in self._fields.items():
            converted.set(
                field_name,
                field_value.to(*args, **kwargs) if hasattr(field_value, "to") else field_value,
            )
        return converted

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances":
        """
        Args:
            item: an index-like object and will be used to index all the fields.

        Returns:
            An `Instances` where all fields are indexed by `item`. A plain `int`
            selects one instance and still yields an `Instances` of length 1.
        """
        if type(item) is int:
            count = len(self)
            if not -count <= item < count:
                raise IndexError("Instances index out of range!")
            # A slice with step `count` picks exactly one element and keeps the
            # leading dimension of every field.
            item = slice(item, None, count)

        selected = Instances(self._image_size)
        for field_name, field_value in self._fields.items():
            selected.set(field_name, field_value[item])
        return selected

    def __len__(self) -> int:
        if not self._fields:
            raise NotImplementedError("Empty Instances does not support __len__!")
        first_value = next(iter(self._fields.values()))
        # use __len__ because len() has to be int and is not friendly to tracing
        return first_value.__len__()

    def __iter__(self):
        raise NotImplementedError("`Instances` object is not iterable!")

    @staticmethod
    def cat(instance_lists: List["Instances"]) -> "Instances":
        """
        Concatenate several `Instances` field by field.

        Args:
            instance_lists (list[Instances])

        Returns:
            Instances
        """
        assert all(isinstance(i, Instances) for i in instance_lists)
        assert len(instance_lists) > 0
        if len(instance_lists) == 1:
            return instance_lists[0]

        head = instance_lists[0]
        image_size = head.image_size
        if not isinstance(image_size, torch.Tensor):  # could be a tensor in tracing
            for other in instance_lists[1:]:
                assert other.image_size == image_size

        merged = Instances(image_size)
        # The field names of the first element decide what gets concatenated.
        for field_name in head._fields.keys():
            pieces = [inst.get(field_name) for inst in instance_lists]
            sample = pieces[0]
            if isinstance(sample, torch.Tensor):
                combined = torch.cat(pieces, dim=0)
            elif isinstance(sample, list):
                combined = list(itertools.chain.from_iterable(pieces))
            elif hasattr(type(sample), "cat"):
                combined = type(sample).cat(pieces)
            else:
                raise ValueError("Unsupported type {} for concatenation".format(type(sample)))
            merged.set(field_name, combined)
        return merged

    def __str__(self) -> str:
        parts = [
            self.__class__.__name__ + "(",
            "num_instances={}, ".format(len(self)),
            "image_height={}, ".format(self._image_size[0]),
            "image_width={}, ".format(self._image_size[1]),
            "fields=[{}])".format(", ".join((f"{k}: {v}" for k, v in self._fields.items()))),
        ]
        return "".join(parts)

    __repr__ = __str__
+ """ + device = keypoints.device if isinstance(keypoints, torch.Tensor) else torch.device("cpu") + keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device) + assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape + self.tensor = keypoints + + def __len__(self) -> int: + return self.tensor.size(0) + + def to(self, *args: Any, **kwargs: Any) -> "Keypoints": + return type(self)(self.tensor.to(*args, **kwargs)) + + @property + def device(self) -> torch.device: + return self.tensor.device + + def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor: + """ + Convert keypoint annotations to a heatmap of one-hot labels for training, + as described in :paper:`Mask R-CNN`. + + Arguments: + boxes: Nx4 tensor, the boxes to draw the keypoints to + + Returns: + heatmaps: + A tensor of shape (N, K), each element is integer spatial label + in the range [0, heatmap_size**2 - 1] for each keypoint in the input. + valid: + A tensor of shape (N, K) containing whether each keypoint is in the roi or not. + """ + return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size) + + def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints": + """ + Create a new `Keypoints` by indexing on this `Keypoints`. + + The following usage are allowed: + + 1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance. + 2. `new_kpts = kpts[2:10]`: return a slice of key points. + 3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor + with `length = len(kpts)`. Nonzero elements in the vector will be selected. + + Note that the returned Keypoints might share storage with this Keypoints, + subject to Pytorch's indexing semantics. 
+ """ + if isinstance(item, int): + return Keypoints([self.tensor[item]]) + return Keypoints(self.tensor[item]) + + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={})".format(len(self.tensor)) + return s + + @staticmethod + def cat(keypoints_list: List["Keypoints"]) -> "Keypoints": + """ + Concatenates a list of Keypoints into a single Keypoints + + Arguments: + keypoints_list (list[Keypoints]) + + Returns: + Keypoints: the concatenated Keypoints + """ + assert isinstance(keypoints_list, (list, tuple)) + assert len(keypoints_list) > 0 + assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list) + + cat_kpts = type(keypoints_list[0])( + torch.cat([kpts.tensor for kpts in keypoints_list], dim=0) + ) + return cat_kpts + + +# TODO make this nicer, this is a direct translation from C2 (but removing the inner loop) +def _keypoints_to_heatmap( + keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space. + + Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the + closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the + continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"): + d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate. + + Arguments: + keypoints: tensor of keypoint locations in of shape (N, K, 3). + rois: Nx4 tensor of rois in xyxy format + heatmap_size: integer side length of square heatmap. + + Returns: + heatmaps: A tensor of shape (N, K) containing an integer spatial label + in the range [0, heatmap_size**2 - 1] for each keypoint in the input. + valid: A tensor of shape (N, K) containing whether each keypoint is in + the roi or not. 
@torch.jit.script_if_tracing
def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
    """
    Extract predicted keypoint locations from heatmaps.

    For every ROI the heatmaps are resized to the ROI's (ceiled) height and width,
    and for each keypoint the location of the maximum logit is reported.

    Args:
        maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for
            each ROI and each keypoint.
        rois (Tensor): (#ROIs, 4). The box of each ROI.

    Returns:
        Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to
        (x, y, logit, score) for each keypoint.

    When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate,
    we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
    Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
    """

    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    # ROI extents are clamped to at least 1 so the resize target is never empty.
    widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
    heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_rois, num_keypoints = maps.shape[:2]
    xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4)

    # Ratio between the true ROI extent and the integer size the map is resized to;
    # used below to scale pixel positions back to ROI coordinates.
    width_corrections = widths / widths_ceil
    height_corrections = heights / heights_ceil

    keypoints_idx = torch.arange(num_keypoints, device=maps.device)

    for i in range(num_rois):
        outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
        # maps[[i]] keeps the batch dimension that F.interpolate expects.
        roi_map = F.interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False)

        # Although semantically equivalent, `reshape` is used instead of `squeeze` due
        # to limitation during ONNX export of `squeeze` in scripting mode
        roi_map = roi_map.reshape(roi_map.shape[1:])  # keypoints x H x W

        # softmax over the spatial region
        # (the per-keypoint maximum is subtracted before exponentiating)
        max_score, _ = roi_map.view(num_keypoints, -1).max(1)
        max_score = max_score.view(num_keypoints, 1, 1)
        tmp_full_resolution = (roi_map - max_score).exp_()
        tmp_pool_resolution = (maps[i] - max_score).exp_()
        # Produce scores over the region H x W, but normalize with POOL_H x POOL_W,
        # so that the scores of objects of different absolute sizes will be more comparable
        roi_map_scores = tmp_full_resolution / tmp_pool_resolution.sum((1, 2), keepdim=True)

        w = roi_map.shape[2]
        # Flat index of the maximum logit per keypoint, split into column and row below.
        pos = roi_map.view(num_keypoints, -1).argmax(1)

        x_int = pos % w
        y_int = (pos - x_int) // w

        # Sanity check: the score at the argmax of the logits is the maximum score.
        assert (
            roi_map_scores[keypoints_idx, y_int, x_int]
            == roi_map_scores.view(num_keypoints, -1).max(1)[0]
        ).all()

        # Discrete -> continuous coordinate (c = d + 0.5), scaled to the true ROI size.
        x = (x_int.float() + 0.5) * width_corrections[i]
        y = (y_int.float() + 0.5) * height_corrections[i]

        xy_preds[i, :, 0] = x + offset_x[i]
        xy_preds[i, :, 1] = y + offset_y[i]
        xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int]
        xy_preds[i, :, 3] = roi_map_scores[keypoints_idx, y_int, x_int]

    return xy_preds
def rasterize_polygons_within_box(
    polygons: List[np.ndarray], box: np.ndarray, mask_size: int
) -> torch.Tensor:
    """
    Draw the polygons of one instance as a mask restricted to `box`, resized to
    (mask_size, mask_size).

    This is used to build training targets for the mask head in Mask R-CNN:
    every predicted box needs a ground-truth mask of size `mask_size x mask_size`
    derived from the original ground-truth polygons.

    Args:
        polygons (list[ndarray[float]]): a list of polygons, which represents an instance.
        box: 4-element numpy array
        mask_size (int):

    Returns:
        Tensor: BoolTensor of shape (mask_size, mask_size)
    """
    x0, y0 = box[0], box[1]
    box_w = box[2] - x0
    box_h = box[3] - y0

    # max() to avoid division by small number
    ratio_h = mask_size / max(box_h, 0.1)
    ratio_w = mask_size / max(box_w, 0.1)
    uniform_scale = ratio_h == ratio_w

    # Work on a deep copy so the caller's polygons are left untouched.
    local_polygons = copy.deepcopy(polygons)
    for poly in local_polygons:
        # 1. Shift the polygon so the box corner becomes the origin
        #    (even entries are x, odd entries are y).
        poly[0::2] = poly[0::2] - x0
        poly[1::2] = poly[1::2] - y0
        # 2. Rescale the polygon to the new box size.
        if uniform_scale:
            poly *= ratio_h
        else:
            poly[0::2] *= ratio_w
            poly[1::2] *= ratio_h

    # 3. Rasterize the polygons with coco api
    bitmask = polygons_to_bitmask(local_polygons, mask_size, mask_size)
    return torch.from_numpy(bitmask)
`new_masks = masks[3]`: return a `BitMasks` which contains only one mask. + 2. `new_masks = masks[2:10]`: return a slice of masks. + 3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor + with `length = len(masks)`. Nonzero elements in the vector will be selected. + + Note that the returned object might share storage with this object, + subject to Pytorch's indexing semantics. + """ + if isinstance(item, int): + return BitMasks(self.tensor[item].unsqueeze(0)) + m = self.tensor[item] + assert m.dim() == 3, "Indexing on BitMasks with {} returns a tensor with shape {}!".format( + item, m.shape + ) + return BitMasks(m) + + @torch.jit.unused + def __iter__(self) -> torch.Tensor: + yield from self.tensor + + @torch.jit.unused + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={})".format(len(self.tensor)) + return s + + def __len__(self) -> int: + return self.tensor.shape[0] + + def nonempty(self) -> torch.Tensor: + """ + Find masks that are non-empty. + + Returns: + Tensor: a BoolTensor which represents + whether each mask is empty (False) or non-empty (True). 
@staticmethod
def from_roi_masks(
    roi_masks: "ROIMasks",
    height: int,
    width: int,
    boxes: Union["Boxes", None] = None,
) -> "BitMasks":
    """
    Paste per-ROI masks into full-image bitmasks.

    Args:
        roi_masks (ROIMasks): masks defined inside their ROI boxes.
        height, width (int): size of the output image.
        boxes (Boxes): the ROI box of each mask. `ROIMasks.to_bitmasks`
            takes the boxes as its first argument, so they are required here;
            the parameter only defaults to None to keep the historical
            three-argument signature importable.

    Returns:
        BitMasks: whatever `roi_masks.to_bitmasks(boxes, height, width)` returns.

    Raises:
        TypeError: if `boxes` is not provided. The previous implementation
            forwarded only (height, width) and therefore always failed with a
            TypeError inside `to_bitmasks`; the exception type is kept.
    """
    if boxes is None:
        raise TypeError(
            "from_roi_masks() missing required argument 'boxes': "
            "ROIMasks.to_bitmasks needs the ROI boxes to paste the masks."
        )
    return roi_masks.to_bitmasks(boxes, height, width)
+ """ + assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self)) + device = self.tensor.device + + batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[:, None] + rois = torch.cat([batch_inds, boxes], dim=1) # Nx5 + + bit_masks = self.tensor.to(dtype=torch.float32) + rois = rois.to(device=device) + output = ( + ROIAlign((mask_size, mask_size), 1.0, 0, aligned=True) + .forward(bit_masks[:, None, :, :], rois) + .squeeze(1) + ) + output = output >= 0.5 + return output + + def get_bounding_boxes(self) -> Boxes: + """ + Returns: + Boxes: tight bounding boxes around bitmasks. + If a mask is empty, it's bounding box will be all zero. + """ + boxes = torch.zeros(self.tensor.shape[0], 4, dtype=torch.float32) + x_any = torch.any(self.tensor, dim=1) + y_any = torch.any(self.tensor, dim=2) + for idx in range(self.tensor.shape[0]): + x = torch.where(x_any[idx, :])[0] + y = torch.where(y_any[idx, :])[0] + if len(x) > 0 and len(y) > 0: + boxes[idx, :] = torch.as_tensor( + [x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=torch.float32 + ) + return Boxes(boxes) + + @staticmethod + def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks": + """ + Concatenates a list of BitMasks into a single BitMasks + + Arguments: + bitmasks_list (list[BitMasks]) + + Returns: + BitMasks: the concatenated BitMasks + """ + assert isinstance(bitmasks_list, (list, tuple)) + assert len(bitmasks_list) > 0 + assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list) + + cat_bitmasks = type(bitmasks_list[0])(torch.cat([bm.tensor for bm in bitmasks_list], dim=0)) + return cat_bitmasks + + +class PolygonMasks: + """ + This class stores the segmentation masks for all objects in one image, in the form of polygons. + + Attributes: + polygons: list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon. 
+ """ + + def __init__(self, polygons: List[List[Union[torch.Tensor, np.ndarray]]]): + """ + Arguments: + polygons (list[list[np.ndarray]]): The first + level of the list correspond to individual instances, + the second level to all the polygons that compose the + instance, and the third level to the polygon coordinates. + The third level array should have the format of + [x0, y0, x1, y1, ..., xn, yn] (n >= 3). + """ + if not isinstance(polygons, list): + raise ValueError( + "Cannot create PolygonMasks: Expect a list of list of polygons per image. " + "Got '{}' instead.".format(type(polygons)) + ) + + def _make_array(t: Union[torch.Tensor, np.ndarray]) -> np.ndarray: + # Use float64 for higher precision, because why not? + # Always put polygons on CPU (self.to is a no-op) since they + # are supposed to be small tensors. + # May need to change this assumption if GPU placement becomes useful + if isinstance(t, torch.Tensor): + t = t.cpu().numpy() + return np.asarray(t).astype("float64") + + def process_polygons( + polygons_per_instance: List[Union[torch.Tensor, np.ndarray]] + ) -> List[np.ndarray]: + if not isinstance(polygons_per_instance, list): + raise ValueError( + "Cannot create polygons: Expect a list of polygons per instance. 
def get_bounding_boxes(self) -> Boxes:
    """
    Returns:
        Boxes: tight bounding boxes around polygon masks, as float32
            (x_min, y_min, x_max, y_max) rows, one per instance.
            An instance that holds no polygon gets an all-zero box.
    """
    boxes = torch.zeros(len(self.polygons), 4, dtype=torch.float32)
    for idx, polygons_per_instance in enumerate(self.polygons):
        if not polygons_per_instance:
            # Nothing to bound: keep the zero row instead of emitting +/-inf.
            continue
        # Gather every vertex of the instance as an (K, 2) table of (x, y).
        coords = torch.cat(
            [torch.from_numpy(polygon).view(-1, 2) for polygon in polygons_per_instance],
            dim=0,
        ).to(dtype=torch.float32)
        # Take min and max from the data itself. Seeding the max with 0 (as was
        # done before) gives a non-tight box when all coordinates are negative.
        boxes[idx, :2] = torch.min(coords, dim=0).values
        boxes[idx, 2:] = torch.max(coords, dim=0).values
    return Boxes(boxes)
A list[int]. It will return an object with the selected instances, + correpsonding to the indices in the list. + 4. A vector mask of type BoolTensor, whose length is num_instances. + It will return an object with the instances whose mask is nonzero. + """ + if isinstance(item, int): + selected_polygons = [self.polygons[item]] + elif isinstance(item, slice): + selected_polygons = self.polygons[item] + elif isinstance(item, list): + selected_polygons = [self.polygons[i] for i in item] + elif isinstance(item, torch.Tensor): + # Polygons is a list, so we have to move the indices back to CPU. + if item.dtype == torch.bool: + assert item.dim() == 1, item.shape + item = item.nonzero().squeeze(1).cpu().numpy().tolist() + elif item.dtype in [torch.int32, torch.int64]: + item = item.cpu().numpy().tolist() + else: + raise ValueError("Unsupported tensor dtype={} for indexing!".format(item.dtype)) + selected_polygons = [self.polygons[i] for i in item] + return PolygonMasks(selected_polygons) + + def __iter__(self) -> Iterator[List[np.ndarray]]: + """ + Yields: + list[ndarray]: the polygons for one instance. + Each Tensor is a float64 vector representing a polygon. + """ + return iter(self.polygons) + + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={})".format(len(self.polygons)) + return s + + def __len__(self) -> int: + return len(self.polygons) + + def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor: + """ + Crop each mask by the given box, and resize results to (mask_size, mask_size). + This can be used to prepare training targets for Mask R-CNN. + + Args: + boxes (Tensor): Nx4 tensor storing the boxes for each mask + mask_size (int): the size of the rasterized mask. + + Returns: + Tensor: A bool tensor of shape (N, mask_size, mask_size), where + N is the number of predicted boxes for this image. 
+ """ + assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self)) + + device = boxes.device + # Put boxes on the CPU, as the polygon representation is not efficient GPU-wise + # (several small tensors for representing a single instance mask) + boxes = boxes.to(torch.device("cpu")) + + results = [ + rasterize_polygons_within_box(poly, box.numpy(), mask_size) + for poly, box in zip(self.polygons, boxes) + ] + """ + poly: list[list[float]], the polygons for one instance + box: a tensor of shape (4,) + """ + if len(results) == 0: + return torch.empty(0, mask_size, mask_size, dtype=torch.bool, device=device) + return torch.stack(results, dim=0).to(device=device) + + def area(self): + """ + Computes area of the mask. + Only works with Polygons, using the shoelace formula: + https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates + + Returns: + Tensor: a vector, area for each instance + """ + + area = [] + for polygons_per_instance in self.polygons: + area_per_instance = 0 + for p in polygons_per_instance: + area_per_instance += polygon_area(p[0::2], p[1::2]) + area.append(area_per_instance) + + return torch.tensor(area) + + @staticmethod + def cat(polymasks_list: List["PolygonMasks"]) -> "PolygonMasks": + """ + Concatenates a list of PolygonMasks into a single PolygonMasks + + Arguments: + polymasks_list (list[PolygonMasks]) + + Returns: + PolygonMasks: the concatenated PolygonMasks + """ + assert isinstance(polymasks_list, (list, tuple)) + assert len(polymasks_list) > 0 + assert all(isinstance(polymask, PolygonMasks) for polymask in polymasks_list) + + cat_polymasks = type(polymasks_list[0])( + list(itertools.chain.from_iterable(pm.polygons for pm in polymasks_list)) + ) + return cat_polymasks + + +class ROIMasks: + """ + Represent masks by N smaller masks defined in some ROIs. 
@torch.jit.unused
def to_bitmasks(self, boxes: Boxes, height, width, threshold=0.5):
    """
    Paste the per-ROI masks into full-image bitmasks.

    Args: see documentation of :func:`paste_masks_in_image`.
        boxes (Boxes): the ROI boxes the masks are defined in. Only
            `boxes.tensor` is read, so a `Boxes` object (not a raw Tensor)
            is expected.
        height, width: size of the output image. While tracing, a Tensor
            `height` selects the tensor-shape variant of the paste function.
        threshold: forwarded to the paste function.

    Returns:
        BitMasks: wraps the tensor returned by the paste function.
    """
    # Imported lazily, as in the original code.
    from detectron2.layers.mask_ops import paste_masks_in_image, _paste_masks_tensor_shape

    if torch.jit.is_tracing():
        if isinstance(height, torch.Tensor):
            paste_func = _paste_masks_tensor_shape
        else:
            paste_func = paste_masks_in_image
    else:
        # Outside of tracing, wrap the call with the CUDA-OOM retry helper.
        paste_func = retry_if_cuda_oom(paste_masks_in_image)
    bitmasks = paste_func(self.tensor, boxes.tensor, (height, width), threshold=threshold)
    return BitMasks(bitmasks)
When angle < 0: + B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CW. + + Mathematically, since the right-handed coordinate system for image space + is (y, x), where y is top->down and x is left->right, the 4 vertices of the + rotated rectangle :math:`(yr_i, xr_i)` (i = 1, 2, 3, 4) can be obtained from + the vertices of the horizontal rectangle :math:`(y_i, x_i)` (i = 1, 2, 3, 4) + in the following way (:math:`\\theta = angle*\\pi/180` is the angle in radians, + :math:`(y_c, x_c)` is the center of the rectangle): + + .. math:: + + yr_i = \\cos(\\theta) (y_i - y_c) - \\sin(\\theta) (x_i - x_c) + y_c, + + xr_i = \\sin(\\theta) (y_i - y_c) + \\cos(\\theta) (x_i - x_c) + x_c, + + which is the standard rigid-body rotation transformation. + + Intuitively, the angle is + (1) the rotation angle from y-axis in image space + to the height vector (top->down in the box's local coordinate system) + of the box in CCW, and + (2) the rotation angle from x-axis in image space + to the width vector (left->right in the box's local coordinate system) + of the box in CCW. + + More intuitively, consider the following horizontal box ABCD represented + in (x1, y1, x2, y2): (3, 2, 7, 4), + covering the [3, 7] x [2, 4] region of the continuous coordinate system + which looks like this: + + .. code:: none + + O--------> x + | + | A---B + | | | + | D---C + | + v y + + Note that each capital letter represents one 0-dimensional geometric point + instead of a 'square pixel' here. + + In the example above, using (x, y) to represent a point we have: + + .. math:: + + O = (0, 0), A = (3, 2), B = (7, 2), C = (7, 4), D = (3, 4) + + We name vector AB = vector DC as the width vector in box's local coordinate system, and + vector AD = vector BC as the height vector in box's local coordinate system. Initially, + when angle = 0 degree, they're aligned with the positive directions of x-axis and y-axis + in the image space, respectively. 
+ + For better illustration, we denote the center of the box as E, + + .. code:: none + + O--------> x + | + | A---B + | | E | + | D---C + | + v y + + where the center E = ((3+7)/2, (2+4)/2) = (5, 3). + + Also, + + .. math:: + + width = |AB| = |CD| = 7 - 3 = 4, + height = |AD| = |BC| = 4 - 2 = 2. + + Therefore, the corresponding representation for the same shape in rotated box in + (x_center, y_center, width, height, angle) format is: + + (5, 3, 4, 2, 0), + + Now, let's consider (5, 3, 4, 2, 90), which is rotated by 90 degrees + CCW (counter-clockwise) by definition. It looks like this: + + .. code:: none + + O--------> x + | B-C + | | | + | |E| + | | | + | A-D + v y + + The center E is still located at the same point (5, 3), while the vertices + ABCD are rotated by 90 degrees CCW with regard to E: + A = (4, 5), B = (4, 1), C = (6, 1), D = (6, 5) + + Here, 90 degrees can be seen as the CCW angle to rotate from y-axis to + vector AD or vector BC (the top->down height vector in box's local coordinate system), + or the CCW angle to rotate from x-axis to vector AB or vector DC (the left->right + width vector in box's local coordinate system). + + .. math:: + + width = |AB| = |CD| = 5 - 1 = 4, + height = |AD| = |BC| = 6 - 4 = 2. + + Next, how about (5, 3, 4, 2, -90), which is rotated by 90 degrees CW (clockwise) + by definition? It looks like this: + + .. code:: none + + O--------> x + | D-A + | | | + | |E| + | | | + | C-B + v y + + The center E is still located at the same point (5, 3), while the vertices + ABCD are rotated by 90 degrees CW with regard to E: + A = (6, 1), B = (6, 5), C = (4, 5), D = (4, 1) + + .. math:: + + width = |AB| = |CD| = 5 - 1 = 4, + height = |AD| = |BC| = 6 - 4 = 2. + + This covers exactly the same region as (5, 3, 4, 2, 90) does, and their IoU + will be 1. However, these two will generate different RoI Pooling results and + should not be treated as an identical box. 
+ + On the other hand, it's easy to see that (X, Y, W, H, A) is identical to + (X, Y, W, H, A+360N), for any integer N. For example (5, 3, 4, 2, 270) would be + identical to (5, 3, 4, 2, -90), because rotating the shape 270 degrees CCW is + equivalent to rotating the same shape 90 degrees CW. + + We could rotate further to get (5, 3, 4, 2, 180), or (5, 3, 4, 2, -180): + + .. code:: none + + O--------> x + | + | C---D + | | E | + | B---A + | + v y + + .. math:: + + A = (7, 4), B = (3, 4), C = (3, 2), D = (7, 2), + + width = |AB| = |CD| = 7 - 3 = 4, + height = |AD| = |BC| = 4 - 2 = 2. + + Finally, this is a very inaccurate (heavily quantized) illustration of + how (5, 3, 4, 2, 60) looks like in case anyone wonders: + + .. code:: none + + O--------> x + | B\ + | / C + | /E / + | A / + | `D + v y + + It's still a rectangle with center of (5, 3), width of 4 and height of 2, + but its angle (and thus orientation) is somewhere between + (5, 3, 4, 2, 0) and (5, 3, 4, 2, 90). + """ + device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu") + tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device) + if tensor.numel() == 0: + # Use reshape, so we don't end up creating a new tensor that does not depend on + # the inputs (and consequently confuses jit) + tensor = tensor.reshape((0, 5)).to(dtype=torch.float32, device=device) + assert tensor.dim() == 2 and tensor.size(-1) == 5, tensor.size() + + self.tensor = tensor + + def clone(self) -> "RotatedBoxes": + """ + Clone the RotatedBoxes. + + Returns: + RotatedBoxes + """ + return RotatedBoxes(self.tensor.clone()) + + def to(self, device: torch.device): + # Boxes are assumed float32 and does not support to(dtype) + return RotatedBoxes(self.tensor.to(device=device)) + + def area(self) -> torch.Tensor: + """ + Computes the area of all the boxes. + + Returns: + torch.Tensor: a vector with areas of each box. 
# Avoid in-place operations so that we can torchscript; NOTE: this creates a new tensor
def normalize_angles(self) -> None:
    """
    Wrap the angle column (index 4) into the range [-180, 180) degrees.

    `self.tensor` is rebound to a freshly concatenated tensor; the first
    four columns are carried over unchanged.
    """
    geometry = self.tensor[:, :4]
    wrapped = (self.tensor[:, 4] + 180.0) % 360.0 - 180.0
    self.tensor = torch.cat((geometry, wrapped.unsqueeze(1)), dim=1)
def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
    """
    Find boxes that are non-empty.
    A box counts as empty when its width or its height is not strictly
    greater than `threshold`.

    Returns:
        Tensor: a boolean vector, False for empty boxes and True otherwise.
    """
    # Columns 2 and 3 hold width and height; both must exceed the threshold.
    sides = self.tensor[:, 2:4]
    return (sides > threshold).all(dim=1)
def inside_box(self, box_size: Tuple[int, int], boundary_threshold: int = 0) -> torch.Tensor:
    """
    Args:
        box_size (height, width): Size of the reference box covering
            [0, width] x [0, height]
        boundary_threshold (int): Boxes that extend beyond the reference box
            boundary by more than boundary_threshold are considered "outside".

    The test is done on the axis-aligned rectangle that encloses each
    rotated box.

    Returns:
        a binary vector, indicating whether each box is inside the reference box.
    """
    height, width = box_size
    boxes = self.tensor

    center_x, center_y = boxes[..., 0], boxes[..., 1]
    half_w, half_h = boxes[..., 2] / 2.0, boxes[..., 3] / 2.0
    angle = boxes[..., 4]
    abs_cos = torch.abs(torch.cos(angle * math.pi / 180.0))
    abs_sin = torch.abs(torch.sin(angle * math.pi / 180.0))

    # Half extents of the horizontal rectangle enclosing the rotated box.
    extent_x = abs_cos * half_w + abs_sin * half_h
    extent_y = abs_cos * half_h + abs_sin * half_w

    left_ok = center_x - extent_x >= -boundary_threshold
    top_ok = center_y - extent_y >= -boundary_threshold
    right_ok = center_x + extent_x < width + boundary_threshold
    bottom_ok = center_y + extent_y < height + boundary_threshold
    return left_ok & top_ok & right_ok & bottom_ok
def scale(self, scale_x: float, scale_y: float) -> None:
    """
    Scale the rotated boxes in place with horizontal and vertical factors.

    Note: when scale_x != scale_y, a rotated box whose angle is not a
    multiple of 90 degrees becomes a skewed parallelogram under resizing.
    Here a rotated rectangle is fitted to that parallelogram as an
    approximation.
    """
    boxes = self.tensor  # same storage: every write below mutates self.tensor

    boxes[:, 0] *= scale_x
    boxes[:, 1] *= scale_y

    # cos/sin of the ORIGINAL angle; must be read before column 4 is rewritten.
    radians = boxes[:, 4] * math.pi / 180.0
    c = torch.cos(radians)
    s = torch.sin(radians)

    # Width: the midpoint of the left edge, (-w/2, 0) in box coordinates, maps
    # to (-c*w/2, s*w/2) in image space and to (-sfx*c*w/2, sfy*s*w/2) after
    # scaling, so
    #   w(new) = sqrt[(sfx * c)^2 + (sfy * s)^2] * w
    # (angle 0/180 -> factor sfx, |angle| 90 -> factor sfy).
    width_factor = torch.sqrt((scale_x * c) ** 2 + (scale_y * s) ** 2)
    boxes[:, 2] *= width_factor

    # Height: the midpoint of the top edge, (0, -h/2), gives by the same argument
    #   h(new) = sqrt[(sfx * s)^2 + (sfy * c)^2] * h
    # (angle 0/180 -> factor sfy, |angle| 90 -> factor sfx).
    height_factor = torch.sqrt((scale_x * s) ** 2 + (scale_y * c) ** 2)
    boxes[:, 3] *= height_factor

    # Angle: CCW rotation from the image y-axis to the scaled height vector,
    #   angle(new) = atan2(sfx * s, sfy * c)
    # which equals the old angle whenever sfx == sfy.
    boxes[:, 4] = torch.atan2(scale_x * s, scale_y * c) * 180 / math.pi
def pairwise_iou(boxes1: RotatedBoxes, boxes2: RotatedBoxes) -> torch.Tensor:
    """
    Given two lists of rotated boxes of size N and M,
    compute the IoU (intersection over union)
    between **all** N x M pairs of boxes.
    The box order must be (x_center, y_center, width, height, angle).

    Args:
        boxes1, boxes2 (RotatedBoxes):
            two `RotatedBoxes`. Contains N & M rotated boxes, respectively.

    Returns:
        Tensor: IoU, sized [N,M].
    """
    # The computation is delegated to the rotated-IoU op on the raw Nx5 / Mx5
    # tensors. The annotation used to say `-> None` although the result is returned.
    return pairwise_iou_rotated(boxes1.tensor, boxes2.tensor)
def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
    """
    Build a tracker head from `cfg.TRACKER_HEADS.TRACKER_NAME`.

    The name is looked up in `TRACKER_HEADS_REGISTRY` and the registered
    class is called with `cfg` as its only argument.

    Args:
        cfg: D2 CfgNode, config file with tracker information
    Return:
        tracker object
    """
    tracker_cls = TRACKER_HEADS_REGISTRY.get(cfg.TRACKER_HEADS.TRACKER_NAME)
    return tracker_cls(cfg)
def update(self, instances: Instances) -> Instances:
    """
    Assign tracking IDs to the predictions of the current frame.

    On the first call there is no previous frame, so the IDs set up by
    ``_initialize_extra_fields`` are kept. On later calls every
    (current, previous) box pair is visited greedily from the highest IoU
    to the lowest; a pair hands the previous ID to the current instance when
    neither side is already matched and its IoU reaches
    ``track_iou_threshold``. Remaining current instances get new IDs, and
    eligible unmatched previous instances are merged back in.

    See BaseTracker description for the fields carried by ``instances``.

    Args:
        instances: D2 Instances, for predictions of the current frame
    Return:
        D2 Instances of the current frame with ID, ID_period and
        lost_frame_count filled in (plus any carried-over previous instances)
    """
    instances = self._initialize_extra_fields(instances)
    if self._prev_instances is not None:
        # calculate IoU of all bbox pairs
        iou_all = pairwise_iou(
            boxes1=instances.pred_boxes,
            boxes2=self._prev_instances.pred_boxes,
        )
        bbox_pairs = self._create_prediction_pairs(instances, iou_all)
        # sort IoU in descending order: the loop below is greedy, so the best
        # overlapping pair must be seen first or a weaker pair could claim
        # the ID. list.sort is stable, so ties keep their enumeration order.
        bbox_pairs.sort(key=lambda pair: float(pair["IoU"]), reverse=True)
        # assign previous ID to current bbox if IoU > track_iou_threshold
        self._reset_fields()
        for bbox_pair in bbox_pairs:
            idx = bbox_pair["idx"]
            prev_id = bbox_pair["prev_id"]
            if (
                idx in self._matched_idx
                or prev_id in self._matched_ID
                or bbox_pair["IoU"] < self._track_iou_threshold
            ):
                continue
            instances.ID[idx] = prev_id
            instances.ID_period[idx] = bbox_pair["prev_period"] + 1
            instances.lost_frame_count[idx] = 0
            self._matched_idx.add(idx)
            self._matched_ID.add(prev_id)
            self._untracked_prev_idx.remove(bbox_pair["prev_idx"])
        instances = self._assign_new_id(instances)
        instances = self._merge_untracked_instances(instances)
    # keep an independent copy so later mutation by the caller cannot alter history
    self._prev_instances = copy.deepcopy(instances)
    return instances
def _merge_untracked_instances(self, instances: Instances) -> Instances:
    """
    For untracked previous instances, under certain condition, still keep them
    in tracking and merge with the current instances.

    A previous instance listed in ``self._untracked_prev_idx`` is dropped when
    its box is narrower/shorter than ``min_box_rel_dim`` of the video size,
    when its ``lost_frame_count`` has reached ``max_lost_frame_count``, or when
    its ``ID_period`` does not exceed ``min_instance_period``. Every other one
    is carried over with ``lost_frame_count`` increased by one.

    Args:
        instances: D2 Instances, for predictions of the current frame
    Return:
        D2 Instances merging current instances and instances from previous
        frame decided to keep tracking
    """
    # optional fields are copied only when the current frame carries them
    has_masks = instances.has("pred_masks")
    has_keypoints = instances.has("pred_keypoints")
    has_heatmaps = instances.has("pred_keypoint_heatmaps")

    untracked_instances = Instances(
        image_size=instances.image_size,
        pred_boxes=[],
        pred_classes=[],
        scores=[],
        ID=[],
        ID_period=[],
        lost_frame_count=[],
    )
    prev_bboxes = list(self._prev_instances.pred_boxes)
    prev_classes = list(self._prev_instances.pred_classes)
    prev_scores = list(self._prev_instances.scores)
    prev_ID_period = self._prev_instances.ID_period
    if has_masks:
        untracked_instances.set("pred_masks", [])
        prev_masks = list(self._prev_instances.pred_masks)
    if has_keypoints:
        untracked_instances.set("pred_keypoints", [])
        prev_keypoints = list(self._prev_instances.pred_keypoints)
    if has_heatmaps:
        untracked_instances.set("pred_keypoint_heatmaps", [])
        prev_keypoint_heatmaps = list(self._prev_instances.pred_keypoint_heatmaps)

    for idx in self._untracked_prev_idx:
        x_left, y_top, x_right, y_bot = prev_bboxes[idx]
        if (
            (1.0 * (x_right - x_left) / self._video_width < self._min_box_rel_dim)
            or (1.0 * (y_bot - y_top) / self._video_height < self._min_box_rel_dim)
            or self._prev_instances.lost_frame_count[idx] >= self._max_lost_frame_count
            or prev_ID_period[idx] <= self._min_instance_period
        ):
            continue
        untracked_instances.pred_boxes.append(list(prev_bboxes[idx].numpy()))
        untracked_instances.pred_classes.append(int(prev_classes[idx]))
        untracked_instances.scores.append(float(prev_scores[idx]))
        untracked_instances.ID.append(self._prev_instances.ID[idx])
        untracked_instances.ID_period.append(self._prev_instances.ID_period[idx])
        untracked_instances.lost_frame_count.append(
            self._prev_instances.lost_frame_count[idx] + 1
        )
        if has_masks:
            untracked_instances.pred_masks.append(prev_masks[idx].numpy().astype(np.uint8))
        if has_keypoints:
            # keep keypoints as float32: a uint8 cast wraps any coordinate
            # above 255 and truncates fractional values to 0
            untracked_instances.pred_keypoints.append(
                prev_keypoints[idx].numpy().astype(np.float32)
            )
        if has_heatmaps:
            untracked_instances.pred_keypoint_heatmaps.append(
                prev_keypoint_heatmaps[idx].numpy().astype(np.float32)
            )

    # turn the accumulated python lists into tensors so Instances.cat can merge them
    untracked_instances.pred_boxes = Boxes(torch.FloatTensor(untracked_instances.pred_boxes))
    untracked_instances.pred_classes = torch.IntTensor(untracked_instances.pred_classes)
    untracked_instances.scores = torch.FloatTensor(untracked_instances.scores)
    if has_masks:
        untracked_instances.pred_masks = torch.IntTensor(untracked_instances.pred_masks)
    if has_keypoints:
        untracked_instances.pred_keypoints = torch.FloatTensor(
            untracked_instances.pred_keypoints
        )
    if has_heatmaps:
        untracked_instances.pred_keypoint_heatmaps = torch.FloatTensor(
            untracked_instances.pred_keypoint_heatmaps
        )

    return Instances.cat(
        [
            instances,
            untracked_instances,
        ]
    )
def update(self, instances: Instances) -> Instances:
    """
    Assign tracking IDs to the predictions of the current frame.

    When a previous frame is stored, a cost matrix between current and
    previous instances is built by ``build_cost_matrix`` and solved with
    ``scipy.optimize.linear_sum_assignment``. The result is then applied in
    three steps: matched pairs, unmatched current instances, unmatched
    previous instances.

    Args:
        instances: D2 Instances, for predictions of the current frame
    Return:
        D2 Instances returned by the processing steps, with ID, ID_period and
        lost_frame_count filled in
    Raises:
        NotImplementedError: if ``instances`` has a ``pred_keypoints`` field
    """
    if instances.has("pred_keypoints"):
        raise NotImplementedError("Need to add support for keypoints")
    instances = self._initialize_extra_fields(instances)
    if self._prev_instances is not None:
        self._untracked_prev_idx = set(range(len(self._prev_instances)))
        cost_matrix = self.build_cost_matrix(instances, self._prev_instances)
        # NOTE(review): every pair returned here is passed on as a match; no
        # step in this method filters pairs by their cost value - confirm that
        # subclasses relying on a "large cost" sentinel are fine with that.
        matched_idx, matched_prev_idx = linear_sum_assignment(cost_matrix)
        instances = self._process_matched_idx(instances, matched_idx, matched_prev_idx)
        instances = self._process_unmatched_idx(instances, matched_idx)
        instances = self._process_unmatched_prev_idx(instances, matched_prev_idx)
    # store a deep copy as the reference for the next call
    self._prev_instances = copy.deepcopy(instances)
    return instances
def _process_unmatched_prev_idx(
    self, instances: Instances, matched_prev_idx: np.ndarray
) -> Instances:
    """
    Carry over previous instances that were not matched in this frame.

    A previous instance whose index is not in ``matched_prev_idx`` is dropped
    when its box is narrower/shorter than ``min_box_rel_dim`` of the video
    size, when its ``lost_frame_count`` has reached ``max_lost_frame_count``,
    or when its ``ID_period`` does not exceed ``min_instance_period``.
    Every other one is appended to the result with ``lost_frame_count``
    increased by one.

    Args:
        instances: D2 Instances, for predictions of the current frame
        matched_prev_idx: indices into the previous instances that were matched
    Return:
        D2 Instances concatenating ``instances`` with the carried-over
        previous instances
    """
    # start from empty python lists; they are converted to tensors after the loop
    untracked_instances = Instances(
        image_size=instances.image_size,
        pred_boxes=[],
        pred_masks=[],
        pred_classes=[],
        scores=[],
        ID=[],
        ID_period=[],
        lost_frame_count=[],
    )
    prev_bboxes = list(self._prev_instances.pred_boxes)
    prev_classes = list(self._prev_instances.pred_classes)
    prev_scores = list(self._prev_instances.scores)
    prev_ID_period = self._prev_instances.ID_period
    # masks are only read when the current frame carries masks as well
    if instances.has("pred_masks"):
        prev_masks = list(self._prev_instances.pred_masks)
    untracked_prev_idx = set(range(len(self._prev_instances))).difference(set(matched_prev_idx))
    for idx in untracked_prev_idx:
        x_left, y_top, x_right, y_bot = prev_bboxes[idx]
        # skip boxes that are too small, lost for too long, or too short-lived
        if (
            (1.0 * (x_right - x_left) / self._video_width < self._min_box_rel_dim)
            or (1.0 * (y_bot - y_top) / self._video_height < self._min_box_rel_dim)
            or self._prev_instances.lost_frame_count[idx] >= self._max_lost_frame_count
            or prev_ID_period[idx] <= self._min_instance_period
        ):
            continue
        untracked_instances.pred_boxes.append(list(prev_bboxes[idx].numpy()))
        untracked_instances.pred_classes.append(int(prev_classes[idx]))
        untracked_instances.scores.append(float(prev_scores[idx]))
        untracked_instances.ID.append(self._prev_instances.ID[idx])
        untracked_instances.ID_period.append(self._prev_instances.ID_period[idx])
        untracked_instances.lost_frame_count.append(
            self._prev_instances.lost_frame_count[idx] + 1
        )
        if instances.has("pred_masks"):
            untracked_instances.pred_masks.append(prev_masks[idx].numpy().astype(np.uint8))

    # convert the accumulated lists into the tensor types used for merging
    untracked_instances.pred_boxes = Boxes(torch.FloatTensor(untracked_instances.pred_boxes))
    untracked_instances.pred_classes = torch.IntTensor(untracked_instances.pred_classes)
    untracked_instances.scores = torch.FloatTensor(untracked_instances.scores)
    if instances.has("pred_masks"):
        untracked_instances.pred_masks = torch.IntTensor(untracked_instances.pred_masks)
    else:
        # the placeholder field was created unconditionally above; drop it so
        # both operands of the concatenation expose the same fields
        untracked_instances.remove("pred_masks")

    return Instances.cat(
        [
            instances,
            untracked_instances,
        ]
    )
def assign_cost_matrix_values(self, cost_matrix: np.ndarray, bbox_pairs: List) -> np.ndarray:
    """
    Based on IoU for each pair of bbox, assign the associated value in cost matrix

    Each listed pair gets the cost ``-IoU``, so a larger overlap is a cheaper
    assignment. Entries without a pair keep their initial value.

    Args:
        cost_matrix: np.ndarray, initialized 2D array with target dimensions.
            A floating-point matrix is updated in place; any other dtype is
            first converted to a float64 copy, so callers must use the
            returned array.
        bbox_pairs: list of bbox pair, in each pair, iou value is stored
    Return:
        np.ndarray, cost_matrix with assigned values
    """
    # An integer matrix (e.g. np.full(shape, 100000)) would truncate every
    # fractional -IoU toward zero and erase the weighting, so switch to float.
    if cost_matrix.dtype.kind != "f":
        cost_matrix = cost_matrix.astype(np.float64)
    for pair in bbox_pairs:
        # assign (-1 * IoU) for above threshold pairs, algorithms will minimize cost
        cost_matrix[pair["idx"]][pair["prev_idx"]] = -1 * float(pair["IoU"])
    return cost_matrix
@TRACKER_HEADS_REGISTRY.register()
class VanillaHungarianBBoxIOUTracker(BaseHungarianTracker):
    """
    Hungarian algo based tracker using bbox iou as metric
    """

    @configurable
    def __init__(
        self,
        *,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        track_iou_threshold: float = 0.5,
        **kwargs,
    ):
        """
        Args:
            video_height: height the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of id allowed to be tracked
            max_lost_frame_count: maximum number of frame an id can lost tracking
                                  exceed this number, an id is considered as lost
                                  forever
            min_box_rel_dim: a percentage, smaller than this dimension, a bbox is
                             removed from tracking
            min_instance_period: an instance will be shown after this number of period
                                 since its first showing up in the video
            track_iou_threshold: iou threshold, below this number a bbox pair is removed
                                 from tracking
        """
        super().__init__(
            video_height=video_height,
            video_width=video_width,
            max_num_instances=max_num_instances,
            max_lost_frame_count=max_lost_frame_count,
            min_box_rel_dim=min_box_rel_dim,
            min_instance_period=min_instance_period,
        )
        self._track_iou_threshold = track_iou_threshold

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Old style initialization using CfgNode

        VIDEO_HEIGHT and VIDEO_WIDTH must be present under cfg.TRACKER_HEADS;
        every other key falls back to the default of the matching __init__
        argument.

        Args:
            cfg: D2 CfgNode, config file
        Return:
            dictionary storing arguments for __init__ method
        """
        assert "VIDEO_HEIGHT" in cfg.TRACKER_HEADS
        assert "VIDEO_WIDTH" in cfg.TRACKER_HEADS
        video_height = cfg.TRACKER_HEADS.get("VIDEO_HEIGHT")
        video_width = cfg.TRACKER_HEADS.get("VIDEO_WIDTH")
        max_num_instances = cfg.TRACKER_HEADS.get("MAX_NUM_INSTANCES", 200)
        max_lost_frame_count = cfg.TRACKER_HEADS.get("MAX_LOST_FRAME_COUNT", 0)
        min_box_rel_dim = cfg.TRACKER_HEADS.get("MIN_BOX_REL_DIM", 0.02)
        min_instance_period = cfg.TRACKER_HEADS.get("MIN_INSTANCE_PERIOD", 1)
        track_iou_threshold = cfg.TRACKER_HEADS.get("TRACK_IOU_THRESHOLD", 0.5)
        return {
            "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker",  # noqa
            "video_height": video_height,
            "video_width": video_width,
            "max_num_instances": max_num_instances,
            "max_lost_frame_count": max_lost_frame_count,
            "min_box_rel_dim": min_box_rel_dim,
            "min_instance_period": min_instance_period,
            "track_iou_threshold": track_iou_threshold,
        }

    def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray:
        """
        Build the cost matrix for assignment problem
        (https://en.wikipedia.org/wiki/Assignment_problem)

        Args:
            instances: D2 Instances, for current frame predictions
            prev_instances: D2 Instances, for previous frame predictions

        Return:
            the cost matrix in numpy array, of shape
            (len(instances), len(prev_instances)) and floating-point dtype
        """
        assert instances is not None and prev_instances is not None
        # calculate IoU of all bbox pairs; use the prev_instances argument
        # throughout so the IoU values and the matrix shape always refer to
        # the same set of previous boxes
        iou_all = pairwise_iou(
            boxes1=instances.pred_boxes,
            boxes2=prev_instances.pred_boxes,
        )
        bbox_pairs = create_prediction_pairs(
            instances, prev_instances, iou_all, threshold=self._track_iou_threshold
        )
        # assign large cost value to make sure pair below IoU threshold won't be matched.
        # The matrix is explicitly float: LARGE_COST_VALUE alone would give an
        # integer array, and a subclass writing fractional costs (e.g. -IoU)
        # into it would have them truncated.
        cost_matrix = np.full(
            (len(instances), len(prev_instances)), LARGE_COST_VALUE, dtype=np.float64
        )
        return self.assign_cost_matrix_values(cost_matrix, bbox_pairs)

    def assign_cost_matrix_values(self, cost_matrix: np.ndarray, bbox_pairs: List) -> np.ndarray:
        """
        Based on IoU for each pair of bbox, assign the associated value in cost matrix

        Args:
            cost_matrix: np.ndarray, initialized 2D array with target dimensions
            bbox_pairs: list of bbox pair, in each pair, iou value is stored
        Return:
            np.ndarray, cost_matrix with assigned values
        """
        for pair in bbox_pairs:
            # assign -1 for IoU above threshold pairs, algorithms will minimize cost
            cost_matrix[pair["idx"]][pair["prev_idx"]] = -1
        return cost_matrix
def flop_count_operators(model: nn.Module, inputs: list) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level flops counting using jit.
    This is a wrapper of :func:`fvcore.nn.flop_count` and adds supports for standard
    detection models in detectron2.
    Please use :class:`FlopCountAnalysis` for more advanced functionalities.

    Note:
        The function runs the input through the model to compute flops.
        The flops of a detection model is often input-dependent, for example,
        the flops of box & mask head depends on the number of proposals &
        the number of detected objects.
        Therefore, the flops counting using a single input may not accurately
        reflect the computation cost of a model. It's recommended to average
        across a number of inputs.

        The model is switched to eval mode for the measurement and its
        previous training flag is restored afterwards, also when the
        measurement raises.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
            Only "image" key will be used.

    Returns:
        dict: Gflop count per operator
    """
    old_train = model.training
    model.eval()
    try:
        ret = FlopCountAnalysis(model, inputs).by_operator()
    finally:
        # never leave the caller's model stuck in eval mode
        model.train(old_train)
    return {k: v / 1e9 for k, v in ret.items()}
def _wrapper_count_operators(
    model: nn.Module, inputs: list, mode: str, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Shared driver for the operator-level counters.

    Args:
        model: a detectron2 model that takes `list[dict]` as input. A
            DataParallel / DistributedDataParallel wrapper is unwrapped first.
        inputs (list[dict]): exactly one input dict; only its "image" key is used.
        mode: FLOPS_MODE or ACTIVATIONS_MODE
        kwargs: forwarded to the fvcore counter. A "supported_ops" entry is
            merged on top of the handlers that ignore ``_IGNORED_OPS``.

    Returns:
        the per-operator counts produced by the fvcore counter

    Raises:
        AssertionError: if ``inputs`` does not hold exactly one element
        NotImplementedError: if ``mode`` is not one of the supported modes
    """
    # ignore some ops
    supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]
    inputs = [{"image": tensor_input}]  # remove other keys, in case there are any

    # reject an unknown mode before the model is touched at all
    counters = {FLOPS_MODE: flop_count, ACTIVATIONS_MODE: activation_count}
    if mode not in counters:
        raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))

    old_train = model.training
    if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
        model = model.module
    try:
        wrapper = TracingAdapter(model, inputs)
        wrapper.eval()
        ret = counters[mode](wrapper, (tensor_input,), **kwargs)
    finally:
        # restore the training flag even when tracing or counting fails
        model.train(old_train)
    # compatible with change in fvcore
    if isinstance(ret, tuple):
        ret = ret[0]
    return ret
def detect_compute_compatibility(CUDA_HOME, so_file):
    """
    Report the compute capabilities a compiled extension was built for.

    Runs ``<CUDA_HOME>/bin/cuobjdump --list-elf <so_file>`` and collects the
    ``sm_XY`` tags found in its output.

    Args:
        CUDA_HOME: CUDA installation directory (may be None)
        so_file: path of the shared object to inspect

    Returns:
        str: comma separated capabilities in ascending order, e.g.
        "7.0, 8.6, 10.0". If cuobjdump does not exist, ``so_file`` followed by
        "; cannot find cuobjdump". On any failure, or when no ``sm_`` tag is
        found, ``so_file`` unchanged.
    """
    try:
        cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
        if not os.path.isfile(cuobjdump):
            return so_file + "; cannot find cuobjdump"
        # argv list instead of a shell string: paths containing quotes or
        # other shell metacharacters are passed through verbatim
        output = subprocess.check_output([cuobjdump, "--list-elf", so_file])
        sm_numbers = set()
        for line in output.decode("utf-8").strip().split("\n"):
            match = re.search(r"\.sm_([0-9]+)\.", line)
            if match:  # lines without an sm tag are ignored
                sm_numbers.add(int(match.group(1)))
        if not sm_numbers:
            return so_file
        # the last digit is the minor version, everything before it the major
        # one, so sm_100 is 10.0 (joining the digits with "." gave "1.0.0")
        return ", ".join(f"{sm // 10}.{sm % 10}" for sm in sorted(sm_numbers))
    except Exception:
        # best-effort diagnostic: any failure falls back to the file name
        return so_file
(ImportError, AttributeError): + pass + else: + data.append( + ("detectron2 arch flags", detect_compute_compatibility(CUDA_HOME, so_file)) + ) + else: + # print compilers that are used to build extension + data.append(("Compiler", _C.get_compiler_version())) + data.append(("CUDA compiler", _C.get_cuda_version())) # cuda or hip + if has_cuda and getattr(_C, "has_cuda", lambda: True)(): + data.append( + ("detectron2 arch flags", detect_compute_compatibility(CUDA_HOME, _C.__file__)) + ) + + data.append(get_env_module()) + data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__))) + data.append(("PyTorch debug build", torch.version.debug)) + try: + data.append(("torch._C._GLIBCXX_USE_CXX11_ABI", torch._C._GLIBCXX_USE_CXX11_ABI)) + except Exception: + pass + + if not has_gpu: + has_gpu_text = "No: torch.cuda.is_available() == False" + else: + has_gpu_text = "Yes" + data.append(("GPU available", has_gpu_text)) + if has_gpu: + devices = defaultdict(list) + for k in range(torch.cuda.device_count()): + cap = ".".join((str(x) for x in torch.cuda.get_device_capability(k))) + name = torch.cuda.get_device_name(k) + f" (arch={cap})" + devices[name].append(str(k)) + for name, devids in devices.items(): + data.append(("GPU " + ",".join(devids), name)) + + if has_rocm: + msg = " - invalid!" if not (ROCM_HOME and os.path.isdir(ROCM_HOME)) else "" + data.append(("ROCM_HOME", str(ROCM_HOME) + msg)) + else: + try: + from torch.utils.collect_env import get_nvidia_driver_version, run as _run + + data.append(("Driver version", get_nvidia_driver_version(_run))) + except Exception: + pass + msg = " - invalid!" 
if not (CUDA_HOME and os.path.isdir(CUDA_HOME)) else "" + data.append(("CUDA_HOME", str(CUDA_HOME) + msg)) + + cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) + if cuda_arch_list: + data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list)) + data.append(("Pillow", PIL.__version__)) + + try: + data.append( + ( + "torchvision", + str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__), + ) + ) + if has_cuda: + try: + torchvision_C = importlib.util.find_spec("torchvision._C").origin + msg = detect_compute_compatibility(CUDA_HOME, torchvision_C) + data.append(("torchvision arch flags", msg)) + except (ImportError, AttributeError): + data.append(("torchvision._C", "Not found")) + except AttributeError: + data.append(("torchvision", "unknown")) + + try: + import fvcore + + data.append(("fvcore", fvcore.__version__)) + except (ImportError, AttributeError): + pass + + try: + import iopath + + data.append(("iopath", iopath.__version__)) + except (ImportError, AttributeError): + pass + + try: + import cv2 + + data.append(("cv2", cv2.__version__)) + except (ImportError, AttributeError): + data.append(("cv2", "Not found")) + env_str = tabulate(data) + "\n" + env_str += collect_torch_env() + return env_str + + +def test_nccl_ops(): + num_gpu = torch.cuda.device_count() + if os.access("/tmp", os.W_OK): + import torch.multiprocessing as mp + + dist_url = "file:///tmp/nccl_tmp_file" + print("Testing NCCL connectivity ... 
this should not hang.") + mp.spawn(_test_nccl_worker, nprocs=num_gpu, args=(num_gpu, dist_url), daemon=False) + print("NCCL succeeded.") + + +def _test_nccl_worker(rank, num_gpu, dist_url): + import torch.distributed as dist + + dist.init_process_group(backend="NCCL", init_method=dist_url, rank=rank, world_size=num_gpu) + dist.barrier(device_ids=[rank]) + + +if __name__ == "__main__": + try: + from detectron2.utils.collect_env import collect_env_info as f + + print(f()) + except ImportError: + print(collect_env_info()) + + if torch.cuda.is_available(): + num_gpu = torch.cuda.device_count() + for k in range(num_gpu): + device = f"cuda:{k}" + try: + x = torch.tensor([1, 2.0], dtype=torch.float32) + x = x.to(device) + except Exception as e: + print( + f"Unable to copy tensor to device={device}: {e}. " + "Your CUDA environment is broken." + ) + if num_gpu > 1: + test_nccl_ops() diff --git a/data_processing/detectron2/detectron2/utils/colormap.py b/data_processing/detectron2/detectron2/utils/colormap.py new file mode 100644 index 0000000..14ded16 --- /dev/null +++ b/data_processing/detectron2/detectron2/utils/colormap.py @@ -0,0 +1,158 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +""" +An awesome colormap for really neat visualizations. +Copied from Detectron, and removed gray colors. 
+""" + +import numpy as np +import random + +__all__ = ["colormap", "random_color", "random_colors"] + +# fmt: off +# RGB: +_COLORS = np.array( + [ + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.000, 0.000, 0.000, + 0.143, 0.143, 0.143, + 0.857, 0.857, 0.857, + 1.000, 1.000, 1.000 + ] +).astype(np.float32).reshape(-1, 3) +# fmt: on + + +def colormap(rgb=False, maximum=255): + """ + Args: + rgb (bool): whether to return RGB 
def random_colors(N, rgb=False, maximum=255):
    """
    Draw ``N`` distinct entries from the color table.

    Args:
        N (int): number of unique colors needed
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1

    Returns:
        list[ndarray]: ``N`` color vectors of 3 numbers each, scaled by
        ``maximum``; channel order is reversed (BGR) unless ``rgb`` is True.
    """
    picked = random.sample(range(len(_COLORS)), N)
    colors = []
    for idx in picked:
        color = _COLORS[idx] * maximum
        colors.append(color if rgb else color[::-1])
    return colors
Please use detectron2's `launch()` " + "to start processes and initialize pytorch process group. If you need to start " + "processes in other ways, please call comm.create_local_process_group(" + "num_workers_per_machine) after calling torch.distributed.init_process_group()." +) + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +@functools.lru_cache() +def create_local_process_group(num_workers_per_machine: int) -> None: + """ + Create a process group that contains ranks within the same machine. + + Detectron2's launch() in engine/launch.py will call this function. If you start + workers without launch(), you'll have to also call this. Otherwise utilities + like `get_local_rank()` will not work. + + This function contains a barrier. All processes must call it together. + + Args: + num_workers_per_machine: the number of worker processes per machine. Typically + the number of GPUs. + """ + global _LOCAL_PROCESS_GROUP + assert _LOCAL_PROCESS_GROUP is None + assert get_world_size() % num_workers_per_machine == 0 + num_machines = get_world_size() // num_workers_per_machine + machine_rank = get_rank() // num_workers_per_machine + for i in range(num_machines): + ranks_on_i = list(range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine)) + pg = dist.new_group(ranks_on_i) + if i == machine_rank: + _LOCAL_PROCESS_GROUP = pg + + +def get_local_process_group(): + """ + Returns: + A torch process group which only includes processes that are on the same + machine as the current process. This group can be useful for communication + within a machine, e.g. a per-machine SyncBN. 
def synchronize():
    """
    Barrier across all processes of a distributed run.

    Returns immediately when ``torch.distributed`` is unavailable, not
    initialized, or the world size is 1.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() == 1:
        return
    if dist.get_backend() != dist.Backend.NCCL:
        dist.barrier()
        return
    # This argument is needed to avoid warnings.
    # It's valid only for NCCL backend.
    dist.barrier(device_ids=[torch.cuda.current_device()])
+ + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() # use CPU group by default, to reduce GPU RAM usage. + world_size = dist.get_world_size(group) + if world_size == 1: + return [data] + + output = [None for _ in range(world_size)] + dist.all_gather_object(output, data, group=group) + return output + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + world_size = dist.get_world_size(group=group) + if world_size == 1: + return [data] + rank = dist.get_rank(group=group) + + if rank == dst: + output = [None for _ in range(world_size)] + dist.gather_object(data, output, dst=dst, group=group) + return output + else: + dist.gather_object(data, None, dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2**31) + all_ints = all_gather(ints) + return all_ints[0] + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + + Args: + input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + average (bool): whether to do average or sum + + Returns: + a dict with the same keys as input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/data_processing/detectron2/detectron2/utils/develop.py b/data_processing/detectron2/detectron2/utils/develop.py new file mode 100644 index 0000000..e841698 --- /dev/null +++ b/data_processing/detectron2/detectron2/utils/develop.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +""" Utilities for developers only. +These are not visible to users (not automatically imported). And should not +appeared in docs.""" +# adapted from https://github.com/tensorpack/tensorpack/blob/master/tensorpack/utils/develop.py + + +def create_dummy_class(klass, dependency, message=""): + """ + When a dependency of a class is not available, create a dummy class which throws ImportError + when used. + + Args: + klass (str): name of the class. + dependency (str): name of the dependency. 
def create_dummy_func(func, dependency, message=""):
    """
    When a dependency of a function is not available, create a dummy function which throws
    ImportError when used.

    Args:
        func (str): name of the function.
        dependency (str or list[str]): name(s) of the dependency. A list or
            tuple is rendered as a comma-separated string in the error.
        message: extra message to print
    Returns:
        function: a function object that raises ``ImportError`` on every call.
    """
    # Join multiple dependencies *before* formatting the message; otherwise
    # the error text would contain the Python repr of the list.
    if isinstance(dependency, (list, tuple)):
        dependency = ",".join(dependency)

    err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
    if message:
        err = err + " " + message

    def _dummy(*args, **kwargs):
        raise ImportError(err)

    return _dummy
+""" + + +def seed_all_rng(seed=None): + """ + Set the random seed for the RNG in torch, numpy and python. + + Args: + seed (int): if None, will use a strong random seed. + """ + if seed is None: + seed = ( + os.getpid() + + int(datetime.now().strftime("%S%f")) + + int.from_bytes(os.urandom(2), "big") + ) + logger = logging.getLogger(__name__) + logger.info("Using a generated random seed {}".format(seed)) + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + + +# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path +def _import_file(module_name, file_path, make_importable=False): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if make_importable: + sys.modules[module_name] = module + return module + + +def _configure_libraries(): + """ + Configurations for some libraries. + """ + # An environment option to disable `import cv2` globally, + # in case it leads to negative performance impact + disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False)) + if disable_cv2: + sys.modules["cv2"] = None + else: + # Disable opencl in opencv since its interaction with cuda often has negative effects + # This envvar is supported after OpenCV 3.4.0 + os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled" + try: + import cv2 + + if int(cv2.__version__.split(".")[0]) >= 3: + cv2.ocl.setUseOpenCL(False) + except ModuleNotFoundError: + # Other types of ImportError, if happened, should not be ignored. 
def fixup_module_metadata(module_name, namespace, keys=None):
    """
    Fix the __qualname__ of module members to be their exported api name, so
    when they are referenced in docs, sphinx can find them. Reference:
    https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241

    Does nothing unless ``DOC_BUILDING`` is truthy.

    Args:
        module_name (str): the public module name to assign as ``__module__``.
            Only objects whose current ``__module__`` starts with this name
            or with ``"fvcore."`` are rewritten.
        namespace (dict): mapping from exported name to object.
        keys: names in ``namespace`` to process; defaults to all of them.
            Names starting with ``"_"`` are always skipped.
    """
    if not DOC_BUILDING:
        return
    seen_ids = set()

    def fix_one(qualname, name, obj):
        # avoid infinite recursion (relevant when using
        # typing.Generic, for example)
        if id(obj) in seen_ids:
            return
        seen_ids.add(id(obj))

        mod = getattr(obj, "__module__", None)
        if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")):
            obj.__module__ = module_name
            # Modules, unlike everything else in Python, put fully-qualified
            # names into their __name__ attribute. We check for "." to avoid
            # rewriting these.
            if hasattr(obj, "__name__") and "." not in obj.__name__:
                obj.__name__ = name
                obj.__qualname__ = qualname
            if isinstance(obj, type):
                for attr_name, attr_value in obj.__dict__.items():
                    # Extend the *current* qualname so members nested more
                    # than one level deep get "A.B.c" rather than "A.c".
                    fix_one(qualname + "." + attr_name, attr_name, attr_value)

    if keys is None:
        keys = namespace.keys()
    for objname in keys:
        if not objname.startswith("_"):
            obj = namespace[objname]
            fix_one(objname, objname, obj)
+import datetime +import json +import logging +import os +import time +from collections import defaultdict +from contextlib import contextmanager +from functools import cached_property +from typing import Optional +import torch +from fvcore.common.history_buffer import HistoryBuffer + +from detectron2.utils.file_io import PathManager + +__all__ = [ + "get_event_storage", + "has_event_storage", + "JSONWriter", + "TensorboardXWriter", + "CommonMetricPrinter", + "EventStorage", +] + +_CURRENT_STORAGE_STACK = [] + + +def get_event_storage(): + """ + Returns: + The :class:`EventStorage` object that's currently being used. + Throws an error if no :class:`EventStorage` is currently enabled. + """ + assert len( + _CURRENT_STORAGE_STACK + ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!" + return _CURRENT_STORAGE_STACK[-1] + + +def has_event_storage(): + """ + Returns: + Check if there are EventStorage() context existed. + """ + return len(_CURRENT_STORAGE_STACK) > 0 + + +class EventWriter: + """ + Base class for writers that obtain events from :class:`EventStorage` and process them. + """ + + def write(self): + raise NotImplementedError + + def close(self): + pass + + +class JSONWriter(EventWriter): + """ + Write scalars to a json file. + + It saves scalars as one json per line (instead of a big json) for easy parsing. 
+ + Examples parsing such a json file: + :: + $ cat metrics.json | jq -s '.[0:2]' + [ + { + "data_time": 0.008433341979980469, + "iteration": 19, + "loss": 1.9228371381759644, + "loss_box_reg": 0.050025828182697296, + "loss_classifier": 0.5316952466964722, + "loss_mask": 0.7236229181289673, + "loss_rpn_box": 0.0856662318110466, + "loss_rpn_cls": 0.48198649287223816, + "lr": 0.007173333333333333, + "time": 0.25401854515075684 + }, + { + "data_time": 0.007216215133666992, + "iteration": 39, + "loss": 1.282649278640747, + "loss_box_reg": 0.06222952902317047, + "loss_classifier": 0.30682939291000366, + "loss_mask": 0.6970193982124329, + "loss_rpn_box": 0.038663312792778015, + "loss_rpn_cls": 0.1471673548221588, + "lr": 0.007706666666666667, + "time": 0.2490077018737793 + } + ] + + $ cat metrics.json | jq '.loss_mask' + 0.7126231789588928 + 0.689423680305481 + 0.6776131987571716 + ... + + """ + + def __init__(self, json_file, window_size=20): + """ + Args: + json_file (str): path to the json file. New data will be appended if the file exists. + window_size (int): the window size of median smoothing for the scalars whose + `smoothing_hint` are True. 
+ """ + self._file_handle = PathManager.open(json_file, "a") + self._window_size = window_size + self._last_write = -1 + + def write(self): + storage = get_event_storage() + to_save = defaultdict(dict) + + for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items(): + # keep scalars that have not been written + if iter <= self._last_write: + continue + to_save[iter][k] = v + if len(to_save): + all_iters = sorted(to_save.keys()) + self._last_write = max(all_iters) + + for itr, scalars_per_iter in to_save.items(): + scalars_per_iter["iteration"] = itr + self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n") + self._file_handle.flush() + try: + os.fsync(self._file_handle.fileno()) + except AttributeError: + pass + + def close(self): + self._file_handle.close() + + +class TensorboardXWriter(EventWriter): + """ + Write all scalars to a tensorboard file. + """ + + def __init__(self, log_dir: str, window_size: int = 20, **kwargs): + """ + Args: + log_dir (str): the directory to save the output events + window_size (int): the scalars will be median-smoothed by this window size + + kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)` + """ + self._window_size = window_size + self._writer_args = {"log_dir": log_dir, **kwargs} + self._last_write = -1 + + @cached_property + def _writer(self): + from torch.utils.tensorboard import SummaryWriter + + return SummaryWriter(**self._writer_args) + + def write(self): + storage = get_event_storage() + new_last_write = self._last_write + for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items(): + if iter > self._last_write: + self._writer.add_scalar(k, v, iter) + new_last_write = max(new_last_write, iter) + self._last_write = new_last_write + + # storage.put_{image,histogram} is only meant to be used by + # tensorboard writer. So we access its internal fields directly from here. 
+ if len(storage._vis_data) >= 1: + for img_name, img, step_num in storage._vis_data: + self._writer.add_image(img_name, img, step_num) + # Storage stores all image data and rely on this writer to clear them. + # As a result it assumes only one writer will use its image data. + # An alternative design is to let storage store limited recent + # data (e.g. only the most recent image) that all writers can access. + # In that case a writer may not see all image data if its period is long. + storage.clear_images() + + if len(storage._histograms) >= 1: + for params in storage._histograms: + self._writer.add_histogram_raw(**params) + storage.clear_histograms() + + def close(self): + if "_writer" in self.__dict__: + self._writer.close() + + +class CommonMetricPrinter(EventWriter): + """ + Print **common** metrics to the terminal, including + iteration time, ETA, memory, all losses, and the learning rate. + It also applies smoothing using a window of 20 elements. + + It's meant to print common metrics in common ways. + To print something in more customized ways, please implement a similar printer by yourself. + """ + + def __init__(self, max_iter: Optional[int] = None, window_size: int = 20): + """ + Args: + max_iter: the maximum number of iterations to train. + Used to compute ETA. If not given, ETA will not be printed. + window_size (int): the losses will be median-smoothed by this window size + """ + self.logger = logging.getLogger("detectron2.utils.events") + self._max_iter = max_iter + self._window_size = window_size + self._last_write = None # (step, time) of last call to write(). 
Used to compute ETA + + def _get_eta(self, storage) -> Optional[str]: + if self._max_iter is None: + return "" + iteration = storage.iter + try: + eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1) + storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False) + return str(datetime.timedelta(seconds=int(eta_seconds))) + except KeyError: + # estimate eta on our own - more noisy + eta_string = None + if self._last_write is not None: + estimate_iter_time = (time.perf_counter() - self._last_write[1]) / ( + iteration - self._last_write[0] + ) + eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + self._last_write = (iteration, time.perf_counter()) + return eta_string + + def write(self): + storage = get_event_storage() + iteration = storage.iter + if iteration == self._max_iter: + # This hook only reports training progress (loss, ETA, etc) but not other data, + # therefore do not write anything after training succeeds, even if this method + # is called. 
class EventStorage:
    """
    The user-facing class that provides metric storage functionalities.

    In the future we may add support for storing / logging other types of data if needed.
    """

    def __init__(self, start_iter=0):
        """
        Args:
            start_iter (int): the iteration number to start with
        """
        self._history = defaultdict(HistoryBuffer)
        self._smoothing_hints = {}
        self._latest_scalars = {}
        self._iter = start_iter
        self._current_prefix = ""
        self._vis_data = []
        self._histograms = []

    def put_image(self, img_name, img_tensor):
        """
        Add an `img_tensor` associated with `img_name`, to be shown on
        tensorboard.

        Args:
            img_name (str): The name of the image to put into tensorboard.
            img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
                Tensor of shape `[channel, height, width]` where `channel` is
                3. The image format should be RGB. The elements in img_tensor
                can either have values in [0, 1] (float32) or [0, 255] (uint8).
                The `img_tensor` will be visualized in tensorboard.
        """
        self._vis_data.append((img_name, img_tensor, self._iter))

    def put_scalar(self, name, value, smoothing_hint=True, cur_iter=None):
        """
        Add a scalar `value` to the `HistoryBuffer` associated with `name`.

        Args:
            name (str): name of the scalar; the current name-scope prefix is
                prepended to it.
            value: the value to record; it is converted with ``float()``.
            smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
                smoothed when logged. The hint will be accessible through
                :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
                and apply custom smoothing rule.

                It defaults to True because most scalars we save need to be smoothed to
                provide any useful signal.
            cur_iter (int): an iteration number to set explicitly instead of current iteration
        """
        name = self._current_prefix + name
        cur_iter = self._iter if cur_iter is None else cur_iter
        history = self._history[name]
        value = float(value)
        history.update(value, cur_iter)
        self._latest_scalars[name] = (value, cur_iter)

        existing_hint = self._smoothing_hints.get(name)

        if existing_hint is not None:
            assert (
                existing_hint == smoothing_hint
            ), "Scalar {} was put with a different smoothing_hint!".format(name)
        else:
            self._smoothing_hints[name] = smoothing_hint

    def put_scalars(self, *, smoothing_hint=True, cur_iter=None, **kwargs):
        """
        Put multiple scalars from keyword arguments.

        Examples:

            storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
        """
        for k, v in kwargs.items():
            self.put_scalar(k, v, smoothing_hint=smoothing_hint, cur_iter=cur_iter)

    def put_histogram(self, hist_name, hist_tensor, bins=1000):
        """
        Create a histogram from a tensor.

        Args:
            hist_name (str): The name of the histogram to put into tensorboard.
            hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
                into a histogram.
            bins (int): Number of histogram bins.
        """
        ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()

        # Create a histogram with PyTorch
        hist_counts = torch.histc(hist_tensor, bins=bins)
        hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)

        # Parameter for the add_histogram_raw function of SummaryWriter
        hist_params = dict(
            tag=hist_name,
            min=ht_min,
            max=ht_max,
            # numel() counts every element; len() would only report the first
            # dimension for tensors that are not 1-D.
            num=hist_tensor.numel(),
            sum=float(hist_tensor.sum()),
            sum_squares=float(torch.sum(hist_tensor**2)),
            bucket_limits=hist_edges[1:].tolist(),
            bucket_counts=hist_counts.tolist(),
            global_step=self._iter,
        )
        self._histograms.append(hist_params)

    def history(self, name):
        """
        Returns:
            HistoryBuffer: the scalar history for name

        Raises:
            KeyError: if nothing has been recorded under `name`.
        """
        ret = self._history.get(name, None)
        if ret is None:
            raise KeyError("No history metric available for {}!".format(name))
        return ret

    def histories(self):
        """
        Returns:
            dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
        """
        return self._history

    def latest(self):
        """
        Returns:
            dict[str -> (float, int)]: mapping from the name of each scalar to the most
                recent value and the iteration number its added.
        """
        return self._latest_scalars

    def latest_with_smoothing_hint(self, window_size=20):
        """
        Similar to :meth:`latest`, but the returned values
        are either the un-smoothed original latest value,
        or a median of the given window_size,
        depend on whether the smoothing_hint is True.

        This provides a default behavior that other writers can use.

        Note: All scalars saved in the past `window_size` iterations are used for smoothing.
        This is different from the `window_size` definition in HistoryBuffer.
        Use :meth:`get_history_window_size` to get the `window_size` used in HistoryBuffer.
        """
        result = {}
        for k, (v, itr) in self._latest_scalars.items():
            result[k] = (
                self._history[k].median(self.count_samples(k, window_size))
                if self._smoothing_hints[k]
                else v,
                itr,
            )
        return result

    def count_samples(self, name, window_size=20):
        """
        Return the number of samples logged in the past `window_size` iterations.

        Returns 0 for a name that has no history. The lookup is read-only: it
        does not create an empty history entry for unknown names.
        """
        history = self._history.get(name)
        if history is None:
            return 0
        data = history.values()
        if not data:
            return 0
        # The window is measured back from the iteration of the newest sample.
        oldest_allowed = data[-1][1] - window_size
        samples = 0
        for _, iter_ in reversed(data):
            if iter_ > oldest_allowed:
                samples += 1
            else:
                break
        return samples

    def smoothing_hints(self):
        """
        Returns:
            dict[name -> bool]: the user-provided hint on whether the scalar
                is noisy and needs smoothing.
        """
        return self._smoothing_hints

    def step(self):
        """
        User should either: (1) Call this function to increment storage.iter when needed. Or
        (2) Set `storage.iter` to the correct iteration number before each iteration.

        The storage will then be able to associate the new data with an iteration number.
        """
        self._iter += 1

    @property
    def iter(self):
        """
        Returns:
            int: The current iteration number. When used together with a trainer,
                this is ensured to be the same as trainer.iter.
        """
        return self._iter

    @iter.setter
    def iter(self, val):
        self._iter = int(val)

    @property
    def iteration(self):
        # for backward compatibility
        return self._iter

    def __enter__(self):
        _CURRENT_STORAGE_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        assert _CURRENT_STORAGE_STACK[-1] == self
        _CURRENT_STORAGE_STACK.pop()

    @contextmanager
    def name_scope(self, name):
        """
        Yields:
            A context within which all the events added to this storage
            will be prefixed by the name scope.
        """
        old_prefix = self._current_prefix
        self._current_prefix = name.rstrip("/") + "/"
        try:
            yield
        finally:
            # Restore the prefix even when the body raises, so later scalars
            # are not recorded under a stale scope.
            self._current_prefix = old_prefix

    def clear_images(self):
        """
        Delete all the stored images for visualization. This should be called
        after images are written to tensorboard.
        """
        self._vis_data = []

    def clear_histograms(self):
        """
        Delete all the stored histograms for visualization.
        This should be called after histograms are written to tensorboard.
        """
        self._histograms = []
class _ColorfulFormatter(logging.Formatter):
    """
    Formatter that abbreviates the root logger name and prepends a colored
    tag to WARNING, ERROR and CRITICAL records.

    Extra keyword arguments (removed before the base class is initialised):
        root_name (str): required; the logger-name prefix to abbreviate.
        abbrev_name (str): replacement for `root_name`; defaults to "".
    """

    def __init__(self, *args, **kwargs):
        root = kwargs.pop("root_name")
        self._root_name = root + "."
        abbrev = kwargs.pop("abbrev_name", "")
        # Only a non-empty abbreviation gets the trailing dot, so an empty one
        # removes the root prefix entirely.
        self._abbrev_name = abbrev + "." if len(abbrev) else abbrev
        super().__init__(*args, **kwargs)

    def formatMessage(self, record):
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        message = super().formatMessage(record)
        level = record.levelno
        if level == logging.WARNING:
            tag = colored("WARNING", "red", attrs=["blink"])
        elif level in (logging.ERROR, logging.CRITICAL):
            tag = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            # Other levels are emitted without a tag.
            return message
        return tag + " " + message
# Cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """
    Open `filename` for appending through PathManager and return the stream.

    The buffering value comes from `_get_log_stream_buffer_size(filename)`.
    The stream's `close` is registered with `atexit`.
    """
    buffer_size = _get_log_stream_buffer_size(filename)
    stream = PathManager.open(filename, "a", buffering=buffer_size)
    atexit.register(stream.close)
    return stream
def _find_caller():
    """
    Walk up the call stack to the first frame outside this logger module.

    The walk starts two frames above this function and skips frames whose
    file name contains ``utils/logger.``.

    Returns:
        str: module name of the caller ("detectron2" when it is "__main__")
        tuple: a hashable key (file name, line number, function name) to be
            used to identify different callers
    """
    internal_marker = os.path.join("utils", "logger.")
    current = sys._getframe(2)
    while current:
        code_obj = current.f_code
        if internal_marker in code_obj.co_filename:
            # Still inside the logging helpers; keep climbing.
            current = current.f_back
            continue
        module = current.f_globals["__name__"]
        if module == "__main__":
            module = "detectron2"
        return module, (code_obj.co_filename, current.f_lineno, code_obj.co_name)
def create_small_table(small_dict):
    """
    Create a small table using the keys of small_dict as headers. This is only
    suitable for small dictionaries.

    Args:
        small_dict (dict): a result dictionary of only a few items.

    Returns:
        str: the table as a string. An empty dict yields an empty string.
    """
    if not small_dict:
        # zip(*...) over no items produces nothing to unpack into
        # (keys, values), so there is no table to render.
        return ""
    keys, values = tuple(zip(*small_dict.items()))
    table = tabulate(
        [values],
        headers=keys,
        tablefmt="pipe",
        floatfmt=".3f",
        stralign="center",
        numalign="center",
    )
    return table
@contextmanager
def _ignore_torch_cuda_oom():
    """
    A context which ignores CUDA OOM exception from pytorch.

    Any other RuntimeError is re-raised unchanged.
    """
    try:
        yield
    except RuntimeError as e:
        # Prefer the dedicated exception type when this PyTorch build has one;
        # the message text is only a fallback because the string may change.
        oom_error = getattr(torch.cuda, "OutOfMemoryError", None)
        if oom_error is not None and isinstance(e, oom_error):
            return
        if "CUDA out of memory. " in str(e):
            return
        raise


def retry_if_cuda_oom(func):
    """
    Makes a function retry itself after encountering
    pytorch's CUDA OOM error.
    It will first retry after calling `torch.cuda.empty_cache()`.

    If that still fails, it will then retry by trying to convert inputs to CPUs.
    In this case, it expects the function to dispatch to CPU implementation.
    The return values may become CPU tensors as well and it's user's
    responsibility to convert it back to CUDA tensor if needed.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, it will only look at each argument and check
           if it has `.device` and `.to` for conversion. Nested structures of tensors
           are not supported.

        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def maybe_to_cpu(x):
        try:
            like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to")
        except AttributeError:
            like_gpu_tensor = False
        if like_gpu_tensor:
            return x.to(device="cpu")
        return x

    @wraps(func)
    def wrapped(*args, **kwargs):
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Clear cache and retry
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Try on CPU. This slows down the code significantly, therefore print a notice.
        logger = logging.getLogger(__name__)
        logger.info("Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func)))
        new_args = tuple(maybe_to_cpu(x) for x in args)
        new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
        # The last attempt is not guarded: an OOM here propagates to the caller.
        return func(*new_args, **new_kwargs)

    return wrapped
class PicklableWrapper:
    """
    Wrap an object to make it more picklable, note that it uses
    heavy weight serialization libraries that are slower than pickle.
    It's best to use it only on closures (which are usually not picklable).

    This is a simplified version of
    https://github.com/joblib/joblib/blob/master/joblib/externals/loky/cloudpickle_wrapper.py
    """

    def __init__(self, obj):
        while isinstance(obj, PicklableWrapper):
            # Wrapping an object twice is no-op
            obj = obj._obj
        self._obj = obj

    def __reduce__(self):
        # Serialize the wrapped object with cloudpickle; loading the result
        # calls cloudpickle.loads on those bytes.
        s = cloudpickle.dumps(self._obj)
        return cloudpickle.loads, (s,)

    def __call__(self, *args, **kwargs):
        return self._obj(*args, **kwargs)

    def __getattr__(self, attr):
        # Ensure that the wrapped object can be used seamlessly as the previous object.
        # __getattr__ only runs when normal lookup fails, so reaching here with
        # "_obj" means the wrapper has no wrapped object yet. Raise instead of
        # looking "_obj" up again, which would recurse without bound.
        if attr == "_obj":
            raise AttributeError(attr)
        return getattr(self._obj, attr)
def random_boxes(num_boxes, max_coord=100, device="cpu"):
    """
    Create a random Nx4 boxes tensor, with coordinates < max_coord.

    All four columns are first drawn from [0, max_coord / 2) and clamped to a
    minimum of 1.0; the first two columns are then added onto the last two.
    """
    half_extent = max_coord * 0.5
    corners = torch.rand(num_boxes, 4, device=device) * half_extent
    # tiny boxes cause numerical instability in box regression
    corners.clamp_(min=1.0)
    # Note: the implementation of this function in torchvision is:
    # boxes[:, 2:] += torch.rand(N, 2) * 100
    # but it does not guarantee non-negative widths/heights constraints:
    # boxes[:, 2] >= boxes[:, 0] and boxes[:, 3] >= boxes[:, 1]:
    corners[:, 2:] += corners[:, :2]
    return corners
def assert_instances_allclose(input, other, *, rtol=1e-5, msg="", size_as_tensor=False):
    """
    Assert that two Instances objects have the same image size, the same set
    of fields, and (approximately) equal field values.

    Args:
        input, other (Instances): objects to compare; anything that is not an
            `Instances` is first passed through `convert_scripted_instances`.
        rtol (float): tolerance factor. Floating-point tensors are compared
            with an absolute tolerance of `rtol` times the largest magnitude
            in `input`'s field; Boxes / ROIMasks use `100 * rtol`.
        msg (str): prefix for the assertion messages.
        size_as_tensor: compare image_size of the Instances as tensors (instead of tuples).
            Useful for comparing outputs of tracing.

    Raises:
        AssertionError: if the two objects differ.
        ValueError: if a field has a type this function cannot compare.
    """
    if not isinstance(input, Instances):
        input = convert_scripted_instances(input)
    if not isinstance(other, Instances):
        other = convert_scripted_instances(other)

    if not msg:
        msg = "Two Instances are different! "
    else:
        msg = msg.rstrip() + " "

    size_error_msg = msg + f"image_size is {input.image_size} vs. {other.image_size}!"
    if size_as_tensor:
        assert torch.equal(
            torch.tensor(input.image_size), torch.tensor(other.image_size)
        ), size_error_msg
    else:
        assert input.image_size == other.image_size, size_error_msg
    fields = sorted(input.get_fields().keys())
    fields_other = sorted(other.get_fields().keys())
    assert fields == fields_other, msg + f"Fields are {fields} vs {fields_other}!"

    for f in fields:
        val1, val2 = input.get(f), other.get(f)
        if isinstance(val1, (Boxes, ROIMasks)):
            # boxes in the range of O(100) and can have a larger tolerance
            assert torch.allclose(val1.tensor, val2.tensor, atol=100 * rtol), (
                msg + f"Field {f} differs too much!"
            )
        elif isinstance(val1, torch.Tensor):
            if val1.dtype.is_floating_point:
                # max() is undefined on an empty tensor (e.g. zero detections),
                # so use a zero magnitude in that case.
                mag = torch.abs(val1).max().cpu().item() if val1.numel() > 0 else 0.0
                assert torch.allclose(val1, val2, atol=mag * rtol), (
                    msg + f"Field {f} differs too much!"
                )
            else:
                assert torch.equal(val1, val2), msg + f"Field {f} is different!"
        else:
            raise ValueError(f"Don't know how to compare type {type(val1)}")
def has_dynamic_axes(onnx_model):
    """
    Return True when all ONNX input/output have only dynamic axes for all ranks

    An axis counts as dynamic here when its `dim_param` string is not numeric.
    NOTE(review): an empty `dim_param` is also non-numeric and so passes this
    check - confirm that is intended for axes that carry only a fixed value.
    """

    def _all_non_numeric(value_infos):
        return all(
            not dim.dim_param.isnumeric()
            for info in value_infos
            for dim in info.type.tensor_type.shape.dim
        )

    # Outputs are only inspected when every input axis passed.
    if not _all_non_numeric(onnx_model.graph.input):
        return False
    return _all_non_numeric(onnx_model.graph.output)
def unregister_custom_op_onnx_export(opname: str, opset_version: int, min_version: str) -> None:
    """
    Unregister PyTorch's symbolic `opname`-`opset_version` for ONNX export.
    The un-registration is performed only when PyTorch's version is < `min_version`
    IMPORTANT: The symbolic must have been manually registered by the caller, otherwise
    the incorrect symbolic may be unregistered instead.

    Args:
        opname: symbolic name, in `domain::name` form.
        opset_version: lowest opset version to unregister the symbolic from.
        min_version: PyTorch version from which this function is a no-op.
    """

    # TODO: _unregister_custom_op_symbolic is introduced PyTorch>=1.10
    #       Remove after PyTorch 1.10+ is used by ALL detectron2's CI
    try:
        from torch.onnx import unregister_custom_op_symbolic as _unregister_custom_op_symbolic
    except ImportError:

        def _unregister_custom_op_symbolic(symbolic_name, opset_version):
            import torch.onnx.symbolic_registry as sym_registry
            from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets

            def _get_ns_op_name_from_custom_op(symbolic_name):
                try:
                    from torch.onnx.utils import get_ns_op_name_from_custom_op

                    ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
                except ImportError as import_error:
                    if not bool(
                        re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)
                    ):
                        raise ValueError(
                            f"Invalid symbolic name {symbolic_name}. Must be `domain::name`"
                        ) from import_error

                    ns, op_name = symbolic_name.split("::")
                    if ns == "onnx":
                        raise ValueError(f"{ns} domain cannot be modified.") from import_error

                    if ns == "aten":
                        ns = ""

                return ns, op_name

            def _unregister_op(opname: str, domain: str, version: int):
                try:
                    # Use this function's own arguments rather than variables
                    # captured from the enclosing scope.
                    sym_registry.unregister_op(opname, domain, version)
                except AttributeError as attribute_error:
                    # No `unregister_op` in this registry module: edit the
                    # registry table directly.
                    if sym_registry.is_registered_op(opname, domain, version):
                        del sym_registry._registry[(domain, version)][opname]
                        if not sym_registry._registry[(domain, version)]:
                            del sym_registry._registry[(domain, version)]
                    else:
                        raise RuntimeError(
                            f"The opname {opname} is not registered."
                        ) from attribute_error

            ns, op_name = _get_ns_op_name_from_custom_op(symbolic_name)
            for ver in _onnx_stable_opsets + [_onnx_main_opset]:
                if ver >= opset_version:
                    _unregister_op(op_name, ns, ver)

    if min_torch_version(min_version):
        return
    _unregister_custom_op_symbolic(opname, opset_version)
    print(f"_unregister_custom_op_onnx_export({opname}, {opset_version}) succeeded.")
# TODO: Remove after PyTorch 1.11.1+ is used by detectron2's CI
def _pytorch1111_symbolic_opset9_to(g, self, *args):
    """aten::to() symbolic that must be used for testing with PyTorch < 1.11.1.

    Device-only casts return ``self`` unchanged; every other recognised overload is
    lowered to an ONNX ``Cast``. The overload is identified by ``len(args)``.
    """
    num_args = len(args)

    def is_aten_to_device_only(args):
        count = len(args)
        if count == 4:
            # aten::to(Tensor, Device, bool, bool, memory_format)
            target = args[0]
            if target.node().kind() == "prim::device":
                return True
            if target.type().isSubtypeOf(ListType.ofInts()):
                return True
            return (
                sym_help._is_value(target)
                and target.node().kind() == "onnx::Constant"
                and isinstance(target.node()["value"], str)
            )
        if count == 5:
            # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
            # When dtype is None, this is a aten::to(device) call
            return sym_help._get_const(args[1], "i", "dtype") is None
        if count in (6, 7):
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format)
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format)
            # When dtype is None, this is a aten::to(device) call
            return sym_help._get_const(args[0], "i", "dtype") is None
        return False

    # ONNX doesn't have a concept of a device, so we ignore device-only casts
    if is_aten_to_device_only(args):
        return self

    if num_args == 4:
        # TestONNXRuntime::test_ones_bool shows args[0] of aten::to can be onnx::Constant[Tensor]
        # In this case, the constant value is a tensor not int,
        # so sym_help._maybe_get_const(args[0], 'i') would not work.
        first = args[0]
        dtype = first
        if sym_help._is_value(first) and first.node().kind() == "onnx::Constant":
            tval = first.node()["value"]
            if isinstance(tval, torch.Tensor):
                if len(tval.shape) == 0:
                    tval = tval.item()
                    dtype = int(tval)
                else:
                    dtype = tval

        if sym_help._is_value(dtype) or isinstance(dtype, torch.Tensor):
            # aten::to(Tensor, Tensor, bool, bool, memory_format)
            dtype = first.type().scalarType()
            return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[dtype])
        # aten::to(Tensor, ScalarType, bool, bool, memory_format)
        # memory_format is ignored
        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])

    # The remaining overloads differ only in where the ScalarType argument sits;
    # layout, device and memory_format are ignored.
    if num_args == 5:
        # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
        dtype_arg = args[1]
    elif num_args in (6, 7):
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool[, bool], memory_format)
        dtype_arg = args[0]
    else:
        return sym_help._onnx_unsupported("Unknown aten::to signature")

    dtype = sym_help._get_const(dtype_arg, "i", "dtype")
    return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
self, g.op("Constant", value_t=torch.tensor([-1]))) + dim = 0 + else: + dim = sym_help._maybe_get_scalar(dim) + + repeats_dim = sym_help._get_tensor_rank(repeats) + repeats_sizes = sym_help._get_tensor_sizes(repeats) + input_sizes = sym_help._get_tensor_sizes(input) + if repeats_dim is None: + raise RuntimeError( + "Unsupported: ONNX export of repeat_interleave for unknown " "repeats rank." + ) + if repeats_sizes is None: + raise RuntimeError( + "Unsupported: ONNX export of repeat_interleave for unknown " "repeats size." + ) + if input_sizes is None: + raise RuntimeError( + "Unsupported: ONNX export of repeat_interleave for unknown " "input size." + ) + + input_sizes_temp = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + input_sizes[idx], input_sizes_temp[idx] = 0, -1 + + # Cases where repeats is an int or single value tensor + if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + if not sym_help._is_tensor(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + if input_sizes[dim] == 0: + return sym_help._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along dimension with unknown input size", + ) + else: + reps = input_sizes[dim] + repeats = expand(g, repeats, g.op("Constant", value_t=torch.tensor([reps])), None) + + # Cases where repeats is a 1 dim Tensor + elif repeats_dim == 1: + if input_sizes[dim] == 0: + return sym_help._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along dimension with unknown input size", + ) + if repeats_sizes[0] is None: + return sym_help._onnx_opset_unsupported_detailed( + "repeat_interleave", 9, 13, "Unsupported for cases with dynamic repeats" + ) + assert ( + repeats_sizes[0] == input_sizes[dim] + ), "repeats must have the same size as input along dim" + reps = repeats_sizes[0] + else: + raise RuntimeError("repeats must be 0-dim or 1-dim tensor") + + final_splits = list() + r_splits 
= sym_help._repeat_interleave_split_helper(g, repeats, reps, 0) + if isinstance(r_splits, torch._C.Value): + r_splits = [r_splits] + i_splits = sym_help._repeat_interleave_split_helper(g, input, reps, dim) + if isinstance(i_splits, torch._C.Value): + i_splits = [i_splits] + input_sizes[dim], input_sizes_temp[dim] = -1, 1 + for idx, r_split in enumerate(r_splits): + i_split = unsqueeze(g, i_splits[idx], dim + 1) + r_concat = [ + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])), + r_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])), + ] + r_concat = g.op("Concat", *r_concat, axis_i=0) + i_split = expand(g, i_split, r_concat, None) + i_split = sym_help._reshape_helper( + g, + i_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes)), + allowzero=0, + ) + final_splits.append(i_split) + return g.op("Concat", *final_splits, axis_i=dim) diff --git a/data_processing/detectron2/detectron2/utils/tracing.py b/data_processing/detectron2/detectron2/utils/tracing.py new file mode 100644 index 0000000..577df4e --- /dev/null +++ b/data_processing/detectron2/detectron2/utils/tracing.py @@ -0,0 +1,71 @@ +import inspect +import torch + +from detectron2.utils.env import TORCH_VERSION + +try: + from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current + + tracing_current_exists = True +except ImportError: + tracing_current_exists = False + +try: + from torch.fx._symbolic_trace import _orig_module_call + + tracing_legacy_exists = True +except ImportError: + tracing_legacy_exists = False + + +@torch.jit.ignore +def is_fx_tracing_legacy() -> bool: + """ + Returns a bool indicating whether torch.fx is currently symbolically tracing a module. + Can be useful for gating module logic that is incompatible with symbolic tracing. 
@torch.jit.ignore
def assert_fx_safe(condition: bool, message: str) -> torch.Tensor:
    """An FX-tracing safe version of assert.

    Avoids erroneous type assertion triggering when types are masked inside
    an fx.proxy.Proxy object during tracing.

    Args:
        condition: either a boolean expression or a string representing the
            condition to test. A string is evaluated with the caller's globals
            and locals. If this assert triggers an exception when tracing due to
            dynamic control flow, try encasing the expression in quotation marks
            and supplying it as a string.
        message: forwarded to ``torch._assert``.

    Returns:
        ``torch.ones(1)`` when the check ran, ``torch.zeros(1)`` when it was skipped.
    """
    # Must return a concrete tensor for compatibility with PyTorch <=1.8.
    # If <=1.8 compatibility is not needed, return type can be converted to None
    if is_fx_tracing():
        return torch.zeros(1)

    try:
        if isinstance(condition, str):
            # The frame lookup must stay in this function: f_back is the caller.
            # NOTE(review): eval() runs arbitrary code - only pass trusted strings.
            caller_frame = inspect.currentframe().f_back
            outcome = eval(condition, caller_frame.f_globals, caller_frame.f_locals)
        else:
            outcome = condition
        torch._assert(outcome, message)
    except torch.fx.proxy.TraceError as e:
        print(
            "Found a non-FX compatible assertion. Skipping the check. Failure is shown below"
            + str(e)
        )
        return torch.zeros(1)
    return torch.ones(1)
class _DetectedInstance:
    """
    Record of one object detected in a video frame, kept so that its color can be
    handed on to matching objects in later frames.

    Attributes:
        label (int):
        bbox (tuple[float]):
        mask_rle (dict):
        color (tuple[float]): RGB colors in range (0, 1)
        ttl (int): time-to-live for the instance. For example, if ttl=2,
            the instance color can be transferred to objects in the next two frames.
    """

    __slots__ = ["label", "bbox", "mask_rle", "color", "ttl"]

    def __init__(self, label, bbox, mask_rle, color, ttl):
        # Slot names are listed in the same order as the constructor arguments.
        values = (label, bbox, mask_rle, color, ttl)
        for slot_name, value in zip(self.__slots__, values):
            setattr(self, slot_name, value)
+ + Returns: + output (VisImage): image object with visualizations. + """ + frame_visualizer = Visualizer(frame, self.metadata) + num_instances = len(predictions) + if num_instances == 0: + return frame_visualizer.output + + boxes = predictions.pred_boxes.tensor.numpy() if predictions.has("pred_boxes") else None + scores = predictions.scores if predictions.has("scores") else None + classes = predictions.pred_classes.numpy() if predictions.has("pred_classes") else None + keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None + colors = predictions.COLOR if predictions.has("COLOR") else [None] * len(predictions) + periods = predictions.ID_period if predictions.has("ID_period") else None + period_threshold = self.metadata.get("period_threshold", 0) + visibilities = ( + [True] * len(predictions) + if periods is None + else [x > period_threshold for x in periods] + ) + + if predictions.has("pred_masks"): + masks = predictions.pred_masks + # mask IOU is not yet enabled + # masks_rles = mask_util.encode(np.asarray(masks.permute(1, 2, 0), order="F")) + # assert len(masks_rles) == num_instances + else: + masks = None + + if not predictions.has("COLOR"): + if predictions.has("ID"): + colors = self._assign_colors_by_id(predictions) + else: + # ToDo: clean old assign color method and use a default tracker to assign id + detected = [ + _DetectedInstance(classes[i], boxes[i], mask_rle=None, color=colors[i], ttl=8) + for i in range(num_instances) + ] + colors = self._assign_colors(detected) + + labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None)) + + if self._instance_mode == ColorMode.IMAGE_BW: + # any() returns uint8 tensor + frame_visualizer.output.reset_image( + frame_visualizer._create_grayscale_image( + (masks.any(dim=0) > 0).numpy() if masks is not None else None + ) + ) + alpha = 0.3 + else: + alpha = 0.5 + + labels = ( + None + if labels is None + else [y[0] for y in filter(lambda x: x[1], zip(labels, 
def draw_sem_seg(self, frame, sem_seg, area_threshold=None):
    """
    Draw a semantic segmentation on one frame by delegating to a fresh
    ``Visualizer`` built from ``frame`` and ``self.metadata``.

    Args:
        frame: image handed to ``Visualizer`` as the picture to draw on.
        sem_seg (ndarray or Tensor): semantic segmentation of shape (H, W),
            each value is the integer label.
        area_threshold (Optional[int]): only draw segmentations larger than the threshold

    Returns:
        the ``output`` attribute of the per-frame ``Visualizer``.
    """
    # don't need to do anything special
    frame_visualizer = Visualizer(frame, self.metadata)
    # Forward the caller's threshold; it used to be replaced by a hard-coded None.
    frame_visualizer.draw_sem_seg(sem_seg, area_threshold=area_threshold)
    return frame_visualizer.output
def _assign_colors(self, instances):
    """
    Naive tracking heuristics to assign same color to the same instance,
    will update the internal state of tracked instances.

    Args:
        instances (list[_DetectedInstance]): detections of the current frame.
            Entries whose ``color`` is None receive a color here. May be empty.

    Returns:
        list[tuple[float]]: list of colors.
    """
    if not instances:
        # Nothing detected in this frame: no old instance can be matched, so age
        # all of them exactly as the unmatched case below does. Without this guard
        # the `instances[0]` access raised IndexError.
        survivors = []
        for inst in self._old_instances:
            inst.ttl -= 1
            if inst.ttl > 0:
                survivors.append(inst)
        self._old_instances = survivors
        return []

    # Compute iou with either boxes or masks:
    is_crowd = np.zeros((len(instances),), dtype=bool)
    if instances[0].bbox is None:
        assert instances[0].mask_rle is not None
        # use mask iou only when box iou is None
        # because box seems good enough
        rles_old = [x.mask_rle for x in self._old_instances]
        rles_new = [x.mask_rle for x in instances]
        ious = mask_util.iou(rles_old, rles_new, is_crowd)
        threshold = 0.5
    else:
        boxes_old = [x.bbox for x in self._old_instances]
        boxes_new = [x.bbox for x in instances]
        ious = mask_util.iou(boxes_old, boxes_new, is_crowd)
        threshold = 0.6
    if len(ious) == 0:
        # no IoU values came back; fall back to an all-zero (old x new) matrix
        ious = np.zeros((len(self._old_instances), len(instances)), dtype="float32")

    # Only allow matching instances of the same label:
    for old_idx, old in enumerate(self._old_instances):
        for new_idx, new in enumerate(instances):
            if old.label != new.label:
                ious[old_idx, new_idx] = 0

    matched_new_per_old = np.asarray(ious).argmax(axis=1)
    max_iou_per_old = np.asarray(ious).max(axis=1)

    # Try to find match for each old instance:
    extra_instances = []
    for idx, inst in enumerate(self._old_instances):
        if max_iou_per_old[idx] > threshold:
            newidx = matched_new_per_old[idx]
            if instances[newidx].color is None:
                instances[newidx].color = inst.color
                continue
        # If an old instance does not match any new instances,
        # keep it for the next frame in case it is just missed by the detector
        inst.ttl -= 1
        if inst.ttl > 0:
            extra_instances.append(inst)

    # Assign random color to newly-detected instances:
    for inst in instances:
        if inst.color is None:
            inst.color = random_color(rgb=True, maximum=1)
    self._old_instances = instances[:] + extra_instances
    return [d.color for d in instances]
colors.append(self._color_pool[self._assigned_colors[id]]) + untracked_ids.remove(id) + else: + assert ( + len(self._color_idx_set) >= 1 + ), f"Number of id exceeded maximum, \ + max = {self._max_num_instances}" + idx = self._color_idx_set.pop() + color = self._color_pool[idx] + self._assigned_colors[id] = idx + colors.append(color) + for id in untracked_ids: + self._color_idx_set.add(self._assigned_colors[id]) + del self._assigned_colors[id] + return colors diff --git a/data_processing/detectron2/detectron2/utils/visualizer.py b/data_processing/detectron2/detectron2/utils/visualizer.py new file mode 100644 index 0000000..5d2cc17 --- /dev/null +++ b/data_processing/detectron2/detectron2/utils/visualizer.py @@ -0,0 +1,1267 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import colorsys +import logging +import math +import numpy as np +from enum import Enum, unique +import cv2 +import matplotlib as mpl +import matplotlib.colors as mplc +import matplotlib.figure as mplfigure +import pycocotools.mask as mask_util +import torch +from matplotlib.backends.backend_agg import FigureCanvasAgg +from PIL import Image + +from detectron2.data import MetadataCatalog +from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes +from detectron2.utils.file_io import PathManager + +from .colormap import random_color + +logger = logging.getLogger(__name__) + +__all__ = ["ColorMode", "VisImage", "Visualizer"] + + +_SMALL_OBJECT_AREA_THRESH = 1000 +_LARGE_MASK_AREA_THRESH = 120000 +_OFF_WHITE = (1.0, 1.0, 240.0 / 255) +_BLACK = (0, 0, 0) +_RED = (1.0, 0, 0) + +_KEYPOINT_THRESHOLD = 0.05 + + +@unique +class ColorMode(Enum): + """ + Enum of different color modes to use for instance visualizations. + """ + + IMAGE = 0 + """ + Picks a random color for every instance and overlay segmentations with low opacity. 
+ """ + SEGMENTATION = 1 + """ + Let instances of the same category have similar colors + (from metadata.thing_colors), and overlay them with + high opacity. This provides more attention on the quality of segmentation. + """ + IMAGE_BW = 2 + """ + Same as IMAGE, but convert all areas without masks to gray-scale. + Only available for drawing per-instance mask predictions. + """ + + +class GenericMask: + """ + Attribute: + polygons (list[ndarray]): list[ndarray]: polygons for this mask. + Each ndarray has format [x, y, x, y, ...] + mask (ndarray): a binary mask + """ + + def __init__(self, mask_or_polygons, height, width): + self._mask = self._polygons = self._has_holes = None + self.height = height + self.width = width + + m = mask_or_polygons + if isinstance(m, dict): + # RLEs + assert "counts" in m and "size" in m + if isinstance(m["counts"], list): # uncompressed RLEs + h, w = m["size"] + assert h == height and w == width + m = mask_util.frPyObjects(m, h, w) + self._mask = mask_util.decode(m)[:, :] + return + + if isinstance(m, list): # list[ndarray] + self._polygons = [np.asarray(x).reshape(-1) for x in m] + return + + if isinstance(m, np.ndarray): # assumed to be a binary mask + assert m.shape[1] != 2, m.shape + assert m.shape == ( + height, + width, + ), f"mask shape: {m.shape}, target dims: {height}, {width}" + self._mask = m.astype("uint8") + return + + raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m))) + + @property + def mask(self): + if self._mask is None: + self._mask = self.polygons_to_mask(self._polygons) + return self._mask + + @property + def polygons(self): + if self._polygons is None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + return self._polygons + + @property + def has_holes(self): + if self._has_holes is None: + if self._mask is not None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + else: + self._has_holes = False # if original format is polygon, does 
not have holes + return self._has_holes + + def mask_to_polygons(self, mask): + # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level + # hierarchy. External contours (boundary) of the object are placed in hierarchy-1. + # Internal contours (holes) are placed in hierarchy-2. + # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours. + mask = np.ascontiguousarray(mask) # some versions of cv2 does not support incontiguous arr + res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) + hierarchy = res[-1] + if hierarchy is None: # empty mask + return [], False + has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0 + res = res[-2] + res = [x.flatten() for x in res] + # These coordinates from OpenCV are integers in range [0, W-1 or H-1]. + # We add 0.5 to turn them into real-value coordinate space. A better solution + # would be to first +0.5 and then dilate the returned polygon by 0.5. + res = [x + 0.5 for x in res if len(x) >= 6] + return res, has_holes + + def polygons_to_mask(self, polygons): + rle = mask_util.frPyObjects(polygons, self.height, self.width) + rle = mask_util.merge(rle) + return mask_util.decode(rle)[:, :] + + def area(self): + return self.mask.sum() + + def bbox(self): + p = mask_util.frPyObjects(self.polygons, self.height, self.width) + p = mask_util.merge(p) + bbox = mask_util.toBbox(p) + bbox[2] += bbox[0] + bbox[3] += bbox[1] + return bbox + + +class _PanopticPrediction: + """ + Unify different panoptic annotation/prediction formats + """ + + def __init__(self, panoptic_seg, segments_info, metadata=None): + if segments_info is None: + assert metadata is not None + # If "segments_info" is None, we assume "panoptic_img" is a + # H*W int32 image storing the panoptic_id in the format of + # category_id * label_divisor + instance_id. We reserve -1 for + # VOID label. 
def non_empty_mask(self):
    """
    Returns:
        (H, W) boolean array, a mask for all pixels that have a prediction
    """
    # ids present in the segmentation that have no entry in the segment info
    empty_ids = [sid for sid in self._seg_ids if sid not in self._sinfo]
    if not empty_ids:
        # Every id has segment info, so every pixel has a prediction. This branch
        # used to return an all-zeros uint8 array, i.e. "no pixel has a prediction",
        # and with a dtype different from the boolean mask returned below.
        return np.ones(self._seg.shape, dtype=bool)
    assert (
        len(empty_ids) == 1
    ), ">1 ids corresponds to no labels. This is currently not supported"
    return (self._seg != empty_ids[0]).numpy().astype(bool)
def _create_text_labels(classes, scores, class_names, is_crowd=None):
    """
    Build one caption per instance from class ids, scores and crowd flags.

    Args:
        classes (list[int] or None):
        scores (list[float] or None):
        class_names (list[str] or None):
        is_crowd (list[bool] or None):

    Returns:
        list[str] or None
    """
    # Start from the class name, or the stringified class id when no names are given.
    if classes is None:
        captions = None
    elif class_names is not None and len(class_names) > 0:
        captions = [class_names[idx] for idx in classes]
    else:
        captions = [str(idx) for idx in classes]

    # Scores are shown as whole-number percentages, alone or after the name.
    if scores is not None:
        if captions is None:
            captions = [f"{score * 100:.0f}%" for score in scores]
        else:
            captions = [
                f"{caption} {score * 100:.0f}%" for caption, score in zip(captions, scores)
            ]

    if captions is None or is_crowd is None:
        return captions
    return [
        caption + ("|crowd" if crowd else "") for caption, crowd in zip(captions, is_crowd)
    ]
+ """ + fig = mplfigure.Figure(frameon=False) + self.dpi = fig.get_dpi() + # add a small 1e-2 to avoid precision lost due to matplotlib's truncation + # (https://github.com/matplotlib/matplotlib/issues/15363) + fig.set_size_inches( + (self.width * self.scale + 1e-2) / self.dpi, + (self.height * self.scale + 1e-2) / self.dpi, + ) + self.canvas = FigureCanvasAgg(fig) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) + ax.axis("off") + self.fig = fig + self.ax = ax + self.reset_image(img) + + def reset_image(self, img): + """ + Args: + img: same as in __init__ + """ + img = img.astype("uint8") + self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest") + + def save(self, filepath): + """ + Args: + filepath (str): a string that contains the absolute path, including the file name, where + the visualized image will be saved. + """ + self.fig.savefig(filepath) + + def get_image(self): + """ + Returns: + ndarray: + the visualized image of shape (H, W, 3) (RGB) in uint8 type. + The shape is scaled w.r.t the input image using the given `scale` argument. + """ + canvas = self.canvas + s, (width, height) = canvas.print_to_buffer() + # buf = io.BytesIO() # works for cairo backend + # canvas.print_rgba(buf) + # width, height = self.width, self.height + # s = buf.getvalue() + + buffer = np.frombuffer(s, dtype="uint8") + + img_rgba = buffer.reshape(height, width, 4) + rgb, alpha = np.split(img_rgba, [3], axis=2) + return rgb.astype("uint8") + + +class Visualizer: + """ + Visualizer that draws data about detection/segmentation on images. + + It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}` + that draw primitive objects to images, as well as high-level wrappers like + `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}` + that draw composite data in some pre-defined style. 
+ + Note that the exact visualization style for the high-level wrappers are subject to change. + Style such as color, opacity, label contents, visibility of labels, or even the visibility + of objects themselves (e.g. when the object is too small) may change according + to different heuristics, as long as the results still look visually reasonable. + + To obtain a consistent style, you can implement custom drawing functions with the + abovementioned primitive methods instead. If you need more customized visualization + styles, you can process the data yourself following their format documented in + tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not + intend to satisfy everyone's preference on drawing styles. + + This visualizer focuses on high rendering quality rather than performance. It is not + designed to be used for real-time applications. + """ + + # TODO implement a fast, rasterized version using OpenCV + + def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE): + """ + Args: + img_rgb: a numpy array of shape (H, W, C), where H and W correspond to + the height and width of the image respectively. C is the number of + color channels. The image is required to be in RGB format since that + is a requirement of the Matplotlib library. The image is also expected + to be in the range [0, 255]. + metadata (Metadata): dataset metadata (e.g. class names and colors) + instance_mode (ColorMode): defines one of the pre-defined style for drawing + instances on an image. 
+ """ + self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8) + if metadata is None: + metadata = MetadataCatalog.get("__nonexist__") + self.metadata = metadata + self.output = VisImage(self.img, scale=scale) + self.cpu_device = torch.device("cpu") + + # too small texts are useless, therefore clamp to 9 + self._default_font_size = max( + np.sqrt(self.output.height * self.output.width) // 90, 10 // scale + ) + self._instance_mode = instance_mode + self.keypoint_threshold = _KEYPOINT_THRESHOLD + + def draw_instance_predictions(self, predictions): + """ + Draw instance-level prediction results on an image. + + Args: + predictions (Instances): the output of an instance detection/segmentation + model. Following fields will be used to draw: + "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle"). + + Returns: + output (VisImage): image object with visualizations. + """ + boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None + scores = predictions.scores if predictions.has("scores") else None + classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None + labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None)) + keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None + + if predictions.has("pred_masks"): + masks = np.asarray(predictions.pred_masks) + masks = [GenericMask(x, self.output.height, self.output.width) for x in masks] + else: + masks = None + + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes + ] + alpha = 0.8 + else: + colors = None + alpha = 0.5 + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image( + self._create_grayscale_image( + (predictions.pred_masks.any(dim=0) > 0).numpy() + if predictions.has("pred_masks") + else None + ) + ) + alpha = 0.3 + + 
self.overlay_instances( + masks=masks, + boxes=boxes, + labels=labels, + keypoints=keypoints, + assigned_colors=colors, + alpha=alpha, + ) + return self.output + + def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8): + """ + Draw semantic segmentation predictions/labels. + + Args: + sem_seg (Tensor or ndarray): the segmentation of shape (H, W). + Each value is the integer label of the pixel. + area_threshold (int): segments with less than `area_threshold` are not drawn. + alpha (float): the larger it is, the more opaque the segmentations are. + + Returns: + output (VisImage): image object with visualizations. + """ + if isinstance(sem_seg, torch.Tensor): + sem_seg = sem_seg.numpy() + labels, areas = np.unique(sem_seg, return_counts=True) + sorted_idxs = np.argsort(-areas).tolist() + labels = labels[sorted_idxs] + for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels): + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[label]] + except (AttributeError, IndexError): + mask_color = None + + binary_mask = (sem_seg == label).astype(np.uint8) + text = self.metadata.stuff_classes[label] + self.draw_binary_mask( + binary_mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + return self.output + + def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7): + """ + Draw panoptic prediction annotations or results. + + Args: + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each + segment. + segments_info (list[dict] or None): Describe each segment in `panoptic_seg`. + If it is a ``list[dict]``, each dict contains keys "id", "category_id". + If None, category id of each pixel is computed by + ``pixel // metadata.label_divisor``. + area_threshold (int): stuff segments with less than `area_threshold` are not drawn. + + Returns: + output (VisImage): image object with visualizations. 
+ """ + pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata) + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask())) + + # draw mask for all semantic segments first i.e. "stuff" + for mask, sinfo in pred.semantic_masks(): + category_idx = sinfo["category_id"] + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]] + except AttributeError: + mask_color = None + + text = self.metadata.stuff_classes[category_idx] + self.draw_binary_mask( + mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + + # draw mask for all instances second + all_instances = list(pred.instance_masks()) + if len(all_instances) == 0: + return self.output + masks, sinfo = list(zip(*all_instances)) + category_ids = [x["category_id"] for x in sinfo] + + try: + scores = [x["score"] for x in sinfo] + except KeyError: + scores = None + labels = _create_text_labels( + category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo] + ) + + try: + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids + ] + except AttributeError: + colors = None + self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha) + + return self.output + + draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility + + def draw_dataset_dict(self, dic): + """ + Draw annotations/segmentations in Detectron2 Dataset format. + + Args: + dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format. + + Returns: + output (VisImage): image object with visualizations. 
+ """ + annos = dic.get("annotations", None) + if annos: + if "segmentation" in annos[0]: + masks = [x["segmentation"] for x in annos] + else: + masks = None + if "keypoints" in annos[0]: + keypts = [x["keypoints"] for x in annos] + keypts = np.array(keypts).reshape(len(annos), -1, 3) + else: + keypts = None + + boxes = [ + BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS) + if len(x["bbox"]) == 4 + else x["bbox"] + for x in annos + ] + + colors = None + category_ids = [x["category_id"] for x in annos] + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) + for c in category_ids + ] + names = self.metadata.get("thing_classes", None) + labels = _create_text_labels( + category_ids, + scores=None, + class_names=names, + is_crowd=[x.get("iscrowd", 0) for x in annos], + ) + self.overlay_instances( + labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors + ) + + sem_seg = dic.get("sem_seg", None) + if sem_seg is None and "sem_seg_file_name" in dic: + with PathManager.open(dic["sem_seg_file_name"], "rb") as f: + sem_seg = Image.open(f) + sem_seg = np.asarray(sem_seg, dtype="uint8") + if sem_seg is not None: + self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5) + + pan_seg = dic.get("pan_seg", None) + if pan_seg is None and "pan_seg_file_name" in dic: + with PathManager.open(dic["pan_seg_file_name"], "rb") as f: + pan_seg = Image.open(f) + pan_seg = np.asarray(pan_seg) + from panopticapi.utils import rgb2id + + pan_seg = rgb2id(pan_seg) + if pan_seg is not None: + segments_info = dic["segments_info"] + pan_seg = torch.tensor(pan_seg) + self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5) + return self.output + + def overlay_instances( + self, + *, + boxes=None, + labels=None, + masks=None, + keypoints=None, + assigned_colors=None, + alpha=0.5, + ): + """ + Args: + boxes (Boxes, RotatedBoxes or 
ndarray): either a :class:`Boxes`, + or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image, + or a :class:`RotatedBoxes`, + or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image, + labels (list[str]): the text to be displayed for each instance. + masks (masks-like object): Supported types are: + + * :class:`detectron2.structures.PolygonMasks`, + :class:`detectron2.structures.BitMasks`. + * list[list[ndarray]]: contains the segmentation masks for all objects in one image. + The first level of the list corresponds to individual instances. The second + level to all the polygon that compose the instance, and the third level + to the polygon coordinates. The third level should have the format of + [x0, y0, x1, y1, ..., xn, yn] (n >= 3). + * list[ndarray]: each ndarray is a binary mask of shape (H, W). + * list[dict]: each dict is a COCO-style RLE. + keypoints (Keypoint or array like): an array-like object of shape (N, K, 3), + where the N is the number of instances and K is the number of keypoints. + The last dimension corresponds to (x, y, visibility or score). + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + Returns: + output (VisImage): image object with visualizations. 
+ """ + num_instances = 0 + if boxes is not None: + boxes = self._convert_boxes(boxes) + num_instances = len(boxes) + if masks is not None: + masks = self._convert_masks(masks) + if num_instances: + assert len(masks) == num_instances + else: + num_instances = len(masks) + if keypoints is not None: + if num_instances: + assert len(keypoints) == num_instances + else: + num_instances = len(keypoints) + keypoints = self._convert_keypoints(keypoints) + if labels is not None: + assert len(labels) == num_instances + if assigned_colors is None: + assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)] + if num_instances == 0: + return self.output + if boxes is not None and boxes.shape[1] == 5: + return self.overlay_rotated_instances( + boxes=boxes, labels=labels, assigned_colors=assigned_colors + ) + + # Display in largest to smallest order to reduce occlusion. + areas = None + if boxes is not None: + areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1) + elif masks is not None: + areas = np.asarray([x.area() for x in masks]) + + if areas is not None: + sorted_idxs = np.argsort(-areas).tolist() + # Re-order overlapped instances in descending order. + boxes = boxes[sorted_idxs] if boxes is not None else None + labels = [labels[k] for k in sorted_idxs] if labels is not None else None + masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None + assigned_colors = [assigned_colors[idx] for idx in sorted_idxs] + keypoints = keypoints[sorted_idxs] if keypoints is not None else None + + for i in range(num_instances): + color = assigned_colors[i] + if boxes is not None: + self.draw_box(boxes[i], edge_color=color) + + if masks is not None: + for segment in masks[i].polygons: + self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha) + + if labels is not None: + # first get a box + if boxes is not None: + x0, y0, x1, y1 = boxes[i] + text_pos = (x0, y0) # if drawing boxes, put text on the box corner. 
+ horiz_align = "left" + elif masks is not None: + # skip small mask without polygon + if len(masks[i].polygons) == 0: + continue + + x0, y0, x1, y1 = masks[i].bbox() + + # draw text in the center (defined by median) when box is not drawn + # median is less sensitive to outliers. + text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1] + horiz_align = "center" + else: + continue # drawing the box confidence for keypoints isn't very useful. + # for small objects, draw text at the side to avoid occlusion + instance_area = (y1 - y0) * (x1 - x0) + if ( + instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale + or y1 - y0 < 40 * self.output.scale + ): + if y1 >= self.output.height - 5: + text_pos = (x1, y0) + else: + text_pos = (x0, y1) + + height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width) + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) + * 0.5 + * self._default_font_size + ) + self.draw_text( + labels[i], + text_pos, + color=lighter_color, + horizontal_alignment=horiz_align, + font_size=font_size, + ) + + # draw keypoints + if keypoints is not None: + for keypoints_per_instance in keypoints: + self.draw_and_connect_keypoints(keypoints_per_instance) + + return self.output + + def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None): + """ + Args: + boxes (ndarray): an Nx5 numpy array of + (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image. + labels (list[str]): the text to be displayed for each instance. + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + + Returns: + output (VisImage): image object with visualizations. 
+ """ + num_instances = len(boxes) + + if assigned_colors is None: + assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)] + if num_instances == 0: + return self.output + + # Display in largest to smallest order to reduce occlusion. + if boxes is not None: + areas = boxes[:, 2] * boxes[:, 3] + + sorted_idxs = np.argsort(-areas).tolist() + # Re-order overlapped instances in descending order. + boxes = boxes[sorted_idxs] + labels = [labels[k] for k in sorted_idxs] if labels is not None else None + colors = [assigned_colors[idx] for idx in sorted_idxs] + + for i in range(num_instances): + self.draw_rotated_box_with_label( + boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None + ) + + return self.output + + def draw_and_connect_keypoints(self, keypoints): + """ + Draws keypoints of an instance and follows the rules for keypoint connections + to draw lines between appropriate keypoints. This follows color heuristics for + line color. + + Args: + keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints + and the last dimension corresponds to (x, y, probability). + + Returns: + output (VisImage): image object with visualizations. + """ + visible = {} + keypoint_names = self.metadata.get("keypoint_names") + for idx, keypoint in enumerate(keypoints): + + # draw keypoint + x, y, prob = keypoint + if prob > self.keypoint_threshold: + self.draw_circle((x, y), color=_RED) + if keypoint_names: + keypoint_name = keypoint_names[idx] + visible[keypoint_name] = (x, y) + + if self.metadata.get("keypoint_connection_rules"): + for kp0, kp1, color in self.metadata.keypoint_connection_rules: + if kp0 in visible and kp1 in visible: + x0, y0 = visible[kp0] + x1, y1 = visible[kp1] + color = tuple(x / 255.0 for x in color) + self.draw_line([x0, x1], [y0, y1], color=color) + + # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip + # Note that this strategy is specific to person keypoints. 
+ # For other keypoints, it should just do nothing + try: + ls_x, ls_y = visible["left_shoulder"] + rs_x, rs_y = visible["right_shoulder"] + mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2 + except KeyError: + pass + else: + # draw line from nose to mid-shoulder + nose_x, nose_y = visible.get("nose", (None, None)) + if nose_x is not None: + self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED) + + try: + # draw line from mid-shoulder to mid-hip + lh_x, lh_y = visible["left_hip"] + rh_x, rh_y = visible["right_hip"] + except KeyError: + pass + else: + mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2 + self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED) + return self.output + + """ + Primitive drawing functions: + """ + + def draw_text( + self, + text, + position, + *, + font_size=None, + color="g", + horizontal_alignment="center", + rotation=0, + ): + """ + Args: + text (str): class label + position (tuple): a tuple of the x and y coordinates to place text on image. + font_size (int, optional): font of the text. If not provided, a font size + proportional to the image width is calculated and used. + color: color of the text. Refer to `matplotlib.colors` for full list + of formats that are accepted. + horizontal_alignment (str): see `matplotlib.text.Text` + rotation: rotation angle in degrees CCW + + Returns: + output (VisImage): image object with text drawn. 
+ """ + if not font_size: + font_size = self._default_font_size + + # since the text background is dark, we don't want the text to be dark + color = np.maximum(list(mplc.to_rgb(color)), 0.2) + color[np.argmax(color)] = max(0.8, np.max(color)) + + x, y = position + self.output.ax.text( + x, + y, + text, + size=font_size * self.output.scale, + family="sans-serif", + bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}, + verticalalignment="top", + horizontalalignment=horizontal_alignment, + color=color, + zorder=10, + rotation=rotation, + ) + return self.output + + def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"): + """ + Args: + box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0 + are the coordinates of the image's top left corner. x1 and y1 are the + coordinates of the image's bottom right corner. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + + Returns: + output (VisImage): image object with box drawn. + """ + x0, y0, x1, y1 = box_coord + width = x1 - x0 + height = y1 - y0 + + linewidth = max(self._default_font_size / 4, 1) + + self.output.ax.add_patch( + mpl.patches.Rectangle( + (x0, y0), + width, + height, + fill=False, + edgecolor=edge_color, + linewidth=linewidth * self.output.scale, + alpha=alpha, + linestyle=line_style, + ) + ) + return self.output + + def draw_rotated_box_with_label( + self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None + ): + """ + Draw a rotated box with label on its top-left corner. + + Args: + rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle), + where cnt_x and cnt_y are the center coordinates of the box. + w and h are the width and height of the box. 
angle represents how + many degrees the box is rotated CCW with regard to the 0-degree box. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + label (string): label for rotated box. It will not be rendered when set to None. + + Returns: + output (VisImage): image object with box drawn. + """ + cnt_x, cnt_y, w, h, angle = rotated_box + area = w * h + # use thinner lines when the box is small + linewidth = self._default_font_size / ( + 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3 + ) + + theta = angle * math.pi / 180.0 + c = math.cos(theta) + s = math.sin(theta) + rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)] + # x: left->right ; y: top->down + rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect] + for k in range(4): + j = (k + 1) % 4 + self.draw_line( + [rotated_rect[k][0], rotated_rect[j][0]], + [rotated_rect[k][1], rotated_rect[j][1]], + color=edge_color, + linestyle="--" if k == 1 else line_style, + linewidth=linewidth, + ) + + if label is not None: + text_pos = rotated_rect[1] # topleft corner + + height_ratio = h / np.sqrt(self.output.height * self.output.width) + label_color = self._change_color_brightness(edge_color, brightness_factor=0.7) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size + ) + self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle) + + return self.output + + def draw_circle(self, circle_coord, color, radius=3): + """ + Args: + circle_coord (list(int) or tuple(int)): contains the x and y coordinates + of the center of the circle. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. 
+ radius (int): radius of the circle. + + Returns: + output (VisImage): image object with box drawn. + """ + x, y = circle_coord + self.output.ax.add_patch( + mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color) + ) + return self.output + + def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None): + """ + Args: + x_data (list[int]): a list containing x values of all the points being drawn. + Length of list should match the length of y_data. + y_data (list[int]): a list containing y values of all the points being drawn. + Length of list should match the length of x_data. + color: color of the line. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + linestyle: style of the line. Refer to `matplotlib.lines.Line2D` + for a full list of formats that are accepted. + linewidth (float or None): width of the line. When it's None, + a default value will be computed and used. + + Returns: + output (VisImage): image object with line drawn. + """ + if linewidth is None: + linewidth = self._default_font_size / 3 + linewidth = max(linewidth, 1) + self.output.ax.add_line( + mpl.lines.Line2D( + x_data, + y_data, + linewidth=linewidth * self.output.scale, + color=color, + linestyle=linestyle, + ) + ) + return self.output + + def draw_binary_mask( + self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=10 + ): + """ + Args: + binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and + W is the image width. Each value in the array is either a 0 or 1 value of uint8 + type. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. 
+ area_threshold (float): a connected component smaller than this area will not be shown. + + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + has_valid_segment = False + binary_mask = binary_mask.astype("uint8") # opencv needs uint8 + mask = GenericMask(binary_mask, self.output.height, self.output.width) + shape2d = (binary_mask.shape[0], binary_mask.shape[1]) + + if not mask.has_holes: + # draw polygons for regular masks + for segment in mask.polygons: + area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1])) + if area < (area_threshold or 0): + continue + has_valid_segment = True + segment = segment.reshape(-1, 2) + self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha) + else: + # TODO: Use Path/PathPatch to draw vector graphics: + # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha + has_valid_segment = True + self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) + + if text is not None and has_valid_segment: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5): + """ + Args: + soft_mask (ndarray): float array of shape (H, W), each value in [0, 1]. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with mask drawn. 
+ """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + shape2d = (soft_mask.shape[0], soft_mask.shape[1]) + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = soft_mask * alpha + self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) + + if text is not None: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + binary_mask = (soft_mask > 0.5).astype("uint8") + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_polygon(self, segment, color, edge_color=None, alpha=0.5): + """ + Args: + segment: numpy array of shape Nx2, containing all the points in the polygon. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. If not provided, a darker shade + of the polygon color will be used instead. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with polygon drawn. + """ + if edge_color is None: + # make edge color darker than the polygon color + if alpha > 0.8: + edge_color = self._change_color_brightness(color, brightness_factor=-0.7) + else: + edge_color = color + edge_color = mplc.to_rgb(edge_color) + (1,) + + polygon = mpl.patches.Polygon( + segment, + fill=True, + facecolor=mplc.to_rgb(color) + (alpha,), + edgecolor=edge_color, + linewidth=max(self._default_font_size // 15 * self.output.scale, 1), + ) + self.output.ax.add_patch(polygon) + return self.output + + """ + Internal methods: + """ + + def _jitter(self, color): + """ + Randomly modifies given color to produce a slightly different color than the color given. + + Args: + color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color + picked. 
The values in the list are in the [0.0, 1.0] range. + + Returns: + jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the + color after being jittered. The values in the list are in the [0.0, 1.0] range. + """ + color = mplc.to_rgb(color) + vec = np.random.rand(3) + # better to do it in another color space + vec = vec / np.linalg.norm(vec) * 0.5 + res = np.clip(vec + color, 0, 1) + return tuple(res) + + def _create_grayscale_image(self, mask=None): + """ + Create a grayscale version of the original image. + The colors in masked area, if given, will be kept. + """ + img_bw = self.img.astype("f4").mean(axis=2) + img_bw = np.stack([img_bw] * 3, axis=2) + if mask is not None: + img_bw[mask] = self.img[mask] + return img_bw + + def _change_color_brightness(self, color, brightness_factor): + """ + Depending on the brightness_factor, gives a lighter or darker color i.e. a color with + less or more saturation than the original color. + + Args: + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of + 0 will correspond to no change, a factor in [-1.0, 0) range will result in + a darker color and a factor in (0, 1.0] range will result in a lighter color. + + Returns: + modified_color (tuple[double]): a tuple containing the RGB values of the + modified color. Each value in the tuple is in the [0.0, 1.0] range. 
+ """ + assert brightness_factor >= -1.0 and brightness_factor <= 1.0 + color = mplc.to_rgb(color) + polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) + modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) + modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness + modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness + modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2]) + return tuple(np.clip(modified_color, 0.0, 1.0)) + + def _convert_boxes(self, boxes): + """ + Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension. + """ + if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes): + return boxes.tensor.detach().numpy() + else: + return np.asarray(boxes) + + def _convert_masks(self, masks_or_polygons): + """ + Convert different format of masks or polygons to a tuple of masks and polygons. + + Returns: + list[GenericMask]: + """ + + m = masks_or_polygons + if isinstance(m, PolygonMasks): + m = m.polygons + if isinstance(m, BitMasks): + m = m.tensor.numpy() + if isinstance(m, torch.Tensor): + m = m.numpy() + ret = [] + for x in m: + if isinstance(x, GenericMask): + ret.append(x) + else: + ret.append(GenericMask(x, self.output.height, self.output.width)) + return ret + + def _draw_text_in_mask(self, binary_mask, text, color): + """ + Find proper places to draw text given a binary mask. + """ + # TODO sometimes drawn on wrong objects. the heuristics here can improve. + _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8) + if stats[1:, -1].size == 0: + return + largest_component_id = np.argmax(stats[1:, -1]) + 1 + + # draw text on the largest component, as well as other very large components. 
+ for cid in range(1, _num_cc): + if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH: + # median is more stable than centroid + # center = centroids[largest_component_id] + center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1] + self.draw_text(text, center, color=color) + + def _convert_keypoints(self, keypoints): + if isinstance(keypoints, Keypoints): + keypoints = keypoints.tensor + keypoints = np.asarray(keypoints) + return keypoints + + def get_output(self): + """ + Returns: + output (VisImage): the image output containing the visualizations added + to the image. + """ + return self.output diff --git a/data_processing/detectron2/dev/README.md b/data_processing/detectron2/dev/README.md new file mode 100644 index 0000000..bec811a --- /dev/null +++ b/data_processing/detectron2/dev/README.md @@ -0,0 +1,7 @@ + +## Some scripts for developers to use, include: + +- `linter.sh`: lint the codebase before commit. +- `run_{inference,instant}_tests.sh`: run inference/training for a few iterations. + Note that these tests require 2 GPUs. +- `parse_results.sh`: parse results from a log file. diff --git a/data_processing/detectron2/dev/linter.sh b/data_processing/detectron2/dev/linter.sh new file mode 100644 index 0000000..55793e0 --- /dev/null +++ b/data_processing/detectron2/dev/linter.sh @@ -0,0 +1,42 @@ +#!/bin/bash -e +# Copyright (c) Facebook, Inc. and its affiliates. + +# cd to detectron2 project root +cd "$(dirname "${BASH_SOURCE[0]}")/.." + +{ + black --version | grep -E "22\." > /dev/null +} || { + echo "Linter requires 'black==22.*' !" + exit 1 +} + +ISORT_VERSION=$(isort --version-number) +if [[ "$ISORT_VERSION" != 4.3* ]]; then + echo "Linter requires isort==4.3.21 !" + exit 1 +fi + +set -v + +echo "Running isort ..." +isort -y -sp . --atomic + +echo "Running black ..." +black -l 100 . + +echo "Running flake8 ..." +if [ -x "$(command -v flake8)" ]; then + flake8 . +else + python3 -m flake8 . +fi + +# echo "Running mypy ..." 
+# Pytorch does not have enough type annotations +# mypy detectron2/solver detectron2/structures detectron2/config + +echo "Running clang-format ..." +find . -regex ".*\.\(cpp\|c\|cc\|cu\|cxx\|h\|hh\|hpp\|hxx\|tcc\|mm\|m\)" -print0 | xargs -0 clang-format -i + +command -v arc > /dev/null && arc lint diff --git a/data_processing/detectron2/dev/packaging/README.md b/data_processing/detectron2/dev/packaging/README.md new file mode 100644 index 0000000..0174b7d --- /dev/null +++ b/data_processing/detectron2/dev/packaging/README.md @@ -0,0 +1,17 @@ + +## To build a cu101 wheel for release: + +``` +$ nvidia-docker run -it --storage-opt "size=20GB" --name pt pytorch/manylinux-cuda101 +# inside the container: +# git clone https://github.com/facebookresearch/detectron2/ +# cd detectron2 +# export CU_VERSION=cu101 D2_VERSION_SUFFIX= PYTHON_VERSION=3.7 PYTORCH_VERSION=1.8 +# ./dev/packaging/build_wheel.sh +``` + +## To build all wheels for combinations of CUDA and Python +``` +./dev/packaging/build_all_wheels.sh +./dev/packaging/gen_wheel_index.sh /path/to/wheels +``` diff --git a/data_processing/detectron2/dev/packaging/build_all_wheels.sh b/data_processing/detectron2/dev/packaging/build_all_wheels.sh new file mode 100644 index 0000000..00f9de5 --- /dev/null +++ b/data_processing/detectron2/dev/packaging/build_all_wheels.sh @@ -0,0 +1,65 @@ +#!/bin/bash -e +# Copyright (c) Facebook, Inc. and its affiliates. + +[[ -d "dev/packaging" ]] || { + echo "Please run this script at detectron2 root!" + exit 1 +} + +build_one() { + cu=$1 + pytorch_ver=$2 + + case "$cu" in + cu*) + container_name=manylinux-cuda${cu/cu/} + ;; + cpu) + container_name=manylinux-cuda101 + ;; + *) + echo "Unrecognized cu=$cu" + exit 1 + ;; + esac + + echo "Launching container $container_name ..." 
+ container_id="$container_name"_"$cu"_"$pytorch_ver" + + py_versions=(3.7 3.8 3.9) + + for py in "${py_versions[@]}"; do + docker run -itd \ + --name "$container_id" \ + --mount type=bind,source="$(pwd)",target=/detectron2 \ + pytorch/$container_name + + cat </dev/null 2>&1 && pwd )" +. "$script_dir/pkg_helpers.bash" + +echo "Build Settings:" +echo "CU_VERSION: $CU_VERSION" # e.g. cu101 +echo "D2_VERSION_SUFFIX: $D2_VERSION_SUFFIX" # e.g. +cu101 or "" +echo "PYTHON_VERSION: $PYTHON_VERSION" # e.g. 3.7 +echo "PYTORCH_VERSION: $PYTORCH_VERSION" # e.g. 1.4 + +setup_cuda +setup_wheel_python + +yum install ninja-build -y +ln -sv /usr/bin/ninja-build /usr/bin/ninja || true + +pip_install pip numpy -U +pip_install "torch==$PYTORCH_VERSION" \ + -f https://download.pytorch.org/whl/"$CU_VERSION"/torch_stable.html + +# use separate directories to allow parallel build +BASE_BUILD_DIR=build/$CU_VERSION-py$PYTHON_VERSION-pt$PYTORCH_VERSION +python setup.py \ + build -b "$BASE_BUILD_DIR" \ + bdist_wheel -b "$BASE_BUILD_DIR/build_dist" -d "wheels/$CU_VERSION/torch$PYTORCH_VERSION" +rm -rf "$BASE_BUILD_DIR" diff --git a/data_processing/detectron2/dev/packaging/gen_install_table.py b/data_processing/detectron2/dev/packaging/gen_install_table.py new file mode 100644 index 0000000..b4c852d --- /dev/null +++ b/data_processing/detectron2/dev/packaging/gen_install_table.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# -*- coding: utf-8 -*- + +import argparse + +template = """
install
\
+python -m pip install detectron2{d2_version} -f \\
+  https://dl.fbaipublicfiles.com/detectron2/wheels/{cuda}/torch{torch}/index.html
+
""" +CUDA_SUFFIX = { + "11.3": "cu113", + "11.1": "cu111", + "11.0": "cu110", + "10.2": "cu102", + "10.1": "cu101", + "10.0": "cu100", + "9.2": "cu92", + "cpu": "cpu", +} + + +def gen_header(torch_versions): + return '' + "".join( + [ + ''.format(t) + for t in torch_versions + ] + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--d2-version", help="detectron2 version number, default to empty") + args = parser.parse_args() + d2_version = f"=={args.d2_version}" if args.d2_version else "" + + all_versions = ( + [("1.8", k) for k in ["11.1", "10.2", "10.1", "cpu"]] + + [("1.9", k) for k in ["11.1", "10.2", "cpu"]] + + [("1.10", k) for k in ["11.3", "11.1", "10.2", "cpu"]] + ) + + torch_versions = sorted( + {k[0] for k in all_versions}, key=lambda x: int(x.split(".")[1]), reverse=True + ) + cuda_versions = sorted( + {k[1] for k in all_versions}, key=lambda x: float(x) if x != "cpu" else 0, reverse=True + ) + + table = gen_header(torch_versions) + for cu in cuda_versions: + table += f""" """ + cu_suffix = CUDA_SUFFIX[cu] + for torch in torch_versions: + if (torch, cu) in all_versions: + cell = template.format(d2_version=d2_version, cuda=cu_suffix, torch=torch) + else: + cell = "" + table += f""" """ + table += "" + table += "
CUDA torch {}
{cu}{cell}
" + print(table) diff --git a/data_processing/detectron2/dev/packaging/gen_wheel_index.sh b/data_processing/detectron2/dev/packaging/gen_wheel_index.sh new file mode 100644 index 0000000..ec96a27 --- /dev/null +++ b/data_processing/detectron2/dev/packaging/gen_wheel_index.sh @@ -0,0 +1,46 @@ +#!/bin/bash -e +# Copyright (c) Facebook, Inc. and its affiliates. + + +root=$(readlink -f $1) +if [[ -z "$root" ]]; then + echo "Usage: ./gen_wheel_index.sh /absolute/path/to/wheels" + exit +fi + +export LC_ALL=C # reproducible sort +# NOTE: all sort in this script might not work when xx.10 is released + +index=$root/index.html + +cd "$root" +for cu in cpu cu92 cu100 cu101 cu102 cu110 cu111 cu113; do + mkdir -p "$root/$cu" + cd "$root/$cu" + echo "Creating $PWD/index.html ..." + # First sort by torch version, then stable sort by d2 version with unique. + # As a result, the latest torch version for each d2 version is kept. + for whl in $(find -type f -name '*.whl' -printf '%P\n' \ + | sort -k 1 -r | sort -t '/' -k 2 --stable -r --unique); do + echo "$whl
" + done > index.html + + + for torch in torch*; do + cd "$root/$cu/$torch" + + # list all whl for each cuda,torch version + echo "Creating $PWD/index.html ..." + for whl in $(find . -type f -name '*.whl' -printf '%P\n' | sort -r); do + echo "$whl
" + done > index.html + done +done + +cd "$root" +# Just list everything: +echo "Creating $index ..." +for whl in $(find . -type f -name '*.whl' -printf '%P\n' | sort -r); do + echo "$whl
" +done > "$index" + diff --git a/data_processing/detectron2/dev/packaging/pkg_helpers.bash b/data_processing/detectron2/dev/packaging/pkg_helpers.bash new file mode 100644 index 0000000..550bb6e --- /dev/null +++ b/data_processing/detectron2/dev/packaging/pkg_helpers.bash @@ -0,0 +1,75 @@ +#!/bin/bash -e +# Copyright (c) Facebook, Inc. and its affiliates. + +# Function to retry functions that sometimes timeout or have flaky failures +retry () { + $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) +} +# Install with pip a bit more robustly than the default +pip_install() { + retry pip install --progress-bar off "$@" +} + + +setup_cuda() { + # Now work out the CUDA settings + # Like other torch domain libraries, we choose common GPU architectures only. + # See https://github.com/pytorch/pytorch/blob/master/torch/utils/cpp_extension.py + # and https://github.com/pytorch/vision/blob/main/packaging/pkg_helpers.bash for reference. 
+ export FORCE_CUDA=1 + case "$CU_VERSION" in + cu113) + export CUDA_HOME=/usr/local/cuda-11.3/ + export TORCH_CUDA_ARCH_LIST="3.7;5.0;5.2;6.0;6.1+PTX;7.0;7.5+PTX;8.0;8.6+PTX" + ;; + cu112) + export CUDA_HOME=/usr/local/cuda-11.2/ + export TORCH_CUDA_ARCH_LIST="3.7;5.0;5.2;6.0;6.1+PTX;7.0;7.5+PTX;8.0;8.6+PTX" + ;; + cu111) + export CUDA_HOME=/usr/local/cuda-11.1/ + export TORCH_CUDA_ARCH_LIST="3.7;5.0;5.2;6.0;6.1+PTX;7.0;7.5+PTX;8.0;8.6+PTX" + ;; + cu110) + export CUDA_HOME=/usr/local/cuda-11.0/ + export TORCH_CUDA_ARCH_LIST="3.7;5.0;5.2;6.0;6.1+PTX;7.0;7.5+PTX;8.0+PTX" + ;; + cu102) + export CUDA_HOME=/usr/local/cuda-10.2/ + export TORCH_CUDA_ARCH_LIST="3.7;5.0;5.2;6.0;6.1+PTX;7.0;7.5+PTX" + ;; + cu101) + export CUDA_HOME=/usr/local/cuda-10.1/ + export TORCH_CUDA_ARCH_LIST="3.7;5.0;5.2;6.0;6.1+PTX;7.0;7.5+PTX" + ;; + cu100) + export CUDA_HOME=/usr/local/cuda-10.0/ + export TORCH_CUDA_ARCH_LIST="3.7;5.0;5.2;6.0;6.1+PTX;7.0;7.5+PTX" + ;; + cu92) + export CUDA_HOME=/usr/local/cuda-9.2/ + export TORCH_CUDA_ARCH_LIST="3.7;5.0;5.2;6.0;6.1+PTX;7.0+PTX" + ;; + cpu) + unset FORCE_CUDA + export CUDA_VISIBLE_DEVICES= + ;; + *) + echo "Unrecognized CU_VERSION=$CU_VERSION" + exit 1 + ;; + esac +} + +setup_wheel_python() { + case "$PYTHON_VERSION" in + 3.7) python_abi=cp37-cp37m ;; + 3.8) python_abi=cp38-cp38 ;; + 3.9) python_abi=cp39-cp39 ;; + *) + echo "Unrecognized PYTHON_VERSION=$PYTHON_VERSION" + exit 1 + ;; + esac + export PATH="/opt/python/$python_abi/bin:$PATH" +} diff --git a/data_processing/detectron2/dev/parse_results.sh b/data_processing/detectron2/dev/parse_results.sh new file mode 100644 index 0000000..80768a4 --- /dev/null +++ b/data_processing/detectron2/dev/parse_results.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. + +# A shell script that parses metrics from the log file. +# Make it easier for developers to track performance of models. 
+ +LOG="$1" + +if [[ -z "$LOG" ]]; then + echo "Usage: $0 /path/to/log/file" + exit 1 +fi + +# [12/15 11:47:32] trainer INFO: Total training time: 12:15:04.446477 (0.4900 s / it) +# [12/15 11:49:03] inference INFO: Total inference time: 0:01:25.326167 (0.13652186737060548 s / img per device, on 8 devices) +# [12/15 11:49:03] inference INFO: Total inference pure compute time: ..... + +# training time +trainspeed=$(grep -o 'Overall training.*' "$LOG" | grep -Eo '\(.*\)' | grep -o '[0-9\.]*') +echo "Training speed: $trainspeed s/it" + +# inference time: there could be multiple inference during training +inferencespeed=$(grep -o 'Total inference pure.*' "$LOG" | tail -n1 | grep -Eo '\(.*\)' | grep -o '[0-9\.]*' | head -n1) +echo "Inference speed: $inferencespeed s/it" + +# [12/15 11:47:18] trainer INFO: eta: 0:00:00 iter: 90000 loss: 0.5407 (0.7256) loss_classifier: 0.1744 (0.2446) loss_box_reg: 0.0838 (0.1160) loss_mask: 0.2159 (0.2722) loss_objectness: 0.0244 (0.0429) loss_rpn_box_reg: 0.0279 (0.0500) time: 0.4487 (0.4899) data: 0.0076 (0.0975) lr: 0.000200 max mem: 4161 +memory=$(grep -o 'max[_ ]mem: [0-9]*' "$LOG" | tail -n1 | grep -o '[0-9]*') +echo "Training memory: $memory MB" + +echo "Easy to copypaste:" +echo "$trainspeed","$inferencespeed","$memory" + +echo "------------------------------" + +# [12/26 17:26:32] engine.coco_evaluation: copypaste: Task: bbox +# [12/26 17:26:32] engine.coco_evaluation: copypaste: AP,AP50,AP75,APs,APm,APl +# [12/26 17:26:32] engine.coco_evaluation: copypaste: 0.0017,0.0024,0.0017,0.0005,0.0019,0.0011 +# [12/26 17:26:32] engine.coco_evaluation: copypaste: Task: segm +# [12/26 17:26:32] engine.coco_evaluation: copypaste: AP,AP50,AP75,APs,APm,APl +# [12/26 17:26:32] engine.coco_evaluation: copypaste: 0.0014,0.0021,0.0016,0.0005,0.0016,0.0011 + +echo "COCO Results:" +num_tasks=$(grep -o 'copypaste:.*Task.*' "$LOG" | sort -u | wc -l) +# each task has 3 lines +grep -o 'copypaste:.*' "$LOG" | cut -d ' ' -f 2- | tail -n $((num_tasks * 
3)) diff --git a/data_processing/detectron2/dev/run_inference_tests.sh b/data_processing/detectron2/dev/run_inference_tests.sh new file mode 100644 index 0000000..bc9dcc5 --- /dev/null +++ b/data_processing/detectron2/dev/run_inference_tests.sh @@ -0,0 +1,44 @@ +#!/bin/bash -e +# Copyright (c) Facebook, Inc. and its affiliates. + +BIN="python tools/train_net.py" +OUTPUT="inference_test_output" +NUM_GPUS=2 + +CFG_LIST=( "${@:1}" ) + +if [ ${#CFG_LIST[@]} -eq 0 ]; then + CFG_LIST=( ./configs/quick_schedules/*inference_acc_test.yaml ) +fi + +echo "========================================================================" +echo "Configs to run:" +echo "${CFG_LIST[@]}" +echo "========================================================================" + + +for cfg in "${CFG_LIST[@]}"; do + echo "========================================================================" + echo "Running $cfg ..." + echo "========================================================================" + $BIN \ + --eval-only \ + --num-gpus $NUM_GPUS \ + --config-file "$cfg" \ + OUTPUT_DIR $OUTPUT + rm -rf $OUTPUT +done + + +echo "========================================================================" +echo "Running demo.py ..." +echo "========================================================================" +DEMO_BIN="python demo/demo.py" +COCO_DIR=datasets/coco/val2014 +mkdir -pv $OUTPUT + +set -v + +$DEMO_BIN --config-file ./configs/quick_schedules/panoptic_fpn_R_50_inference_acc_test.yaml \ + --input $COCO_DIR/COCO_val2014_0000001933* --output $OUTPUT +rm -rf $OUTPUT diff --git a/data_processing/detectron2/dev/run_instant_tests.sh b/data_processing/detectron2/dev/run_instant_tests.sh new file mode 100644 index 0000000..9fd9ba0 --- /dev/null +++ b/data_processing/detectron2/dev/run_instant_tests.sh @@ -0,0 +1,27 @@ +#!/bin/bash -e +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +BIN="python tools/train_net.py" +OUTPUT="instant_test_output" +NUM_GPUS=2 + +CFG_LIST=( "${@:1}" ) +if [ ${#CFG_LIST[@]} -eq 0 ]; then + CFG_LIST=( ./configs/quick_schedules/*instant_test.yaml ) +fi + +echo "========================================================================" +echo "Configs to run:" +echo "${CFG_LIST[@]}" +echo "========================================================================" + +for cfg in "${CFG_LIST[@]}"; do + echo "========================================================================" + echo "Running $cfg ..." + echo "========================================================================" + $BIN --num-gpus $NUM_GPUS --config-file "$cfg" \ + SOLVER.IMS_PER_BATCH $(($NUM_GPUS * 2)) \ + OUTPUT_DIR "$OUTPUT" + rm -rf "$OUTPUT" +done + diff --git a/data_processing/detectron2/docker/Dockerfile b/data_processing/detectron2/docker/Dockerfile new file mode 100644 index 0000000..fae0060 --- /dev/null +++ b/data_processing/detectron2/docker/Dockerfile @@ -0,0 +1,47 @@ +FROM nvidia/cuda:11.1.1-cudnn8-devel-ubuntu18.04 +# use an older system (18.04) to avoid opencv incompatibility (issue#3524) + +ENV DEBIAN_FRONTEND noninteractive +RUN apt-get update && apt-get install -y \ + python3-opencv ca-certificates python3-dev git wget sudo ninja-build +RUN ln -sv /usr/bin/python3 /usr/bin/python + +# create a non-root user +ARG USER_ID=1000 +RUN useradd -m --no-log-init --system --uid ${USER_ID} appuser -g sudo +RUN echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers +USER appuser +WORKDIR /home/appuser + +ENV PATH="/home/appuser/.local/bin:${PATH}" +RUN wget https://bootstrap.pypa.io/pip/3.6/get-pip.py && \ + python3 get-pip.py --user && \ + rm get-pip.py + +# install dependencies +# See https://pytorch.org/ for other options if you use a different version of CUDA +RUN pip install --user tensorboard cmake onnx # cmake from apt-get is too 
old +RUN pip install --user torch==1.10 torchvision==0.11.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html + +RUN pip install --user 'git+https://github.com/facebookresearch/fvcore' +# install detectron2 +RUN git clone https://github.com/facebookresearch/detectron2 detectron2_repo +# set FORCE_CUDA because during `docker build` cuda is not accessible +ENV FORCE_CUDA="1" +# This will by default build detectron2 for all common cuda architectures and take a lot more time, +# because inside `docker build`, there is no way to tell which architecture will be used. +ARG TORCH_CUDA_ARCH_LIST="Kepler;Kepler+Tesla;Maxwell;Maxwell+Tegra;Pascal;Volta;Turing" +ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" + +RUN pip install --user -e detectron2_repo + +# Set a fixed model cache directory. +ENV FVCORE_CACHE="/tmp" +WORKDIR /home/appuser/detectron2_repo + +# run detectron2 under user "appuser": +# wget http://images.cocodataset.org/val2017/000000439715.jpg -O input.jpg +# python3 demo/demo.py \ + #--config-file configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \ + #--input input.jpg --output outputs/ \ + #--opts MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl diff --git a/data_processing/detectron2/docker/README.md b/data_processing/detectron2/docker/README.md new file mode 100644 index 0000000..ea709f3 --- /dev/null +++ b/data_processing/detectron2/docker/README.md @@ -0,0 +1,45 @@ + +## Use the container (with docker ≥ 19.03) + +``` +cd docker/ +# Build: +docker build --build-arg USER_ID=$UID -t detectron2:v0 . 
+# Launch (require GPUs): +docker run --gpus all -it \ + --shm-size=8gb --env="DISPLAY" --volume="/tmp/.X11-unix:/tmp/.X11-unix:rw" \ + --name=detectron2 detectron2:v0 + +# Grant docker access to host X server to show images +xhost +local:`docker inspect --format='{{ .Config.Hostname }}' detectron2` +``` + +## Use the container (with docker-compose ≥ 1.28.0) + +Install docker-compose and nvidia-docker-toolkit, then run: +``` +cd docker && USER_ID=$UID docker-compose run detectron2 +``` + +## Use the deployment container (to test C++ examples) +After building the base detectron2 container as above, do: +``` +# Build: +docker build -t detectron2-deploy:v0 -f deploy.Dockerfile . +# Launch: +docker run --gpus all -it detectron2-deploy:v0 +``` + +#### Using a persistent cache directory + +You can prevent models from being re-downloaded on every run, +by storing them in a cache directory. + +To do this, add `--volume=$HOME/.torch/fvcore_cache:/tmp:rw` in the run command. + +## Install new dependencies +Add the following to `Dockerfile` to make persistent changes. +``` +RUN sudo apt-get update && sudo apt-get install -y vim +``` +Or run them in the container to make temporary changes. diff --git a/data_processing/detectron2/docker/deploy.Dockerfile b/data_processing/detectron2/docker/deploy.Dockerfile new file mode 100644 index 0000000..30b4ed7 --- /dev/null +++ b/data_processing/detectron2/docker/deploy.Dockerfile @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# This file defines a container that compiles the C++ examples of detectron2. +# See docker/README.md for usage. 
+ +# Depends on the image produced by "./Dockerfile" +FROM detectron2:v0 + +USER appuser +ENV HOME=/home/appuser +WORKDIR $HOME + +# Let torchvision find libtorch +ENV CMAKE_PREFIX_PATH=$HOME/.local/lib/python3.6/site-packages/torch/ + +RUN sudo apt-get update && sudo apt-get install libopencv-dev --yes + +# install libtorchvision +RUN git clone --branch v0.11.1 https://github.com/pytorch/vision/ +RUN mkdir vision/build && cd vision/build && \ + cmake .. -DCMAKE_INSTALL_PREFIX=$HOME/.local -DCMAKE_BUILD_TYPE=Release -DWITH_CUDA=on -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST && \ + make -j && make install + +# make our installation take effect +ENV CPATH=$HOME/.local/include \ + LIBRARY_PATH=$HOME/.local/lib \ + LD_LIBRARY_PATH=$HOME/.local/lib + + +# build C++ examples of detectron2 +RUN cd detectron2_repo/tools/deploy && mkdir build && cd build && \ + cmake -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST .. && make +# binaries will be available under tools/deploy/build diff --git a/data_processing/detectron2/docker/docker-compose.yml b/data_processing/detectron2/docker/docker-compose.yml new file mode 100644 index 0000000..6665ab4 --- /dev/null +++ b/data_processing/detectron2/docker/docker-compose.yml @@ -0,0 +1,26 @@ +version: "2.3" +services: + detectron2: + build: + context: . 
+ dockerfile: Dockerfile + args: + USER_ID: ${USER_ID:-1000} + deploy: + resources: + reservations: + devices: + - capabilities: + - gpu + shm_size: "8gb" + ulimits: + memlock: -1 + stack: 67108864 + volumes: + - /tmp/.X11-unix:/tmp/.X11-unix:ro + environment: + - DISPLAY=$DISPLAY + - NVIDIA_VISIBLE_DEVICES=all + # Uncomment with proper source to access webcam from docker + # devices: + # - /dev/video0:/dev/video0 diff --git a/data_processing/detectron2/docs/.gitignore b/data_processing/detectron2/docs/.gitignore new file mode 100644 index 0000000..e35d885 --- /dev/null +++ b/data_processing/detectron2/docs/.gitignore @@ -0,0 +1 @@ +_build diff --git a/data_processing/detectron2/docs/Makefile b/data_processing/detectron2/docs/Makefile new file mode 100644 index 0000000..718eddc --- /dev/null +++ b/data_processing/detectron2/docs/Makefile @@ -0,0 +1,19 @@ +# Minimal makefile for Sphinx documentation +# Copyright (c) Facebook, Inc. and its affiliates. + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/data_processing/detectron2/docs/README.md b/data_processing/detectron2/docs/README.md new file mode 100644 index 0000000..8531caf --- /dev/null +++ b/data_processing/detectron2/docs/README.md @@ -0,0 +1,15 @@ +# Read the docs: + +The latest documentation built from this directory is available at [detectron2.readthedocs.io](https://detectron2.readthedocs.io/). +Documents in this directory are not meant to be read on github. 
+ +# Build the docs: + +1. Install detectron2 according to [INSTALL.md](../INSTALL.md). +2. Install additional libraries required to build docs: + - docutils==0.16 + - Sphinx==3.2.0 + - recommonmark==0.6.0 + - sphinx_rtd_theme + +3. Run `make html` from this directory. diff --git a/data_processing/detectron2/docs/_static/css/custom.css b/data_processing/detectron2/docs/_static/css/custom.css new file mode 100644 index 0000000..6c51176 --- /dev/null +++ b/data_processing/detectron2/docs/_static/css/custom.css @@ -0,0 +1,30 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * some extra css to make markdown look similar between github/sphinx + */ + +/* + * Below is for install.md: + */ +.rst-content code { + white-space: pre; + border: 0px; +} + +.rst-content th { + border: 1px solid #e1e4e5; +} + +.rst-content th p { + /* otherwise will be default 24px for regular paragraph */ + margin-bottom: 0px; +} + +.rst-content .line-block { + /* otherwise will be 24px */ + margin-bottom: 0px; +} + +div.section > details { + padding-bottom: 1em; +} diff --git a/data_processing/detectron2/docs/conf.py b/data_processing/detectron2/docs/conf.py new file mode 100644 index 0000000..1fb3e30 --- /dev/null +++ b/data_processing/detectron2/docs/conf.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +# flake8: noqa + +# Configuration file for the Sphinx documentation builder. +# +# This file does only contain a selection of the most common options. For a +# full list see the documentation: +# http://www.sphinx-doc.org/en/master/config + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. 
+# +import os +import sys +from unittest import mock +from sphinx.domains import Domain +from typing import Dict, List, Tuple + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +import sphinx_rtd_theme + + +class GithubURLDomain(Domain): + """ + Resolve certain links in markdown files to github source. + """ + + name = "githuburl" + ROOT = "https://github.com/facebookresearch/detectron2/blob/main/" + LINKED_DOC = ["tutorials/install", "tutorials/getting_started"] + + def resolve_any_xref(self, env, fromdocname, builder, target, node, contnode): + github_url = None + if not target.endswith("html") and target.startswith("../../"): + url = target.replace("../", "") + github_url = url + if fromdocname in self.LINKED_DOC: + # unresolved links in these docs are all github links + github_url = target + + if github_url is not None: + if github_url.endswith("MODEL_ZOO") or github_url.endswith("README"): + # bug of recommonmark. 
+ # https://github.com/readthedocs/recommonmark/blob/ddd56e7717e9745f11300059e4268e204138a6b1/recommonmark/parser.py#L152-L155 + github_url += ".md" + print("Ref {} resolved to github:{}".format(target, github_url)) + contnode["refuri"] = self.ROOT + github_url + return [("githuburl:any", contnode)] + else: + return [] + + +# to support markdown +from recommonmark.parser import CommonMarkParser + +sys.path.insert(0, os.path.abspath("../")) +os.environ["_DOC_BUILDING"] = "True" +DEPLOY = os.environ.get("READTHEDOCS") == "True" + + +# -- Project information ----------------------------------------------------- + +# fmt: off +try: + import torch # noqa +except ImportError: + for m in [ + "torch", "torchvision", "torch.nn", "torch.nn.parallel", "torch.distributed", "torch.multiprocessing", "torch.autograd", + "torch.autograd.function", "torch.nn.modules", "torch.nn.modules.utils", "torch.utils", "torch.utils.data", "torch.onnx", + "torchvision", "torchvision.ops", + ]: + sys.modules[m] = mock.Mock(name=m) + sys.modules['torch'].__version__ = "1.7" # fake version + HAS_TORCH = False +else: + try: + torch.ops.detectron2 = mock.Mock(name="torch.ops.detectron2") + except: + pass + HAS_TORCH = True + +for m in [ + "cv2", "scipy", "portalocker", "detectron2._C", + "pycocotools", "pycocotools.mask", "pycocotools.coco", "pycocotools.cocoeval", + "google", "google.protobuf", "google.protobuf.internal", "onnx", + "caffe2", "caffe2.proto", "caffe2.python", "caffe2.python.utils", "caffe2.python.onnx", "caffe2.python.onnx.backend", +]: + sys.modules[m] = mock.Mock(name=m) +# fmt: on +sys.modules["cv2"].__version__ = "3.4" + +import detectron2 # isort: skip + +if HAS_TORCH: + from detectron2.utils.env import fixup_module_metadata + + fixup_module_metadata("torch.nn", torch.nn.__dict__) + fixup_module_metadata("torch.utils.data", torch.utils.data.__dict__) + + +project = "detectron2" +copyright = "2019-2020, detectron2 
contributors" +author = "detectron2 contributors" + +# The short X.Y version +version = detectron2.__version__ +# The full version, including alpha/beta/rc tags +release = version + + +# -- General configuration --------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +needs_sphinx = "3.0" + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "recommonmark", + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx.ext.githubpages", +] + +# -- Configurations for plugins ------------ +napoleon_google_docstring = True +napoleon_include_init_with_doc = True +napoleon_include_special_with_doc = True +napoleon_numpy_docstring = False +napoleon_use_rtype = False +autodoc_inherit_docstrings = False +autodoc_member_order = "bysource" + +if DEPLOY: + intersphinx_timeout = 10 +else: + # skip this when building locally + intersphinx_timeout = 0.5 +intersphinx_mapping = { + "python": ("https://docs.python.org/3.7", None), + "numpy": ("https://docs.scipy.org/doc/numpy/", None), + "torch": ("https://pytorch.org/docs/master/", None), +} +# ------------------------- + + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +source_suffix = [".rst", ".md"] + +# The master toctree document. +master_doc = "index" + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. 
+# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "build", "README.md", "tutorials/README.md"] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + + +# -- Options for HTML output ------------------------------------------------- + +html_theme = "sphinx_rtd_theme" +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] +html_css_files = ["css/custom.css"] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# The default sidebars (for documents that don't match any pattern) are +# defined by theme itself. Builtin themes are using these templates by +# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', +# 'searchbox.html']``. +# +# html_sidebars = {} + + +# -- Options for HTMLHelp output --------------------------------------------- + +# Output file base name for HTML help builder. +htmlhelp_basename = "detectron2doc" + + +# -- Options for LaTeX output ------------------------------------------------ + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). 
+ # + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, "detectron2.tex", "detectron2 Documentation", "detectron2 contributors", "manual") +] + + +# -- Options for manual page output ------------------------------------------ + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [(master_doc, "detectron2", "detectron2 Documentation", [author], 1)] + + +# -- Options for Texinfo output ---------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ( + master_doc, + "detectron2", + "detectron2 Documentation", + author, + "detectron2", + "One line description of project.", + "Miscellaneous", + ) +] + + +# -- Options for todo extension ---------------------------------------------- + +# If true, `todo` and `todoList` produce output, else they produce nothing. 
+todo_include_todos = True + + +def autodoc_skip_member(app, what, name, obj, skip, options): + # we hide something deliberately + if getattr(obj, "__HIDE_SPHINX_DOC__", False): + return True + + # Hide some that are deprecated or not intended to be used + HIDDEN = { + "ResNetBlockBase", + "GroupedBatchSampler", + "build_transform_gen", + "apply_transform_gens", + "TransformGen", + "apply_augmentations", + "StandardAugInput", + "build_batch_data_loader", + "draw_panoptic_seg_predictions", + "WarmupCosineLR", + "WarmupMultiStepLR", + "downgrade_config", + "upgrade_config", + "add_export_config", + } + try: + if name in HIDDEN or ( + hasattr(obj, "__doc__") and obj.__doc__.lower().strip().startswith("deprecated") + ): + print("Skipping deprecated object: {}".format(name)) + return True + except: + pass + return skip + + +_PAPER_DATA = { + "resnet": ("1512.03385", "Deep Residual Learning for Image Recognition"), + "fpn": ("1612.03144", "Feature Pyramid Networks for Object Detection"), + "mask r-cnn": ("1703.06870", "Mask R-CNN"), + "faster r-cnn": ( + "1506.01497", + "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks", + ), + "deformconv": ("1703.06211", "Deformable Convolutional Networks"), + "deformconv2": ("1811.11168", "Deformable ConvNets v2: More Deformable, Better Results"), + "panopticfpn": ("1901.02446", "Panoptic Feature Pyramid Networks"), + "retinanet": ("1708.02002", "Focal Loss for Dense Object Detection"), + "cascade r-cnn": ("1712.00726", "Cascade R-CNN: Delving into High Quality Object Detection"), + "lvis": ("1908.03195", "LVIS: A Dataset for Large Vocabulary Instance Segmentation"), + "rrpn": ("1703.01086", "Arbitrary-Oriented Scene Text Detection via Rotation Proposals"), + "imagenet in 1h": ("1706.02677", "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"), + "xception": ("1610.02357", "Xception: Deep Learning with Depthwise Separable Convolutions"), + "mobilenet": ( + "1704.04861", + "MobileNets: Efficient 
Convolutional Neural Networks for Mobile Vision Applications", + ), + "deeplabv3+": ( + "1802.02611", + "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation", + ), + "dds": ("2003.13678", "Designing Network Design Spaces"), + "scaling": ("2103.06877", "Fast and Accurate Model Scaling"), + "fcos": ("2006.09214", "FCOS: A Simple and Strong Anchor-free Object Detector"), + "rethinking-batchnorm": ("2105.07576", 'Rethinking "Batch" in BatchNorm'), + "vitdet": ("2203.16527", "Exploring Plain Vision Transformer Backbones for Object Detection"), + "mvitv2": ( + "2112.01526", + "MViTv2: Improved Multiscale Vision Transformers for Classification and Detection", + ), + "swin": ( + "2103.14030", + "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows", + ), + "omni3d": ( + "2207.10660", + "Omni3D: A Large Benchmark and Model for 3D Object Detection in the Wild", + ), +} + + +def paper_ref_role( + typ: str, + rawtext: str, + text: str, + lineno: int, + inliner, + options: Dict = {}, + content: List[str] = [], +): + """ + Parse :paper:`xxx`. Similar to the "extlinks" sphinx extension. 
+ """ + from docutils import nodes, utils + from sphinx.util.nodes import split_explicit_title + + text = utils.unescape(text) + has_explicit_title, title, link = split_explicit_title(text) + link = link.lower() + if link not in _PAPER_DATA: + inliner.reporter.warning("Cannot find paper " + link) + paper_url, paper_title = "#", link + else: + paper_url, paper_title = _PAPER_DATA[link] + if "/" not in paper_url: + paper_url = "https://arxiv.org/abs/" + paper_url + if not has_explicit_title: + title = paper_title + pnode = nodes.reference(title, title, internal=False, refuri=paper_url) + return [pnode], [] + + +def setup(app): + from recommonmark.transform import AutoStructify + + app.add_domain(GithubURLDomain) + app.connect("autodoc-skip-member", autodoc_skip_member) + app.add_role("paper", paper_ref_role) + app.add_config_value( + "recommonmark_config", + {"enable_math": True, "enable_inline_math": True, "enable_eval_rst": True}, + True, + ) + app.add_transform(AutoStructify) diff --git a/data_processing/detectron2/docs/index.rst b/data_processing/detectron2/docs/index.rst new file mode 100644 index 0000000..8634b7b --- /dev/null +++ b/data_processing/detectron2/docs/index.rst @@ -0,0 +1,14 @@ +.. detectron2 documentation master file, created by + sphinx-quickstart on Sat Sep 21 13:46:45 2019. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to detectron2's documentation! +====================================== + +.. toctree:: + :maxdepth: 2 + + tutorials/index + notes/index + modules/index diff --git a/data_processing/detectron2/docs/modules/checkpoint.rst b/data_processing/detectron2/docs/modules/checkpoint.rst new file mode 100644 index 0000000..449caaf --- /dev/null +++ b/data_processing/detectron2/docs/modules/checkpoint.rst @@ -0,0 +1,7 @@ +detectron2.checkpoint +============================= + +.. 
automodule:: detectron2.checkpoint + :members: + :undoc-members: + :show-inheritance: diff --git a/data_processing/detectron2/docs/modules/config.rst b/data_processing/detectron2/docs/modules/config.rst new file mode 100644 index 0000000..c76913d --- /dev/null +++ b/data_processing/detectron2/docs/modules/config.rst @@ -0,0 +1,18 @@ +detectron2.config +========================= + +Related tutorials: :doc:`../tutorials/configs`, :doc:`../tutorials/extend`. + +.. automodule:: detectron2.config + :members: + :undoc-members: + :show-inheritance: + + +Yaml Config References +---------------------- + +.. literalinclude:: ../../detectron2/config/defaults.py + :language: python + :linenos: + :lines: 7- diff --git a/data_processing/detectron2/docs/modules/data.rst b/data_processing/detectron2/docs/modules/data.rst new file mode 100644 index 0000000..0d5bd89 --- /dev/null +++ b/data_processing/detectron2/docs/modules/data.rst @@ -0,0 +1,37 @@ +detectron2.data +======================= + +.. autodata:: detectron2.data.DatasetCatalog(dict) + :annotation: + +.. autodata:: detectron2.data.MetadataCatalog(dict) + :annotation: + +.. automodule:: detectron2.data + :members: + :undoc-members: + :show-inheritance: + +detectron2.data.detection\_utils module +--------------------------------------- + +.. automodule:: detectron2.data.detection_utils + :members: + :undoc-members: + :show-inheritance: + +detectron2.data.datasets module +--------------------------------------- + +.. automodule:: detectron2.data.datasets + :members: + :undoc-members: + :show-inheritance: + +detectron2.data.samplers module +--------------------------------------- + +.. 
automodule:: detectron2.data.samplers + :members: + :undoc-members: + :show-inheritance: diff --git a/data_processing/detectron2/docs/modules/data_transforms.rst b/data_processing/detectron2/docs/modules/data_transforms.rst new file mode 100644 index 0000000..1533a43 --- /dev/null +++ b/data_processing/detectron2/docs/modules/data_transforms.rst @@ -0,0 +1,10 @@ +detectron2.data.transforms +==================================== + +Related tutorial: :doc:`../tutorials/augmentation`. + +.. automodule:: detectron2.data.transforms + :members: + :undoc-members: + :show-inheritance: + :imported-members: diff --git a/data_processing/detectron2/docs/modules/engine.rst b/data_processing/detectron2/docs/modules/engine.rst new file mode 100644 index 0000000..7e0d2b0 --- /dev/null +++ b/data_processing/detectron2/docs/modules/engine.rst @@ -0,0 +1,26 @@ +detectron2.engine +========================= + +Related tutorial: :doc:`../tutorials/training`. + +.. automodule:: detectron2.engine + :members: + :undoc-members: + :show-inheritance: + + +detectron2.engine.defaults module +--------------------------------- + +.. automodule:: detectron2.engine.defaults + :members: + :undoc-members: + :show-inheritance: + +detectron2.engine.hooks module +--------------------------------- + +.. automodule:: detectron2.engine.hooks + :members: + :undoc-members: + :show-inheritance: diff --git a/data_processing/detectron2/docs/modules/evaluation.rst b/data_processing/detectron2/docs/modules/evaluation.rst new file mode 100644 index 0000000..69bfc4b --- /dev/null +++ b/data_processing/detectron2/docs/modules/evaluation.rst @@ -0,0 +1,7 @@ +detectron2.evaluation +============================= + +.. 
automodule:: detectron2.evaluation + :members: + :undoc-members: + :show-inheritance: diff --git a/data_processing/detectron2/docs/modules/export.rst b/data_processing/detectron2/docs/modules/export.rst new file mode 100644 index 0000000..dcee14f --- /dev/null +++ b/data_processing/detectron2/docs/modules/export.rst @@ -0,0 +1,9 @@ +detectron2.export +========================= + +Related tutorial: :doc:`../tutorials/deployment`. + +.. automodule:: detectron2.export + :members: + :undoc-members: + :show-inheritance: diff --git a/data_processing/detectron2/docs/modules/fvcore.rst b/data_processing/detectron2/docs/modules/fvcore.rst new file mode 100644 index 0000000..c8bf9f5 --- /dev/null +++ b/data_processing/detectron2/docs/modules/fvcore.rst @@ -0,0 +1,49 @@ +fvcore documentation +==================== + +Detectron2 depends on utilities in +`fvcore <https://github.com/facebookresearch/fvcore>`_. +We include part of the fvcore documentation here for easier reference. + +fvcore.nn +----------------- + +.. automodule:: fvcore.nn + :members: + :inherited-members: + :undoc-members: + :show-inheritance: + +fvcore.common +--------------------- + +.. automodule:: fvcore.common.checkpoint + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: fvcore.common.config + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: fvcore.common.history_buffer + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: fvcore.common.param_scheduler + :members: + :inherited-members: + :undoc-members: + :show-inheritance: + +.. automodule:: fvcore.common.registry + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: fvcore.common.timer + :members: + :undoc-members: + :show-inheritance: diff --git a/data_processing/detectron2/docs/modules/index.rst b/data_processing/detectron2/docs/modules/index.rst new file mode 100644 index 0000000..14b7543 --- /dev/null +++ b/data_processing/detectron2/docs/modules/index.rst @@ -0,0 +1,19 @@ +API Documentation +================== + +.. 
toctree:: + + checkpoint + config + data + data_transforms + engine + evaluation + layers + model_zoo + modeling + solver + structures + utils + export + fvcore diff --git a/data_processing/detectron2/docs/modules/layers.rst b/data_processing/detectron2/docs/modules/layers.rst new file mode 100644 index 0000000..b43b42a --- /dev/null +++ b/data_processing/detectron2/docs/modules/layers.rst @@ -0,0 +1,7 @@ +detectron2.layers +========================= + +.. automodule:: detectron2.layers + :members: + :undoc-members: + :show-inheritance: diff --git a/data_processing/detectron2/docs/modules/model_zoo.rst b/data_processing/detectron2/docs/modules/model_zoo.rst new file mode 100644 index 0000000..5abbad1 --- /dev/null +++ b/data_processing/detectron2/docs/modules/model_zoo.rst @@ -0,0 +1,7 @@ +detectron2.model_zoo +============================ + +.. automodule:: detectron2.model_zoo + :members: + :undoc-members: + :show-inheritance: diff --git a/data_processing/detectron2/docs/modules/modeling.rst b/data_processing/detectron2/docs/modules/modeling.rst new file mode 100644 index 0000000..a22c7ed --- /dev/null +++ b/data_processing/detectron2/docs/modules/modeling.rst @@ -0,0 +1,58 @@ +detectron2.modeling +=========================== + +.. automodule:: detectron2.modeling + :members: + :undoc-members: + :show-inheritance: + + +detectron2.modeling.poolers module +--------------------------------------- + +.. automodule:: detectron2.modeling.poolers + :members: + :undoc-members: + :show-inheritance: + + +detectron2.modeling.sampling module +------------------------------------ + +.. automodule:: detectron2.modeling.sampling + :members: + :undoc-members: + :show-inheritance: + + +detectron2.modeling.box_regression module +------------------------------------------ + +.. automodule:: detectron2.modeling.box_regression + :members: + :undoc-members: + :show-inheritance: + + +Model Registries +----------------- + +These are different registries provided in modeling. 
+Each registry provides you the ability to replace it with your customized component, +without having to modify detectron2's code. + +Note that it is impossible to allow users to customize any line of code directly. +Even just to add one line at some place, +you'll likely need to find out the smallest registry which contains that line, +and register your component to that registry. + + +.. autodata:: detectron2.modeling.META_ARCH_REGISTRY +.. autodata:: detectron2.modeling.BACKBONE_REGISTRY +.. autodata:: detectron2.modeling.PROPOSAL_GENERATOR_REGISTRY +.. autodata:: detectron2.modeling.RPN_HEAD_REGISTRY +.. autodata:: detectron2.modeling.ANCHOR_GENERATOR_REGISTRY +.. autodata:: detectron2.modeling.ROI_HEADS_REGISTRY +.. autodata:: detectron2.modeling.ROI_BOX_HEAD_REGISTRY +.. autodata:: detectron2.modeling.ROI_MASK_HEAD_REGISTRY +.. autodata:: detectron2.modeling.ROI_KEYPOINT_HEAD_REGISTRY diff --git a/data_processing/detectron2/docs/modules/solver.rst b/data_processing/detectron2/docs/modules/solver.rst new file mode 100644 index 0000000..59d98c7 --- /dev/null +++ b/data_processing/detectron2/docs/modules/solver.rst @@ -0,0 +1,7 @@ +detectron2.solver +========================= + +.. automodule:: detectron2.solver + :members: + :undoc-members: + :show-inheritance: diff --git a/data_processing/detectron2/docs/modules/structures.rst b/data_processing/detectron2/docs/modules/structures.rst new file mode 100644 index 0000000..1369dc0 --- /dev/null +++ b/data_processing/detectron2/docs/modules/structures.rst @@ -0,0 +1,7 @@ +detectron2.structures +============================= + +.. 
automodule:: detectron2.structures + :members: + :undoc-members: + :show-inheritance: diff --git a/data_processing/detectron2/docs/modules/utils.rst b/data_processing/detectron2/docs/modules/utils.rst new file mode 100644 index 0000000..ab58f2c --- /dev/null +++ b/data_processing/detectron2/docs/modules/utils.rst @@ -0,0 +1,80 @@ +detectron2.utils +======================== + +detectron2.utils.colormap module +-------------------------------- + +.. automodule:: detectron2.utils.colormap + :members: + :undoc-members: + :show-inheritance: + +detectron2.utils.comm module +---------------------------- + +.. automodule:: detectron2.utils.comm + :members: + :undoc-members: + :show-inheritance: + + +detectron2.utils.events module +------------------------------ + +.. automodule:: detectron2.utils.events + :members: + :undoc-members: + :show-inheritance: + + +detectron2.utils.logger module +------------------------------ + +.. automodule:: detectron2.utils.logger + :members: + :undoc-members: + :show-inheritance: + + +detectron2.utils.registry module +-------------------------------- + +.. automodule:: detectron2.utils.registry + :members: + :undoc-members: + :show-inheritance: + +detectron2.utils.memory module +---------------------------------- + +.. automodule:: detectron2.utils.memory + :members: + :undoc-members: + :show-inheritance: + + +detectron2.utils.analysis module +---------------------------------- + +.. automodule:: detectron2.utils.analysis + :members: + :undoc-members: + :show-inheritance: + + +detectron2.utils.visualizer module +---------------------------------- + +.. automodule:: detectron2.utils.visualizer + :members: + :undoc-members: + :show-inheritance: + +detectron2.utils.video\_visualizer module +----------------------------------------- + +.. 
automodule:: detectron2.utils.video_visualizer + :members: + :undoc-members: + :show-inheritance: + diff --git a/data_processing/detectron2/docs/notes/benchmarks.md b/data_processing/detectron2/docs/notes/benchmarks.md new file mode 100644 index 0000000..b41588d --- /dev/null +++ b/data_processing/detectron2/docs/notes/benchmarks.md @@ -0,0 +1,196 @@ + +# Benchmarks + +Here we benchmark the training speed of a Mask R-CNN in detectron2, +with some other popular open source Mask R-CNN implementations. + + +### Settings + +* Hardware: 8 NVIDIA V100s with NVLink. +* Software: Python 3.7, CUDA 10.1, cuDNN 7.6.5, PyTorch 1.5, + TensorFlow 1.15.0rc2, Keras 2.2.5, MxNet 1.6.0b20190820. +* Model: an end-to-end R-50-FPN Mask-RCNN model, using the same hyperparameter as the + [Detectron baseline config](https://github.com/facebookresearch/Detectron/blob/master/configs/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml) + (it does not have scale augmentation). +* Metrics: We use the average throughput in iterations 100-500 to skip GPU warmup time. + Note that for R-CNN-style models, the throughput of a model typically changes during training, because + it depends on the predictions of the model. Therefore this metric is not directly comparable with + "train speed" in model zoo, which is the average speed of the entire training run. 
+ + +### Main Results + +```eval_rst ++-------------------------------+--------------------+ +| Implementation | Throughput (img/s) | ++===============================+====================+ +| |D2| |PT| | 62 | ++-------------------------------+--------------------+ +| mmdetection_ |PT| | 53 | ++-------------------------------+--------------------+ +| maskrcnn-benchmark_ |PT| | 53 | ++-------------------------------+--------------------+ +| tensorpack_ |TF| | 50 | ++-------------------------------+--------------------+ +| simpledet_ |mxnet| | 39 | ++-------------------------------+--------------------+ +| Detectron_ |C2| | 19 | ++-------------------------------+--------------------+ +| `matterport/Mask_RCNN`__ |TF| | 14 | ++-------------------------------+--------------------+ + +.. _maskrcnn-benchmark: https://github.com/facebookresearch/maskrcnn-benchmark/ +.. _tensorpack: https://github.com/tensorpack/tensorpack/tree/master/examples/FasterRCNN +.. _mmdetection: https://github.com/open-mmlab/mmdetection/ +.. _simpledet: https://github.com/TuSimple/simpledet/ +.. _Detectron: https://github.com/facebookresearch/Detectron +__ https://github.com/matterport/Mask_RCNN/ + +.. |D2| image:: https://github.com/facebookresearch/detectron2/raw/main/.github/Detectron2-Logo-Horz.svg?sanitize=true + :height: 15pt + :target: https://github.com/facebookresearch/detectron2/ +.. 
|PT| image:: https://pytorch.org/assets/images/logo-icon.svg + :width: 15pt + :height: 15pt + :target: https://pytorch.org +.. |TF| image:: https://static.nvidiagrid.net/ngc/containers/tensorflow.png + :width: 15pt + :height: 15pt + :target: https://tensorflow.org +.. |mxnet| image:: https://github.com/dmlc/web-data/raw/master/mxnet/image/mxnet_favicon.png + :width: 15pt + :height: 15pt + :target: https://mxnet.apache.org/ +.. |C2| image:: https://caffe2.ai/static/logo.svg + :width: 15pt + :height: 15pt + :target: https://caffe2.ai +``` + + +Details for each implementation: + +* __Detectron2__: with release v0.1.2, run: + ``` + python tools/train_net.py --config-file configs/Detectron1-Comparisons/mask_rcnn_R_50_FPN_noaug_1x.yaml --num-gpus 8 + ``` + +* __mmdetection__: at commit `b0d845f`, run + ``` + ./tools/dist_train.sh configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_1x_coco.py 8 + ``` + +* __maskrcnn-benchmark__: use commit `0ce8f6f` with `sed -i 's/torch.uint8/torch.bool/g' **/*.py; sed -i 's/AT_CHECK/TORCH_CHECK/g' **/*.cu` + to make it compatible with PyTorch 1.5. Then, run training with + ``` + python -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py --config-file configs/e2e_mask_rcnn_R_50_FPN_1x.yaml + ``` + The speed we observed is faster than its model zoo, likely due to different software versions. 
+ +* __tensorpack__: at commit `caafda`, `export TF_CUDNN_USE_AUTOTUNE=0`, then run + ``` + mpirun -np 8 ./train.py --config DATA.BASEDIR=/data/coco TRAINER=horovod BACKBONE.STRIDE_1X1=True TRAIN.STEPS_PER_EPOCH=50 --load ImageNet-R50-AlignPadding.npz + ``` + +* __SimpleDet__: at commit `9187a1`, run + ``` + python detection_train.py --config config/mask_r50v1_fpn_1x.py + ``` + +* __Detectron__: run + ``` + python tools/train_net.py --cfg configs/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml + ``` + Note that many of its ops run on CPUs, therefore the performance is limited. + +* __matterport/Mask_RCNN__: at commit `3deaec`, apply the following diff, `export TF_CUDNN_USE_AUTOTUNE=0`, then run + ``` + python coco.py train --dataset=/data/coco/ --model=imagenet + ``` + Note that many small details in this implementation might be different + from Detectron's standards. + +
+ + (diff to make it use the same hyperparameters - click to expand) + + + ```diff + diff --git i/mrcnn/model.py w/mrcnn/model.py + index 62cb2b0..61d7779 100644 + --- i/mrcnn/model.py + +++ w/mrcnn/model.py + @@ -2367,8 +2367,8 @@ class MaskRCNN(): + epochs=epochs, + steps_per_epoch=self.config.STEPS_PER_EPOCH, + callbacks=callbacks, + - validation_data=val_generator, + - validation_steps=self.config.VALIDATION_STEPS, + + #validation_data=val_generator, + + #validation_steps=self.config.VALIDATION_STEPS, + max_queue_size=100, + workers=workers, + use_multiprocessing=True, + diff --git i/mrcnn/parallel_model.py w/mrcnn/parallel_model.py + index d2bf53b..060172a 100644 + --- i/mrcnn/parallel_model.py + +++ w/mrcnn/parallel_model.py + @@ -32,6 +32,7 @@ class ParallelModel(KM.Model): + keras_model: The Keras model to parallelize + gpu_count: Number of GPUs. Must be > 1 + """ + + super().__init__() + self.inner_model = keras_model + self.gpu_count = gpu_count + merged_outputs = self.make_parallel() + diff --git i/samples/coco/coco.py w/samples/coco/coco.py + index 5d172b5..239ed75 100644 + --- i/samples/coco/coco.py + +++ w/samples/coco/coco.py + @@ -81,7 +81,10 @@ class CocoConfig(Config): + IMAGES_PER_GPU = 2 + + # Uncomment to train on 8 GPUs (default is 1) + - # GPU_COUNT = 8 + + GPU_COUNT = 8 + + BACKBONE = "resnet50" + + STEPS_PER_EPOCH = 50 + + TRAIN_ROIS_PER_IMAGE = 512 + + # Number of classes (including background) + NUM_CLASSES = 1 + 80 # COCO has 80 classes + @@ -496,29 +499,10 @@ if __name__ == '__main__': + # *** This training schedule is an example. 
Update to your needs *** + + # Training - Stage 1 + - print("Training network heads") + model.train(dataset_train, dataset_val, + learning_rate=config.LEARNING_RATE, + epochs=40, + - layers='heads', + - augmentation=augmentation) + - + - # Training - Stage 2 + - # Finetune layers from ResNet stage 4 and up + - print("Fine tune Resnet stage 4 and up") + - model.train(dataset_train, dataset_val, + - learning_rate=config.LEARNING_RATE, + - epochs=120, + - layers='4+', + - augmentation=augmentation) + - + - # Training - Stage 3 + - # Fine tune all layers + - print("Fine tune all layers") + - model.train(dataset_train, dataset_val, + - learning_rate=config.LEARNING_RATE / 10, + - epochs=160, + - layers='all', + + layers='3+', + augmentation=augmentation) + + elif args.command == "evaluate": + ``` + +
diff --git a/data_processing/detectron2/docs/notes/changelog.md b/data_processing/detectron2/docs/notes/changelog.md new file mode 100644 index 0000000..000e9f8 --- /dev/null +++ b/data_processing/detectron2/docs/notes/changelog.md @@ -0,0 +1,48 @@ +# Change Log and Backward Compatibility + +### Releases +See release logs at +[https://github.com/facebookresearch/detectron2/releases](https://github.com/facebookresearch/detectron2/releases) +for new updates. + +### Backward Compatibility + +Due to the research nature of what the library does, there might be backward incompatible changes. +But we try to reduce users' disruption in the following ways: +* APIs listed in [API documentation](https://detectron2.readthedocs.io/modules/index.html), including + function/class names, their arguments, and documented class attributes, are considered *stable* unless + otherwise noted in the documentation. + They are less likely to be broken, but if needed, will trigger a deprecation warning for a reasonable period + before getting broken, and will be documented in release logs. +* Other functions/classes/attributes are considered internal, and are more likely to change. + However, we're aware that some of them may be already used by other projects, and in particular we may + use them for convenience among projects under `detectron2/projects`. + For such APIs, we may treat them as stable APIs and also apply the above strategies. + They may be promoted to stable when we're ready. +* Projects under "detectron2/projects" or imported with "detectron2.projects" are research projects + and are all considered experimental. +* Classes/functions that contain the word "default" or are explicitly documented to produce + "default behavior" may change their behaviors when new features are added. 
+ +Despite the possible breakage, if a third-party project would like to keep up with the latest updates +in detectron2, using it as a library will still be less disruptive than forking, because +the frequency and scope of API changes will be much smaller than code changes. + +To see such changes, search for "incompatible changes" in [release logs](https://github.com/facebookresearch/detectron2/releases). + +### Config Version Change Log + +Detectron2's config version has not been changed since it was open-sourced. +There is no need for an open source user to worry about this. + +* v1: Rename `RPN_HEAD.NAME` to `RPN.HEAD_NAME`. +* v2: A batch of renames of many configurations before release. + +### Silent Regressions in Historical Versions: + +We list a few silent regressions, since they may silently produce incorrect results and will be hard to debug. + +* 04/01/2020 - 05/11/2020: Bad accuracy if `TRAIN_ON_PRED_BOXES` is set to True. +* 03/30/2020 - 04/01/2020: ResNets are not correctly built. +* 12/19/2019 - 12/26/2019: Using aspect ratio grouping causes a drop in accuracy. +* - 11/9/2019: Test time augmentation does not predict the last category. diff --git a/data_processing/detectron2/docs/notes/compatibility.md b/data_processing/detectron2/docs/notes/compatibility.md new file mode 100644 index 0000000..83d93f5 --- /dev/null +++ b/data_processing/detectron2/docs/notes/compatibility.md @@ -0,0 +1,84 @@ +# Compatibility with Other Libraries + +## Compatibility with Detectron (and maskrcnn-benchmark) + +Detectron2 addresses some legacy issues left in Detectron. As a result, their models +are not compatible: +running inference with the same model weights will produce different results in the two code bases. 
+ +The major differences regarding inference are: + +- The height and width of a box with corners (x1, y1) and (x2, y2) are now computed more naturally as + width = x2 - x1 and height = y2 - y1; + In Detectron, a "+ 1" was added to both height and width. + + Note that the relevant ops in Caffe2 have [adopted this change of convention](https://github.com/pytorch/pytorch/pull/20550) + with an extra option. + So it is still possible to run inference with a Detectron2-trained model in Caffe2. + + The change in height/width calculations most notably changes: + - encoding/decoding in bounding box regression. + - non-maximum suppression. The effect here is negligible, though. + +- RPN now uses simpler anchors with fewer quantization artifacts. + + In Detectron, the anchors were quantized and + [do not have accurate areas](https://github.com/facebookresearch/Detectron/issues/227). + In Detectron2, the anchors are center-aligned to feature grid points and not quantized. + +- Classification layers have a different ordering of class labels. + + This involves any trainable parameter with shape (..., num_categories + 1, ...). + In Detectron2, integer labels [0, K-1] correspond to the K = num_categories object categories + and the label "K" corresponds to the special "background" category. + In Detectron, label "0" means background, and labels [1, K] correspond to the K categories. + +- ROIAlign is implemented differently. The new implementation is [available in Caffe2](https://github.com/pytorch/pytorch/pull/23706). + + 1. All the ROIs are shifted by half a pixel compared to Detectron in order to create better image-feature-map alignment. + See `layers/roi_align.py` for details. + To enable the old behavior, use `ROIAlign(aligned=False)`, or `POOLER_TYPE=ROIAlign` instead of + `ROIAlignV2` (the default). + + 1. 
The ROIs are not required to have a minimum size of 1. + This will lead to tiny differences in the output, but should be negligible. + +- The mask inference function is different. + + In Detectron2, the "paste_mask" function is different and should be more accurate than in Detectron. This change + can improve mask AP on COCO by ~0.5% absolute. + +There are some other differences in training as well, but they won't affect +model-level compatibility. The major ones are: + +- We fixed a [bug](https://github.com/facebookresearch/Detectron/issues/459) in + Detectron, by making `RPN.POST_NMS_TOPK_TRAIN` per-image, rather than per-batch. + The fix may lead to a small accuracy drop for a few models (e.g. keypoint + detection) and will require some parameter tuning to match the Detectron results. +- For simplicity, we change the default loss in bounding box regression to L1 loss, instead of smooth L1 loss. + We have observed that this tends to slightly decrease box AP50 while improving box AP for higher + overlap thresholds (and leading to a slight overall improvement in box AP). +- We interpret the coordinates in COCO bounding box and segmentation annotations + as coordinates in range `[0, width]` or `[0, height]`. The coordinates in + COCO keypoint annotations are interpreted as pixel indices in range `[0, width - 1]` or `[0, height - 1]`. + Note that this affects how flip augmentation is implemented. + + +[This article](https://ppwwyyxx.com/blog/2021/Where-are-Pixels/) +explains more details on the above-mentioned issues +about pixels, coordinates, and "+1"s. + + +## Compatibility with Caffe2 + +As mentioned above, despite the incompatibilities with Detectron, the relevant +ops have been implemented in Caffe2. +Therefore, models trained with detectron2 can be converted to Caffe2. +See [Deployment](../tutorials/deployment.md) for the tutorial. 
+ +## Compatibility with TensorFlow + +Most ops are available in TensorFlow, although some tiny differences in +the implementation of resize / ROIAlign / padding need to be addressed. +A working conversion script is provided by [tensorpack Faster R-CNN](https://github.com/tensorpack/tensorpack/tree/master/examples/FasterRCNN/convert_d2) +to run a standard detectron2 model in TensorFlow. diff --git a/data_processing/detectron2/docs/notes/contributing.md b/data_processing/detectron2/docs/notes/contributing.md new file mode 100644 index 0000000..9518123 --- /dev/null +++ b/data_processing/detectron2/docs/notes/contributing.md @@ -0,0 +1 @@ +../../.github/CONTRIBUTING.md \ No newline at end of file diff --git a/data_processing/detectron2/docs/notes/index.rst b/data_processing/detectron2/docs/notes/index.rst new file mode 100644 index 0000000..63cf907 --- /dev/null +++ b/data_processing/detectron2/docs/notes/index.rst @@ -0,0 +1,10 @@ +Notes +====================================== + +.. 
toctree:: + :maxdepth: 2 + + benchmarks + compatibility + contributing + changelog diff --git a/data_processing/detectron2/docs/requirements.txt b/data_processing/detectron2/docs/requirements.txt new file mode 100644 index 0000000..720a1b1 --- /dev/null +++ b/data_processing/detectron2/docs/requirements.txt @@ -0,0 +1,23 @@ +docutils==0.16 +# https://github.com/sphinx-doc/sphinx/commit/7acd3ada3f38076af7b2b5c9f3b60bb9c2587a3d +sphinx==3.2.0 +recommonmark==0.6.0 +sphinx_rtd_theme +# Dependencies here are only those required by import +termcolor +numpy +tqdm +matplotlib +termcolor +yacs +tabulate +cloudpickle +Pillow +future +git+https://github.com/facebookresearch/fvcore.git +https://download.pytorch.org/whl/cpu/torch-1.8.1%2Bcpu-cp37-cp37m-linux_x86_64.whl +https://download.pytorch.org/whl/cpu/torchvision-0.9.1%2Bcpu-cp37-cp37m-linux_x86_64.whl +omegaconf>=2.1.0.dev24 +hydra-core>=1.1.0.dev5 +scipy +timm diff --git a/data_processing/detectron2/docs/tutorials/README.md b/data_processing/detectron2/docs/tutorials/README.md new file mode 100644 index 0000000..1ca9c94 --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/README.md @@ -0,0 +1,4 @@ +# Read the docs: + +The latest documentation built from this directory is available at [detectron2.readthedocs.io](https://detectron2.readthedocs.io/). +Documents in this directory are not meant to be read on github. 
diff --git a/data_processing/detectron2/docs/tutorials/augmentation.md b/data_processing/detectron2/docs/tutorials/augmentation.md new file mode 100644 index 0000000..7601a08 --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/augmentation.md @@ -0,0 +1,186 @@ + +# Data Augmentation + +Augmentation is an important part of training. +Detectron2's data augmentation system aims at addressing the following goals: + +1. Allow augmenting multiple data types together + (e.g., images together with their bounding boxes and masks) +2. Allow applying a sequence of statically-declared augmentations +3. Allow adding custom new data types to augment (rotated bounding boxes, video clips, etc.) +4. Process and manipulate the __operations__ that are applied by augmentations + +The first two features cover most of the common use cases, and are also +available in other libraries such as [albumentations](https://medium.com/pytorch/multi-target-in-albumentations-16a777e9006e). +Supporting other features adds some overhead to detectron2's augmentation API, +which we'll explain in this tutorial. + +This tutorial focuses on how to use augmentations when writing new data loaders, +and how to write new augmentations. +If you use the default data loader in detectron2, it already supports taking a user-provided list of custom augmentations, +as explained in the [Dataloader tutorial](data_loading). 
+ +## Basic Usage + +The basic usage of feature (1) and (2) is like the following: +```python +from detectron2.data import transforms as T +# Define a sequence of augmentations: +augs = T.AugmentationList([ + T.RandomBrightness(0.9, 1.1), + T.RandomFlip(prob=0.5), + T.RandomCrop("absolute", (640, 640)) +]) # type: T.Augmentation + +# Define the augmentation input ("image" required, others optional): +input = T.AugInput(image, boxes=boxes, sem_seg=sem_seg) +# Apply the augmentation: +transform = augs(input) # type: T.Transform +image_transformed = input.image # new image +sem_seg_transformed = input.sem_seg # new semantic segmentation + +# For any extra data that needs to be augmented together, use transform, e.g.: +image2_transformed = transform.apply_image(image2) +polygons_transformed = transform.apply_polygons(polygons) +``` + +Three basic concepts are involved here. They are: +* [T.Augmentation](../modules/data_transforms.html#detectron2.data.transforms.Augmentation) defines the __"policy"__ to modify inputs. + * its `__call__(AugInput) -> Transform` method augments the inputs in-place, and returns the operation that is applied +* [T.Transform](../modules/data_transforms.html#detectron2.data.transforms.Transform) + implements the actual __operations__ to transform data + * it has methods such as `apply_image`, `apply_coords` that define how to transform each data type +* [T.AugInput](../modules/data_transforms.html#detectron2.data.transforms.AugInput) + stores inputs needed by `T.Augmentation` and how they should be transformed. + This concept is needed for some advanced usage. + Using this class directly should be sufficient for all common use cases, + since extra data not in `T.AugInput` can be augmented using the returned + `transform`, as shown in the above example. + +## Write New Augmentations + +Most 2D augmentations only need to know about the input image. 
Such augmentation can be implemented easily like this: + +```python +class MyColorAugmentation(T.Augmentation): + def get_transform(self, image): + r = np.random.rand(2) + return T.ColorTransform(lambda x: x * r[0] + r[1] * 10) + +class MyCustomResize(T.Augmentation): + def get_transform(self, image): + old_h, old_w = image.shape[:2] + new_h, new_w = int(old_h * np.random.rand()), int(old_w * 1.5) + return T.ResizeTransform(old_h, old_w, new_h, new_w) + +augs = MyCustomResize() +transform = augs(input) +``` + +In addition to image, any attributes of the given `AugInput` can be used as long +as they are part of the function signature, e.g.: + +```python +class MyCustomCrop(T.Augmentation): + def get_transform(self, image, sem_seg): + # decide where to crop using both image and sem_seg + return T.CropTransform(...) + +augs = MyCustomCrop() +assert hasattr(input, "image") and hasattr(input, "sem_seg") +transform = augs(input) +``` + +New transform operation can also be added by subclassing +[T.Transform](../modules/data_transforms.html#detectron2.data.transforms.Transform). + +## Advanced Usage + +We give a few examples of advanced usages that +are enabled by our system. +These options can be interesting to new research, +although changing them is often not needed +for standard use cases. + +### Custom transform strategy + +Instead of only returning the augmented data, detectron2's `Augmentation` returns the __operations__ as `T.Transform`. +This allows users to apply custom transform strategy on their data. +We use keypoints data as an example. + +Keypoints are (x, y) coordinates, but they are not so trivial to augment due to the semantic meaning they carry. +Such meaning is only known to the users, therefore users may want to augment them manually +by looking at the returned `transform`. +For example, when an image is horizontally flipped, we'd like to swap the keypoint annotations for "left eye" and "right eye". 
+This can be done like this (included by default in detectron2's default data loader): +```python +# augs, input are defined as in previous examples +transform = augs(input) # type: T.Transform +keypoints_xy = transform.apply_coords(keypoints_xy) # transform the coordinates + +# get a list of all transforms that were applied +transforms = T.TransformList([transform]).transforms +# check if it is flipped for odd number of times +do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms) % 2 == 1 +if do_hflip: + keypoints_xy = keypoints_xy[flip_indices_mapping] +``` + +As another example, keypoints annotations often have a "visibility" field. +A sequence of augmentations might augment a visible keypoint out of the image boundary (e.g. with cropping), +but then bring it back within the boundary afterwards (e.g. with image padding). +If users decide to label such keypoints "invisible", +then the visibility check has to happen after every transform step. +This can be achieved by: + +```python +transform = augs(input) # type: T.TransformList +assert isinstance(transform, T.TransformList) +for t in transform.transforms: + keypoints_xy = t.apply_coords(keypoints_xy) + visibility &= (keypoints_xy >= [0, 0] & keypoints_xy <= [W, H]).all(axis=1) + +# btw, detectron2's `transform_keypoint_annotations` function chooses to label such keypoints "visible": +# keypoints_xy = transform.apply_coords(keypoints_xy) +# visibility &= (keypoints_xy >= [0, 0] & keypoints_xy <= [W, H]).all(axis=1) +``` + + +### Geometrically invert the transform +If images are pre-processed by augmentations before inference, the predicted results +such as segmentation masks are localized on the augmented image. 
+We'd like to invert the applied augmentation with the [inverse()](../modules/data_transforms.html#detectron2.data.transforms.Transform.inverse) +API, to obtain results on the original image: +```python +transform = augs(input) +pred_mask = make_prediction(input.image) +inv_transform = transform.inverse() +pred_mask_orig = inv_transform.apply_segmentation(pred_mask) +``` + +### Add new data types + +[T.Transform](../modules/data_transforms.html#detectron2.data.transforms.Transform) +supports a few common data types to transform, including images, coordinates, masks, boxes, polygons. +It allows registering new data types, e.g.: +```python +@T.HFlipTransform.register_type("rotated_boxes") +def func(flip_transform: T.HFlipTransform, rotated_boxes: Any): + # do the work + return flipped_rotated_boxes + +t = HFlipTransform(width=800) +transformed_rotated_boxes = t.apply_rotated_boxes(rotated_boxes) # func will be called +``` + +### Extend T.AugInput + +An augmentation can only access attributes available in the given input. +[T.AugInput](../modules/data_transforms.html#detectron2.data.transforms.StandardAugInput) defines "image", "boxes", "sem_seg", +which are sufficient for common augmentation strategies to decide how to augment. +If not, a custom implementation is needed. + +By re-implementing the "transform()" method in AugInput, it is also possible to +augment different fields in ways that are dependent on each other. +Such a use case is uncommon (e.g. post-process bounding box based on augmented masks), but allowed by the system. 
+ diff --git a/data_processing/detectron2/docs/tutorials/builtin_datasets.md b/data_processing/detectron2/docs/tutorials/builtin_datasets.md new file mode 100644 index 0000000..0ba8242 --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/builtin_datasets.md @@ -0,0 +1 @@ +../../datasets/README.md \ No newline at end of file diff --git a/data_processing/detectron2/docs/tutorials/configs.md b/data_processing/detectron2/docs/tutorials/configs.md new file mode 100644 index 0000000..49538d0 --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/configs.md @@ -0,0 +1,62 @@ +# Yacs Configs + +Detectron2 provides a key-value based config system that can be +used to obtain standard, common behaviors. + +This system uses YAML and [yacs](https://github.com/rbgirshick/yacs). +Yaml is a very limited language, +so we do not expect all features in detectron2 to be available through configs. +If you need something that's not available in the config space, +please write code using detectron2's API. + +With the introduction of a more powerful [LazyConfig system](lazyconfigs.md), +we no longer add functionality / new keys to the Yacs/Yaml-based config system. + +### Basic Usage + +Some basic usage of the `CfgNode` object is shown here. See more in [documentation](../modules/config.html#detectron2.config.CfgNode). +```python +from detectron2.config import get_cfg +cfg = get_cfg() # obtain detectron2's default config +cfg.xxx = yyy # add new configs for your own custom components +cfg.merge_from_file("my_cfg.yaml") # load values from a file + +cfg.merge_from_list(["MODEL.WEIGHTS", "weights.pth"]) # can also load values from a list of str +print(cfg.dump()) # print formatted configs +with open("output.yaml", "w") as f: + f.write(cfg.dump()) # save config to file +``` + +In addition to the basic Yaml syntax, the config file can +define a `_BASE_: base.yaml` field, which will load a base config file first. 
+Values in the base config will be overwritten in sub-configs, if there are any conflicts. +We provided several base configs for standard model architectures. + +Many builtin tools in detectron2 accept command line config overwrite: +Key-value pairs provided in the command line will overwrite the existing values in the config file. +For example, [demo.py](../../demo/demo.py) can be used with +```sh +./demo.py --config-file config.yaml [--other-options] \ + --opts MODEL.WEIGHTS /path/to/weights INPUT.MIN_SIZE_TEST 1000 +``` + +To see a list of available configs in detectron2 and what they mean, +check [Config References](../modules/config.html#config-references) + +### Configs in Projects + +A project that lives outside the detectron2 library may define its own configs, which will need to be added +for the project to be functional, e.g.: +```python +from detectron2.projects.point_rend import add_pointrend_config +cfg = get_cfg() # obtain detectron2's default config +add_pointrend_config(cfg) # add pointrend's default config +# ... ... +``` + +### Best Practice with Configs + +1. Treat the configs you write as "code": avoid copying them or duplicating them; use `_BASE_` + to share common parts between configs. + +2. Keep the configs you write simple: don't include keys that do not affect the experimental setting. diff --git a/data_processing/detectron2/docs/tutorials/data_loading.md b/data_processing/detectron2/docs/tutorials/data_loading.md new file mode 100644 index 0000000..1d2769f --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/data_loading.md @@ -0,0 +1,95 @@ + +# Dataloader + +Dataloader is the component that provides data to models. +A dataloader usually (but not necessarily) takes raw information from [datasets](./datasets.md), +and process them into a format needed by the model. + +## How the Existing Dataloader Works + +Detectron2 contains a builtin data loading pipeline. 
+It's good to understand how it works, in case you need to write a custom one. + +Detectron2 provides two functions +[build_detection_{train,test}_loader](../modules/data.html#detectron2.data.build_detection_train_loader) +that create a default data loader from a given config. +Here is how `build_detection_{train,test}_loader` work: + +1. It takes the name of a registered dataset (e.g., "coco_2017_train") and loads a `list[dict]` representing the dataset items + in a lightweight format. These dataset items are not yet ready to be used by the model (e.g., images are + not loaded into memory, random augmentations have not been applied, etc.). + Details about the dataset format and dataset registration can be found in + [datasets](./datasets.md). +2. Each dict in this list is mapped by a function ("mapper"): + * Users can customize this mapping function by specifying the "mapper" argument in + `build_detection_{train,test}_loader`. The default mapper is [DatasetMapper](../modules/data.html#detectron2.data.DatasetMapper). + * The output format of the mapper can be arbitrary, as long as it is accepted by the consumer of this data loader (usually the model). + The outputs of the default mapper, after batching, follow the default model input format documented in + [Use Models](./models.html#model-input-format). + * The role of the mapper is to transform the lightweight representation of a dataset item into a format + that is ready for the model to consume (including, e.g., read images, perform random data augmentation and convert to torch Tensors). + If you would like to perform custom transformations to data, you often want a custom mapper. +3. The outputs of the mapper are batched (simply into a list). +4. This batched data is the output of the data loader. Typically, it's also the input of + `model.forward()`. 
+ + +## Write a Custom Dataloader + +Using a different "mapper" with `build_detection_{train,test}_loader(mapper=)` works for most use cases +of custom data loading. +For example, if you want to resize all images to a fixed size for training, use: + +```python +import detectron2.data.transforms as T +from detectron2.data import DatasetMapper # the default mapper +dataloader = build_detection_train_loader(cfg, + mapper=DatasetMapper(cfg, is_train=True, augmentations=[ + T.Resize((800, 800)) + ])) +# use this dataloader instead of the default +``` +If the arguments of the default [DatasetMapper](../modules/data.html#detectron2.data.DatasetMapper) +does not provide what you need, you may write a custom mapper function and use it instead, e.g.: + +```python +from detectron2.data import detection_utils as utils + # Show how to implement a minimal mapper, similar to the default DatasetMapper +def mapper(dataset_dict): + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + # can use other ways to read image + image = utils.read_image(dataset_dict["file_name"], format="BGR") + # See "Data Augmentation" tutorial for details usage + auginput = T.AugInput(image) + transform = T.Resize((800, 800))(auginput) + image = torch.from_numpy(auginput.image.transpose(2, 0, 1)) + annos = [ + utils.transform_instance_annotations(annotation, [transform], image.shape[1:]) + for annotation in dataset_dict.pop("annotations") + ] + return { + # create the format that the model expects + "image": image, + "instances": utils.annotations_to_instances(annos, image.shape[1:]) + } +dataloader = build_detection_train_loader(cfg, mapper=mapper) +``` + +If you want to change not only the mapper (e.g., in order to implement different sampling or batching logic), +`build_detection_train_loader` won't work and you will need to write a different data loader. +The data loader is simply a +python iterator that produces [the format](./models.md) that the model accepts. 
+You can implement it using any tools you like. + +No matter what you implement, it's recommended to +check out [API documentation of detectron2.data](../modules/data) to learn more about the APIs of +these functions. + +## Use a Custom Dataloader + +If you use [DefaultTrainer](../modules/engine.html#detectron2.engine.defaults.DefaultTrainer), +you can override its `build_{train,test}_loader` method to use your own dataloader. +See the [deeplab dataloader](../../projects/DeepLab/train_net.py) +for an example. + +If you write your own training loop, you can plug in your data loader easily. diff --git a/data_processing/detectron2/docs/tutorials/datasets.md b/data_processing/detectron2/docs/tutorials/datasets.md new file mode 100644 index 0000000..91103f6 --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/datasets.md @@ -0,0 +1,290 @@ +# Use Custom Datasets + +This document explains how the dataset APIs +([DatasetCatalog](../modules/data.html#detectron2.data.DatasetCatalog), [MetadataCatalog](../modules/data.html#detectron2.data.MetadataCatalog)) +work, and how to use them to add custom datasets. + +Datasets that have builtin support in detectron2 are listed in [builtin datasets](builtin_datasets.md). +If you want to use a custom dataset while also reusing detectron2's data loaders, +you will need to: + +1. __Register__ your dataset (i.e., tell detectron2 how to obtain your dataset). +2. Optionally, __register metadata__ for your dataset. + +Next, we explain the above two concepts in detail. + +The [Colab tutorial](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5) +has a live example of how to register and train on a dataset of custom formats. 
+ +### Register a Dataset + +To let detectron2 know how to obtain a dataset named "my_dataset", users need to implement +a function that returns the items in your dataset and then tell detectron2 about this +function: +```python +def my_dataset_function(): + ... + return list[dict] in the following format + +from detectron2.data import DatasetCatalog +DatasetCatalog.register("my_dataset", my_dataset_function) +# later, to access the data: +data: List[Dict] = DatasetCatalog.get("my_dataset") +``` + +Here, the snippet associates a dataset named "my_dataset" with a function that returns the data. +The function must return the same data (with same order) if called multiple times. +The registration stays effective until the process exits. + +The function can do arbitrary things and should return the data in `list[dict]`, each dict in either +of the following formats: +1. Detectron2's standard dataset dict, described below. This will make it work with many other builtin + features in detectron2, so it's recommended to use it when it's sufficient. +2. Any custom format. You can also return arbitrary dicts in your own format, + such as adding extra keys for new tasks. + Then you will need to handle them properly downstream as well. + See below for more details. + +#### Standard Dataset Dicts + +For standard tasks +(instance detection, instance/semantic/panoptic segmentation, keypoint detection), +we load the original dataset into `list[dict]` with a specification similar to COCO's annotations. +This is our standard representation for a dataset. + +Each dict contains information about one image. +The dict may have the following fields, +and the required fields vary based on what the dataloader or the task needs (see more below). + +```eval_rst +.. 
list-table:: + :header-rows: 1 + + * - Task + - Fields + * - Common + - file_name, height, width, image_id + + * - Instance detection/segmentation + - annotations + + * - Semantic segmentation + - sem_seg_file_name + + * - Panoptic segmentation + - pan_seg_file_name, segments_info +``` + ++ `file_name`: the full path to the image file. ++ `height`, `width`: integer. The shape of the image. ++ `image_id` (str or int): a unique id that identifies this image. Required by many + evaluators to identify the images, but a dataset may use it for different purposes. ++ `annotations` (list[dict]): Required by __instance detection/segmentation or keypoint detection__ tasks. + Each dict corresponds to annotations of one instance in this image, and + may contain the following keys: + + `bbox` (list[float], required): list of 4 numbers representing the bounding box of the instance. + + `bbox_mode` (int, required): the format of bbox. It must be a member of + [structures.BoxMode](../modules/structures.html#detectron2.structures.BoxMode). + Currently supports: `BoxMode.XYXY_ABS`, `BoxMode.XYWH_ABS`. + + `category_id` (int, required): an integer in the range [0, num_categories-1] representing the category label. + The value num_categories is reserved to represent the "background" category, if applicable. + + `segmentation` (list[list[float]] or dict): the segmentation mask of the instance. + + If `list[list[float]]`, it represents a list of polygons, one for each connected component + of the object. Each `list[float]` is one simple polygon in the format of `[x1, y1, ..., xn, yn]` (n≥3). + The Xs and Ys are absolute coordinates in unit of pixels. + + If `dict`, it represents the per-pixel segmentation mask in COCO's compressed RLE format. + The dict should have keys "size" and "counts". You can convert a uint8 segmentation mask of 0s and + 1s into such dict by `pycocotools.mask.encode(np.asarray(mask, order="F"))`. 
+ `cfg.INPUT.MASK_FORMAT` must be set to `bitmask` if using the default data loader with such format. + + `keypoints` (list[float]): in the format of [x1, y1, v1,..., xn, yn, vn]. + v[i] means the [visibility](http://cocodataset.org/#format-data) of this keypoint. + `n` must be equal to the number of keypoint categories. + The Xs and Ys are absolute real-value coordinates in range [0, W or H]. + + (Note that the keypoint coordinates in COCO format are integers in range [0, W-1 or H-1], which is different + from our standard format. Detectron2 adds 0.5 to COCO keypoint coordinates to convert them from discrete + pixel indices to floating point coordinates.) + + `iscrowd`: 0 (default) or 1. Whether this instance is labeled as COCO's "crowd + region". Don't include this field if you don't know what it means. + + If `annotations` is an empty list, it means the image is labeled to have no objects. + Such images will by default be removed from training, + but can be included using `DATALOADER.FILTER_EMPTY_ANNOTATIONS`. + ++ `sem_seg_file_name` (str): + The full path to the semantic segmentation ground truth file. + It should be a grayscale image whose pixel values are integer labels. ++ `pan_seg_file_name` (str): + The full path to panoptic segmentation ground truth file. + It should be an RGB image whose pixel values are integer ids encoded using the + [panopticapi.utils.id2rgb](https://github.com/cocodataset/panopticapi/) function. + The ids are defined by `segments_info`. + If an id does not appear in `segments_info`, the pixel is considered unlabeled + and is usually ignored in training & evaluation. ++ `segments_info` (list[dict]): defines the meaning of each id in panoptic segmentation ground truth. + Each dict has the following keys: + + `id` (int): integer that appears in the ground truth image. 
+ + `category_id` (int): an integer in the range [0, num_categories-1] representing the category label. + + `iscrowd`: 0 (default) or 1. Whether this instance is labeled as COCO's "crowd region". + + +```eval_rst + +.. note:: + + The PanopticFPN model does not use the panoptic segmentation + format defined here, but a combination of both instance segmentation and semantic segmentation data + format. See :doc:`builtin_datasets` for instructions on COCO. + +``` + +Fast R-CNN (with pre-computed proposals) models are rarely used today. +To train a Fast R-CNN, the following extra keys are needed: + ++ `proposal_boxes` (array): 2D numpy array with shape (K, 4) representing K precomputed proposal boxes for this image. ++ `proposal_objectness_logits` (array): numpy array with shape (K, ), which corresponds to the objectness + logits of proposals in 'proposal_boxes'. ++ `proposal_bbox_mode` (int): the format of the precomputed proposal bbox. + It must be a member of + [structures.BoxMode](../modules/structures.html#detectron2.structures.BoxMode). + Default is `BoxMode.XYXY_ABS`. + + + +#### Custom Dataset Dicts for New Tasks + +In the `list[dict]` that your dataset function returns, the dictionary can also have __arbitrary custom data__. +This will be useful for a new task that needs extra information not covered +by the standard dataset dicts. In this case, you need to make sure the downstream code can handle your data +correctly. Usually this requires writing a new `mapper` for the dataloader (see [Use Custom Dataloaders](./data_loading.md)). + +When designing a custom format, note that all dicts are stored in memory +(sometimes serialized and with multiple copies). +To save memory, each dict is meant to contain __small__ but sufficient information +about each sample, such as file names and annotations. +Loading full samples typically happens in the data loader. + +For attributes shared among the entire dataset, use `Metadata` (see below). 
+To avoid extra memory, do not save such information inside each sample. + +### "Metadata" for Datasets + +Each dataset is associated with some metadata, accessible through +`MetadataCatalog.get(dataset_name).some_metadata`. +Metadata is a key-value mapping that contains information that's shared among +the entire dataset, and usually is used to interpret what's in the dataset, e.g., +names of classes, colors of classes, root of files, etc. +This information will be useful for augmentation, evaluation, visualization, logging, etc. +The structure of metadata depends on what is needed from the corresponding downstream code. + +If you register a new dataset through `DatasetCatalog.register`, +you may also want to add its corresponding metadata through +`MetadataCatalog.get(dataset_name).some_key = some_value`, to enable any features that need the metadata. +You can do it like this (using the metadata key "thing_classes" as an example): + +```python +from detectron2.data import MetadataCatalog +MetadataCatalog.get("my_dataset").thing_classes = ["person", "dog"] +``` + +Here is a list of metadata keys that are used by builtin features in detectron2. +If you add your own dataset without these metadata, some features may be +unavailable to you: + +* `thing_classes` (list[str]): Used by all instance detection/segmentation tasks. + A list of names for each instance/thing category. + If you load a COCO format dataset, it will be automatically set by the function `load_coco_json`. + +* `thing_colors` (list[tuple(r, g, b)]): Pre-defined color (in [0, 255]) for each thing category. + Used for visualization. If not given, random colors will be used. + +* `stuff_classes` (list[str]): Used by semantic and panoptic segmentation tasks. + A list of names for each stuff category. + +* `stuff_colors` (list[tuple(r, g, b)]): Pre-defined color (in [0, 255]) for each stuff category. + Used for visualization. If not given, random colors are used. 
+ +* `ignore_label` (int): Used by semantic and panoptic segmentation tasks. Pixels in ground-truth + annotations with this category label should be ignored in evaluation. Typically these are "unlabeled" + pixels. + +* `keypoint_names` (list[str]): Used by keypoint detection. A list of names for each keypoint. + +* `keypoint_flip_map` (list[tuple[str]]): Used by keypoint detection. A list of pairs of names, + where each pair are the two keypoints that should be flipped if the image is + flipped horizontally during augmentation. +* `keypoint_connection_rules`: list[tuple(str, str, (r, g, b))]. Each tuple specifies a pair of keypoints + that are connected and the color (in [0, 255]) to use for the line between them when visualized. + +Some additional metadata that are specific to the evaluation of certain datasets (e.g. COCO): + +* `thing_dataset_id_to_contiguous_id` (dict[int->int]): Used by all instance detection/segmentation tasks in the COCO format. + A mapping from instance class ids in the dataset to contiguous ids in range [0, #class). + Will be automatically set by the function `load_coco_json`. + +* `stuff_dataset_id_to_contiguous_id` (dict[int->int]): Used when generating prediction json files for + semantic/panoptic segmentation. + A mapping from semantic segmentation class ids in the dataset + to contiguous ids in [0, num_categories). It is useful for evaluation only. + +* `json_file`: The COCO annotation json file. Used by COCO evaluation for COCO-format datasets. +* `panoptic_root`, `panoptic_json`: Used by COCO-format panoptic evaluation. +* `evaluator_type`: Used by the builtin main training script to select + evaluator. Don't use it in a new training script. + You can just provide the [DatasetEvaluator](../modules/evaluation.html#detectron2.evaluation.DatasetEvaluator) + for your dataset directly in your main script. + +```eval_rst +.. 
note:: + + In recognition, sometimes we use the term "thing" for instance-level tasks, + and "stuff" for semantic segmentation tasks. + Both are used in panoptic segmentation tasks. + For background on the concept of "thing" and "stuff", see + `On Seeing Stuff: The Perception of Materials by Humans and Machines + <http://persci.mit.edu/pub_pdfs/adelson_spie_01.pdf>`_. +``` + +### Register a COCO Format Dataset + +If your instance-level (detection, segmentation, keypoint) dataset is already a json file in the COCO format, +the dataset and its associated metadata can be registered easily with: +```python +from detectron2.data.datasets import register_coco_instances +register_coco_instances("my_dataset", {}, "json_annotation.json", "path/to/image/dir") +``` + +If your dataset is in COCO format but needs to be further processed, or has extra custom per-instance annotations, +the [load_coco_json](../modules/data.html#detectron2.data.datasets.load_coco_json) +function might be useful. + +### Update the Config for New Datasets + +Once you've registered the dataset, you can use the name of the dataset (e.g., "my_dataset" in +the example above) in `cfg.DATASETS.{TRAIN,TEST}`. +There are other configs you might want to change to train or evaluate on new datasets: + +* `MODEL.ROI_HEADS.NUM_CLASSES` and `MODEL.RETINANET.NUM_CLASSES` are the number of thing classes + for R-CNN and RetinaNet models, respectively. +* `MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS` sets the number of keypoints for Keypoint R-CNN. + You'll also need to set [Keypoint OKS](http://cocodataset.org/#keypoints-eval) + with `TEST.KEYPOINT_OKS_SIGMAS` for evaluation. +* `MODEL.SEM_SEG_HEAD.NUM_CLASSES` sets the number of stuff classes for Semantic FPN & Panoptic FPN. +* `TEST.DETECTIONS_PER_IMAGE` controls the maximum number of objects to be detected. + Set it to a larger number if test images may contain >100 objects. 
+* If you're training Fast R-CNN (with precomputed proposals), `DATASETS.PROPOSAL_FILES_{TRAIN,TEST}` + need to match the datasets. The format of proposal files is documented + [here](../modules/data.html#detectron2.data.load_proposals_into_dataset). + +New models +(e.g. [TensorMask](../../projects/TensorMask), +[PointRend](../../projects/PointRend)) +often have similar configs of their own that need to be changed as well. + +```eval_rst +.. tip:: + + After changing the number of classes, certain layers in a pre-trained model will become incompatible + and therefore cannot be loaded into the new model. + This is expected, and loading such pre-trained models will produce warnings about such layers. +``` diff --git a/data_processing/detectron2/docs/tutorials/deployment.md b/data_processing/detectron2/docs/tutorials/deployment.md new file mode 100644 index 0000000..f759888 --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/deployment.md @@ -0,0 +1,137 @@ +# Deployment + +Models written in Python need to go through an export process to become a deployable artifact. +A few basic concepts about this process: + +__"Export method"__ is how a Python model is fully serialized to a deployable format. +We support the following export methods: + +* `tracing`: see [pytorch documentation](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html) to learn about it +* `scripting`: see [pytorch documentation](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html) to learn about it +* `caffe2_tracing`: replace parts of the model by caffe2 operators, then use tracing. + +__"Format"__ is how a serialized model is described in a file, e.g. +TorchScript, Caffe2 protobuf, ONNX format. +__"Runtime"__ is an engine that loads a serialized model and executes it, +e.g., PyTorch, Caffe2, TensorFlow, onnxruntime, TensorRT, etc. 
+A runtime is often tied to a specific format +(e.g. PyTorch needs TorchScript format, Caffe2 needs protobuf format). +We currently support the following combination and each has some limitations: + +```eval_rst ++----------------------------+-------------+-------------+-----------------------------+ +| Export Method | tracing | scripting | caffe2_tracing | ++============================+=============+=============+=============================+ +| **Formats** | TorchScript | TorchScript | Caffe2, TorchScript, ONNX | ++----------------------------+-------------+-------------+-----------------------------+ +| **Runtime** | PyTorch | PyTorch | Caffe2, PyTorch | ++----------------------------+-------------+-------------+-----------------------------+ +| C++/Python inference | ✅ | ✅ | ✅ | ++----------------------------+-------------+-------------+-----------------------------+ +| Dynamic resolution | ✅ | ✅ | ✅ | ++----------------------------+-------------+-------------+-----------------------------+ +| Batch size requirement | Constant | Dynamic | Batch inference unsupported | ++----------------------------+-------------+-------------+-----------------------------+ +| Extra runtime deps | torchvision | torchvision | Caffe2 ops (usually already | +| | | | | +| | | | included in PyTorch) | ++----------------------------+-------------+-------------+-----------------------------+ +| Faster/Mask/Keypoint R-CNN | ✅ | ✅ | ✅ | ++----------------------------+-------------+-------------+-----------------------------+ +| RetinaNet | ✅ | ✅ | ✅ | ++----------------------------+-------------+-------------+-----------------------------+ +| PointRend R-CNN | ✅ | ❌ | ❌ | ++----------------------------+-------------+-------------+-----------------------------+ +| Cascade R-CNN | ✅ | ❌ | ❌ | ++----------------------------+-------------+-------------+-----------------------------+ + +``` + +`caffe2_tracing` is going to be deprecated. 
+We don't plan to work on additional support for other formats/runtimes, but contributions are welcome. + + +## Deployment with Tracing or Scripting + +Models can be exported to TorchScript format, by either +[tracing or scripting](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html). +The output model file can be loaded without detectron2 dependency in either Python or C++. +The exported model often requires torchvision (or its C++ library) dependency for some custom ops. + +This feature requires PyTorch ≥ 1.8. + +### Coverage +Most official models under the meta architectures `GeneralizedRCNN` and `RetinaNet` +are supported in both tracing and scripting mode. +Cascade R-CNN and PointRend are currently supported in tracing only. +Users' custom extensions are supported if they are also scriptable or traceable. + +For models exported with tracing, dynamic input resolution is allowed, but batch size +(number of input images) must be fixed. +Scripting can support dynamic batch size. + +### Usage + +The main export APIs for tracing and scripting are [TracingAdapter](../modules/export.html#detectron2.export.TracingAdapter) +and [scripting_with_instances](../modules/export.html#detectron2.export.scripting_with_instances). +Their usage is currently demonstrated in [test_export_torchscript.py](../../tests/test_export_torchscript.py) +(see `TestScripting` and `TestTracing`) +as well as the [deployment example](../../tools/deploy). +Please check that these examples can run, and then modify them for your use cases. +The usage now requires some user effort and necessary knowledge for each model to work around the limitations of scripting and tracing. +In the future we plan to wrap these under simpler APIs to lower the bar to use them. + +## Deployment with Caffe2-tracing +We provide [Caffe2Tracer](../modules/export.html#detectron2.export.Caffe2Tracer) +that performs the export logic. 
+It replaces parts of the model with Caffe2 operators, +and then exports the model into Caffe2, TorchScript or ONNX format. + +The converted model is able to run in either Python or C++ without detectron2/torchvision dependency, on CPU or GPUs. +It has a runtime optimized for CPU & mobile inference, but not optimized for GPU inference. + +This feature requires ONNX ≥ 1.6. + +### Coverage + +Most official models under these 3 common meta architectures: `GeneralizedRCNN`, `RetinaNet`, `PanopticFPN` +are supported. Cascade R-CNN is not supported. Batch inference is not supported. + +Users' custom extensions under these architectures (added through registration) are supported +as long as they do not contain control flow or operators not available in Caffe2 (e.g. deformable convolution). +For example, custom backbones and heads are often supported out of the box. + +### Usage + +The APIs are listed at [the API documentation](../modules/export). +We provide [export_model.py](../../tools/deploy/) as an example that uses +these APIs to convert a standard model. For custom models/datasets, you can add them to this script. + +### Use the model in C++/Python + +The model can be loaded in C++ and deployed with +either Caffe2 or PyTorch runtime. [C++ examples](../../tools/deploy/) for Mask R-CNN +are given as a reference. Note that: + +* Models exported with the `caffe2_tracing` method take a special input format + described in [documentation](../modules/export.html#detectron2.export.Caffe2Tracer). + This was taken care of in the C++ example. + +* The converted models do not contain post-processing operations that + transform raw layer outputs into formatted predictions. + For example, the C++ examples only produce raw outputs (28x28 masks) from the final + layers that are not post-processed, because in actual deployment, an application often needs + its custom lightweight post-processing, so this step is left for users. 
+ +To help use the Caffe2-format model in python, +we provide a python wrapper around the converted model, in the +[Caffe2Model.\_\_call\_\_](../modules/export.html#detectron2.export.Caffe2Model.__call__) method. +This method has an interface that's identical to the [pytorch versions of models](./models.md), +and it internally applies pre/post-processing code to match the formats. +This wrapper can serve as a reference for how to use Caffe2's python API, +or for how to implement pre/post-processing in actual deployment. + +## Conversion to TensorFlow +[tensorpack Faster R-CNN](https://github.com/tensorpack/tensorpack/tree/master/examples/FasterRCNN/convert_d2) +provides scripts to convert a few standard detectron2 R-CNN models to TensorFlow's pb format. +It works by translating configs and weights, and therefore only supports a few models. diff --git a/data_processing/detectron2/docs/tutorials/evaluation.md b/data_processing/detectron2/docs/tutorials/evaluation.md new file mode 100644 index 0000000..2ef94fa --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/evaluation.md @@ -0,0 +1,68 @@ + +# Evaluation + +Evaluation is a process that takes a number of inputs/outputs pairs and aggregates them. +You can always [use the model](./models.md) directly and just parse its inputs/outputs manually to perform +evaluation. +Alternatively, evaluation is implemented in detectron2 using the [DatasetEvaluator](../modules/evaluation.html#detectron2.evaluation.DatasetEvaluator) +interface. + +Detectron2 includes a few `DatasetEvaluator` classes that compute metrics using standard dataset-specific +APIs (e.g., COCO, LVIS). +You can also implement your own `DatasetEvaluator` that performs some other jobs +using the inputs/outputs pairs. 
+For example, to count how many instances are detected on the validation set: + +```python +class Counter(DatasetEvaluator): + def reset(self): + self.count = 0 + def process(self, inputs, outputs): + for output in outputs: + self.count += len(output["instances"]) + def evaluate(self): + # save self.count somewhere, or print it, or return it. + return {"count": self.count} +``` + +## Use evaluators + +To evaluate using the methods of evaluators manually: +```python +def get_all_inputs_outputs(): + for data in data_loader: + yield data, model(data) + +evaluator.reset() +for inputs, outputs in get_all_inputs_outputs(): + evaluator.process(inputs, outputs) +eval_results = evaluator.evaluate() +``` + +Evaluators can also be used with [inference_on_dataset](../modules/evaluation.html#detectron2.evaluation.inference_on_dataset). +For example, + +```python +eval_results = inference_on_dataset( + model, + data_loader, + DatasetEvaluators([COCOEvaluator(...), Counter()])) +``` +This will execute `model` on all inputs from `data_loader`, and call evaluator to process them. + +Compared to running the evaluation manually using the model, the benefit of this function is that +evaluators can be merged together using [DatasetEvaluators](../modules/evaluation.html#detectron2.evaluation.DatasetEvaluators), +and all the evaluation can finish in one forward pass over the dataset. +This function also provides accurate speed benchmarks for the given model and dataset. + +## Evaluators for custom dataset + +Many evaluators in detectron2 are made for specific datasets, +in order to obtain scores using each dataset's official API. 
+In addition to that, two evaluators are able to evaluate any generic dataset +that follows detectron2's [standard dataset format](./datasets.md), so they +can be used to evaluate custom datasets: + +* [COCOEvaluator](../modules/evaluation.html#detectron2.evaluation.COCOEvaluator) is able to evaluate AP (Average Precision) for box detection, + instance segmentation, keypoint detection on any custom dataset. +* [SemSegEvaluator](../modules/evaluation.html#detectron2.evaluation.SemSegEvaluator) is able to evaluate semantic segmentation metrics on any custom dataset. diff --git a/data_processing/detectron2/docs/tutorials/extend.md b/data_processing/detectron2/docs/tutorials/extend.md new file mode 100644 index 0000000..a6af550 --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/extend.md @@ -0,0 +1,141 @@ +# Extend Detectron2's Defaults + +__Research is about doing things in new ways__. +This brings a tension in how to create abstractions in code, +which is a challenge for any research engineering project of a significant size: + +1. On one hand, it needs to have very thin abstractions to allow for the possibility of doing + everything in new ways. It should be reasonably easy to break existing + abstractions and replace them with new ones. + +2. On the other hand, such a project also needs reasonably high-level + abstractions, so that users can easily do things in standard ways, + without worrying too much about the details that only certain researchers care about. + +In detectron2, there are two types of interfaces that address this tension together: + +1. Functions and classes that take a config (`cfg`) argument + created from a yaml file + (sometimes with few extra arguments). + + Such functions and classes implement + the "standard default" behavior: it will read what it needs from a given + config and do the "standard" thing. 
+ Users only need to load an expert-made config and pass it around, without having to worry about + which arguments are used and what they all mean. + + See [Yacs Configs](configs.md) for a detailed tutorial. + +2. Functions and classes that have well-defined explicit arguments. + + Each of these is a small building block of the entire system. + They require users' expertise to understand what each argument should be, + and require more effort to stitch together to a larger system. + But they can be stitched together in more flexible ways. + + When you need to implement something not supported by the "standard defaults" + included in detectron2, these well-defined components can be reused. + + The [LazyConfig system](lazyconfigs.md) relies on such functions and classes. + +3. A few functions and classes are implemented with the + [@configurable](../modules/config.html#detectron2.config.configurable) + decorator - they can be called with either a config, or with explicit arguments, or a mixture of both. + Their explicit argument interfaces are currently experimental. + + As an example, a Mask R-CNN model can be built in the following ways: + + 1. Config-only: + ```python + # load proper yaml config file, then + model = build_model(cfg) + ``` + + 2. Mixture of config and additional argument overrides: + ```python + model = GeneralizedRCNN( + cfg, + roi_heads=StandardROIHeads(cfg, batch_size_per_image=666), + pixel_std=[57.0, 57.0, 57.0]) + ``` + + 3. Full explicit arguments: +
+ + (click to expand) + + + ```python + model = GeneralizedRCNN( + backbone=FPN( + ResNet( + BasicStem(3, 64, norm="FrozenBN"), + ResNet.make_default_stages(50, stride_in_1x1=True, norm="FrozenBN"), + out_features=["res2", "res3", "res4", "res5"], + ).freeze(2), + ["res2", "res3", "res4", "res5"], + 256, + top_block=LastLevelMaxPool(), + ), + proposal_generator=RPN( + in_features=["p2", "p3", "p4", "p5", "p6"], + head=StandardRPNHead(in_channels=256, num_anchors=3), + anchor_generator=DefaultAnchorGenerator( + sizes=[[32], [64], [128], [256], [512]], + aspect_ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + offset=0.0, + ), + anchor_matcher=Matcher([0.3, 0.7], [0, -1, 1], allow_low_quality_matches=True), + box2box_transform=Box2BoxTransform([1.0, 1.0, 1.0, 1.0]), + batch_size_per_image=256, + positive_fraction=0.5, + pre_nms_topk=(2000, 1000), + post_nms_topk=(1000, 1000), + nms_thresh=0.7, + ), + roi_heads=StandardROIHeads( + num_classes=80, + batch_size_per_image=512, + positive_fraction=0.25, + proposal_matcher=Matcher([0.5], [0, 1], allow_low_quality_matches=False), + box_in_features=["p2", "p3", "p4", "p5"], + box_pooler=ROIPooler(7, (1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), 0, "ROIAlignV2"), + box_head=FastRCNNConvFCHead( + ShapeSpec(channels=256, height=7, width=7), conv_dims=[], fc_dims=[1024, 1024] + ), + box_predictor=FastRCNNOutputLayers( + ShapeSpec(channels=1024), + test_score_thresh=0.05, + box2box_transform=Box2BoxTransform((10, 10, 5, 5)), + num_classes=80, + ), + mask_in_features=["p2", "p3", "p4", "p5"], + mask_pooler=ROIPooler(14, (1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), 0, "ROIAlignV2"), + mask_head=MaskRCNNConvUpsampleHead( + ShapeSpec(channels=256, width=14, height=14), + num_classes=80, + conv_dims=[256, 256, 256, 256, 256], + ), + ), + pixel_mean=[103.530, 116.280, 123.675], + pixel_std=[1.0, 1.0, 1.0], + input_format="BGR", + ) + ``` + +
+ + +If you only need the standard behavior, the [Beginner's Tutorial](./getting_started.md) +should suffice. If you need to extend detectron2 to your own needs, +see the following tutorials for more details: + +* Detectron2 includes a few standard datasets. To use custom ones, see + [Use Custom Datasets](./datasets.md). +* Detectron2 contains the standard logic that creates a data loader for training/testing from a + dataset, but you can write your own as well. See [Use Custom Data Loaders](./data_loading.md). +* Detectron2 implements many standard detection models, and provides ways for you + to override their behaviors. See [Use Models](./models.md) and [Write Models](./write-models.md). +* Detectron2 provides a default training loop that is good for common training tasks. + You can customize it with hooks, or write your own loop instead. See [training](./training.md). diff --git a/data_processing/detectron2/docs/tutorials/getting_started.md b/data_processing/detectron2/docs/tutorials/getting_started.md new file mode 100644 index 0000000..e90bde7 --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/getting_started.md @@ -0,0 +1 @@ +../../GETTING_STARTED.md \ No newline at end of file diff --git a/data_processing/detectron2/docs/tutorials/index.rst b/data_processing/detectron2/docs/tutorials/index.rst new file mode 100644 index 0000000..850b95c --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/index.rst @@ -0,0 +1,20 @@ +Tutorials +====================================== + +.. 
toctree:: + :maxdepth: 2 + + install + getting_started + builtin_datasets + extend + datasets + data_loading + augmentation + models + write-models + training + evaluation + configs + lazyconfigs + deployment diff --git a/data_processing/detectron2/docs/tutorials/install.md b/data_processing/detectron2/docs/tutorials/install.md new file mode 100644 index 0000000..5f52b2b --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/install.md @@ -0,0 +1 @@ +../../INSTALL.md \ No newline at end of file diff --git a/data_processing/detectron2/docs/tutorials/lazyconfigs.md b/data_processing/detectron2/docs/tutorials/lazyconfigs.md new file mode 100644 index 0000000..a01101a --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/lazyconfigs.md @@ -0,0 +1,170 @@ +# Lazy Configs + +The traditional yacs-based config system provides basic, standard functionalities. +However, it does not offer enough flexibility for many new projects. +We develop an alternative, non-intrusive config system that can be used with +detectron2 or potentially any other complex projects. + +## Python Syntax + +Our config objects are still dictionaries. Instead of using Yaml to define dictionaries, +we create dictionaries in Python directly. This gives users the following power that +doesn't exist in Yaml: + +* Easily manipulate the dictionary (addition & deletion) using Python. +* Write simple arithmetics or call simple functions. +* Use more data types / objects. +* Import / compose other config files, using the familiar Python import syntax. 
+ +A Python config file can be loaded like this: +```python +# config.py: +a = dict(x=1, y=2, z=dict(xx=1)) +b = dict(x=3, y=4) + +# my_code.py: +from detectron2.config import LazyConfig +cfg = LazyConfig.load("path/to/config.py") # an omegaconf dictionary +assert cfg.a.z.xx == 1 +``` + +After [LazyConfig.load](../modules/config.html#detectron2.config.LazyConfig.load), `cfg` will be a dictionary that contains all dictionaries +defined in the global scope of the config file. Note that: +* All dictionaries are turned into an [omegaconf](https://omegaconf.readthedocs.io/) + config object during loading. This enables access to omegaconf features, + such as its [access syntax](https://omegaconf.readthedocs.io/en/2.1_branch/usage.html#access-and-manipulation) + and [interpolation](https://omegaconf.readthedocs.io/en/2.1_branch/usage.html#variable-interpolation). +* Absolute imports in `config.py` work the same as in regular Python. +* Relative imports can only import dictionaries from config files. + They are simply syntactic sugar for [LazyConfig.load_rel](../modules/config.html#detectron2.config.LazyConfig.load_rel). + They can load Python files at relative paths without requiring `__init__.py`. + +[LazyConfig.save](../modules/config.html#detectron2.config.LazyConfig.save) can save a config object to yaml. +Note that this is not always successful if non-serializable objects appear in the config file (e.g. lambdas). +It is up to users whether to sacrifice the ability to save in exchange for flexibility. + +## Recursive Instantiation + +The LazyConfig system heavily uses recursive instantiation, which is a pattern that +uses a dictionary to describe a +call to a function/class. The dictionary consists of: + +1. A "\_target\_" key which contains the path to the callable, such as "module.submodule.class_name". +2. 
Other keys that represent arguments to pass to the callable. Arguments themselves can be defined + using recursive instantiation. + +We provide a helper function [LazyCall](../modules/config.html#detectron2.config.LazyCall) that helps create such dictionaries. +The following code using `LazyCall` +```python +from detectron2.config import LazyCall as L +from my_app import Trainer, Optimizer +cfg = L(Trainer)( + optimizer=L(Optimizer)( + lr=0.01, + algo="SGD" + ) +) +``` +creates a dictionary like this: +```python +cfg = { + "_target_": "my_app.Trainer", + "optimizer": { + "_target_": "my_app.Optimizer", + "lr": 0.01, "algo": "SGD" + } +} +``` + +By representing objects using such dictionaries, a general +[instantiate](../modules/config.html#detectron2.config.instantiate) +function can turn them into actual objects, i.e.: +```python +from detectron2.config import instantiate +trainer = instantiate(cfg) +# equivalent to: +# from my_app import Trainer, Optimizer +# trainer = Trainer(optimizer=Optimizer(lr=0.01, algo="SGD")) +``` + +This pattern is powerful enough to describe very complex objects, e.g.: + +
+ +A Full Mask R-CNN described in recursive instantiation (click to expand) + + +```eval_rst +.. literalinclude:: ../../configs/common/models/mask_rcnn_fpn.py + :language: python + :linenos: +``` + +
+ +There are also objects or logic that cannot be described simply by a dictionary, +such as reused objects or method calls. They may require some refactoring +to work with recursive instantiation. + +## Using Model Zoo LazyConfigs + +We provide some configs in the model zoo using the LazyConfig system, for example: + +* [common baselines](../../configs/common/). +* [new Mask R-CNN baselines](../../configs/new_baselines/) + +After installing detectron2, they can be loaded by the model zoo API +[model_zoo.get_config](../modules/model_zoo.html#detectron2.model_zoo.get_config). + +Using these as references, you're free to define custom config structure / fields for your own +project, as long as your training script can understand them. +Despite this, our model zoo configs still follow some simple conventions for consistency, e.g. +`cfg.model` defines a model object, `cfg.dataloader.{train,test}` define dataloader objects, +and `cfg.train` contains training options in key-value form. +In addition to `print()`, a better way to view the structure of a config is like this: +```python +from detectron2.model_zoo import get_config +from detectron2.config import LazyConfig +print(LazyConfig.to_py(get_config("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.py"))) +``` +From the output it's easier to find relevant options to change, e.g. +`dataloader.train.total_batch_size` for the batch size, or `optimizer.lr` for the base learning rate. + +We provide a reference training script +[tools/lazyconfig_train_net.py](../../tools/lazyconfig_train_net.py), +that can train/eval our model zoo configs. +It also shows how to support command line value overrides. + +To demonstrate the power and flexibility of the new system, we show that +[a simple config file](../../configs/Misc/torchvision_imagenet_R_50.py) +can let detectron2 train an ImageNet classification model from torchvision, even though +detectron2 contains no features about ImageNet classification. 
+This can serve as a reference for using detectron2 in other deep learning tasks. + +## Summary + +By using recursive instantiation to create objects, +we avoid passing a giant config to many places, because `cfg` is only passed to `instantiate`. +This has the following benefits: + +* It's __non-intrusive__: objects to be constructed are config-agnostic, regular Python + functions/classes. + They can even live in other libraries. For example, + `{"_target_": "torch.nn.Conv2d", "in_channels": 10, "out_channels": 10, "kernel_size": 1}` + defines a conv layer. +* __Clarity__ of what function/classes will be called, and what arguments they use. +* `cfg` doesn't need pre-defined keys and structures. It's valid as long as it translates to valid + code. This gives a lot more __flexibility__. +* You can still pass huge dictionaries as arguments, just like the old way. + +Recursive instantiation and Python syntax are orthogonal: you can use one without the other. +But by putting them together, the config file looks a lot like the code that will be executed: + +![img](./lazyconfig.jpg) + +However, the config file just defines dictionaries, which can be easily manipulated further +by composition or overrides. +The corresponding code will only be executed +later when `instantiate` is called. In some way, +in config files we're writing "editable code" that will be "lazily executed" later when needed. +That's why we call this system "LazyConfig". 
diff --git a/data_processing/detectron2/docs/tutorials/models.md b/data_processing/detectron2/docs/tutorials/models.md new file mode 100644 index 0000000..a2def5c --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/models.md @@ -0,0 +1,180 @@ +# Use Models + +## Build Models from Yacs Config +From a yacs config object, +models (and their sub-models) can be built by +functions such as `build_model`, `build_backbone`, `build_roi_heads`: +```python +from detectron2.modeling import build_model +model = build_model(cfg) # returns a torch.nn.Module +``` + +`build_model` only builds the model structure and fills it with random parameters. +See below for how to load an existing checkpoint into the model and how to use the `model` object. + +### Load/Save a Checkpoint +```python +from detectron2.checkpoint import DetectionCheckpointer +DetectionCheckpointer(model).load(file_path_or_url) # load a file, usually from cfg.MODEL.WEIGHTS + +checkpointer = DetectionCheckpointer(model, save_dir="output") +checkpointer.save("model_999") # save to output/model_999.pth +``` + +Detectron2's checkpointer recognizes models in pytorch's `.pth` format, as well as the `.pkl` files +in our model zoo. +See [API doc](../modules/checkpoint.html#detectron2.checkpoint.DetectionCheckpointer) +for more details about its usage. + +The model files can be arbitrarily manipulated using `torch.{load,save}` for `.pth` files or +`pickle.{dump,load}` for `.pkl` files. + +### Use a Model + +A model can be called by `outputs = model(inputs)`, where `inputs` is a `list[dict]`. +Each dict corresponds to one image and the required keys +depend on the type of model, and whether the model is in training or evaluation mode. +For example, in order to do inference, +all existing models expect the "image" key, and optionally "height" and "width". +The detailed format of inputs and outputs of existing models is explained below. 
+ +__Training__: When in training mode, all models are required to be used under an `EventStorage`. +The training statistics will be put into the storage: +```python +from detectron2.utils.events import EventStorage +with EventStorage() as storage: + losses = model(inputs) +``` + +__Inference__: If you only want to do simple inference using an existing model, +[DefaultPredictor](../modules/engine.html#detectron2.engine.defaults.DefaultPredictor) +is a wrapper around the model that provides such basic functionality. +It includes default behavior including model loading, preprocessing, +and operates on a single image rather than batches. See its documentation for usage. + +You can also run inference directly like this: +```python +model.eval() +with torch.no_grad(): + outputs = model(inputs) +``` + +### Model Input Format + +Users can implement custom models that support any arbitrary input format. +Here we describe the standard input format that all builtin models support in detectron2. +They all take a `list[dict]` as the inputs. Each dict +corresponds to information about one image. + +The dict may contain the following keys: + +* "image": `Tensor` in (C, H, W) format. The meaning of the channels is defined by `cfg.INPUT.FORMAT`. + Image normalization, if any, will be performed inside the model using + `cfg.MODEL.PIXEL_{MEAN,STD}`. +* "height", "width": the **desired** output height and width **in inference**, which is not necessarily the same + as the height or width of the `image` field. + For example, the `image` field contains the resized image, if resize is used as a preprocessing step. + But you may want the outputs to be in **original** resolution. + If provided, the model will produce output in this resolution, + rather than in the resolution of the `image` as input into the model. This is more efficient and accurate. 
+* "instances": an [Instances](../modules/structures.html#detectron2.structures.Instances) + object for training, with the following fields: + + "gt_boxes": a [Boxes](../modules/structures.html#detectron2.structures.Boxes) object storing N boxes, one for each instance. + + "gt_classes": `Tensor` of long type, a vector of N labels, in range [0, num_categories). + + "gt_masks": a [PolygonMasks](../modules/structures.html#detectron2.structures.PolygonMasks) + or [BitMasks](../modules/structures.html#detectron2.structures.BitMasks) object storing N masks, one for each instance. + + "gt_keypoints": a [Keypoints](../modules/structures.html#detectron2.structures.Keypoints) + object storing N keypoint sets, one for each instance. +* "sem_seg": `Tensor[int]` in (H, W) format. The semantic segmentation ground truth for training. + Values represent category labels starting from 0. +* "proposals": an [Instances](../modules/structures.html#detectron2.structures.Instances) + object used only in Fast R-CNN style models, with the following fields: + + "proposal_boxes": a [Boxes](../modules/structures.html#detectron2.structures.Boxes) object storing P proposal boxes. + + "objectness_logits": `Tensor`, a vector of P scores, one for each proposal. + +For inference of builtin models, only "image" key is required, and "width/height" are optional. + +We currently don't define standard input format for panoptic segmentation training, +because models now use custom formats produced by custom data loaders. + +#### How it connects to data loader: + +The output of the default [DatasetMapper]( ../modules/data.html#detectron2.data.DatasetMapper) is a dict +that follows the above format. +After the data loader performs batching, it becomes `list[dict]` which the builtin models support. + + +### Model Output Format + +When in training mode, the builtin models output a `dict[str->ScalarTensor]` with all the losses. 
+ +When in inference mode, the builtin models output a `list[dict]`, one dict for each image. +Based on the tasks the model is doing, each dict may contain the following fields: + +* "instances": [Instances](../modules/structures.html#detectron2.structures.Instances) + object with the following fields: + * "pred_boxes": [Boxes](../modules/structures.html#detectron2.structures.Boxes) object storing N boxes, one for each detected instance. + * "scores": `Tensor`, a vector of N confidence scores. + * "pred_classes": `Tensor`, a vector of N labels in range [0, num_categories). + + "pred_masks": a `Tensor` of shape (N, H, W), masks for each detected instance. + + "pred_keypoints": a `Tensor` of shape (N, num_keypoint, 3). + Each row in the last dimension is (x, y, score). Confidence scores are larger than 0. +* "sem_seg": `Tensor` of (num_categories, H, W), the semantic segmentation prediction. +* "proposals": [Instances](../modules/structures.html#detectron2.structures.Instances) + object with the following fields: + * "proposal_boxes": [Boxes](../modules/structures.html#detectron2.structures.Boxes) + object storing N boxes. + * "objectness_logits": a torch vector of N confidence scores. +* "panoptic_seg": A tuple of `(pred: Tensor, segments_info: Optional[list[dict]])`. + The `pred` tensor has shape (H, W), containing the segment id of each pixel. + + * If `segments_info` exists, each dict describes one segment id in `pred` and has the following fields: + + * "id": the segment id + * "isthing": whether the segment is a thing or stuff + * "category_id": the category id of this segment. + + If a pixel's id does not exist in `segments_info`, it is considered to be void label + defined in [Panoptic Segmentation](https://arxiv.org/abs/1801.00868). + + * If `segments_info` is None, all pixel values in `pred` must be ≥ -1. + Pixels with value -1 are assigned void labels. 
+ Otherwise, the category id of each pixel is obtained by + `category_id = pixel // metadata.label_divisor`. + + +### Partially execute a model: + +Sometimes you may want to obtain an intermediate tensor inside a model, +such as the input of certain layer, the output before post-processing. +Since there are typically hundreds of intermediate tensors, there isn't an API that provides you +the intermediate result you need. +You have the following options: + +1. Write a (sub)model. Following the [tutorial](./write-models.md), you can + rewrite a model component (e.g. a head of a model), such that it + does the same thing as the existing component, but returns the output + you need. +2. Partially execute a model. You can create the model as usual, + but use custom code to execute it instead of its `forward()`. For example, + the following code obtains mask features before mask head. + + ```python + images = ImageList.from_tensors(...) # preprocessed input tensor + model = build_model(cfg) + model.eval() + features = model.backbone(images.tensor) + proposals, _ = model.proposal_generator(images, features) + instances, _ = model.roi_heads(images, features, proposals) + mask_features = [features[f] for f in model.roi_heads.in_features] + mask_features = model.roi_heads.mask_pooler(mask_features, [x.pred_boxes for x in instances]) + ``` + +3. Use [forward hooks](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks). + Forward hooks can help you obtain inputs or outputs of a certain module. + If they are not exactly what you want, they can at least be used together with partial execution + to obtain other tensors. + +All options require you to read documentation and sometimes code +of the existing models to understand the internal logic, +in order to write code to obtain the internal tensors. 
diff --git a/data_processing/detectron2/docs/tutorials/training.md b/data_processing/detectron2/docs/tutorials/training.md new file mode 100644 index 0000000..83a6cb0 --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/training.md @@ -0,0 +1,67 @@ +# Training + +From the previous tutorials, you may now have a custom model and a data loader. +To run training, users typically have a preference in one of the following two styles: + +### Custom Training Loop + +With a model and a data loader ready, everything else needed to write a training loop can +be found in PyTorch, and you are free to write the training loop yourself. +This style allows researchers to manage the entire training logic more clearly and have full control. +One such example is provided in [tools/plain_train_net.py](../../tools/plain_train_net.py). + +Any customization on the training logic is then easily controlled by the user. + +### Trainer Abstraction + +We also provide a standardized "trainer" abstraction with a +hook system that helps simplify the standard training behavior. +It includes the following two instantiations: + +* [SimpleTrainer](../modules/engine.html#detectron2.engine.SimpleTrainer) + provides a minimal training loop for single-cost single-optimizer single-data-source training, with nothing else. + Other tasks (checkpointing, logging, etc) can be implemented using + [the hook system](../modules/engine.html#detectron2.engine.HookBase). +* [DefaultTrainer](../modules/engine.html#detectron2.engine.defaults.DefaultTrainer) is a `SimpleTrainer` initialized from a + yacs config, used by + [tools/train_net.py](../../tools/train_net.py) and many scripts. + It includes more standard default behaviors that one might want to opt in, + including default configurations for optimizer, learning rate schedule, + logging, evaluation, checkpointing etc. + +To customize a `DefaultTrainer`: + +1. For simple customizations (e.g. 
change optimizer, evaluator, LR scheduler, data loader, etc.), overwrite [its methods](../modules/engine.html#detectron2.engine.defaults.DefaultTrainer) in a subclass, just like [tools/train_net.py](../../tools/train_net.py). +2. For extra tasks during training, check the + [hook system](../modules/engine.html#detectron2.engine.HookBase) to see if it's supported. + + As an example, to print hello during training: + ```python + class HelloHook(HookBase): + def after_step(self): + if self.trainer.iter % 100 == 0: + print(f"Hello at iteration {self.trainer.iter}!") + ``` +3. Using a trainer+hook system means there will always be some non-standard behaviors that cannot be supported, especially in research. + For this reason, we intentionally keep the trainer & hook system minimal, rather than powerful. + If anything cannot be achieved by such a system, it's easier to start from [tools/plain_train_net.py](../../tools/plain_train_net.py) to implement custom training logic manually. + +### Logging of Metrics + +During training, detectron2 models and trainer put metrics to a centralized [EventStorage](../modules/utils.html#detectron2.utils.events.EventStorage). +You can use the following code to access it and log metrics to it: +```python +from detectron2.utils.events import get_event_storage + +# inside the model: +if self.training: + value = # compute the value from inputs + storage = get_event_storage() + storage.put_scalar("some_accuracy", value) +``` + +Refer to its documentation for more details. + +Metrics are then written to various destinations with [EventWriter](../modules/utils.html#module-detectron2.utils.events). +DefaultTrainer enables a few `EventWriter` with default configurations. +See above for how to customize them. 
diff --git a/data_processing/detectron2/docs/tutorials/write-models.md b/data_processing/detectron2/docs/tutorials/write-models.md new file mode 100644 index 0000000..967d126 --- /dev/null +++ b/data_processing/detectron2/docs/tutorials/write-models.md @@ -0,0 +1,90 @@ +# Write Models + +If you are trying to do something completely new, you may wish to implement +a model entirely from scratch. However, in many situations you may +be interested in modifying or extending some components of an existing model. +Therefore, we also provide mechanisms that let users override the +behavior of certain internal components of standard models. + + +## Register New Components + +For common concepts that users often want to customize, such as "backbone feature extractor", "box head", +we provide a registration mechanism for users to inject custom implementation that +will be immediately available to use in config files. + +For example, to add a new backbone, import this code in your code: +```python +from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec + +@BACKBONE_REGISTRY.register() +class ToyBackbone(Backbone): + def __init__(self, cfg, input_shape): + super().__init__() + # create your own backbone + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=16, padding=3) + + def forward(self, image): + return {"conv1": self.conv1(image)} + + def output_shape(self): + return {"conv1": ShapeSpec(channels=64, stride=16)} +``` + +In this code, we implement a new backbone following the interface of the +[Backbone](../modules/modeling.html#detectron2.modeling.Backbone) class, +and register it into the [BACKBONE_REGISTRY](../modules/modeling.html#detectron2.modeling.BACKBONE_REGISTRY) +which requires subclasses of `Backbone`. +After importing this code, detectron2 can link the name of the class to its implementation. Therefore you can write the following code: + +```python +cfg = ... 
# read a config +cfg.MODEL.BACKBONE.NAME = 'ToyBackbone' # or set it in the config file +model = build_model(cfg) # it will find `ToyBackbone` defined above +``` + +As another example, to add new abilities to the ROI heads in the Generalized R-CNN meta-architecture, +you can implement a new +[ROIHeads](../modules/modeling.html#detectron2.modeling.ROIHeads) subclass and put it in the `ROI_HEADS_REGISTRY`. +[DensePose](../../projects/DensePose) +and [MeshRCNN](https://github.com/facebookresearch/meshrcnn) +are two examples that implement new ROIHeads to perform new tasks. +And [projects/](../../projects/) +contains more examples that implement different architectures. + +A complete list of registries can be found in [API documentation](../modules/modeling.html#model-registries). +You can register components in these registries to customize different parts of a model, or the +entire model. + +## Construct Models with Explicit Arguments + +Registry is a bridge to connect names in config files to the actual code. +They are meant to cover a few main components that users frequently need to replace. +However, the capability of a text-based config file is sometimes limited and +some deeper customization may be available only through writing code. + +Most model components in detectron2 have a clear `__init__` interface that documents +what input arguments it needs. Calling them with custom arguments will give you a custom variant +of the model. + +As an example, to use __custom loss function__ in the box head of a Faster R-CNN, we can do the following: + +1. Losses are currently computed in [FastRCNNOutputLayers](../modules/modeling.html#detectron2.modeling.FastRCNNOutputLayers). + We need to implement a variant or a subclass of it, with custom loss functions, named `MyRCNNOutput`. +2. Call `StandardROIHeads` with `box_predictor=MyRCNNOutput()` argument instead of the builtin `FastRCNNOutputLayers`. 
+ If all other arguments should stay unchanged, this can be easily achieved by using the [configurable `__init__`](../modules/config.html#detectron2.config.configurable) mechanism: + + ```python + roi_heads = StandardROIHeads( + cfg, backbone.output_shape(), + box_predictor=MyRCNNOutput(...) + ) + ``` +3. (optional) If we want to enable this new model from a config file, registration is needed: + ```python + @ROI_HEADS_REGISTRY.register() + class MyStandardROIHeads(StandardROIHeads): + def __init__(self, cfg, input_shape): + super().__init__(cfg, input_shape, + box_predictor=MyRCNNOutput(...)) + ``` diff --git a/data_processing/detectron2/projects/DeepLab/README.md b/data_processing/detectron2/projects/DeepLab/README.md new file mode 100644 index 0000000..bd03cf1 --- /dev/null +++ b/data_processing/detectron2/projects/DeepLab/README.md @@ -0,0 +1,100 @@ +# DeepLab in Detectron2 + +In this repository, we implement DeepLabV3 and DeepLabV3+ in Detectron2. + +## Installation +Install Detectron2 following [the instructions](https://detectron2.readthedocs.io/tutorials/install.html). + +## Training + +To train a model with 8 GPUs run: +```bash +cd /path/to/detectron2/projects/DeepLab +python train_net.py --config-file configs/Cityscapes-SemanticSegmentation/deeplab_v3_plus_R_103_os16_mg124_poly_90k_bs16.yaml --num-gpus 8 +``` + +## Evaluation + +Model evaluation can be done similarly: +```bash +cd /path/to/detectron2/projects/DeepLab +python train_net.py --config-file configs/Cityscapes-SemanticSegmentation/deeplab_v3_plus_R_103_os16_mg124_poly_90k_bs16.yaml --eval-only MODEL.WEIGHTS /path/to/model_checkpoint +``` + +## Cityscapes Semantic Segmentation +Cityscapes models are trained with ImageNet pretraining. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
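+| Method | Backbone | Output resolution | mIoU | model id | download |
+| :-- | :-: | :-: | :-: | :-: | :-: |
+| DeepLabV3 | R101-DC5 | 1024×2048 | 76.7 | - | - \| - |
+| DeepLabV3 | R103-DC5 | 1024×2048 | 78.5 | 28041665 | model \| metrics |
+| DeepLabV3+ | R101-DC5 | 1024×2048 | 78.1 | - | - \| - |
+| DeepLabV3+ | R103-DC5 | 1024×2048 | 80.0 | 28054032 | model \| metrics |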
+ +Note: +- [R103](https://dl.fbaipublicfiles.com/detectron2/DeepLab/R-103.pkl): a ResNet-101 with its first 7x7 convolution replaced by 3 3x3 convolutions. +This modification has been used in most semantic segmentation papers. We pre-train this backbone on ImageNet using the default recipe of [pytorch examples](https://github.com/pytorch/examples/tree/master/imagenet). +- DC5 means using dilated convolution in `res5`. + +## Citing DeepLab + +If you use DeepLab, please use the following BibTeX entry. + +* DeepLabv3+: + +``` +@inproceedings{deeplabv3plus2018, + title={Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation}, + author={Liang-Chieh Chen and Yukun Zhu and George Papandreou and Florian Schroff and Hartwig Adam}, + booktitle={ECCV}, + year={2018} +} +``` + +* DeepLabv3: + +``` +@article{deeplabv32018, + title={Rethinking atrous convolution for semantic image segmentation}, + author={Chen, Liang-Chieh and Papandreou, George and Schroff, Florian and Adam, Hartwig}, + journal={arXiv:1706.05587}, + year={2017} +} +``` diff --git a/data_processing/detectron2/projects/DeepLab/configs/Cityscapes-SemanticSegmentation/Base-DeepLabV3-OS16-Semantic.yaml b/data_processing/detectron2/projects/DeepLab/configs/Cityscapes-SemanticSegmentation/Base-DeepLabV3-OS16-Semantic.yaml new file mode 100644 index 0000000..fa6edb5 --- /dev/null +++ b/data_processing/detectron2/projects/DeepLab/configs/Cityscapes-SemanticSegmentation/Base-DeepLabV3-OS16-Semantic.yaml @@ -0,0 +1,36 @@ +_BASE_: "../../../../configs/Base-RCNN-DilatedC5.yaml" +MODEL: + META_ARCHITECTURE: "SemanticSegmentor" + BACKBONE: + FREEZE_AT: 0 + SEM_SEG_HEAD: + NAME: "DeepLabV3Head" + IN_FEATURES: ["res5"] + ASPP_CHANNELS: 256 + ASPP_DILATIONS: [6, 12, 18] + ASPP_DROPOUT: 0.1 + CONVS_DIM: 256 + COMMON_STRIDE: 16 + NUM_CLASSES: 19 + LOSS_TYPE: "hard_pixel_mining" +DATASETS: + TRAIN: 
("cityscapes_fine_sem_seg_train",) + TEST: ("cityscapes_fine_sem_seg_val",) +SOLVER: + BASE_LR: 0.01 + MAX_ITER: 90000 + LR_SCHEDULER_NAME: "WarmupPolyLR" + IMS_PER_BATCH: 16 +INPUT: + MIN_SIZE_TRAIN: (512, 768, 1024, 1280, 1536, 1792, 2048) + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 1024 + MAX_SIZE_TRAIN: 4096 + MAX_SIZE_TEST: 2048 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (512, 1024) + SINGLE_CATEGORY_MAX_AREA: 1.0 +DATALOADER: + NUM_WORKERS: 10 diff --git a/data_processing/detectron2/projects/DeepLab/configs/Cityscapes-SemanticSegmentation/deeplab_v3_R_103_os16_mg124_poly_90k_bs16.yaml b/data_processing/detectron2/projects/DeepLab/configs/Cityscapes-SemanticSegmentation/deeplab_v3_R_103_os16_mg124_poly_90k_bs16.yaml new file mode 100644 index 0000000..a2f5a54 --- /dev/null +++ b/data_processing/detectron2/projects/DeepLab/configs/Cityscapes-SemanticSegmentation/deeplab_v3_R_103_os16_mg124_poly_90k_bs16.yaml @@ -0,0 +1,19 @@ +_BASE_: Base-DeepLabV3-OS16-Semantic.yaml +MODEL: + WEIGHTS: "detectron2://DeepLab/R-103.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + BACKBONE: + NAME: "build_resnet_deeplab_backbone" + RESNETS: + DEPTH: 101 + NORM: "SyncBN" + RES5_MULTI_GRID: [1, 2, 4] + STEM_TYPE: "deeplab" + STEM_OUT_CHANNELS: 128 + STRIDE_IN_1X1: False + SEM_SEG_HEAD: + NAME: "DeepLabV3Head" + NORM: "SyncBN" +INPUT: + FORMAT: "RGB" diff --git a/data_processing/detectron2/projects/DeepLab/configs/Cityscapes-SemanticSegmentation/deeplab_v3_plus_R_103_os16_mg124_poly_90k_bs16.yaml b/data_processing/detectron2/projects/DeepLab/configs/Cityscapes-SemanticSegmentation/deeplab_v3_plus_R_103_os16_mg124_poly_90k_bs16.yaml new file mode 100644 index 0000000..c03a72d --- /dev/null +++ b/data_processing/detectron2/projects/DeepLab/configs/Cityscapes-SemanticSegmentation/deeplab_v3_plus_R_103_os16_mg124_poly_90k_bs16.yaml @@ -0,0 +1,24 @@ +_BASE_: Base-DeepLabV3-OS16-Semantic.yaml +MODEL: + WEIGHTS: 
def build_lr_scheduler(cfg: CfgNode, optimizer: torch.optim.Optimizer) -> LRScheduler:
    """
    Create the LR scheduler selected by ``cfg.SOLVER.LR_SCHEDULER_NAME``.

    The name ``"WarmupPolyLR"`` is handled here using the ``cfg.SOLVER``
    warmup and poly settings; any other name is forwarded unchanged to
    detectron2's own ``build_lr_scheduler``.
    """
    solver_cfg = cfg.SOLVER

    # Everything except the poly schedule is delegated to detectron2.
    if solver_cfg.LR_SCHEDULER_NAME != "WarmupPolyLR":
        return build_d2_lr_scheduler(cfg, optimizer)

    poly_options = dict(
        warmup_factor=solver_cfg.WARMUP_FACTOR,
        warmup_iters=solver_cfg.WARMUP_ITERS,
        warmup_method=solver_cfg.WARMUP_METHOD,
        power=solver_cfg.POLY_LR_POWER,
        constant_ending=solver_cfg.POLY_LR_CONSTANT_ENDING,
    )
    return WarmupPolyLR(optimizer, solver_cfg.MAX_ITER, **poly_options)
class DeepLabCE(nn.Module):
    """
    Hard pixel mining with cross entropy loss, for semantic segmentation.
    This is used in TensorFlow DeepLab frameworks.
    Paper: DeeperLab: Single-Shot Image Parser
    Reference: https://github.com/tensorflow/models/blob/bd488858d610e44df69da6f89277e9de8a03722c/research/deeplab/utils/train_utils.py#L33  # noqa
    Arguments:
        ignore_label: Integer, label to ignore.
        top_k_percent_pixels: Float, the value lies in [0.0, 1.0]. When its
            value < 1.0, only compute the loss for the top k percent pixels
            (e.g., the top 20% pixels). This is useful for hard pixel mining.
            At least one pixel is always kept for a non-empty input, so a
            small input never produces a NaN loss.
        weight: Tensor, a manual rescaling weight given to each class.
    """

    def __init__(self, ignore_label=-1, top_k_percent_pixels=1.0, weight=None):
        super().__init__()
        self.top_k_percent_pixels = top_k_percent_pixels
        self.ignore_label = ignore_label
        # reduction="none" keeps one loss value per pixel so that the hardest
        # ones can be selected in forward().
        self.criterion = nn.CrossEntropyLoss(
            weight=weight, ignore_index=ignore_label, reduction="none"
        )

    def forward(self, logits, labels, weights=None):
        """
        Compute the (optionally hard-mined) mean cross entropy.

        Args:
            logits: class scores passed to ``nn.CrossEntropyLoss``.
            labels: integer targets passed to ``nn.CrossEntropyLoss``.
            weights: optional per-pixel multipliers applied to the
                unreduced loss before averaging.

        Returns:
            A scalar tensor: the mean over all pixels when
            ``top_k_percent_pixels == 1.0``, otherwise the mean over the
            ``top_k_percent_pixels`` fraction of pixels with the largest loss.
        """
        pixel_losses = self.criterion(logits, labels)
        if weights is not None:
            # Apply per-pixel loss weights.
            pixel_losses = pixel_losses * weights
        pixel_losses = pixel_losses.contiguous().view(-1)

        if self.top_k_percent_pixels == 1.0:
            return pixel_losses.mean()

        num_pixels = pixel_losses.numel()
        top_k_pixels = int(self.top_k_percent_pixels * num_pixels)
        if num_pixels > 0:
            # int() truncates, so e.g. 0.2 * 4 pixels would select 0 pixels and
            # the mean of an empty tensor is NaN. Always keep the hardest pixel.
            top_k_pixels = max(1, top_k_pixels)
        pixel_losses, _ = torch.topk(pixel_losses, top_k_pixels)
        return pixel_losses.mean()
+ Reference: https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/utils/train_utils.py#L337 # noqa + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + max_iters: int, + warmup_factor: float = 0.001, + warmup_iters: int = 1000, + warmup_method: str = "linear", + last_epoch: int = -1, + power: float = 0.9, + constant_ending: float = 0.0, + ): + self.max_iters = max_iters + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + self.power = power + self.constant_ending = constant_ending + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + warmup_factor = _get_warmup_factor_at_iter( + self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor + ) + if self.constant_ending > 0 and warmup_factor == 1.0: + # Constant ending lr. + if ( + math.pow((1.0 - self.last_epoch / self.max_iters), self.power) + < self.constant_ending + ): + return [base_lr * self.constant_ending for base_lr in self.base_lrs] + return [ + base_lr * warmup_factor * math.pow((1.0 - self.last_epoch / self.max_iters), self.power) + for base_lr in self.base_lrs + ] + + def _compute_values(self) -> List[float]: + # The new interface + return self.get_lr() diff --git a/data_processing/detectron2/projects/DeepLab/deeplab/resnet.py b/data_processing/detectron2/projects/DeepLab/deeplab/resnet.py new file mode 100644 index 0000000..2cc277b --- /dev/null +++ b/data_processing/detectron2/projects/DeepLab/deeplab/resnet.py @@ -0,0 +1,158 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
@BACKBONE_REGISTRY.register()
def build_resnet_deeplab_backbone(cfg, input_shape):
    """
    Assemble the DeepLab variant of ResNet from ``cfg.MODEL``.

    The stem is chosen by ``cfg.MODEL.RESNETS.STEM_TYPE`` ("basic" or
    "deeplab"). res4 and res5 may be dilated, and the dilation of each res5
    block is additionally multiplied by ``cfg.MODEL.RESNETS.RES5_MULTI_GRID``.

    Returns:
        ResNet: a :class:`ResNet` instance, after ``freeze`` has been called
        with ``cfg.MODEL.BACKBONE.FREEZE_AT``.

    Raises:
        ValueError: if the stem type is neither "basic" nor "deeplab".
    """
    # need registration of new blocks/stems?
    resnet_cfg = cfg.MODEL.RESNETS
    norm = resnet_cfg.NORM

    stem_type = resnet_cfg.STEM_TYPE
    if stem_type == "basic":
        stem_cls = BasicStem
    elif stem_type == "deeplab":
        stem_cls = DeepLabStem
    else:
        raise ValueError("Unknown stem type: {}".format(cfg.MODEL.RESNETS.STEM_TYPE))
    stem = stem_cls(
        in_channels=input_shape.channels,
        out_channels=resnet_cfg.STEM_OUT_CHANNELS,
        norm=norm,
    )

    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = resnet_cfg.OUT_FEATURES
    depth = resnet_cfg.DEPTH
    num_groups = resnet_cfg.NUM_GROUPS
    width_per_group = resnet_cfg.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels = resnet_cfg.STEM_OUT_CHANNELS
    out_channels = resnet_cfg.RES2_OUT_CHANNELS
    stride_in_1x1 = resnet_cfg.STRIDE_IN_1X1
    res4_dilation = resnet_cfg.RES4_DILATION
    res5_dilation = resnet_cfg.RES5_DILATION
    deform_on_per_stage = resnet_cfg.DEFORM_ON_PER_STAGE
    deform_modulated = resnet_cfg.DEFORM_MODULATED
    deform_num_groups = resnet_cfg.DEFORM_NUM_GROUPS
    res5_multi_grid = resnet_cfg.RES5_MULTI_GRID

    assert res4_dilation in {1, 2}, "res4_dilation cannot be {}.".format(res4_dilation)
    assert res5_dilation in {1, 2, 4}, "res5_dilation cannot be {}.".format(res5_dilation)
    if res4_dilation == 2:
        # Always dilate res5 if res4 is dilated.
        assert res5_dilation == 4

    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]

    # Avoid creating variables without gradients
    # It consumes extra memory and may cause allreduce to fail
    out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features]
    max_stage_idx = max(out_stage_idx)

    # Only res4 and res5 can be dilated; every other stage uses dilation 1.
    dilation_by_stage = {4: res4_dilation, 5: res5_dilation}

    stages = []
    for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
        dilation = dilation_by_stage.get(stage_idx, 1)
        # The first stage and dilated stages keep resolution; others downsample by 2.
        first_stride = 2 if (idx != 0 and dilation <= 1) else 1
        num_blocks = num_blocks_per_stage[idx]

        stage_kargs = {
            "num_blocks": num_blocks,
            "stride_per_block": [first_stride] + [1] * (num_blocks - 1),
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm": norm,
            "bottleneck_channels": bottleneck_channels,
            "stride_in_1x1": stride_in_1x1,
            "num_groups": num_groups,
        }
        if stage_idx == 5:
            # res5 takes a per-block dilation (multi-grid) instead of a single value.
            stage_kargs["dilation_per_block"] = [dilation * mg for mg in res5_multi_grid]
        else:
            stage_kargs["dilation"] = dilation

        if deform_on_per_stage[idx]:
            stage_kargs["block_class"] = DeformBottleneckBlock
            stage_kargs["deform_modulated"] = deform_modulated
            stage_kargs["deform_num_groups"] = deform_num_groups
        else:
            stage_kargs["block_class"] = BottleneckBlock

        stages.append(ResNet.make_stage(**stage_kargs))

        # Channel widths double from one stage to the next.
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2

    return ResNet(stem, stages, out_features=out_features).freeze(freeze_at)
+from typing import Callable, Dict, List, Optional, Tuple, Union +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import ASPP, Conv2d, DepthwiseSeparableConv2d, ShapeSpec, get_norm +from detectron2.modeling import SEM_SEG_HEADS_REGISTRY + +from .loss import DeepLabCE + + +@SEM_SEG_HEADS_REGISTRY.register() +class DeepLabV3PlusHead(nn.Module): + """ + A semantic segmentation head described in :paper:`DeepLabV3+`. + """ + + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + project_channels: List[int], + aspp_dilations: List[int], + aspp_dropout: float, + decoder_channels: List[int], + common_stride: int, + norm: Union[str, Callable], + train_size: Optional[Tuple], + loss_weight: float = 1.0, + loss_type: str = "cross_entropy", + ignore_value: int = -1, + num_classes: Optional[int] = None, + use_depthwise_separable_conv: bool = False, + ): + """ + NOTE: this interface is experimental. + + Args: + input_shape: shape of the input features. They will be ordered by stride + and the last one (with largest stride) is used as the input to the + decoder (i.e. the ASPP module); the rest are low-level feature for + the intermediate levels of decoder. + project_channels (list[int]): a list of low-level feature channels. + The length should be len(in_features) - 1. + aspp_dilations (list(int)): a list of 3 dilations in ASPP. + aspp_dropout (float): apply dropout on the output of ASPP. + decoder_channels (list[int]): a list of output channels of each + decoder stage. It should have the same length as "in_features" + (each element in "in_features" corresponds to one decoder stage). + common_stride (int): output stride of decoder. + norm (str or callable): normalization for all conv layers. + train_size (tuple): (height, width) of training images. + loss_weight (float): loss weight. 
+ loss_type (str): type of loss function, 2 opptions: + (1) "cross_entropy" is the standard cross entropy loss. + (2) "hard_pixel_mining" is the loss in DeepLab that samples + top k% hardest pixels. + ignore_value (int): category to be ignored during training. + num_classes (int): number of classes, if set to None, the decoder + will not construct a predictor. + use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d + in ASPP and decoder. + """ + super().__init__() + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + + # fmt: off + self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" + in_channels = [x[1].channels for x in input_shape] + in_strides = [x[1].stride for x in input_shape] + aspp_channels = decoder_channels[-1] + self.ignore_value = ignore_value + self.common_stride = common_stride # output stride + self.loss_weight = loss_weight + self.loss_type = loss_type + self.decoder_only = num_classes is None + self.use_depthwise_separable_conv = use_depthwise_separable_conv + # fmt: on + + assert ( + len(project_channels) == len(self.in_features) - 1 + ), "Expected {} project_channels, got {}".format( + len(self.in_features) - 1, len(project_channels) + ) + assert len(decoder_channels) == len( + self.in_features + ), "Expected {} decoder_channels, got {}".format( + len(self.in_features), len(decoder_channels) + ) + self.decoder = nn.ModuleDict() + + use_bias = norm == "" + for idx, in_channel in enumerate(in_channels): + decoder_stage = nn.ModuleDict() + + if idx == len(self.in_features) - 1: + # ASPP module + if train_size is not None: + train_h, train_w = train_size + encoder_stride = in_strides[-1] + if train_h % encoder_stride or train_w % encoder_stride: + raise ValueError("Crop size need to be divisible by encoder stride.") + pool_h = train_h // encoder_stride + pool_w = train_w // encoder_stride + pool_kernel_size = (pool_h, pool_w) + else: + pool_kernel_size = None + project_conv = ASPP( + 
in_channel, + aspp_channels, + aspp_dilations, + norm=norm, + activation=F.relu, + pool_kernel_size=pool_kernel_size, + dropout=aspp_dropout, + use_depthwise_separable_conv=use_depthwise_separable_conv, + ) + fuse_conv = None + else: + project_conv = Conv2d( + in_channel, + project_channels[idx], + kernel_size=1, + bias=use_bias, + norm=get_norm(norm, project_channels[idx]), + activation=F.relu, + ) + weight_init.c2_xavier_fill(project_conv) + if use_depthwise_separable_conv: + # We use a single 5x5 DepthwiseSeparableConv2d to replace + # 2 3x3 Conv2d since they have the same receptive field, + # proposed in :paper:`Panoptic-DeepLab`. + fuse_conv = DepthwiseSeparableConv2d( + project_channels[idx] + decoder_channels[idx + 1], + decoder_channels[idx], + kernel_size=5, + padding=2, + norm1=norm, + activation1=F.relu, + norm2=norm, + activation2=F.relu, + ) + else: + fuse_conv = nn.Sequential( + Conv2d( + project_channels[idx] + decoder_channels[idx + 1], + decoder_channels[idx], + kernel_size=3, + padding=1, + bias=use_bias, + norm=get_norm(norm, decoder_channels[idx]), + activation=F.relu, + ), + Conv2d( + decoder_channels[idx], + decoder_channels[idx], + kernel_size=3, + padding=1, + bias=use_bias, + norm=get_norm(norm, decoder_channels[idx]), + activation=F.relu, + ), + ) + weight_init.c2_xavier_fill(fuse_conv[0]) + weight_init.c2_xavier_fill(fuse_conv[1]) + + decoder_stage["project_conv"] = project_conv + decoder_stage["fuse_conv"] = fuse_conv + + self.decoder[self.in_features[idx]] = decoder_stage + + if not self.decoder_only: + self.predictor = Conv2d( + decoder_channels[0], num_classes, kernel_size=1, stride=1, padding=0 + ) + nn.init.normal_(self.predictor.weight, 0, 0.001) + nn.init.constant_(self.predictor.bias, 0) + + if self.loss_type == "cross_entropy": + self.loss = nn.CrossEntropyLoss(reduction="mean", ignore_index=self.ignore_value) + elif self.loss_type == "hard_pixel_mining": + self.loss = DeepLabCE(ignore_label=self.ignore_value, 
top_k_percent_pixels=0.2) + else: + raise ValueError("Unexpected loss type: %s" % self.loss_type) + + @classmethod + def from_config(cls, cfg, input_shape): + if cfg.INPUT.CROP.ENABLED: + assert cfg.INPUT.CROP.TYPE == "absolute" + train_size = cfg.INPUT.CROP.SIZE + else: + train_size = None + decoder_channels = [cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM] * ( + len(cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES) - 1 + ) + [cfg.MODEL.SEM_SEG_HEAD.ASPP_CHANNELS] + ret = dict( + input_shape={ + k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + }, + project_channels=cfg.MODEL.SEM_SEG_HEAD.PROJECT_CHANNELS, + aspp_dilations=cfg.MODEL.SEM_SEG_HEAD.ASPP_DILATIONS, + aspp_dropout=cfg.MODEL.SEM_SEG_HEAD.ASPP_DROPOUT, + decoder_channels=decoder_channels, + common_stride=cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE, + norm=cfg.MODEL.SEM_SEG_HEAD.NORM, + train_size=train_size, + loss_weight=cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, + loss_type=cfg.MODEL.SEM_SEG_HEAD.LOSS_TYPE, + ignore_value=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, + use_depthwise_separable_conv=cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV, + ) + return ret + + def forward(self, features, targets=None): + """ + Returns: + In training, returns (None, dict of losses) + In inference, returns (CxHxW logits, {}) + """ + y = self.layers(features) + if self.decoder_only: + # Output from self.layers() only contains decoder feature. 
@SEM_SEG_HEADS_REGISTRY.register()
class DeepLabV3Head(nn.Module):
    """
    A semantic segmentation head described in :paper:`DeepLabV3`.

    The head applies an ASPP module to a single input feature map and then a
    1x1 convolution that predicts per-class logits.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """
        Args:
            cfg: config node; ``cfg.MODEL.SEM_SEG_HEAD`` and ``cfg.INPUT.CROP``
                are read.
            input_shape: shapes of the backbone features, keyed by feature
                name. Exactly one feature must be selected by
                ``SEM_SEG_HEAD.IN_FEATURES``.

        Raises:
            ValueError: if the crop size is not divisible by the output stride,
                or if ``LOSS_TYPE`` is not one of the supported values.
        """
        super().__init__()

        # fmt: off
        self.in_features = cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
        in_channels = [input_shape[f].channels for f in self.in_features]
        aspp_channels = cfg.MODEL.SEM_SEG_HEAD.ASPP_CHANNELS
        aspp_dilations = cfg.MODEL.SEM_SEG_HEAD.ASPP_DILATIONS
        self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
        num_classes = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        self.common_stride = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE  # output stride
        norm = cfg.MODEL.SEM_SEG_HEAD.NORM
        self.loss_weight = cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT
        self.loss_type = cfg.MODEL.SEM_SEG_HEAD.LOSS_TYPE
        train_crop_size = cfg.INPUT.CROP.SIZE
        aspp_dropout = cfg.MODEL.SEM_SEG_HEAD.ASPP_DROPOUT
        use_depthwise_separable_conv = cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV
        # fmt: on

        assert len(self.in_features) == 1
        assert len(in_channels) == 1

        # ASPP module
        if cfg.INPUT.CROP.ENABLED:
            assert cfg.INPUT.CROP.TYPE == "absolute"
            train_crop_h, train_crop_w = train_crop_size
            if train_crop_h % self.common_stride or train_crop_w % self.common_stride:
                raise ValueError("Crop size need to be divisible by output stride.")
            # The pooling kernel is the crop size expressed at feature resolution.
            pool_h = train_crop_h // self.common_stride
            pool_w = train_crop_w // self.common_stride
            pool_kernel_size = (pool_h, pool_w)
        else:
            pool_kernel_size = None
        self.aspp = ASPP(
            in_channels[0],
            aspp_channels,
            aspp_dilations,
            norm=norm,
            activation=F.relu,
            pool_kernel_size=pool_kernel_size,
            dropout=aspp_dropout,
            use_depthwise_separable_conv=use_depthwise_separable_conv,
        )

        # The predictor consumes the ASPP output directly, so its input width
        # must be the ASPP output width (ASPP_CHANNELS). Sizing it with
        # CONVS_DIM only works when both settings happen to be equal.
        self.predictor = Conv2d(aspp_channels, num_classes, kernel_size=1, stride=1, padding=0)
        nn.init.normal_(self.predictor.weight, 0, 0.001)
        nn.init.constant_(self.predictor.bias, 0)

        if self.loss_type == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(reduction="mean", ignore_index=self.ignore_value)
        elif self.loss_type == "hard_pixel_mining":
            self.loss = DeepLabCE(ignore_label=self.ignore_value, top_k_percent_pixels=0.2)
        else:
            raise ValueError("Unexpected loss type: %s" % self.loss_type)

    def forward(self, features, targets=None):
        """
        Args:
            features: mapping from feature name to feature map; only the single
                feature named in ``self.in_features`` is used.
            targets: ground truth passed to the loss; only used in training.

        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        x = features[self.in_features[0]]
        x = self.aspp(x)
        x = self.predictor(x)
        if self.training:
            return None, self.losses(x, targets)
        else:
            # Bring the logits back to input resolution.
            x = F.interpolate(
                x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return x, {}

    def losses(self, predictions, targets):
        """
        Upsample ``predictions`` by ``common_stride`` and compute the weighted loss.

        Returns:
            dict: ``{"loss_sem_seg": loss * self.loss_weight}``.
        """
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        loss = self.loss(predictions, targets)
        losses = {"loss_sem_seg": loss * self.loss_weight}
        return losses
class Trainer(DefaultTrainer):
    """
    We use the "DefaultTrainer" which contains a number pre-defined logic for
    standard training workflow. They may not work for you, especially if you
    are working on a new research project. In that case you can use the cleaner
    "SimpleTrainer", or write your own training loop.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create the evaluator for a given dataset.
        This uses the special metadata "evaluator_type" associated with each builtin dataset.
        For your own dataset, you can simply create an evaluator manually in your
        script and do not have to worry about the hacky if-else logic here.

        Args:
            cfg: config node; ``cfg.OUTPUT_DIR`` is used for the default output folder.
            dataset_name: name of a dataset registered in ``MetadataCatalog``.
            output_folder: where evaluation output is written; defaults to
                ``<cfg.OUTPUT_DIR>/inference``.

        Raises:
            NotImplementedError: if the dataset's evaluator type is neither
                "sem_seg" nor "cityscapes_sem_seg".
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type == "sem_seg":
            return SemSegEvaluator(
                dataset_name,
                distributed=True,
                output_dir=output_folder,
            )
        if evaluator_type == "cityscapes_sem_seg":
            return CityscapesSemSegEvaluator(dataset_name)
        # Both supported types return above, so reaching this point always
        # means the dataset has no matching evaluator.
        raise NotImplementedError(
            "no Evaluator for the dataset {} with the type {}".format(
                dataset_name, evaluator_type
            )
        )

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Build the training data loader.

        When the meta architecture name contains "SemanticSegmentor", a
        ``DatasetMapper`` with the semantic-segmentation augmentations is used;
        otherwise ``mapper=None`` is passed to the loader builder.
        """
        if "SemanticSegmentor" in cfg.MODEL.META_ARCHITECTURE:
            mapper = DatasetMapper(cfg, is_train=True, augmentations=build_sem_seg_train_aug(cfg))
        else:
            mapper = None
        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It calls ``build_lr_scheduler`` from :mod:`detectron2.projects.deeplab`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)
def main(args):
    """
    Entry point run on every worker: train, or evaluate when ``--eval-only`` is set.

    Returns:
        The result of ``Trainer.test`` in evaluation mode, otherwise the result
        of ``trainer.train()``.
    """
    cfg = setup(args)

    if not args.eval_only:
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()

    # Evaluation only: build the bare model and load the configured weights.
    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
    return Trainer.test(cfg, model)
0000000..38f4f83 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/README.md @@ -0,0 +1,64 @@ +# DensePose in Detectron2 + +DensePose aims at learning and establishing dense correspondences between image pixels +and 3D object geometry for deformable objects, such as humans or animals. +In this repository, we provide the code to train and evaluate DensePose R-CNN and +various tools to visualize DensePose annotations and results. + +There are two main paradigms that are used within DensePose project. + +## [Chart-based Dense Pose Estimation for Humans and Animals](doc/DENSEPOSE_IUV.md) + +
+ +
+ +For chart-based estimation, 3D object mesh is split into charts and +for each pixel the model estimates chart index `I` and local chart coordinates `(U, V)`. +Please follow the link above to find a [detailed overview](doc/DENSEPOSE_IUV.md#Overview) +of the method, links to trained models along with their performance evaluation in the +[Model Zoo](doc/DENSEPOSE_IUV.md#ModelZoo) and +[references](doc/DENSEPOSE_IUV.md#References) to the corresponding papers. + +## [Continuous Surface Embeddings for Dense Pose Estimation for Humans and Animals](doc/DENSEPOSE_CSE.md) + +
+ +
+ +To establish continuous surface embeddings, the model simultaneously learns +descriptors for mesh vertices and for image pixels. +The embeddings are put into correspondence, thus the location +of each pixel on the 3D model is derived. +Please follow the link above to find a [detailed overview](doc/DENSEPOSE_CSE.md#Overview) +of the method, links to trained models along with their performance evaluation in the +[Model Zoo](doc/DENSEPOSE_CSE.md#ModelZoo) and +[references](doc/DENSEPOSE_CSE.md#References) to the corresponding papers. + +# Quick Start + +See [ Getting Started ](doc/GETTING_STARTED.md) + +# Model Zoo + +Please check the dedicated pages +for [chart-based model zoo](doc/DENSEPOSE_IUV.md#ModelZoo) +and for [continuous surface embeddings model zoo](doc/DENSEPOSE_CSE.md#ModelZoo). + +# What's New + +* June 2021: [DensePose CSE with Cycle Losses](doc/RELEASE_2021_06.md) +* March 2021: [DensePose CSE (a framework to extend DensePose to various categories using 3D models) + and DensePose Evolution (a framework to bootstrap DensePose on unlabeled data) released](doc/RELEASE_2021_03.md) +* April 2020: [DensePose Confidence Estimation and Model Zoo Improvements](doc/RELEASE_2020_04.md) + +# License + +Detectron2 is released under the [Apache 2.0 license](../../LICENSE) + +## Citing DensePose + +If you use DensePose, please refer to the BibTeX entries +for [chart-based models](doc/DENSEPOSE_IUV.md#References) +and for [continuous surface embeddings](doc/DENSEPOSE_CSE.md#References). + diff --git a/data_processing/detectron2/projects/DensePose/apply_net.py b/data_processing/detectron2/projects/DensePose/apply_net.py new file mode 100644 index 0000000..c854f4e --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/apply_net.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. 
class Action(object):
    """
    Base class of the apply_net actions; contributes the options shared by all of them.
    """

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """
        Register the common command-line options on ``parser``.

        Adds ``-v`` / ``--verbosity`` as a counting flag: the parsed value is
        the number of times the flag was given, or ``None`` when it is absent.
        """
        verbosity_help = "Verbose mode. Multiple -v options increase the verbosity."
        parser.add_argument("-v", "--verbosity", action="count", help=verbosity_help)
+ with torch.no_grad(): + outputs = predictor(img)["instances"] + cls.execute_on_outputs(context, {"file_name": file_name, "image": img}, outputs) + cls.postexecute(context) + + @classmethod + def setup_config( + cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str] + ): + cfg = get_cfg() + add_densepose_config(cfg) + cfg.merge_from_file(config_fpath) + cfg.merge_from_list(args.opts) + if opts: + cfg.merge_from_list(opts) + cfg.MODEL.WEIGHTS = model_fpath + cfg.freeze() + return cfg + + @classmethod + def _get_input_file_list(cls: type, input_spec: str): + #print('input_spec: ', input_spec) + + if os.path.isdir(input_spec): + file_list = [ + os.path.join(input_spec, fname) + for fname in os.listdir(input_spec) + if os.path.isfile(os.path.join(input_spec, fname)) + ] + elif os.path.isfile(input_spec): + file_list = [input_spec] + else: + file_list = glob.glob(input_spec) + + #print('file_list: ', file_list) + return file_list + + +@register_action +class DumpAction(InferenceAction): + """ + Dump action that outputs results to a pickle file + """ + + COMMAND: ClassVar[str] = "dump" + + @classmethod + def add_parser(cls: type, subparsers: argparse._SubParsersAction): + parser = subparsers.add_parser(cls.COMMAND, help="Dump model outputs to a file.") + cls.add_arguments(parser) + parser.set_defaults(func=cls.execute) + + @classmethod + def add_arguments(cls: type, parser: argparse.ArgumentParser): + super(DumpAction, cls).add_arguments(parser) + parser.add_argument( + "--output", + metavar="", + default="results.pkl", + help="File name to save dump to", + ) + + @classmethod + def execute_on_outputs( + cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances + ): + image_fpath = entry["file_name"] + logger.info(f"Processing {image_fpath}") + result = {"file_name": image_fpath} + if outputs.has("scores"): + result["scores"] = outputs.get("scores").cpu() + if outputs.has("pred_boxes"): + result["pred_boxes_XYXY"] = 
@register_action
class ShowAction(InferenceAction):
    """
    Show action that visualizes selected entries on an image

    Each visualization returned by the visualizer is drawn on a blank white
    canvas and written as ``<input stem>_seg<i>.png`` into the directory given
    by ``--output``.
    """

    COMMAND: ClassVar[str] = "show"
    VISUALIZERS: ClassVar[Dict[str, object]] = {
        "dp_contour": DensePoseResultsContourVisualizer,
        "dp_segm": DensePoseResultsFineSegmentationVisualizer,
        "dp_u": DensePoseResultsUVisualizer,
        "dp_v": DensePoseResultsVVisualizer,
        "dp_iuv_texture": DensePoseResultsVisualizerWithTexture,
        "dp_cse_texture": DensePoseOutputsTextureVisualizer,
        "dp_vertex": DensePoseOutputsVertexVisualizer,
        "bbox": ScoredBoundingBoxVisualizer,
    }

    @classmethod
    def add_parser(cls: type, subparsers: argparse._SubParsersAction):
        """Register the ``show`` sub-command and bind it to ``cls.execute``."""
        parser = subparsers.add_parser(cls.COMMAND, help="Visualize selected entries")
        cls.add_arguments(parser)
        parser.set_defaults(func=cls.execute)

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """Add the ``show``-specific options on top of the inference options."""
        super(ShowAction, cls).add_arguments(parser)
        parser.add_argument(
            "visualizations",
            metavar="",
            help="Comma separated list of visualizations, possible values: "
            "[{}]".format(",".join(sorted(cls.VISUALIZERS.keys()))),
        )
        parser.add_argument(
            "--min_score",
            metavar="",
            default=0.8,
            type=float,
            help="Minimum detection score to visualize",
        )
        parser.add_argument(
            "--nms_thresh", metavar="", default=None, type=float, help="NMS threshold"
        )
        parser.add_argument(
            "--texture_atlas",
            metavar="",
            default=None,
            help="Texture atlas file (for IUV texture transfer)",
        )
        parser.add_argument(
            "--texture_atlases_map",
            metavar="",
            default=None,
            help="JSON string of a dict containing texture atlas files for each mesh",
        )
        parser.add_argument(
            "--output",
            metavar="",
            default="outputres.png",
            # execute_on_outputs treats this value as a directory, not a file.
            help="Directory to save output images to",
        )

    @classmethod
    def setup_config(
        cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str]
    ):
        """
        Build the config, adding the score (and optional NMS) test thresholds
        from the command line to ``opts`` before delegating to the parent.

        Note: ``opts`` is extended in place.
        """
        opts.append("MODEL.ROI_HEADS.SCORE_THRESH_TEST")
        opts.append(str(args.min_score))
        if args.nms_thresh is not None:
            opts.append("MODEL.ROI_HEADS.NMS_THRESH_TEST")
            opts.append(str(args.nms_thresh))
        cfg = super(ShowAction, cls).setup_config(config_fpath, model_fpath, args, opts)
        return cfg

    @classmethod
    def execute_on_outputs(
        cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances
    ):
        """
        Render the visualizations for one input image and write them to disk.

        Args:
            context: dict created by ``create_context``; ``"out_fname"`` is used
                as the output directory and ``"entry_idx"`` counts written files.
            entry: dict with the input ``"file_name"`` and the loaded ``"image"``.
            outputs: predicted instances for the image.

        Visualizations that are ``None``, or whose channel 2 is 255 on more than
        90% of the pixels (i.e. almost nothing was drawn), are skipped.
        """
        import cv2
        import numpy as np

        visualizer = context["visualizer"]
        extractor = context["extractor"]
        image_fpath = entry["file_name"]
        logger.info(f"Processing {image_fpath}")

        # Results are drawn on a blank white canvas of the input's height and
        # width rather than on top of the input image itself.
        height, width = entry["image"].shape[:2]
        canvas = np.full((height, width, 3), 255, dtype=entry["image"].dtype)

        data = extractor(outputs)
        image_vis_list = visualizer.visualize(canvas, data)

        # NOTE: the --output value is used as a directory here.
        out_dir = context["out_fname"]
        # splitext removes only the final extension, so "a.b.jpg" and "a.c.jpg"
        # no longer map to the same output name.
        stem = os.path.splitext(os.path.basename(image_fpath))[0]

        for i, image_vis in enumerate(image_vis_list):
            if image_vis is None:
                continue

            white = image_vis[:, :, 2] == 255
            if white.all() or white.mean() > 0.9:
                continue

            out_fname = os.path.join(out_dir, f"{stem}_seg{i}.png")
            if len(out_dir) > 0:
                os.makedirs(out_dir, exist_ok=True)
            # cv2.imwrite reports failure through its return value.
            if not cv2.imwrite(out_fname, image_vis):
                logger.warning(f"Failed to write {out_fname}")
                continue
            logger.info(f"Output saved to {out_fname}")
            context["entry_idx"] += 1

    @classmethod
    def postexecute(cls: type, context: Dict[str, Any]):
        """Nothing to finalize: every image is written as soon as it is produced."""
        pass

    @classmethod
    def _get_out_fname(cls: type, entry_idx: int, fname_base: str):
        """Return ``fname_base`` with a zero-padded ``.NNNN`` index inserted before its extension."""
        base, ext = os.path.splitext(fname_base)
        return base + ".{0:04d}".format(entry_idx) + ext

    @classmethod
    def create_context(cls: type, args: argparse.Namespace, cfg: CfgNode) -> Dict[str, Any]:
        """
        Instantiate the requested visualizers and their extractors.

        Raises:
            KeyError: if a requested visualization is not in ``VISUALIZERS``.
        """
        vis_specs = args.visualizations.split(",")
        # The atlases depend only on the command-line arguments, so load them
        # once instead of once per requested visualization.
        texture_atlas = get_texture_atlas(args.texture_atlas)
        texture_atlases_dict = get_texture_atlases(args.texture_atlases_map)
        visualizers = []
        extractors = []
        for vis_spec in vis_specs:
            vis = cls.VISUALIZERS[vis_spec](
                cfg=cfg,
                texture_atlas=texture_atlas,
                texture_atlases_dict=texture_atlases_dict,
            )
            visualizers.append(vis)
            extractors.append(create_extractor(vis))
        context = {
            "extractor": CompoundExtractor(extractors),
            "visualizer": CompoundVisualizer(visualizers),
            "out_fname": args.output,
            "entry_idx": 0,
        }
        return context
def main():
    """
    Parse the command line, configure the module logger and run the selected action.

    The module-level ``logger`` is rebound to the logger returned by
    ``setup_logger`` and its level is derived from the ``-v`` count.
    """
    global logger

    args = create_argument_parser().parse_args()
    # "verbosity" is only defined by the sub-command parsers, so it is absent
    # when no action was selected.
    level = verbosity_to_level(getattr(args, "verbosity", None))
    logger = setup_logger(name=LOGGER_NAME)
    logger.setLevel(level)
    args.func(args)
+ POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + + DENSEPOSE_ON: True + ROI_HEADS: + NAME: "DensePoseROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + NUM_CLASSES: 1 + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + POOLER_SAMPLING_RATIO: 2 + POOLER_TYPE: "ROIAlign" + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + POOLER_TYPE: "ROIAlign" + NUM_COARSE_SEGM_CHANNELS: 2 +DATASETS: + TRAIN: ("densepose_coco_2014_train", "densepose_coco_2014_valminusminival") + TEST: ("densepose_coco_2014_minival",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.01 + STEPS: (60000, 80000) + MAX_ITER: 90000 + WARMUP_FACTOR: 0.1 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) diff --git a/data_processing/detectron2/projects/DensePose/configs/HRNet/densepose_rcnn_HRFPN_HRNet_w32_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/HRNet/densepose_rcnn_HRFPN_HRNet_w32_s1x.yaml new file mode 100644 index 0000000..36eabfe --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/HRNet/densepose_rcnn_HRFPN_HRNet_w32_s1x.yaml @@ -0,0 +1,16 @@ +_BASE_: "../Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://1drv.ms/u/s!Aus8VCZ_C_33dYBMemi9xOUFR0w" + BACKBONE: + NAME: "build_hrfpn_backbone" + RPN: + IN_FEATURES: ['p1', 'p2', 'p3', 'p4', 'p5'] + ROI_HEADS: + IN_FEATURES: ['p1', 'p2', 'p3', 'p4', 'p5'] +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: "norm" + BASE_LR: 0.03 diff --git a/data_processing/detectron2/projects/DensePose/configs/HRNet/densepose_rcnn_HRFPN_HRNet_w40_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/HRNet/densepose_rcnn_HRFPN_HRNet_w40_s1x.yaml new file mode 100644 index 0000000..0ca8085 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/HRNet/densepose_rcnn_HRFPN_HRNet_w40_s1x.yaml @@ -0,0 +1,23 @@ +_BASE_: "../Base-DensePose-RCNN-FPN.yaml" 
+MODEL: + WEIGHTS: "https://1drv.ms/u/s!Aus8VCZ_C_33ck0gvo5jfoWBOPo" + BACKBONE: + NAME: "build_hrfpn_backbone" + RPN: + IN_FEATURES: ['p1', 'p2', 'p3', 'p4', 'p5'] + ROI_HEADS: + IN_FEATURES: ['p1', 'p2', 'p3', 'p4', 'p5'] + HRNET: + STAGE2: + NUM_CHANNELS: [40, 80] + STAGE3: + NUM_CHANNELS: [40, 80, 160] + STAGE4: + NUM_CHANNELS: [40, 80, 160, 320] +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: "norm" + BASE_LR: 0.03 diff --git a/data_processing/detectron2/projects/DensePose/configs/HRNet/densepose_rcnn_HRFPN_HRNet_w48_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/HRNet/densepose_rcnn_HRFPN_HRNet_w48_s1x.yaml new file mode 100644 index 0000000..a3f437a --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/HRNet/densepose_rcnn_HRFPN_HRNet_w48_s1x.yaml @@ -0,0 +1,23 @@ +_BASE_: "../Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://1drv.ms/u/s!Aus8VCZ_C_33dKvqI6pBZlifgJk" + BACKBONE: + NAME: "build_hrfpn_backbone" + RPN: + IN_FEATURES: ['p1', 'p2', 'p3', 'p4', 'p5'] + ROI_HEADS: + IN_FEATURES: ['p1', 'p2', 'p3', 'p4', 'p5'] + HRNET: + STAGE2: + NUM_CHANNELS: [48, 96] + STAGE3: + NUM_CHANNELS: [48, 96, 192] + STAGE4: + NUM_CHANNELS: [48, 96, 192, 384] +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: "norm" + BASE_LR: 0.03 diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml new file mode 100644 index 0000000..1d44a8a --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml @@ -0,0 +1,21 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + ROI_DENSEPOSE_HEAD: + CSE: + EMBEDDERS: + "smpl_27554": + TYPE: 
vertex_feature + NUM_VERTICES: 27554 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + #INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_smpl_27554_256.pkl" + INIT_FILE: "./phi_smpl_27554_256.pkl" +DATASETS: + TRAIN: + - "densepose_coco_2014_train_cse" + - "densepose_coco_2014_valminusminival_cse" + TEST: + - "densepose_coco_2014_minival_cse" + CLASS_TO_MESH_NAME_MAPPING: + "0": "smpl_27554" diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml new file mode 100644 index 0000000..de3b260 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml @@ -0,0 +1,60 @@ +VERSION: 2 +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: "build_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map + ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level + PRE_NMS_TOPK_TEST: 1000 # Per FPN level + # Detectron1 uses 2000 proposals per-batch, + # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) + # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
+ POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + + DENSEPOSE_ON: True + ROI_HEADS: + NAME: "DensePoseROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + NUM_CLASSES: 1 + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + POOLER_SAMPLING_RATIO: 2 + POOLER_TYPE: "ROIAlign" + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + POOLER_TYPE: "ROIAlign" + NUM_COARSE_SEGM_CHANNELS: 2 + PREDICTOR_NAME: "DensePoseEmbeddingPredictor" + LOSS_NAME: "DensePoseCseLoss" + CSE: + # embedding loss, possible values: + # - "EmbeddingLoss" + # - "SoftEmbeddingLoss" + EMBED_LOSS_NAME: "EmbeddingLoss" +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.01 + STEPS: (60000, 80000) + MAX_ITER: 90000 + WARMUP_FACTOR: 0.1 + CLIP_GRADIENTS: + CLIP_TYPE: norm + CLIP_VALUE: 1.0 + ENABLED: true + NORM_TYPE: 2.0 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +DENSEPOSE_EVALUATION: + TYPE: cse + STORAGE: file diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_s1x.yaml new file mode 100644 index 0000000..69d8589 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_s1x.yaml @@ -0,0 +1,12 @@ +_BASE_: "Base-DensePose-RCNN-FPN-Human.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + CSE: + EMBED_LOSS_NAME: "EmbeddingLoss" +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml new file mode 100644 index 0000000..141657c --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml @@ -0,0 +1,12 @@ 
+_BASE_: "Base-DensePose-RCNN-FPN-Human.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_s1x.yaml new file mode 100644 index 0000000..d2eea1e --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_s1x.yaml @@ -0,0 +1,12 @@ +_BASE_: "Base-DensePose-RCNN-FPN-Human.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + CSE: + EMBED_LOSS_NAME: "EmbeddingLoss" +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_soft_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_soft_s1x.yaml new file mode 100644 index 0000000..1c362e1 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_soft_s1x.yaml @@ -0,0 +1,12 @@ +_BASE_: "Base-DensePose-RCNN-FPN-Human.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_DL_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_DL_s1x.yaml new file mode 100644 index 0000000..26684de --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_DL_s1x.yaml @@ -0,0 
+1,12 @@ +_BASE_: "Base-DensePose-RCNN-FPN-Human.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + CSE: + EMBED_LOSS_NAME: "EmbeddingLoss" +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_DL_soft_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_DL_soft_s1x.yaml new file mode 100644 index 0000000..b53501d --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_DL_soft_s1x.yaml @@ -0,0 +1,12 @@ +_BASE_: "Base-DensePose-RCNN-FPN-Human.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml new file mode 100644 index 0000000..c186625 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml @@ -0,0 +1,12 @@ +_BASE_: "Base-DensePose-RCNN-FPN-Human.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + CSE: + EMBED_LOSS_NAME: "EmbeddingLoss" +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_CA_finetune_16k.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_CA_finetune_16k.yaml new file mode 100644 index 0000000..69ab226 --- /dev/null +++ 
b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_CA_finetune_16k.yaml @@ -0,0 +1,133 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_s1x/250533982/model_final_2c4512.pkl" + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 1 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + COARSE_SEGM_TRAINED_BY_MASKS: True + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" + EMBEDDING_DIST_GAUSS_SIGMA: 0.1 + GEODESIC_DIST_GAUSS_SIGMA: 0.1 + EMBEDDERS: + "cat_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cat_7466_256.pkl" + "dog_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_dog_7466_256.pkl" + "sheep_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_sheep_5004_256.pkl" + "horse_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_horse_5004_256.pkl" + "zebra_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_zebra_5002_256.pkl" + 
"giraffe_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_giraffe_5002_256.pkl" + "elephant_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_elephant_5002_256.pkl" + "cow_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cow_5002_256.pkl" + "bear_4936": + TYPE: vertex_feature + NUM_VERTICES: 4936 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_bear_4936_256.pkl" +DATASETS: + TRAIN: + - "densepose_lvis_v1_ds2_train_v1" + TEST: + - "densepose_lvis_v1_ds2_val_v1" + WHITELISTED_CATEGORIES: + "densepose_lvis_v1_ds2_train_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + "densepose_lvis_v1_ds2_val_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + CATEGORY_MAPS: + "densepose_lvis_v1_ds2_train_v1": + "1202": 943 # zebra -> sheep + "569": 943 # horse -> sheep + "496": 943 # giraffe -> sheep + "422": 943 # elephant -> sheep + "80": 943 # cow -> sheep + "76": 943 # bear -> sheep + "225": 943 # cat -> sheep + "378": 943 # dog -> sheep + "densepose_lvis_v1_ds2_val_v1": + "1202": 943 # zebra -> sheep + "569": 943 # horse -> sheep + "496": 943 # giraffe -> sheep + 
"422": 943 # elephant -> sheep + "80": 943 # cow -> sheep + "76": 943 # bear -> sheep + "225": 943 # cat -> sheep + "378": 943 # dog -> sheep + CLASS_TO_MESH_NAME_MAPPING: + # Note: different classes are mapped to a single class + # mesh is chosen based on GT data, so this is just some + # value which has no particular meaning + "0": "sheep_5004" +SOLVER: + MAX_ITER: 16000 + STEPS: (12000, 14000) +DENSEPOSE_EVALUATION: + EVALUATE_MESH_ALIGNMENT: True diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_CA_finetune_4k.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_CA_finetune_4k.yaml new file mode 100644 index 0000000..921a9c1 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_CA_finetune_4k.yaml @@ -0,0 +1,133 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_s1x/250533982/model_final_2c4512.pkl" + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 1 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + COARSE_SEGM_TRAINED_BY_MASKS: True + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" + EMBEDDING_DIST_GAUSS_SIGMA: 0.1 + GEODESIC_DIST_GAUSS_SIGMA: 0.1 + EMBEDDERS: + "cat_5001": + TYPE: vertex_feature + NUM_VERTICES: 5001 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cat_5001_256.pkl" + "dog_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_dog_5002_256.pkl" + "sheep_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + 
FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_sheep_5004_256.pkl" + "horse_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_horse_5004_256.pkl" + "zebra_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_zebra_5002_256.pkl" + "giraffe_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_giraffe_5002_256.pkl" + "elephant_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_elephant_5002_256.pkl" + "cow_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cow_5002_256.pkl" + "bear_4936": + TYPE: vertex_feature + NUM_VERTICES: 4936 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_bear_4936_256.pkl" +DATASETS: + TRAIN: + - "densepose_lvis_v1_ds1_train_v1" + TEST: + - "densepose_lvis_v1_ds1_val_v1" + WHITELISTED_CATEGORIES: + 
"densepose_lvis_v1_ds1_train_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + "densepose_lvis_v1_ds1_val_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + CATEGORY_MAPS: + "densepose_lvis_v1_ds1_train_v1": + "1202": 943 # zebra -> sheep + "569": 943 # horse -> sheep + "496": 943 # giraffe -> sheep + "422": 943 # elephant -> sheep + "80": 943 # cow -> sheep + "76": 943 # bear -> sheep + "225": 943 # cat -> sheep + "378": 943 # dog -> sheep + "densepose_lvis_v1_ds1_val_v1": + "1202": 943 # zebra -> sheep + "569": 943 # horse -> sheep + "496": 943 # giraffe -> sheep + "422": 943 # elephant -> sheep + "80": 943 # cow -> sheep + "76": 943 # bear -> sheep + "225": 943 # cat -> sheep + "378": 943 # dog -> sheep + CLASS_TO_MESH_NAME_MAPPING: + # Note: different classes are mapped to a single class + # mesh is chosen based on GT data, so this is just some + # value which has no particular meaning + "0": "sheep_5004" +SOLVER: + MAX_ITER: 4000 + STEPS: (3000, 3500) +DENSEPOSE_EVALUATION: + EVALUATE_MESH_ALIGNMENT: True diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_I0_finetune_16k.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_I0_finetune_16k.yaml new file mode 100644 index 0000000..1b5a098 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_I0_finetune_16k.yaml @@ -0,0 +1,119 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_maskonly_24k/270668502/model_final_21b1d2.pkl" + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 9 + ROI_DENSEPOSE_HEAD: + 
NAME: "DensePoseV1ConvXHead" + COARSE_SEGM_TRAINED_BY_MASKS: True + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" + EMBEDDING_DIST_GAUSS_SIGMA: 0.1 + GEODESIC_DIST_GAUSS_SIGMA: 0.1 + EMBEDDERS: + "cat_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cat_7466_256.pkl" + "dog_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_dog_7466_256.pkl" + "sheep_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_sheep_5004_256.pkl" + "horse_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_horse_5004_256.pkl" + "zebra_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_zebra_5002_256.pkl" + "giraffe_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_giraffe_5002_256.pkl" + "elephant_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: 
"https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_elephant_5002_256.pkl" + "cow_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cow_5002_256.pkl" + "bear_4936": + TYPE: vertex_feature + NUM_VERTICES: 4936 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_bear_4936_256.pkl" +DATASETS: + TRAIN: + - "densepose_lvis_v1_ds2_train_v1" + TEST: + - "densepose_lvis_v1_ds2_val_v1" + WHITELISTED_CATEGORIES: + "densepose_lvis_v1_ds2_train_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + "densepose_lvis_v1_ds2_val_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + CLASS_TO_MESH_NAME_MAPPING: + "0": "bear_4936" + "1": "cow_5002" + "2": "cat_7466" + "3": "dog_7466" + "4": "elephant_5002" + "5": "giraffe_5002" + "6": "horse_5004" + "7": "sheep_5004" + "8": "zebra_5002" +SOLVER: + MAX_ITER: 16000 + STEPS: (12000, 14000) +DENSEPOSE_EVALUATION: + EVALUATE_MESH_ALIGNMENT: True diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_I0_finetune_i2m_16k.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_I0_finetune_i2m_16k.yaml new file mode 100644 index 0000000..18d6dac --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_I0_finetune_i2m_16k.yaml @@ -0,0 +1,121 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: 
"https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_maskonly_24k/270668502/model_final_21b1d2.pkl" + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 9 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + COARSE_SEGM_TRAINED_BY_MASKS: True + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" + EMBEDDING_DIST_GAUSS_SIGMA: 0.1 + GEODESIC_DIST_GAUSS_SIGMA: 0.1 + PIX_TO_SHAPE_CYCLE_LOSS: + ENABLED: True + EMBEDDERS: + "cat_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cat_7466_256.pkl" + "dog_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_dog_7466_256.pkl" + "sheep_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_sheep_5004_256.pkl" + "horse_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_horse_5004_256.pkl" + "zebra_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_zebra_5002_256.pkl" + "giraffe_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + 
INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_giraffe_5002_256.pkl" + "elephant_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_elephant_5002_256.pkl" + "cow_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cow_5002_256.pkl" + "bear_4936": + TYPE: vertex_feature + NUM_VERTICES: 4936 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_bear_4936_256.pkl" +DATASETS: + TRAIN: + - "densepose_lvis_v1_ds2_train_v1" + TEST: + - "densepose_lvis_v1_ds2_val_v1" + WHITELISTED_CATEGORIES: + "densepose_lvis_v1_ds2_train_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + "densepose_lvis_v1_ds2_val_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + CLASS_TO_MESH_NAME_MAPPING: + "0": "bear_4936" + "1": "cow_5002" + "2": "cat_7466" + "3": "dog_7466" + "4": "elephant_5002" + "5": "giraffe_5002" + "6": "horse_5004" + "7": "sheep_5004" + "8": "zebra_5002" +SOLVER: + MAX_ITER: 16000 + STEPS: (12000, 14000) +DENSEPOSE_EVALUATION: + EVALUATE_MESH_ALIGNMENT: True diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_I0_finetune_m2m_16k.yaml 
b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_I0_finetune_m2m_16k.yaml new file mode 100644 index 0000000..6b798ae --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_I0_finetune_m2m_16k.yaml @@ -0,0 +1,138 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_maskonly_24k/267687159/model_final_354e61.pkl" + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 9 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + COARSE_SEGM_TRAINED_BY_MASKS: True + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" + EMBEDDING_DIST_GAUSS_SIGMA: 0.1 + GEODESIC_DIST_GAUSS_SIGMA: 0.1 + SHAPE_TO_SHAPE_CYCLE_LOSS: + ENABLED: True + EMBEDDERS: + "cat_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cat_7466_256.pkl" + "dog_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_dog_7466_256.pkl" + "sheep_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_sheep_5004_256.pkl" + "horse_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_horse_5004_256.pkl" + "zebra_5002": + 
TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_zebra_5002_256.pkl" + "giraffe_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_giraffe_5002_256.pkl" + "elephant_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_elephant_5002_256.pkl" + "cow_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cow_5002_256.pkl" + "bear_4936": + TYPE: vertex_feature + NUM_VERTICES: 4936 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_bear_4936_256.pkl" + "smpl_27554": + TYPE: vertex_feature + NUM_VERTICES: 27554 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_smpl_27554_256.pkl" +DATASETS: + TRAIN: + - "densepose_lvis_v1_ds2_train_v1" + TEST: + - "densepose_lvis_v1_ds2_val_v1" + WHITELISTED_CATEGORIES: + "densepose_lvis_v1_ds2_train_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + "densepose_lvis_v1_ds2_val_v1": + - 943 # sheep + - 1202 # 
zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + CLASS_TO_MESH_NAME_MAPPING: + "0": "bear_4936" + "1": "cow_5002" + "2": "cat_7466" + "3": "dog_7466" + "4": "elephant_5002" + "5": "giraffe_5002" + "6": "horse_5004" + "7": "sheep_5004" + "8": "zebra_5002" +SOLVER: + MAX_ITER: 16000 + STEPS: (12000, 14000) +DENSEPOSE_EVALUATION: + EVALUATE_MESH_ALIGNMENT: True + MESH_ALIGNMENT_MESH_NAMES: + - bear_4936 + - cow_5002 + - cat_7466 + - dog_7466 + - elephant_5002 + - giraffe_5002 + - horse_5004 + - sheep_5004 + - zebra_5002 diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_16k.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_16k.yaml new file mode 100644 index 0000000..b1462e3 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_16k.yaml @@ -0,0 +1,119 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_s1x/250533982/model_final_2c4512.pkl" + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 9 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + COARSE_SEGM_TRAINED_BY_MASKS: True + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" + EMBEDDING_DIST_GAUSS_SIGMA: 0.1 + GEODESIC_DIST_GAUSS_SIGMA: 0.1 + EMBEDDERS: + "cat_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cat_7466_256.pkl" + "dog_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: 
"https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_dog_7466_256.pkl" + "sheep_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_sheep_5004_256.pkl" + "horse_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_horse_5004_256.pkl" + "zebra_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_zebra_5002_256.pkl" + "giraffe_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_giraffe_5002_256.pkl" + "elephant_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_elephant_5002_256.pkl" + "cow_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cow_5002_256.pkl" + "bear_4936": + TYPE: vertex_feature + NUM_VERTICES: 4936 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: 
"https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_bear_4936_256.pkl" +DATASETS: + TRAIN: + - "densepose_lvis_v1_ds2_train_v1" + TEST: + - "densepose_lvis_v1_ds2_val_v1" + WHITELISTED_CATEGORIES: + "densepose_lvis_v1_ds2_train_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + "densepose_lvis_v1_ds2_val_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + CLASS_TO_MESH_NAME_MAPPING: + "0": "bear_4936" + "1": "cow_5002" + "2": "cat_7466" + "3": "dog_7466" + "4": "elephant_5002" + "5": "giraffe_5002" + "6": "horse_5004" + "7": "sheep_5004" + "8": "zebra_5002" +SOLVER: + MAX_ITER: 16000 + STEPS: (12000, 14000) +DENSEPOSE_EVALUATION: + EVALUATE_MESH_ALIGNMENT: True diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_4k.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_4k.yaml new file mode 100644 index 0000000..ba4b81d --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_4k.yaml @@ -0,0 +1,119 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_s1x/250533982/model_final_2c4512.pkl" + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 9 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + COARSE_SEGM_TRAINED_BY_MASKS: True + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" + EMBEDDING_DIST_GAUSS_SIGMA: 0.1 + GEODESIC_DIST_GAUSS_SIGMA: 0.1 + EMBEDDERS: + "cat_5001": + TYPE: vertex_feature + NUM_VERTICES: 5001 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: 
"https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cat_5001_256.pkl" + "dog_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_dog_5002_256.pkl" + "sheep_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_sheep_5004_256.pkl" + "horse_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_horse_5004_256.pkl" + "zebra_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_zebra_5002_256.pkl" + "giraffe_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_giraffe_5002_256.pkl" + "elephant_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_elephant_5002_256.pkl" + "cow_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: 
"https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cow_5002_256.pkl" + "bear_4936": + TYPE: vertex_feature + NUM_VERTICES: 4936 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_bear_4936_256.pkl" +DATASETS: + TRAIN: + - "densepose_lvis_v1_ds1_train_v1" + TEST: + - "densepose_lvis_v1_ds1_val_v1" + WHITELISTED_CATEGORIES: + "densepose_lvis_v1_ds1_train_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + "densepose_lvis_v1_ds1_val_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + CLASS_TO_MESH_NAME_MAPPING: + "0": "bear_4936" + "1": "cow_5002" + "2": "cat_5001" + "3": "dog_5002" + "4": "elephant_5002" + "5": "giraffe_5002" + "6": "horse_5004" + "7": "sheep_5004" + "8": "zebra_5002" +SOLVER: + MAX_ITER: 4000 + STEPS: (3000, 3500) +DENSEPOSE_EVALUATION: + EVALUATE_MESH_ALIGNMENT: True diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_maskonly_24k.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_maskonly_24k.yaml new file mode 100644 index 0000000..bb6136e --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_maskonly_24k.yaml @@ -0,0 +1,118 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_s1x/250533982/model_final_2c4512.pkl" + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 9 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + 
COARSE_SEGM_TRAINED_BY_MASKS: True + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" + EMBED_LOSS_WEIGHT: 0.0 + EMBEDDING_DIST_GAUSS_SIGMA: 0.1 + GEODESIC_DIST_GAUSS_SIGMA: 0.1 + EMBEDDERS: + "cat_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cat_7466_256.pkl" + "dog_7466": + TYPE: vertex_feature + NUM_VERTICES: 7466 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_dog_7466_256.pkl" + "sheep_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_sheep_5004_256.pkl" + "horse_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_horse_5004_256.pkl" + "zebra_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_zebra_5002_256.pkl" + "giraffe_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_giraffe_5002_256.pkl" + "elephant_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: 
"https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_elephant_5002_256.pkl" + "cow_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cow_5002_256.pkl" + "bear_4936": + TYPE: vertex_feature + NUM_VERTICES: 4936 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_bear_4936_256.pkl" +DATASETS: + TRAIN: + - "densepose_lvis_v1_ds2_train_v1" + TEST: + - "densepose_lvis_v1_ds2_val_v1" + WHITELISTED_CATEGORIES: + "densepose_lvis_v1_ds2_train_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + "densepose_lvis_v1_ds2_val_v1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + CLASS_TO_MESH_NAME_MAPPING: + "0": "bear_4936" + "1": "cow_5002" + "2": "cat_7466" + "3": "dog_7466" + "4": "elephant_5002" + "5": "giraffe_5002" + "6": "horse_5004" + "7": "sheep_5004" + "8": "zebra_5002" +SOLVER: + MAX_ITER: 24000 + STEPS: (20000, 22000) diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_chimps_finetune_4k.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_chimps_finetune_4k.yaml new file mode 100644 index 0000000..3bccb78 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_chimps_finetune_4k.yaml @@ -0,0 +1,29 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: 
"https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_s1x/250533982/model_final_2c4512.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" + EMBEDDING_DIST_GAUSS_SIGMA: 0.1 + GEODESIC_DIST_GAUSS_SIGMA: 0.1 + EMBEDDERS: + "chimp_5029": + TYPE: vertex_feature + NUM_VERTICES: 5029 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_chimp_5029_256.pkl" +DATASETS: + TRAIN: + - "densepose_chimps_cse_train" + TEST: + - "densepose_chimps_cse_val" + CLASS_TO_MESH_NAME_MAPPING: + "0": "chimp_5029" +SOLVER: + MAX_ITER: 4000 + STEPS: (3000, 3500) diff --git a/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_s1x.yaml new file mode 100644 index 0000000..9662fb8 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_soft_s1x.yaml @@ -0,0 +1,12 @@ +_BASE_: "Base-DensePose-RCNN-FPN-Human.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC1M_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC1M_s1x.yaml new file mode 100644 index 0000000..3c16763 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC1M_s1x.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + 
RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC1_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC1_s1x.yaml new file mode 100644 index 0000000..15475b1 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC1_s1x.yaml @@ -0,0 +1,16 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC2M_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC2M_s1x.yaml new file mode 100644 index 0000000..0cbe07f --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC2M_s1x.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "indep_aniso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC2_s1x.yaml 
b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC2_s1x.yaml new file mode 100644 index 0000000..7546b96 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_WC2_s1x.yaml @@ -0,0 +1,16 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "indep_aniso" + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_s1x.yaml new file mode 100644 index 0000000..045f7f0 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_DL_s1x.yaml @@ -0,0 +1,10 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC1M_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC1M_s1x.yaml new file mode 100644 index 0000000..9334e18 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC1M_s1x.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) + WARMUP_FACTOR: 
0.025 diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC1_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC1_s1x.yaml new file mode 100644 index 0000000..ace6209 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC1_s1x.yaml @@ -0,0 +1,16 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) + WARMUP_FACTOR: 0.025 diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC2M_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC2M_s1x.yaml new file mode 100644 index 0000000..90f0be2 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC2M_s1x.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + UV_CONFIDENCE: + ENABLED: True + TYPE: "indep_aniso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) + WARMUP_FACTOR: 0.025 diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC2_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC2_s1x.yaml new file mode 100644 index 0000000..766c098 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_WC2_s1x.yaml @@ -0,0 +1,16 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + 
ROI_DENSEPOSE_HEAD: + UV_CONFIDENCE: + ENABLED: True + TYPE: "indep_aniso" + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) + WARMUP_FACTOR: 0.025 diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_s1x.yaml new file mode 100644 index 0000000..af44fb7 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_s1x.yaml @@ -0,0 +1,8 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_s1x_legacy.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_s1x_legacy.yaml new file mode 100644 index 0000000..8e79a1b --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_101_FPN_s1x_legacy.yaml @@ -0,0 +1,17 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + ROI_DENSEPOSE_HEAD: + NUM_COARSE_SEGM_CHANNELS: 15 + POOLER_RESOLUTION: 14 + HEATMAP_SIZE: 56 + INDEX_WEIGHTS: 2.0 + PART_WEIGHTS: 0.3 + POINT_REGRESSION_WEIGHTS: 0.1 + DECODER_ON: False +SOLVER: + BASE_LR: 0.002 + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC1M_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC1M_s1x.yaml new file mode 100644 index 0000000..18a417a --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC1M_s1x.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: 
"detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC1_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC1_s1x.yaml new file mode 100644 index 0000000..f3720ef --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC1_s1x.yaml @@ -0,0 +1,16 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC2M_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC2M_s1x.yaml new file mode 100644 index 0000000..8a413d2 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC2M_s1x.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "indep_aniso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC2_s1x.yaml 
b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC2_s1x.yaml new file mode 100644 index 0000000..5a47cc0 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_WC2_s1x.yaml @@ -0,0 +1,16 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "indep_aniso" + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_s1x.yaml new file mode 100644 index 0000000..52a170b --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_DL_s1x.yaml @@ -0,0 +1,10 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC1M_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC1M_s1x.yaml new file mode 100644 index 0000000..8a81f2a --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC1M_s1x.yaml @@ -0,0 +1,20 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: norm + CLIP_VALUE: 100.0 + MAX_ITER: 130000 + STEPS: (100000, 
120000) + WARMUP_FACTOR: 0.025 diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC1_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC1_s1x.yaml new file mode 100644 index 0000000..d36e542 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC1_s1x.yaml @@ -0,0 +1,16 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) + WARMUP_FACTOR: 0.025 diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC2M_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC2M_s1x.yaml new file mode 100644 index 0000000..5cf29ea --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC2M_s1x.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + UV_CONFIDENCE: + ENABLED: True + TYPE: "indep_aniso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) + WARMUP_FACTOR: 0.025 diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC2_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC2_s1x.yaml new file mode 100644 index 0000000..e880d46 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_WC2_s1x.yaml @@ -0,0 +1,16 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + 
DEPTH: 50 + ROI_DENSEPOSE_HEAD: + UV_CONFIDENCE: + ENABLED: True + TYPE: "indep_aniso" + POINT_REGRESSION_WEIGHTS: 0.0005 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 130000 + STEPS: (100000, 120000) + WARMUP_FACTOR: 0.025 diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x.yaml new file mode 100644 index 0000000..d2dd14c --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x.yaml @@ -0,0 +1,8 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +SOLVER: + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x_legacy.yaml b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x_legacy.yaml new file mode 100644 index 0000000..6c5391f --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x_legacy.yaml @@ -0,0 +1,17 @@ +_BASE_: "Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + NUM_COARSE_SEGM_CHANNELS: 15 + POOLER_RESOLUTION: 14 + HEATMAP_SIZE: 56 + INDEX_WEIGHTS: 2.0 + PART_WEIGHTS: 0.3 + POINT_REGRESSION_WEIGHTS: 0.1 + DECODER_ON: False +SOLVER: + BASE_LR: 0.002 + MAX_ITER: 130000 + STEPS: (100000, 120000) diff --git a/data_processing/detectron2/projects/DensePose/configs/evolution/Base-RCNN-FPN-Atop10P_CA.yaml b/data_processing/detectron2/projects/DensePose/configs/evolution/Base-RCNN-FPN-Atop10P_CA.yaml new file mode 100644 index 0000000..f09d723 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/evolution/Base-RCNN-FPN-Atop10P_CA.yaml @@ -0,0 +1,91 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: 
"build_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map + ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level + PRE_NMS_TOPK_TEST: 1000 # Per FPN level + # Detectron1 uses 2000 proposals per-batch, + # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) + # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. + POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + NUM_CLASSES: 1 + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 +DATASETS: + TRAIN: ("base_coco_2017_train", "densepose_coco_2014_train") + TEST: ("densepose_chimps",) + CATEGORY_MAPS: + "base_coco_2017_train": + "16": 1 # bird -> person + "17": 1 # cat -> person + "18": 1 # dog -> person + "19": 1 # horse -> person + "20": 1 # sheep -> person + "21": 1 # cow -> person + "22": 1 # elephant -> person + "23": 1 # bear -> person + "24": 1 # zebra -> person + "25": 1 # giraffe -> person + "base_coco_2017_val": + "16": 1 # bird -> person + "17": 1 # cat -> person + "18": 1 # dog -> person + "19": 1 # horse -> person + "20": 1 # sheep -> person + "21": 1 # cow -> person + "22": 1 # elephant -> person + "23": 1 # bear -> person + "24": 1 # zebra -> person + "25": 1 # giraffe -> person + WHITELISTED_CATEGORIES: + "base_coco_2017_train": + - 1 # person + - 16 # bird + - 17 # cat + - 18 # dog + - 19 # horse + - 20 # sheep + - 21 # cow + - 22 # elephant + - 23 # bear + - 24 # zebra + - 25 # giraffe + "base_coco_2017_val": + - 1 # person + - 16 # bird + - 17
# cat + - 18 # dog + - 19 # horse + - 20 # sheep + - 21 # cow + - 22 # elephant + - 23 # bear + - 24 # zebra + - 25 # giraffe +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +VERSION: 2 diff --git a/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA.yaml b/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA.yaml new file mode 100644 index 0000000..6296692 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA.yaml @@ -0,0 +1,28 @@ +_BASE_: "Base-RCNN-FPN-Atop10P_CA.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + DENSEPOSE_ON: True + ROI_HEADS: + NAME: "DensePoseROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + NUM_CLASSES: 1 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 + POOLER_TYPE: "ROIAlign" + NUM_COARSE_SEGM_CHANNELS: 2 + COARSE_SEGM_TRAINED_BY_MASKS: True + INDEX_WEIGHTS: 1.0 +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + WARMUP_FACTOR: 0.025 + MAX_ITER: 270000 + STEPS: (210000, 250000) diff --git a/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_coarsesegm.yaml b/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_coarsesegm.yaml new file mode 100644 index 0000000..033918e --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_coarsesegm.yaml @@ -0,0 +1,56 @@ +_BASE_: "Base-RCNN-FPN-Atop10P_CA.yaml" +MODEL: + WEIGHTS:
https://dl.fbaipublicfiles.com/densepose/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA/217578784/model_final_9fe1cc.pkl + RESNETS: + DEPTH: 50 + DENSEPOSE_ON: True + ROI_HEADS: + NAME: "DensePoseROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + NUM_CLASSES: 1 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 + POOLER_TYPE: "ROIAlign" + NUM_COARSE_SEGM_CHANNELS: 2 + COARSE_SEGM_TRAINED_BY_MASKS: True +BOOTSTRAP_DATASETS: + - DATASET: "chimpnsee" + RATIO: 1.0 + IMAGE_LOADER: + TYPE: "video_keyframe" + SELECT: + STRATEGY: "random_k" + NUM_IMAGES: 4 + TRANSFORM: + TYPE: "resize" + MIN_SIZE: 800 + MAX_SIZE: 1333 + BATCH_SIZE: 8 + NUM_WORKERS: 1 + INFERENCE: + INPUT_BATCH_SIZE: 1 + OUTPUT_BATCH_SIZE: 1 + DATA_SAMPLER: + # supported types: + # densepose_uniform + # densepose_UV_confidence + # densepose_fine_segm_confidence + # densepose_coarse_segm_confidence + TYPE: "densepose_coarse_segm_confidence" + COUNT_PER_CLASS: 8 + FILTER: + TYPE: "detection_score" + MIN_VALUE: 0.8 +BOOTSTRAP_MODEL: + WEIGHTS: https://dl.fbaipublicfiles.com/densepose/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA/217578784/model_final_9fe1cc.pkl +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 270000 + STEPS: (210000, 250000) diff --git a/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_finesegm.yaml b/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_finesegm.yaml new file mode 100644 index 0000000..5814a4a --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_finesegm.yaml @@ -0,0 +1,56 @@ +_BASE_: "Base-RCNN-FPN-Atop10P_CA.yaml" +MODEL: + WEIGHTS: 
https://dl.fbaipublicfiles.com/densepose/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA/217578784/model_final_9fe1cc.pkl + RESNETS: + DEPTH: 50 + DENSEPOSE_ON: True + ROI_HEADS: + NAME: "DensePoseROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + NUM_CLASSES: 1 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 + POOLER_TYPE: "ROIAlign" + NUM_COARSE_SEGM_CHANNELS: 2 + COARSE_SEGM_TRAINED_BY_MASKS: True +BOOTSTRAP_DATASETS: + - DATASET: "chimpnsee" + RATIO: 1.0 + IMAGE_LOADER: + TYPE: "video_keyframe" + SELECT: + STRATEGY: "random_k" + NUM_IMAGES: 4 + TRANSFORM: + TYPE: "resize" + MIN_SIZE: 800 + MAX_SIZE: 1333 + BATCH_SIZE: 8 + NUM_WORKERS: 1 + INFERENCE: + INPUT_BATCH_SIZE: 1 + OUTPUT_BATCH_SIZE: 1 + DATA_SAMPLER: + # supported types: + # densepose_uniform + # densepose_UV_confidence + # densepose_fine_segm_confidence + # densepose_coarse_segm_confidence + TYPE: "densepose_fine_segm_confidence" + COUNT_PER_CLASS: 8 + FILTER: + TYPE: "detection_score" + MIN_VALUE: 0.8 +BOOTSTRAP_MODEL: + WEIGHTS: https://dl.fbaipublicfiles.com/densepose/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA/217578784/model_final_9fe1cc.pkl +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 270000 + STEPS: (210000, 250000) diff --git a/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_uniform.yaml b/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_uniform.yaml new file mode 100644 index 0000000..d591ea6 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_uniform.yaml @@ -0,0 +1,56 @@ +_BASE_: "Base-RCNN-FPN-Atop10P_CA.yaml" +MODEL: + WEIGHTS: 
https://dl.fbaipublicfiles.com/densepose/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA/217578784/model_final_9fe1cc.pkl + RESNETS: + DEPTH: 50 + DENSEPOSE_ON: True + ROI_HEADS: + NAME: "DensePoseROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + NUM_CLASSES: 1 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 + POOLER_TYPE: "ROIAlign" + NUM_COARSE_SEGM_CHANNELS: 2 + COARSE_SEGM_TRAINED_BY_MASKS: True +BOOTSTRAP_DATASETS: + - DATASET: "chimpnsee" + RATIO: 1.0 + IMAGE_LOADER: + TYPE: "video_keyframe" + SELECT: + STRATEGY: "random_k" + NUM_IMAGES: 4 + TRANSFORM: + TYPE: "resize" + MIN_SIZE: 800 + MAX_SIZE: 1333 + BATCH_SIZE: 8 + NUM_WORKERS: 1 + INFERENCE: + INPUT_BATCH_SIZE: 1 + OUTPUT_BATCH_SIZE: 1 + DATA_SAMPLER: + # supported types: + # densepose_uniform + # densepose_UV_confidence + # densepose_fine_segm_confidence + # densepose_coarse_segm_confidence + TYPE: "densepose_uniform" + COUNT_PER_CLASS: 8 + FILTER: + TYPE: "detection_score" + MIN_VALUE: 0.8 +BOOTSTRAP_MODEL: + WEIGHTS: https://dl.fbaipublicfiles.com/densepose/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA/217578784/model_final_9fe1cc.pkl +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 270000 + STEPS: (210000, 250000) diff --git a/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_uv.yaml b/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_uv.yaml new file mode 100644 index 0000000..110acff --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_uv.yaml @@ -0,0 +1,56 @@ +_BASE_: "Base-RCNN-FPN-Atop10P_CA.yaml" +MODEL: + WEIGHTS: 
https://dl.fbaipublicfiles.com/densepose/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA/217578784/model_final_9fe1cc.pkl + RESNETS: + DEPTH: 50 + DENSEPOSE_ON: True + ROI_HEADS: + NAME: "DensePoseROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + NUM_CLASSES: 1 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + SEGM_CONFIDENCE: + ENABLED: True + POINT_REGRESSION_WEIGHTS: 0.0005 + POOLER_TYPE: "ROIAlign" + NUM_COARSE_SEGM_CHANNELS: 2 + COARSE_SEGM_TRAINED_BY_MASKS: True +BOOTSTRAP_DATASETS: + - DATASET: "chimpnsee" + RATIO: 1.0 + IMAGE_LOADER: + TYPE: "video_keyframe" + SELECT: + STRATEGY: "random_k" + NUM_IMAGES: 4 + TRANSFORM: + TYPE: "resize" + MIN_SIZE: 800 + MAX_SIZE: 1333 + BATCH_SIZE: 8 + NUM_WORKERS: 1 + INFERENCE: + INPUT_BATCH_SIZE: 1 + OUTPUT_BATCH_SIZE: 1 + DATA_SAMPLER: + # supported types: + # densepose_uniform + # densepose_UV_confidence + # densepose_fine_segm_confidence + # densepose_coarse_segm_confidence + TYPE: "densepose_UV_confidence" + COUNT_PER_CLASS: 8 + FILTER: + TYPE: "detection_score" + MIN_VALUE: 0.8 +BOOTSTRAP_MODEL: + WEIGHTS: https://dl.fbaipublicfiles.com/densepose/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA/217578784/model_final_9fe1cc.pkl +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 270000 + STEPS: (210000, 250000) diff --git a/data_processing/detectron2/projects/DensePose/configs/quick_schedules/cse/densepose_rcnn_R_50_FPN_DL_instant_test.yaml b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/cse/densepose_rcnn_R_50_FPN_DL_instant_test.yaml new file mode 100644 index 0000000..3b43f75 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/cse/densepose_rcnn_R_50_FPN_DL_instant_test.yaml @@ -0,0 +1,11 @@ +_BASE_: "../../cse/Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: 
"detectron2://ImageNetPretrained/MSRA/R-50.pkl" + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" +DATASETS: + TRAIN: ("densepose_coco_2014_minival_100_cse",) + TEST: ("densepose_coco_2014_minival_100_cse",) +SOLVER: + MAX_ITER: 40 + STEPS: (30,) diff --git a/data_processing/detectron2/projects/DensePose/configs/quick_schedules/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_instant_test.yaml b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_instant_test.yaml new file mode 100644 index 0000000..a2c49a2 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_instant_test.yaml @@ -0,0 +1,126 @@ +_BASE_: "../../cse/Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 9 + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseV1ConvXHead" + CSE: + EMBED_LOSS_NAME: "SoftEmbeddingLoss" + EMBEDDING_DIST_GAUSS_SIGMA: 0.1 + EMBEDDERS: + "cat_5001": + TYPE: vertex_feature + NUM_VERTICES: 5001 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cat_5001_256.pkl" + "dog_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_dog_5002_256.pkl" + "sheep_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_sheep_5004_256.pkl" + "horse_5004": + TYPE: vertex_feature + NUM_VERTICES: 5004 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + 
IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_horse_5004_256.pkl" + "zebra_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_zebra_5002_256.pkl" + "giraffe_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_giraffe_5002_256.pkl" + "elephant_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_elephant_5002_256.pkl" + "cow_5002": + TYPE: vertex_feature + NUM_VERTICES: 5002 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_cow_5002_256.pkl" + "bear_4936": + TYPE: vertex_feature + NUM_VERTICES: 4936 + FEATURE_DIM: 256 + FEATURES_TRAINABLE: False + IS_TRAINABLE: True + INIT_FILE: "https://dl.fbaipublicfiles.com/densepose/data/cse/lbo/phi_bear_4936_256.pkl" +DATASETS: + TRAIN: + - "densepose_lvis_v1_train1" + - "densepose_lvis_v1_train2" + TEST: + - "densepose_lvis_v1_val_animals_100" + WHITELISTED_CATEGORIES: + "densepose_lvis_v1_train1": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + "densepose_lvis_v1_train2": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 
76 # bear + - 225 # cat + - 378 # dog + "densepose_lvis_v1_val_animals_100": + - 943 # sheep + - 1202 # zebra + - 569 # horse + - 496 # giraffe + - 422 # elephant + - 80 # cow + - 76 # bear + - 225 # cat + - 378 # dog + CLASS_TO_MESH_NAME_MAPPING: + "0": "bear_4936" + "1": "cow_5002" + "2": "cat_5001" + "3": "dog_5002" + "4": "elephant_5002" + "5": "giraffe_5002" + "6": "horse_5004" + "7": "sheep_5004" + "8": "zebra_5002" +SOLVER: + MAX_ITER: 40 + STEPS: (30,) diff --git a/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_HRFPN_HRNet_w32_instant_test.yaml b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_HRFPN_HRNet_w32_instant_test.yaml new file mode 100644 index 0000000..95677ce --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_HRFPN_HRNet_w32_instant_test.yaml @@ -0,0 +1,8 @@ +_BASE_: "../HRNet/densepose_rcnn_HRFPN_HRNet_w32_s1x.yaml" +DATASETS: + TRAIN: ("densepose_coco_2014_minival_100",) + TEST: ("densepose_coco_2014_minival_100",) +SOLVER: + MAX_ITER: 40 + STEPS: (30,) + IMS_PER_BATCH: 2 diff --git a/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_DL_instant_test.yaml b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_DL_instant_test.yaml new file mode 100644 index 0000000..b90989e --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_DL_instant_test.yaml @@ -0,0 +1,11 @@ +_BASE_: "../Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + ROI_DENSEPOSE_HEAD: + NAME: "DensePoseDeepLabHead" +DATASETS: + TRAIN: ("densepose_coco_2014_minival_100",) + TEST: ("densepose_coco_2014_minival_100",) +SOLVER: + MAX_ITER: 40 + STEPS: (30,) diff --git 
a/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_TTA_inference_acc_test.yaml b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_TTA_inference_acc_test.yaml new file mode 100644 index 0000000..b124da1 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_TTA_inference_acc_test.yaml @@ -0,0 +1,13 @@ +_BASE_: "../densepose_rcnn_R_50_FPN_s1x.yaml" +MODEL: + WEIGHTS: "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl" +DATASETS: + TRAIN: () + TEST: ("densepose_coco_2014_minival_100",) +TEST: + AUG: + ENABLED: True + MIN_SIZES: (400, 500, 600, 700, 800, 900, 1000, 1100, 1200) + MAX_SIZE: 4000 + FLIP: True + EXPECTED_RESULTS: [["bbox_TTA", "AP", 61.74, 0.03], ["densepose_gps_TTA", "AP", 60.22, 0.03], ["densepose_gpsm_TTA", "AP", 63.59, 0.03]] diff --git a/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_WC1_instant_test.yaml b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_WC1_instant_test.yaml new file mode 100644 index 0000000..f0fe611 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_WC1_instant_test.yaml @@ -0,0 +1,19 @@ +_BASE_: "../Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + UV_CONFIDENCE: + ENABLED: True + TYPE: "iid_iso" + POINT_REGRESSION_WEIGHTS: 0.0005 +DATASETS: + TRAIN: ("densepose_coco_2014_minival_100",) + TEST: ("densepose_coco_2014_minival_100",) +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 40 + STEPS: (30,) + WARMUP_FACTOR: 0.025 diff --git 
a/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_WC2_instant_test.yaml b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_WC2_instant_test.yaml new file mode 100644 index 0000000..f0d9358 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_WC2_instant_test.yaml @@ -0,0 +1,19 @@ +_BASE_: "../Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_DENSEPOSE_HEAD: + UV_CONFIDENCE: + ENABLED: True + TYPE: "indep_aniso" + POINT_REGRESSION_WEIGHTS: 0.0005 +DATASETS: + TRAIN: ("densepose_coco_2014_minival_100",) + TEST: ("densepose_coco_2014_minival_100",) +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + MAX_ITER: 40 + STEPS: (30,) + WARMUP_FACTOR: 0.025 diff --git a/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_inference_acc_test.yaml b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_inference_acc_test.yaml new file mode 100644 index 0000000..d607c98 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_inference_acc_test.yaml @@ -0,0 +1,8 @@ +_BASE_: "../densepose_rcnn_R_50_FPN_s1x.yaml" +MODEL: + WEIGHTS: "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl" +DATASETS: + TRAIN: () + TEST: ("densepose_coco_2014_minival_100",) +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 59.27, 0.025], ["densepose_gps", "AP", 60.11, 0.02], ["densepose_gpsm", "AP", 64.09, 0.02]] diff --git a/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_instant_test.yaml b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_instant_test.yaml new file 
mode 100644 index 0000000..057c876 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_instant_test.yaml @@ -0,0 +1,9 @@ +_BASE_: "../Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" +DATASETS: + TRAIN: ("densepose_coco_2014_minival_100",) + TEST: ("densepose_coco_2014_minival_100",) +SOLVER: + MAX_ITER: 40 + STEPS: (30,) diff --git a/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_training_acc_test.yaml b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_training_acc_test.yaml new file mode 100644 index 0000000..0053c9d --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/configs/quick_schedules/densepose_rcnn_R_50_FPN_training_acc_test.yaml @@ -0,0 +1,18 @@ +_BASE_: "../Base-DensePose-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + ROI_HEADS: + NUM_CLASSES: 1 +DATASETS: + TRAIN: ("densepose_coco_2014_minival",) + TEST: ("densepose_coco_2014_minival",) +SOLVER: + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: norm + CLIP_VALUE: 1.0 + MAX_ITER: 6000 + STEPS: (5500, 5800) +TEST: + EXPECTED_RESULTS: [["bbox", "AP", 76.2477, 1.0], ["densepose_gps", "AP", 79.6090, 1.5], ["densepose_gpsm", "AP", 80.0061, 1.5]] + diff --git a/data_processing/detectron2/projects/DensePose/densepose/__init__.py b/data_processing/detectron2/projects/DensePose/densepose/__init__.py new file mode 100644 index 0000000..b50a3da --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+from .data.datasets import builtin # just to register data +from .converters import builtin as builtin_converters # register converters +from .config import ( + add_densepose_config, + add_densepose_head_config, + add_hrnet_config, + add_dataset_category_config, + add_bootstrap_config, + load_bootstrap_config, +) +from .structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData +from .evaluation import DensePoseCOCOEvaluator +from .modeling.roi_heads import DensePoseROIHeads +from .modeling.test_time_augmentation import ( + DensePoseGeneralizedRCNNWithTTA, + DensePoseDatasetMapperTTA, +) +from .utils.transform import load_from_cfg +from .modeling.hrfpn import build_hrfpn_backbone diff --git a/data_processing/detectron2/projects/DensePose/densepose/config.py b/data_processing/detectron2/projects/DensePose/densepose/config.py new file mode 100644 index 0000000..2a06a09 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/config.py @@ -0,0 +1,277 @@ +# -*- coding = utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. 
+# pyre-ignore-all-errors + +from detectron2.config import CfgNode as CN + + +def add_dataset_category_config(cfg: CN) -> None: + """ + Add config for additional category-related dataset options + - category whitelisting + - category mapping + """ + _C = cfg + _C.DATASETS.CATEGORY_MAPS = CN(new_allowed=True) + _C.DATASETS.WHITELISTED_CATEGORIES = CN(new_allowed=True) + # class to mesh mapping + _C.DATASETS.CLASS_TO_MESH_NAME_MAPPING = CN(new_allowed=True) + + +def add_evaluation_config(cfg: CN) -> None: + _C = cfg + _C.DENSEPOSE_EVALUATION = CN() + # evaluator type, possible values: + # - "iou": evaluator for models that produce iou data + # - "cse": evaluator for models that produce cse data + _C.DENSEPOSE_EVALUATION.TYPE = "iou" + # storage for DensePose results, possible values: + # - "none": no explicit storage, all the results are stored in the + # dictionary with predictions, memory intensive; + # historically the default storage type + # - "ram": RAM storage, uses per-process RAM storage, which is + # reduced to a single process storage on later stages, + # less memory intensive + # - "file": file storage, uses per-process file-based storage, + # the least memory intensive, but may create bottlenecks + # on file system accesses + _C.DENSEPOSE_EVALUATION.STORAGE = "none" + # minimum threshold for IOU values: the lower its values is, + # the more matches are produced (and the higher the AP score) + _C.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD = 0.5 + # Non-distributed inference is slower (at inference time) but can avoid RAM OOM + _C.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE = True + # evaluate mesh alignment based on vertex embeddings, only makes sense in CSE context + _C.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT = False + # meshes to compute mesh alignment for + _C.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES = [] + + +def add_bootstrap_config(cfg: CN) -> None: + """ """ + _C = cfg + _C.BOOTSTRAP_DATASETS = [] + _C.BOOTSTRAP_MODEL = CN() + 
_C.BOOTSTRAP_MODEL.WEIGHTS = "" + _C.BOOTSTRAP_MODEL.DEVICE = "cuda" + + +def get_bootstrap_dataset_config() -> CN: + _C = CN() + _C.DATASET = "" + # ratio used to mix data loaders + _C.RATIO = 0.1 + # image loader + _C.IMAGE_LOADER = CN(new_allowed=True) + _C.IMAGE_LOADER.TYPE = "" + _C.IMAGE_LOADER.BATCH_SIZE = 4 + _C.IMAGE_LOADER.NUM_WORKERS = 4 + _C.IMAGE_LOADER.CATEGORIES = [] + _C.IMAGE_LOADER.MAX_COUNT_PER_CATEGORY = 1_000_000 + _C.IMAGE_LOADER.CATEGORY_TO_CLASS_MAPPING = CN(new_allowed=True) + # inference + _C.INFERENCE = CN() + # batch size for model inputs + _C.INFERENCE.INPUT_BATCH_SIZE = 4 + # batch size to group model outputs + _C.INFERENCE.OUTPUT_BATCH_SIZE = 2 + # sampled data + _C.DATA_SAMPLER = CN(new_allowed=True) + _C.DATA_SAMPLER.TYPE = "" + _C.DATA_SAMPLER.USE_GROUND_TRUTH_CATEGORIES = False + # filter + _C.FILTER = CN(new_allowed=True) + _C.FILTER.TYPE = "" + return _C + + +def load_bootstrap_config(cfg: CN) -> None: + """ + Bootstrap datasets are given as a list of `dict` that are not automatically + converted into CfgNode. 
This method processes all bootstrap dataset entries + and ensures that they are in CfgNode format and comply with the specification + """ + if not cfg.BOOTSTRAP_DATASETS: + return + + bootstrap_datasets_cfgnodes = [] + for dataset_cfg in cfg.BOOTSTRAP_DATASETS: + _C = get_bootstrap_dataset_config().clone() + _C.merge_from_other_cfg(CN(dataset_cfg)) + bootstrap_datasets_cfgnodes.append(_C) + cfg.BOOTSTRAP_DATASETS = bootstrap_datasets_cfgnodes + + +def add_densepose_head_cse_config(cfg: CN) -> None: + """ + Add configuration options for Continuous Surface Embeddings (CSE) + """ + _C = cfg + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE = CN() + # Dimensionality D of the embedding space + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE = 16 + # Embedder specifications for various mesh IDs + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS = CN(new_allowed=True) + # normalization coefficient for embedding distances + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA = 0.01 + # normalization coefficient for geodesic distances + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA = 0.01 + # embedding loss weight + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT = 0.6 + # embedding loss name, currently the following options are supported: + # - EmbeddingLoss: cross-entropy on vertex labels + # - SoftEmbeddingLoss: cross-entropy on vertex label combined with + # Gaussian penalty on distance between vertices + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME = "EmbeddingLoss" + # optimizer hyperparameters + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR = 1.0 + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR = 1.0 + # Shape to shape cycle consistency loss parameters: + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False}) + # shape to shape cycle consistency loss weight + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.025 + # norm type used for loss computation + 
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.NORM_P = 2 + # normalization term for embedding similarity matrices + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.TEMPERATURE = 0.05 + # maximum number of vertices to include into shape to shape cycle loss + # if negative or zero, all vertices are considered + # if positive, random subset of vertices of given size is considered + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES = 4936 + # Pixel to shape cycle consistency loss parameters: + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False}) + # pixel to shape cycle consistency loss weight + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.0001 + # norm type used for loss computation + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P = 2 + # map images to all meshes and back (if false, use only gt meshes from the batch) + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY = False + # Randomly select at most this number of pixels from every instance + # if negative or zero, all vertices are considered + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE = 100 + # normalization factor for pixel to pixel distances (higher value = smoother distribution) + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA = 5.0 + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX = 0.05 + _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL = 0.05 + + +def add_densepose_head_config(cfg: CN) -> None: + """ + Add config for densepose head. 
+ """ + _C = cfg + + _C.MODEL.DENSEPOSE_ON = True + + _C.MODEL.ROI_DENSEPOSE_HEAD = CN() + _C.MODEL.ROI_DENSEPOSE_HEAD.NAME = "" + _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS = 8 + # Number of parts used for point labels + _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES = 24 + _C.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL = 4 + _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM = 512 + _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL = 3 + _C.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE = 2 + _C.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE = 112 + _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE = "ROIAlignV2" + _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION = 28 + _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO = 2 + _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS = 2 # 15 or 2 + # Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD) + _C.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD = 0.7 + # Loss weights for annotation masks.(14 Parts) + _C.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS = 5.0 + # Loss weights for surface parts. (24 Parts) + _C.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS = 1.0 + # Loss weights for UV regression. 
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS = 0.01 + # Coarse segmentation is trained using instance segmentation task data + _C.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS = False + # For Decoder + _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_ON = True + _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NUM_CLASSES = 256 + _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_CONV_DIMS = 256 + _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NORM = "" + _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_COMMON_STRIDE = 4 + # For DeepLab head + _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB = CN() + _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM = "GN" + _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON = 0 + # Predictor class name, must be registered in DENSEPOSE_PREDICTOR_REGISTRY + # Some registered predictors: + # "DensePoseChartPredictor": predicts segmentation and UV coordinates for predefined charts + # "DensePoseChartWithConfidencePredictor": predicts segmentation, UV coordinates + # and associated confidences for predefined charts (default) + # "DensePoseEmbeddingWithConfidencePredictor": predicts segmentation, embeddings + # and associated confidences for CSE + _C.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME = "DensePoseChartWithConfidencePredictor" + # Loss class name, must be registered in DENSEPOSE_LOSS_REGISTRY + # Some registered losses: + # "DensePoseChartLoss": loss for chart-based models that estimate + # segmentation and UV coordinates + # "DensePoseChartWithConfidenceLoss": loss for chart-based models that estimate + # segmentation, UV coordinates and the corresponding confidences (default) + _C.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME = "DensePoseChartWithConfidenceLoss" + # Confidences + # Enable learning UV confidences (variances) along with the actual values + _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE = CN({"ENABLED": False}) + # UV confidence lower bound + _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.EPSILON = 0.01 + # Enable learning segmentation confidences (variances) along with the actual values + 
_C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE = CN({"ENABLED": False}) + # Segmentation confidence lower bound + _C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE.EPSILON = 0.01 + # Statistical model type for confidence learning, possible values: + # - "iid_iso": statistically independent identically distributed residuals + # with isotropic covariance + # - "indep_aniso": statistically independent residuals with anisotropic + # covariances + _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.TYPE = "iid_iso" + # List of angles for rotation in data augmentation during training + _C.INPUT.ROTATION_ANGLES = [0] + _C.TEST.AUG.ROTATION_ANGLES = () # Rotation TTA + + add_densepose_head_cse_config(cfg) + + +def add_hrnet_config(cfg: CN) -> None: + """ + Add config for HRNet backbone. + """ + _C = cfg + + # For HigherHRNet w32 + _C.MODEL.HRNET = CN() + _C.MODEL.HRNET.STEM_INPLANES = 64 + _C.MODEL.HRNET.STAGE2 = CN() + _C.MODEL.HRNET.STAGE2.NUM_MODULES = 1 + _C.MODEL.HRNET.STAGE2.NUM_BRANCHES = 2 + _C.MODEL.HRNET.STAGE2.BLOCK = "BASIC" + _C.MODEL.HRNET.STAGE2.NUM_BLOCKS = [4, 4] + _C.MODEL.HRNET.STAGE2.NUM_CHANNELS = [32, 64] + _C.MODEL.HRNET.STAGE2.FUSE_METHOD = "SUM" + _C.MODEL.HRNET.STAGE3 = CN() + _C.MODEL.HRNET.STAGE3.NUM_MODULES = 4 + _C.MODEL.HRNET.STAGE3.NUM_BRANCHES = 3 + _C.MODEL.HRNET.STAGE3.BLOCK = "BASIC" + _C.MODEL.HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4] + _C.MODEL.HRNET.STAGE3.NUM_CHANNELS = [32, 64, 128] + _C.MODEL.HRNET.STAGE3.FUSE_METHOD = "SUM" + _C.MODEL.HRNET.STAGE4 = CN() + _C.MODEL.HRNET.STAGE4.NUM_MODULES = 3 + _C.MODEL.HRNET.STAGE4.NUM_BRANCHES = 4 + _C.MODEL.HRNET.STAGE4.BLOCK = "BASIC" + _C.MODEL.HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] + _C.MODEL.HRNET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] + _C.MODEL.HRNET.STAGE4.FUSE_METHOD = "SUM" + + _C.MODEL.HRNET.HRFPN = CN() + _C.MODEL.HRNET.HRFPN.OUT_CHANNELS = 256 + + +def add_densepose_config(cfg: CN) -> None: + add_densepose_head_config(cfg) + add_hrnet_config(cfg) + add_bootstrap_config(cfg) + 
add_dataset_category_config(cfg) + add_evaluation_config(cfg) diff --git a/data_processing/detectron2/projects/DensePose/densepose/converters/__init__.py b/data_processing/detectron2/projects/DensePose/densepose/converters/__init__.py new file mode 100644 index 0000000..930339e --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/converters/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from .hflip import HFlipConverter +from .to_mask import ToMaskConverter +from .to_chart_result import ToChartResultConverter, ToChartResultConverterWithConfidences +from .segm_to_mask import ( + predictor_output_with_fine_and_coarse_segm_to_mask, + predictor_output_with_coarse_segm_to_mask, + resample_fine_and_coarse_segm_to_bbox, +) +from .chart_output_to_chart_result import ( + densepose_chart_predictor_output_to_result, + densepose_chart_predictor_output_to_result_with_confidences, +) +from .chart_output_hflip import densepose_chart_predictor_output_hflip diff --git a/data_processing/detectron2/projects/DensePose/densepose/converters/base.py b/data_processing/detectron2/projects/DensePose/densepose/converters/base.py new file mode 100644 index 0000000..c9dbe56 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/converters/base.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import Any, Tuple, Type +import torch + + +class BaseConverter: + """ + Converter base class to be reused by various converters. + Converter allows one to convert data from various source types to a particular + destination type. Each source type needs to register its converter. The + registration for each source type is valid for all descendants of that type. + """ + + @classmethod + def register(cls, from_type: Type, converter: Any = None): + """ + Registers a converter for the specified type. + Can be used as a decorator (if converter is None), or called as a method. 
+ + Args: + from_type (type): type to register the converter for; + all instances of this type will use the same converter + converter (callable): converter to be registered for the given + type; if None, this method is assumed to be a decorator for the converter + """ + + if converter is not None: + cls._do_register(from_type, converter) + + def wrapper(converter: Any) -> Any: + cls._do_register(from_type, converter) + return converter + + return wrapper + + @classmethod + def _do_register(cls, from_type: Type, converter: Any): + cls.registry[from_type] = converter # pyre-ignore[16] + + @classmethod + def _lookup_converter(cls, from_type: Type) -> Any: + """ + Perform recursive lookup for the given type + to find registered converter. If a converter was found for some base + class, it gets registered for this class to save on further lookups. + + Args: + from_type: type for which to find a converter + Return: + callable or None - registered converter or None + if no suitable entry was found in the registry + """ + if from_type in cls.registry: # pyre-ignore[16] + return cls.registry[from_type] + for base in from_type.__bases__: + converter = cls._lookup_converter(base) + if converter is not None: + cls._do_register(from_type, converter) + return converter + return None + + @classmethod + def convert(cls, instance: Any, *args, **kwargs): + """ + Convert an instance to the destination type using some registered + converter. Does recursive lookup for base classes, so there's no need + for explicit registration for derived classes. 
+ + Args: + instance: source instance to convert to the destination type + Return: + An instance of the destination type obtained from the source instance + Raises KeyError, if no suitable converter found + """ + instance_type = type(instance) + converter = cls._lookup_converter(instance_type) + if converter is None: + if cls.dst_type is None: # pyre-ignore[16] + output_type_str = "itself" + else: + output_type_str = cls.dst_type + raise KeyError(f"Could not find converter from {instance_type} to {output_type_str}") + return converter(instance, *args, **kwargs) + + +IntTupleBox = Tuple[int, int, int, int] + + +def make_int_box(box: torch.Tensor) -> IntTupleBox: + int_box = [0, 0, 0, 0] + int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist()) + return int_box[0], int_box[1], int_box[2], int_box[3] diff --git a/data_processing/detectron2/projects/DensePose/densepose/converters/builtin.py b/data_processing/detectron2/projects/DensePose/densepose/converters/builtin.py new file mode 100644 index 0000000..3bd48f8 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/converters/builtin.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from ..structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput +from . 
import ( + HFlipConverter, + ToChartResultConverter, + ToChartResultConverterWithConfidences, + ToMaskConverter, + densepose_chart_predictor_output_hflip, + densepose_chart_predictor_output_to_result, + densepose_chart_predictor_output_to_result_with_confidences, + predictor_output_with_coarse_segm_to_mask, + predictor_output_with_fine_and_coarse_segm_to_mask, +) + +ToMaskConverter.register( + DensePoseChartPredictorOutput, predictor_output_with_fine_and_coarse_segm_to_mask +) +ToMaskConverter.register( + DensePoseEmbeddingPredictorOutput, predictor_output_with_coarse_segm_to_mask +) + +ToChartResultConverter.register( + DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result +) + +ToChartResultConverterWithConfidences.register( + DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result_with_confidences +) + +HFlipConverter.register(DensePoseChartPredictorOutput, densepose_chart_predictor_output_hflip) diff --git a/data_processing/detectron2/projects/DensePose/densepose/converters/chart_output_hflip.py b/data_processing/detectron2/projects/DensePose/densepose/converters/chart_output_hflip.py new file mode 100644 index 0000000..17d2948 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/converters/chart_output_hflip.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from dataclasses import fields +import torch + +from densepose.structures import DensePoseChartPredictorOutput, DensePoseTransformData + + +def densepose_chart_predictor_output_hflip( + densepose_predictor_output: DensePoseChartPredictorOutput, + transform_data: DensePoseTransformData, +) -> DensePoseChartPredictorOutput: + """ + Change to take into account a Horizontal flip. 
+ """ + if len(densepose_predictor_output) > 0: + + PredictorOutput = type(densepose_predictor_output) + output_dict = {} + + for field in fields(densepose_predictor_output): + field_value = getattr(densepose_predictor_output, field.name) + # flip tensors + if isinstance(field_value, torch.Tensor): + setattr(densepose_predictor_output, field.name, torch.flip(field_value, [3])) + + densepose_predictor_output = _flip_iuv_semantics_tensor( + densepose_predictor_output, transform_data + ) + densepose_predictor_output = _flip_segm_semantics_tensor( + densepose_predictor_output, transform_data + ) + + for field in fields(densepose_predictor_output): + output_dict[field.name] = getattr(densepose_predictor_output, field.name) + + return PredictorOutput(**output_dict) + else: + return densepose_predictor_output + + +def _flip_iuv_semantics_tensor( + densepose_predictor_output: DensePoseChartPredictorOutput, + dp_transform_data: DensePoseTransformData, +) -> DensePoseChartPredictorOutput: + point_label_symmetries = dp_transform_data.point_label_symmetries + uv_symmetries = dp_transform_data.uv_symmetries + + N, C, H, W = densepose_predictor_output.u.shape + u_loc = (densepose_predictor_output.u[:, 1:, :, :].clamp(0, 1) * 255).long() + v_loc = (densepose_predictor_output.v[:, 1:, :, :].clamp(0, 1) * 255).long() + Iindex = torch.arange(C - 1, device=densepose_predictor_output.u.device)[ + None, :, None, None + ].expand(N, C - 1, H, W) + densepose_predictor_output.u[:, 1:, :, :] = uv_symmetries["U_transforms"][Iindex, v_loc, u_loc] + densepose_predictor_output.v[:, 1:, :, :] = uv_symmetries["V_transforms"][Iindex, v_loc, u_loc] + + for el in ["fine_segm", "u", "v"]: + densepose_predictor_output.__dict__[el] = densepose_predictor_output.__dict__[el][ + :, point_label_symmetries, :, : + ] + return densepose_predictor_output + + +def _flip_segm_semantics_tensor( + densepose_predictor_output: DensePoseChartPredictorOutput, dp_transform_data +): + if 
densepose_predictor_output.coarse_segm.shape[1] > 2: + densepose_predictor_output.coarse_segm = densepose_predictor_output.coarse_segm[ + :, dp_transform_data.mask_label_symmetries, :, : + ] + return densepose_predictor_output diff --git a/data_processing/detectron2/projects/DensePose/densepose/converters/chart_output_to_chart_result.py b/data_processing/detectron2/projects/DensePose/densepose/converters/chart_output_to_chart_result.py new file mode 100644 index 0000000..4248f6c --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/converters/chart_output_to_chart_result.py @@ -0,0 +1,188 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import Dict +import torch +from torch.nn import functional as F + +from detectron2.structures.boxes import Boxes, BoxMode + +from ..structures import ( + DensePoseChartPredictorOutput, + DensePoseChartResult, + DensePoseChartResultWithConfidences, +) +from . import resample_fine_and_coarse_segm_to_bbox +from .base import IntTupleBox, make_int_box + + +def resample_uv_tensors_to_bbox( + u: torch.Tensor, + v: torch.Tensor, + labels: torch.Tensor, + box_xywh_abs: IntTupleBox, +) -> torch.Tensor: + """ + Resamples U and V coordinate estimates for the given bounding box + + Args: + u (tensor [1, C, H, W] of float): U coordinates + v (tensor [1, C, H, W] of float): V coordinates + labels (tensor [H, W] of long): labels obtained by resampling segmentation + outputs for the given bounding box + box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs + Return: + Resampled U and V coordinates - a tensor [2, H, W] of float + """ + x, y, w, h = box_xywh_abs + w = max(int(w), 1) + h = max(int(h), 1) + u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False) + v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False) + uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device) + for part_id in range(1, u_bbox.size(1)): + uv[0][labels == 
part_id] = u_bbox[0, part_id][labels == part_id] + uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id] + return uv + + +def resample_uv_to_bbox( + predictor_output: DensePoseChartPredictorOutput, + labels: torch.Tensor, + box_xywh_abs: IntTupleBox, +) -> torch.Tensor: + """ + Resamples U and V coordinate estimates for the given bounding box + + Args: + predictor_output (DensePoseChartPredictorOutput): DensePose predictor + output to be resampled + labels (tensor [H, W] of long): labels obtained by resampling segmentation + outputs for the given bounding box + box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs + Return: + Resampled U and V coordinates - a tensor [2, H, W] of float + """ + return resample_uv_tensors_to_bbox( + predictor_output.u, + predictor_output.v, + labels, + box_xywh_abs, + ) + + +def densepose_chart_predictor_output_to_result( + predictor_output: DensePoseChartPredictorOutput, boxes: Boxes +) -> DensePoseChartResult: + """ + Convert densepose chart predictor outputs to results + + Args: + predictor_output (DensePoseChartPredictorOutput): DensePose predictor + output to be converted to results, must contain only 1 output + boxes (Boxes): bounding box that corresponds to the predictor output, + must contain only 1 bounding box + Return: + DensePose chart-based result (DensePoseChartResult) + """ + assert len(predictor_output) == 1 and len(boxes) == 1, ( + f"Predictor output to result conversion can operate only single outputs" + f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes" + ) + + boxes_xyxy_abs = boxes.tensor.clone() + boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) + box_xywh = make_int_box(boxes_xywh_abs[0]) + + labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0) + uv = resample_uv_to_bbox(predictor_output, labels, box_xywh) + return DensePoseChartResult(labels=labels, uv=uv) + + +def 
resample_confidences_to_bbox( + predictor_output: DensePoseChartPredictorOutput, + labels: torch.Tensor, + box_xywh_abs: IntTupleBox, +) -> Dict[str, torch.Tensor]: + """ + Resamples confidences for the given bounding box + + Args: + predictor_output (DensePoseChartPredictorOutput): DensePose predictor + output to be resampled + labels (tensor [H, W] of long): labels obtained by resampling segmentation + outputs for the given bounding box + box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs + Return: + Resampled confidences - a dict of [H, W] tensors of float + """ + + x, y, w, h = box_xywh_abs + w = max(int(w), 1) + h = max(int(h), 1) + + confidence_names = [ + "sigma_1", + "sigma_2", + "kappa_u", + "kappa_v", + "fine_segm_confidence", + "coarse_segm_confidence", + ] + confidence_results = {key: None for key in confidence_names} + confidence_names = [ + key for key in confidence_names if getattr(predictor_output, key) is not None + ] + confidence_base = torch.zeros([h, w], dtype=torch.float32, device=predictor_output.u.device) + + # assign data from channels that correspond to the labels + for key in confidence_names: + resampled_confidence = F.interpolate( + getattr(predictor_output, key), + (h, w), + mode="bilinear", + align_corners=False, + ) + result = confidence_base.clone() + for part_id in range(1, predictor_output.u.size(1)): + if resampled_confidence.size(1) != predictor_output.u.size(1): + # confidence is not part-based, don't try to fill it part by part + continue + result[labels == part_id] = resampled_confidence[0, part_id][labels == part_id] + + if resampled_confidence.size(1) != predictor_output.u.size(1): + # confidence is not part-based, fill the data with the first channel + # (targeted for segmentation confidences that have only 1 channel) + result = resampled_confidence[0, 0] + + confidence_results[key] = result + + return confidence_results # pyre-ignore[7] + + +def 
densepose_chart_predictor_output_to_result_with_confidences( + predictor_output: DensePoseChartPredictorOutput, boxes: Boxes +) -> DensePoseChartResultWithConfidences: + """ + Convert densepose chart predictor outputs to results + + Args: + predictor_output (DensePoseChartPredictorOutput): DensePose predictor + output with confidences to be converted to results, must contain only 1 output + boxes (Boxes): bounding box that corresponds to the predictor output, + must contain only 1 bounding box + Return: + DensePose chart-based result with confidences (DensePoseChartResultWithConfidences) + """ + assert len(predictor_output) == 1 and len(boxes) == 1, ( + f"Predictor output to result conversion can operate only single outputs" + f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes" + ) + + boxes_xyxy_abs = boxes.tensor.clone() + boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) + box_xywh = make_int_box(boxes_xywh_abs[0]) + + labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0) + uv = resample_uv_to_bbox(predictor_output, labels, box_xywh) + confidences = resample_confidences_to_bbox(predictor_output, labels, box_xywh) + return DensePoseChartResultWithConfidences(labels=labels, uv=uv, **confidences) diff --git a/data_processing/detectron2/projects/DensePose/densepose/converters/hflip.py b/data_processing/detectron2/projects/DensePose/densepose/converters/hflip.py new file mode 100644 index 0000000..6df1442 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/converters/hflip.py @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import Any + +from .base import BaseConverter + + +class HFlipConverter(BaseConverter): + """ + Converts various DensePose predictor outputs to DensePose results. + Each DensePose predictor output type has to register its convertion strategy. 
+ """ + + registry = {} + dst_type = None + + @classmethod + # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter` + # inconsistently. + def convert(cls, predictor_outputs: Any, transform_data: Any, *args, **kwargs): + """ + Performs an horizontal flip on DensePose predictor outputs. + Does recursive lookup for base classes, so there's no need + for explicit registration for derived classes. + + Args: + predictor_outputs: DensePose predictor output to be converted to BitMasks + transform_data: Anything useful for the flip + Return: + An instance of the same type as predictor_outputs + """ + return super(HFlipConverter, cls).convert( + predictor_outputs, transform_data, *args, **kwargs + ) diff --git a/data_processing/detectron2/projects/DensePose/densepose/converters/segm_to_mask.py b/data_processing/detectron2/projects/DensePose/densepose/converters/segm_to_mask.py new file mode 100644 index 0000000..6433d5d --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/converters/segm_to_mask.py @@ -0,0 +1,150 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +from typing import Any +import torch +from torch.nn import functional as F + +from detectron2.structures import BitMasks, Boxes, BoxMode + +from .base import IntTupleBox, make_int_box +from .to_mask import ImageSizeType + + +def resample_coarse_segm_tensor_to_bbox(coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox): + """ + Resample coarse segmentation tensor to the given + bounding box and derive labels for each pixel of the bounding box + + Args: + coarse_segm: float tensor of shape [1, K, Hout, Wout] + box_xywh_abs (tuple of 4 int): bounding box given by its upper-left + corner coordinates, width (W) and height (H) + Return: + Labels for each pixel of the bounding box, a long tensor of size [1, H, W] + """ + x, y, w, h = box_xywh_abs + w = max(int(w), 1) + h = max(int(h), 1) + labels = F.interpolate(coarse_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1) + return labels + + +def resample_fine_and_coarse_segm_tensors_to_bbox( + fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox +): + """ + Resample fine and coarse segmentation tensors to the given + bounding box and derive labels for each pixel of the bounding box + + Args: + fine_segm: float tensor of shape [1, C, Hout, Wout] + coarse_segm: float tensor of shape [1, K, Hout, Wout] + box_xywh_abs (tuple of 4 int): bounding box given by its upper-left + corner coordinates, width (W) and height (H) + Return: + Labels for each pixel of the bounding box, a long tensor of size [1, H, W] + """ + x, y, w, h = box_xywh_abs + w = max(int(w), 1) + h = max(int(h), 1) + # coarse segmentation + coarse_segm_bbox = F.interpolate( + coarse_segm, + (h, w), + mode="bilinear", + align_corners=False, + ).argmax(dim=1) + # combined coarse and fine segmentation + labels = ( + F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1) + * (coarse_segm_bbox > 0).long() + ) + return labels + + +def resample_fine_and_coarse_segm_to_bbox(predictor_output: Any, 
box_xywh_abs: IntTupleBox): + """ + Resample fine and coarse segmentation outputs from a predictor to the given + bounding box and derive labels for each pixel of the bounding box + + Args: + predictor_output: DensePose predictor output that contains segmentation + results to be resampled + box_xywh_abs (tuple of 4 int): bounding box given by its upper-left + corner coordinates, width (W) and height (H) + Return: + Labels for each pixel of the bounding box, a long tensor of size [1, H, W] + """ + return resample_fine_and_coarse_segm_tensors_to_bbox( + predictor_output.fine_segm, + predictor_output.coarse_segm, + box_xywh_abs, + ) + + +def predictor_output_with_coarse_segm_to_mask( + predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType +) -> BitMasks: + """ + Convert predictor output with coarse and fine segmentation to a mask. + Assumes that predictor output has the following attributes: + - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation + unnormalized scores for N instances; D is the number of coarse + segmentation labels, H and W is the resolution of the estimate + + Args: + predictor_output: DensePose predictor output to be converted to mask + boxes (Boxes): bounding boxes that correspond to the DensePose + predictor outputs + image_size_hw (tuple [int, int]): image height Himg and width Wimg + Return: + BitMasks that contain a bool tensor of size [N, Himg, Wimg] with + a mask of the size of the image for each instance + """ + H, W = image_size_hw + boxes_xyxy_abs = boxes.tensor.clone() + boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) + N = len(boxes_xywh_abs) + masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device) + for i in range(len(boxes_xywh_abs)): + box_xywh = make_int_box(boxes_xywh_abs[i]) + box_mask = resample_coarse_segm_tensor_to_bbox(predictor_output[i].coarse_segm, box_xywh) + x, y, w, h = box_xywh + masks[i, y : y + h, x : x + w] = box_mask + + return 
BitMasks(masks) + + +def predictor_output_with_fine_and_coarse_segm_to_mask( + predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType +) -> BitMasks: + """ + Convert predictor output with coarse and fine segmentation to a mask. + Assumes that predictor output has the following attributes: + - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation + unnormalized scores for N instances; D is the number of coarse + segmentation labels, H and W is the resolution of the estimate + - fine_segm (tensor of size [N, C, H, W]): fine segmentation + unnormalized scores for N instances; C is the number of fine + segmentation labels, H and W is the resolution of the estimate + + Args: + predictor_output: DensePose predictor output to be converted to mask + boxes (Boxes): bounding boxes that correspond to the DensePose + predictor outputs + image_size_hw (tuple [int, int]): image height Himg and width Wimg + Return: + BitMasks that contain a bool tensor of size [N, Himg, Wimg] with + a mask of the size of the image for each instance + """ + H, W = image_size_hw + boxes_xyxy_abs = boxes.tensor.clone() + boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) + N = len(boxes_xywh_abs) + masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device) + for i in range(len(boxes_xywh_abs)): + box_xywh = make_int_box(boxes_xywh_abs[i]) + labels_i = resample_fine_and_coarse_segm_to_bbox(predictor_output[i], box_xywh) + x, y, w, h = box_xywh + masks[i, y : y + h, x : x + w] = labels_i > 0 + return BitMasks(masks) diff --git a/data_processing/detectron2/projects/DensePose/densepose/converters/to_chart_result.py b/data_processing/detectron2/projects/DensePose/densepose/converters/to_chart_result.py new file mode 100644 index 0000000..3eabd26 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/converters/to_chart_result.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +from typing import Any + +from detectron2.structures import Boxes + +from ..structures import DensePoseChartResult, DensePoseChartResultWithConfidences +from .base import BaseConverter + + +class ToChartResultConverter(BaseConverter): + """ + Converts various DensePose predictor outputs to DensePose results. + Each DensePose predictor output type has to register its convertion strategy. + """ + + registry = {} + dst_type = DensePoseChartResult + + @classmethod + # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter` + # inconsistently. + def convert(cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs) -> DensePoseChartResult: + """ + Convert DensePose predictor outputs to DensePoseResult using some registered + converter. Does recursive lookup for base classes, so there's no need + for explicit registration for derived classes. + + Args: + densepose_predictor_outputs: DensePose predictor output to be + converted to BitMasks + boxes (Boxes): bounding boxes that correspond to the DensePose + predictor outputs + Return: + An instance of DensePoseResult. If no suitable converter was found, raises KeyError + """ + return super(ToChartResultConverter, cls).convert(predictor_outputs, boxes, *args, **kwargs) + + +class ToChartResultConverterWithConfidences(BaseConverter): + """ + Converts various DensePose predictor outputs to DensePose results. + Each DensePose predictor output type has to register its convertion strategy. + """ + + registry = {} + dst_type = DensePoseChartResultWithConfidences + + @classmethod + # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter` + # inconsistently. + def convert( + cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs + ) -> DensePoseChartResultWithConfidences: + """ + Convert DensePose predictor outputs to DensePoseResult with confidences + using some registered converter. Does recursive lookup for base classes, + so there's no need for explicit registration for derived classes. 
+ + Args: + densepose_predictor_outputs: DensePose predictor output with confidences + to be converted to BitMasks + boxes (Boxes): bounding boxes that correspond to the DensePose + predictor outputs + Return: + An instance of DensePoseResult. If no suitable converter was found, raises KeyError + """ + return super(ToChartResultConverterWithConfidences, cls).convert( + predictor_outputs, boxes, *args, **kwargs + ) diff --git a/data_processing/detectron2/projects/DensePose/densepose/converters/to_mask.py b/data_processing/detectron2/projects/DensePose/densepose/converters/to_mask.py new file mode 100644 index 0000000..a57fd71 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/converters/to_mask.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import Any, Tuple + +from detectron2.structures import BitMasks, Boxes + +from .base import BaseConverter + +ImageSizeType = Tuple[int, int] + + +class ToMaskConverter(BaseConverter): + """ + Converts various DensePose predictor outputs to masks + in bit mask format (see `BitMasks`). Each DensePose predictor output type + has to register its convertion strategy. + """ + + registry = {} + dst_type = BitMasks + + @classmethod + # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter` + # inconsistently. + def convert( + cls, + densepose_predictor_outputs: Any, + boxes: Boxes, + image_size_hw: ImageSizeType, + *args, + **kwargs + ) -> BitMasks: + """ + Convert DensePose predictor outputs to BitMasks using some registered + converter. Does recursive lookup for base classes, so there's no need + for explicit registration for derived classes. + + Args: + densepose_predictor_outputs: DensePose predictor output to be + converted to BitMasks + boxes (Boxes): bounding boxes that correspond to the DensePose + predictor outputs + image_size_hw (tuple [int, int]): image height and width + Return: + An instance of `BitMasks`. 
If no suitable converter was found, raises KeyError + """ + return super(ToMaskConverter, cls).convert( + densepose_predictor_outputs, boxes, image_size_hw, *args, **kwargs + ) diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/__init__.py b/data_processing/detectron2/projects/DensePose/densepose/data/__init__.py new file mode 100644 index 0000000..bf21ba7 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from .meshes import builtin +from .build import ( + build_detection_test_loader, + build_detection_train_loader, + build_combined_loader, + build_frame_selector, + build_inference_based_loaders, + has_inference_based_loaders, + BootstrapDatasetFactoryCatalog, +) +from .combined_loader import CombinedDataLoader +from .dataset_mapper import DatasetMapper +from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter +from .image_list_dataset import ImageListDataset +from .utils import is_relative_local_path, maybe_prepend_base_path + +# ensure the builtin datasets are registered +from . import datasets + +# ensure the bootstrap datasets builders are registered +from . import build + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/build.py b/data_processing/detectron2/projects/DensePose/densepose/data/build.py new file mode 100644 index 0000000..39edbd8 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/build.py @@ -0,0 +1,736 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import itertools +import logging +import numpy as np +from collections import UserDict, defaultdict +from dataclasses import dataclass +from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple +import torch +from torch.utils.data.dataset import Dataset + +from detectron2.config import CfgNode +from detectron2.data.build import build_detection_test_loader as d2_build_detection_test_loader +from detectron2.data.build import build_detection_train_loader as d2_build_detection_train_loader +from detectron2.data.build import ( + load_proposals_into_dataset, + print_instances_class_histogram, + trivial_batch_collator, + worker_init_reset_seed, +) +from detectron2.data.catalog import DatasetCatalog, Metadata, MetadataCatalog +from detectron2.data.samplers import TrainingSampler +from detectron2.utils.comm import get_world_size + +from densepose.config import get_bootstrap_dataset_config +from densepose.modeling import build_densepose_embedder + +from .combined_loader import CombinedDataLoader, Loader +from .dataset_mapper import DatasetMapper +from .datasets.coco import DENSEPOSE_CSE_KEYS_WITHOUT_MASK, DENSEPOSE_IUV_KEYS_WITHOUT_MASK +from .datasets.dataset_type import DatasetType +from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter +from .samplers import ( + DensePoseConfidenceBasedSampler, + DensePoseCSEConfidenceBasedSampler, + DensePoseCSEUniformSampler, + DensePoseUniformSampler, + MaskFromDensePoseSampler, + PredictionToGroundTruthSampler, +) +from .transform import ImageResizeTransform +from .utils import get_category_to_class_mapping, get_class_to_mesh_name_mapping +from .video import ( + FirstKFramesSelector, + FrameSelectionStrategy, + LastKFramesSelector, + RandomKFramesSelector, + VideoKeyframeDataset, + video_list_from_file, +) + +__all__ = ["build_detection_train_loader", "build_detection_test_loader"] + + +Instance = Dict[str, Any] +InstancePredicate = Callable[[Instance], bool] + + +def 
_compute_num_images_per_worker(cfg: CfgNode) -> int: + num_workers = get_world_size() + images_per_batch = cfg.SOLVER.IMS_PER_BATCH + assert ( + images_per_batch % num_workers == 0 + ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format( + images_per_batch, num_workers + ) + assert ( + images_per_batch >= num_workers + ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format( + images_per_batch, num_workers + ) + images_per_worker = images_per_batch // num_workers + return images_per_worker + + +def _map_category_id_to_contiguous_id(dataset_name: str, dataset_dicts: Iterable[Instance]) -> None: + meta = MetadataCatalog.get(dataset_name) + for dataset_dict in dataset_dicts: + for ann in dataset_dict["annotations"]: + ann["category_id"] = meta.thing_dataset_id_to_contiguous_id[ann["category_id"]] + + +@dataclass +class _DatasetCategory: + """ + Class representing category data in a dataset: + - id: category ID, as specified in the dataset annotations file + - name: category name, as specified in the dataset annotations file + - mapped_id: category ID after applying category maps (DATASETS.CATEGORY_MAPS config option) + - mapped_name: category name after applying category maps + - dataset_name: dataset in which the category is defined + + For example, when training models in a class-agnostic manner, one could take LVIS 1.0 + dataset and map the animal categories to the same category as human data from COCO: + id = 225 + name = "cat" + mapped_id = 1 + mapped_name = "person" + dataset_name = "lvis_v1_animals_dp_train" + """ + + id: int + name: str + mapped_id: int + mapped_name: str + dataset_name: str + + +_MergedCategoriesT = Dict[int, List[_DatasetCategory]] + + +def _add_category_id_to_contiguous_id_maps_to_metadata( + merged_categories: _MergedCategoriesT, +) -> None: + merged_categories_per_dataset = {} + for contiguous_cat_id, cat_id in enumerate(sorted(merged_categories.keys())): + for cat in 
merged_categories[cat_id]: + if cat.dataset_name not in merged_categories_per_dataset: + merged_categories_per_dataset[cat.dataset_name] = defaultdict(list) + merged_categories_per_dataset[cat.dataset_name][cat_id].append( + ( + contiguous_cat_id, + cat, + ) + ) + + logger = logging.getLogger(__name__) + for dataset_name, merged_categories in merged_categories_per_dataset.items(): + meta = MetadataCatalog.get(dataset_name) + if not hasattr(meta, "thing_classes"): + meta.thing_classes = [] + meta.thing_dataset_id_to_contiguous_id = {} + meta.thing_dataset_id_to_merged_id = {} + else: + meta.thing_classes.clear() + meta.thing_dataset_id_to_contiguous_id.clear() + meta.thing_dataset_id_to_merged_id.clear() + logger.info(f"Dataset {dataset_name}: category ID to contiguous ID mapping:") + for _cat_id, categories in sorted(merged_categories.items()): + added_to_thing_classes = False + for contiguous_cat_id, cat in categories: + if not added_to_thing_classes: + meta.thing_classes.append(cat.mapped_name) + added_to_thing_classes = True + meta.thing_dataset_id_to_contiguous_id[cat.id] = contiguous_cat_id + meta.thing_dataset_id_to_merged_id[cat.id] = cat.mapped_id + logger.info(f"{cat.id} ({cat.name}) -> {contiguous_cat_id}") + + +def _maybe_create_general_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]: + def has_annotations(instance: Instance) -> bool: + return "annotations" in instance + + def has_only_crowd_anotations(instance: Instance) -> bool: + for ann in instance["annotations"]: + if ann.get("is_crowd", 0) == 0: + return False + return True + + def general_keep_instance_predicate(instance: Instance) -> bool: + return has_annotations(instance) and not has_only_crowd_anotations(instance) + + if not cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS: + return None + return general_keep_instance_predicate + + +def _maybe_create_keypoints_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]: + + min_num_keypoints = 
cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + + def has_sufficient_num_keypoints(instance: Instance) -> bool: + num_kpts = sum( + (np.array(ann["keypoints"][2::3]) > 0).sum() + for ann in instance["annotations"] + if "keypoints" in ann + ) + return num_kpts >= min_num_keypoints + + if cfg.MODEL.KEYPOINT_ON and (min_num_keypoints > 0): + return has_sufficient_num_keypoints + return None + + +def _maybe_create_mask_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]: + if not cfg.MODEL.MASK_ON: + return None + + def has_mask_annotations(instance: Instance) -> bool: + return any("segmentation" in ann for ann in instance["annotations"]) + + return has_mask_annotations + + +def _maybe_create_densepose_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]: + if not cfg.MODEL.DENSEPOSE_ON: + return None + + use_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS + + def has_densepose_annotations(instance: Instance) -> bool: + for ann in instance["annotations"]: + if all(key in ann for key in DENSEPOSE_IUV_KEYS_WITHOUT_MASK) or all( + key in ann for key in DENSEPOSE_CSE_KEYS_WITHOUT_MASK + ): + return True + if use_masks and "segmentation" in ann: + return True + return False + + return has_densepose_annotations + + +def _maybe_create_specific_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]: + specific_predicate_creators = [ + _maybe_create_keypoints_keep_instance_predicate, + _maybe_create_mask_keep_instance_predicate, + _maybe_create_densepose_keep_instance_predicate, + ] + predicates = [creator(cfg) for creator in specific_predicate_creators] + predicates = [p for p in predicates if p is not None] + if not predicates: + return None + + def combined_predicate(instance: Instance) -> bool: + return any(p(instance) for p in predicates) + + return combined_predicate + + +def _get_train_keep_instance_predicate(cfg: CfgNode): + general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg) 
def _maybe_filter_and_map_categories(
    dataset_name: str, dataset_dicts: List[Instance]
) -> List[Instance]:
    """
    Drop annotations of unmapped categories and remap the remaining category IDs.

    The mapping is read from the `thing_dataset_id_to_contiguous_id` metadata
    of `dataset_name`. Both the annotation dicts and the dataset dicts are
    modified in place; every dataset dict is kept, even if it ends up with
    no annotations.

    Args:
        dataset_name (str): dataset whose metadata provides the ID mapping
        dataset_dicts (List[Instance]): dataset dicts to process

    Returns:
        List[Instance]: a new list holding the same (updated) dataset dicts
    """
    contiguous_id_by_category = MetadataCatalog.get(dataset_name).thing_dataset_id_to_contiguous_id

    def _remap_in_place(annotations):
        kept = []
        for annotation in annotations:
            original_id = annotation["category_id"]
            if original_id in contiguous_id_by_category:
                annotation["category_id"] = contiguous_id_by_category[original_id]
                kept.append(annotation)
        return kept

    processed = []
    for record in dataset_dicts:
        record["annotations"] = _remap_in_place(record["annotations"])
        processed.append(record)
    return processed
def _maybe_add_class_to_mesh_name_map_to_metadata(dataset_names: List[str], cfg: CfgNode) -> None:
    """
    Set `class_to_mesh_name` in the metadata of each dataset that lacks it.

    Datasets whose metadata already has the attribute are left untouched.

    Args:
        dataset_names (List[str]): names of the datasets to update
        cfg (CfgNode): config passed to `get_class_to_mesh_name_mapping`
    """
    for name in dataset_names:
        metadata = MetadataCatalog.get(name)
        if hasattr(metadata, "class_to_mesh_name"):
            continue
        metadata.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
def combine_detection_dataset_dicts(
    dataset_names: Collection[str],
    keep_instance_predicate: Optional[InstancePredicate] = None,
    proposal_files: Optional[Collection[str]] = None,
) -> List[Instance]:
    """
    Load and prepare dataset dicts for training / testing

    Args:
        dataset_names (Collection[str]): a list of dataset names
        keep_instance_predicate (Callable: Dict[str, Any] -> bool): predicate
            applied to instance dicts which defines whether to keep the instance
        proposal_files (Collection[str]): if given, a list of object proposal files
            that match each dataset in `dataset_names`.

    Returns:
        List[Instance]: dataset dicts of all the datasets chained together,
        restricted to those accepted by `keep_instance_predicate` (if given)

    Raises:
        AssertionError: if no dataset names are given, if the number of
            proposal files differs from the number of datasets, or if one
            of the datasets is empty
    """
    assert len(dataset_names)
    if proposal_files is None:
        proposal_files = [None] * len(dataset_names)
    assert len(dataset_names) == len(proposal_files)
    # load datasets and metadata
    dataset_name_to_dicts = {}
    for dataset_name in dataset_names:
        loaded_dicts = DatasetCatalog.get(dataset_name)
        # check the dataset that was just loaded; checking the name -> dicts
        # mapping instead would never fail once the first dataset is stored
        assert len(loaded_dicts), f"Dataset '{dataset_name}' is empty!"
        dataset_name_to_dicts[dataset_name] = loaded_dicts
    # merge categories, requires category metadata to be loaded
    # cat_id -> [(orig_cat_id, cat_name, dataset_name)]
    merged_categories = _merge_categories(dataset_names)
    _warn_if_merged_different_categories(merged_categories)
    merged_category_names = [
        merged_categories[cat_id][0].mapped_name for cat_id in sorted(merged_categories)
    ]
    # map to contiguous category IDs
    _add_category_id_to_contiguous_id_maps_to_metadata(merged_categories)
    # load annotations and dataset metadata
    for dataset_name, proposal_file in zip(dataset_names, proposal_files):
        dataset_dicts = dataset_name_to_dicts[dataset_name]
        assert len(dataset_dicts), f"Dataset '{dataset_name}' is empty!"
        if proposal_file is not None:
            dataset_dicts = load_proposals_into_dataset(dataset_dicts, proposal_file)
        dataset_dicts = _maybe_filter_and_map_categories(dataset_name, dataset_dicts)
        print_instances_class_histogram(dataset_dicts, merged_category_names)
        dataset_name_to_dicts[dataset_name] = dataset_dicts

    all_dataset_dicts = itertools.chain.from_iterable(dataset_name_to_dicts.values())
    if keep_instance_predicate is None:
        return list(all_dataset_dicts)
    return [d for d in all_dataset_dicts if keep_instance_predicate(d)]
def build_detection_test_loader(cfg, dataset_name, mapper=None):
    """
    Build a data loader for evaluation on a single dataset.

    Counterpart of `build_detection_train_loader` that loads the dataset
    given by `dataset_name` instead of the training datasets named in cfg.

    Args:
        cfg: a detectron2 CfgNode
        dataset_name (str): a name of the dataset that's available in the DatasetCatalog
        mapper (callable): a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, False)`.

    Returns:
        the loader created by `d2_build_detection_test_loader` for the
        given dataset
    """
    _add_category_whitelists_to_metadata(cfg)
    _add_category_maps_to_metadata(cfg)
    _maybe_add_class_to_mesh_name_map_to_metadata([dataset_name], cfg)

    keep_predicate = _get_test_keep_instance_predicate(cfg)
    if cfg.MODEL.LOAD_PROPOSALS:
        # the proposal file is looked up at the dataset's position in DATASETS.TEST
        test_proposal_files = cfg.DATASETS.PROPOSAL_FILES_TEST
        dataset_position = list(cfg.DATASETS.TEST).index(dataset_name)
        proposal_files = [test_proposal_files[dataset_position]]
    else:
        proposal_files = None
    dataset_dicts = combine_detection_dataset_dicts(
        [dataset_name],
        keep_instance_predicate=keep_predicate,
        proposal_files=proposal_files,
    )

    # an explicit sequential sampler is only passed for non-distributed inference
    if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE:
        sampler = None
    else:
        sampler = torch.utils.data.SequentialSampler(dataset_dicts)

    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    return d2_build_detection_test_loader(
        dataset_dicts, mapper=mapper, num_workers=cfg.DATALOADER.NUM_WORKERS, sampler=sampler
    )
def build_data_sampler(cfg: CfgNode, sampler_cfg: CfgNode, embedder: Optional[torch.nn.Module]):
    """
    Build a sampler that converts DensePose predictions into ground truth.

    Every supported sampler type yields a `PredictionToGroundTruthSampler`
    with two registered samplers: a type-specific one for
    "pred_densepose" -> "gt_densepose" and a `MaskFromDensePoseSampler`
    for "pred_densepose" -> "gt_masks".

    Args:
        cfg (CfgNode): the config, passed on to the CSE samplers
        sampler_cfg (CfgNode): sampler config; `TYPE` selects the sampler,
            `COUNT_PER_CLASS` (and `USE_GROUND_TRUTH_CATEGORIES` for the CSE
            types) are passed on to it
        embedder (Optional[torch.nn.Module]): embedder, required by the
            CSE sampler types

    Returns:
        the configured `PredictionToGroundTruthSampler`

    Raises:
        ValueError: if `sampler_cfg.TYPE` is not a known sampler type
        AssertionError: if a CSE sampler type is requested without an embedder
    """
    sampler_type = sampler_cfg.TYPE
    # confidence channel used by each (non-CSE) confidence based sampler type
    confidence_channel_by_type = {
        "densepose_UV_confidence": "sigma_2",
        "densepose_fine_segm_confidence": "fine_segm_confidence",
        "densepose_coarse_segm_confidence": "coarse_segm_confidence",
    }
    cse_sampler_types = ("densepose_cse_uniform", "densepose_cse_coarse_segm_confidence")

    is_known_type = (
        sampler_type == "densepose_uniform"
        or sampler_type in confidence_channel_by_type
        or sampler_type in cse_sampler_types
    )
    if not is_known_type:
        raise ValueError(f"Unknown data sampler type {sampler_cfg.TYPE}")
    if sampler_type in cse_sampler_types:
        assert embedder is not None

    data_sampler = PredictionToGroundTruthSampler()
    if sampler_type == "densepose_uniform":
        densepose_sampler = DensePoseUniformSampler(count_per_class=sampler_cfg.COUNT_PER_CLASS)
    elif sampler_type in confidence_channel_by_type:
        densepose_sampler = DensePoseConfidenceBasedSampler(
            confidence_channel=confidence_channel_by_type[sampler_type],
            count_per_class=sampler_cfg.COUNT_PER_CLASS,
            search_proportion=0.5,
        )
    elif sampler_type == "densepose_cse_uniform":
        densepose_sampler = DensePoseCSEUniformSampler(
            cfg=cfg,
            use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
            embedder=embedder,
            count_per_class=sampler_cfg.COUNT_PER_CLASS,
        )
    else:
        # "densepose_cse_coarse_segm_confidence"
        densepose_sampler = DensePoseCSEConfidenceBasedSampler(
            cfg=cfg,
            use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
            embedder=embedder,
            confidence_channel="coarse_segm_confidence",
            count_per_class=sampler_cfg.COUNT_PER_CLASS,
            search_proportion=0.5,
        )

    # transform densepose pred -> gt
    data_sampler.register_sampler("pred_densepose", "gt_densepose", densepose_sampler)
    data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
    return data_sampler
""" + return len(cfg.BOOTSTRAP_DATASETS) > 0 + + +def build_inference_based_loaders( + cfg: CfgNode, model: torch.nn.Module +) -> Tuple[List[InferenceBasedLoader], List[float]]: + loaders = [] + ratios = [] + embedder = build_densepose_embedder(cfg).to(device=model.device) # pyre-ignore[16] + for dataset_spec in cfg.BOOTSTRAP_DATASETS: + dataset_cfg = get_bootstrap_dataset_config().clone() + dataset_cfg.merge_from_other_cfg(CfgNode(dataset_spec)) + loader = build_inference_based_loader(cfg, dataset_cfg, model, embedder) + loaders.append(loader) + ratios.append(dataset_cfg.RATIO) + return loaders, ratios + + +def build_video_list_dataset(meta: Metadata, cfg: CfgNode): + video_list_fpath = meta.video_list_fpath + video_base_path = meta.video_base_path + category = meta.category + if cfg.TYPE == "video_keyframe": + frame_selector = build_frame_selector(cfg.SELECT) + transform = build_transform(cfg.TRANSFORM, data_type="image") + video_list = video_list_from_file(video_list_fpath, video_base_path) + keyframe_helper_fpath = getattr(cfg, "KEYFRAME_HELPER", None) + return VideoKeyframeDataset( + video_list, category, frame_selector, transform, keyframe_helper_fpath + ) + + +class _BootstrapDatasetFactoryCatalog(UserDict): + """ + A global dictionary that stores information about bootstrapped datasets creation functions + from metadata and config, for diverse DatasetType + """ + + def register(self, dataset_type: DatasetType, factory: Callable[[Metadata, CfgNode], Dataset]): + """ + Args: + dataset_type (DatasetType): a DatasetType e.g. DatasetType.VIDEO_LIST + factory (Callable[Metadata, CfgNode]): a callable which takes Metadata and cfg + arguments and returns a dataset object. 
Loader = Iterable[Any]


def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]):
    """
    Return the next single sample of a loader.

    Samples are served from `pool`; when it is empty, it is refilled with
    the next batch taken from `iterator`.

    Raises:
        StopIteration: if `pool` is empty and `iterator` is exhausted
    """
    if not pool:
        pool.extend(next(iterator))
    return pool.popleft()


class CombinedDataLoader:
    """
    Combines data loaders using the provided sampling ratios
    """

    # number of batches worth of loader indices drawn at once
    BATCH_COUNT = 100

    def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
        """
        Args:
            loaders (Collection[Loader]): loaders to combine, each one
                yields batches (iterables of samples)
            batch_size (int): number of samples in each produced batch
            ratios (Sequence[float]): sampling weights, one per loader
        """
        self.loaders = loaders
        self.batch_size = batch_size
        self.ratios = ratios

    def __iter__(self) -> Iterator[List[Any]]:
        """
        Yield batches of `batch_size` samples; the source loader of every
        sample is drawn at random, weighted by `ratios`. Iteration stops as
        soon as a loader selected for a sample is exhausted.
        """
        iters = [iter(loader) for loader in self.loaders]
        indices = []
        # one independent pool per loader; `[deque()] * n` would alias a
        # single deque, so samples of different loaders would get mixed up
        # and the requested ratios would not be honoured
        pool = [deque() for _ in iters]
        # infinite iterator, as in D2
        while True:
            if not indices:
                # just a buffer of indices, its size doesn't matter
                # as long as it's a multiple of batch_size
                k = self.batch_size * self.BATCH_COUNT
                indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
            try:
                batch = [_pooled_next(iters[i], pool[i]) for i in indices[: self.batch_size]]
            except StopIteration:
                break
            indices = indices[self.batch_size :]
            yield batch
class DatasetMapper:
    """
    A customized version of `detectron2.data.DatasetMapper`

    Maps a dataset dict (Detectron2 Dataset format) to the format consumed
    by the model: reads and augments the image and, in training mode,
    transforms the annotations (including DensePose data) accordingly.
    """

    def __init__(self, cfg, is_train=True):
        """
        Args:
            cfg: the config
            is_train (bool): whether the mapper is used for training;
                annotations are dropped from the output otherwise
        """
        self.augmentation = build_augmentation(cfg, is_train)

        self.img_format = cfg.INPUT.FORMAT
        # masks are also needed when DensePose coarse segmentation is trained by masks
        self.mask_on = cfg.MODEL.MASK_ON or (
            cfg.MODEL.DENSEPOSE_ON and cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
        )
        self.keypoint_on = cfg.MODEL.KEYPOINT_ON
        self.densepose_on = cfg.MODEL.DENSEPOSE_ON
        assert not cfg.MODEL.LOAD_PROPOSALS, "not supported yet"
        if self.keypoint_on and is_train:
            # Flip only makes sense in training
            self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
        else:
            self.keypoint_hflip_indices = None

        if self.densepose_on:
            densepose_transform_srcs = [
                MetadataCatalog.get(ds).densepose_transform_src
                for ds in cfg.DATASETS.TRAIN + cfg.DATASETS.TEST
            ]
            assert len(densepose_transform_srcs) > 0
            # TODO: check that DensePose transformation data is the same for
            # all the datasets. Otherwise one would have to pass DB ID with
            # each entry to select proper transformation data. For now, since
            # all DensePose annotated data uses the same data semantics, we
            # omit this check.
            densepose_transform_data_fpath = PathManager.get_local_path(densepose_transform_srcs[0])
            self.densepose_transform_data = DensePoseTransformData.load(
                densepose_transform_data_fpath
            )

        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        image, transforms = T.apply_transform_gens(self.augmentation, image)
        image_shape = image.shape[:2]  # h, w
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))

        if not self.is_train:
            dataset_dict.pop("annotations", None)
            return dataset_dict

        for anno in dataset_dict["annotations"]:
            if not self.mask_on:
                anno.pop("segmentation", None)
            if not self.keypoint_on:
                anno.pop("keypoints", None)

        # USER: Implement additional transformations if you have other types of data
        # USER: Don't call transpose_densepose if you don't need
        annos = [
            self._transform_densepose(
                utils.transform_instance_annotations(
                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
                ),
                transforms,
            )
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]

        if self.mask_on:
            self._add_densepose_masks_as_segmentation(annos, image_shape)

        instances = utils.annotations_to_instances(annos, image_shape, mask_format="bitmask")
        densepose_annotations = [obj.get("densepose") for obj in annos]
        if densepose_annotations and not all(v is None for v in densepose_annotations):
            instances.gt_densepose = DensePoseList(
                densepose_annotations, instances.gt_boxes, image_shape
            )

        dataset_dict["instances"] = instances[instances.gt_boxes.nonempty()]
        return dataset_dict

    def _transform_densepose(self, annotation, transforms):
        """
        Convert the DensePose part of an annotation and apply `transforms` to it.

        If DensePose is enabled, `annotation["densepose"]` is set to the
        transformed `DensePoseDataRelative`, or to None if the annotation
        does not validate as DensePose data. The annotation is modified in
        place and returned.
        """
        if not self.densepose_on:
            return annotation

        # Handle densepose annotations
        is_valid, reason_not_valid = DensePoseDataRelative.validate_annotation(annotation)
        if is_valid:
            densepose_data = DensePoseDataRelative(annotation, cleanup=True)
            densepose_data.apply_transform(transforms, self.densepose_transform_data)
            annotation["densepose"] = densepose_data
        else:
            # `reason_not_valid` is available here for debugging purposes
            DensePoseDataRelative.cleanup_annotation(annotation)
            # NOTE: annotations for certain instances may be unavailable.
            # 'None' is accepted by the DensePoseList data structure.
            annotation["densepose"] = None
        return annotation

    def _add_densepose_masks_as_segmentation(
        self, annotations: List[Dict[str, Any]], image_shape_hw: Tuple[int, int]
    ):
        """
        Add a "segmentation" mask derived from DensePose data to annotations
        that have DensePose data but no segmentation.

        Args:
            annotations (List[Dict[str, Any]]): annotations, modified in place
            image_shape_hw (Tuple[int, int]): image height and width
        """
        for obj in annotations:
            densepose_data = obj.get("densepose")
            # `_transform_densepose` stores None for instances without valid
            # DensePose data, so the key being present is not enough
            if densepose_data is None or "segmentation" in obj:
                continue
            # DP segmentation: torch.Tensor [S, S] of float32, S=256
            segm_dp = torch.zeros_like(densepose_data.segm)
            segm_dp[densepose_data.segm > 0] = 1
            segm_h, segm_w = segm_dp.shape
            bbox_segm_dp = torch.tensor((0, 0, segm_h - 1, segm_w - 1), dtype=torch.float32)
            # image bbox
            # NOTE(review): the coordinates are used as slice bounds below,
            # confirm that they are integral for the bbox types in use
            x0, y0, x1, y1 = (
                v.item() for v in BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS)
            )
            segm_aligned = (
                ROIAlign((y1 - y0, x1 - x0), 1.0, 0, aligned=True)
                .forward(segm_dp.view(1, 1, *segm_dp.shape), bbox_segm_dp)
                .squeeze()
            )
            image_mask = torch.zeros(*image_shape_hw, dtype=torch.float32)
            image_mask[y0:y1, x0:x1] = segm_aligned
            # segmentation for BitMask: np.array [H, W] of bool
            obj["segmentation"] = image_mask >= 0.5
+from .chimpnsee import register_dataset as register_chimpnsee_dataset +from .coco import BASE_DATASETS as BASE_COCO_DATASETS +from .coco import DATASETS as COCO_DATASETS +from .coco import register_datasets as register_coco_datasets +from .lvis import DATASETS as LVIS_DATASETS +from .lvis import register_datasets as register_lvis_datasets + +DEFAULT_DATASETS_ROOT = "datasets" + + +register_coco_datasets(COCO_DATASETS, DEFAULT_DATASETS_ROOT) +register_coco_datasets(BASE_COCO_DATASETS, DEFAULT_DATASETS_ROOT) +register_lvis_datasets(LVIS_DATASETS, DEFAULT_DATASETS_ROOT) + +register_chimpnsee_dataset(DEFAULT_DATASETS_ROOT) # pyre-ignore[19] diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/datasets/chimpnsee.py b/data_processing/detectron2/projects/DensePose/densepose/data/datasets/chimpnsee.py new file mode 100644 index 0000000..61e0b50 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/datasets/chimpnsee.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
def register_dataset(datasets_root: Optional[str] = None) -> None:
    """
    Register the Chimp&See video list dataset and its metadata.

    The dataset is registered in `DatasetCatalog` with a loader that
    returns nothing; the video list file path, the video base path, the
    dataset type and the category are stored in `MetadataCatalog`.

    Args:
        datasets_root (Optional[str]): base path passed to
            `maybe_prepend_base_path` for both dataset paths
    """

    def empty_load_callback():
        return None

    metadata = {
        "dataset_type": DatasetType.VIDEO_LIST,
        "video_list_fpath": maybe_prepend_base_path(
            datasets_root, "chimpnsee/cdna.eva.mpg.de/video_list.txt"
        ),
        "video_base_path": maybe_prepend_base_path(datasets_root, "chimpnsee/cdna.eva.mpg.de"),
        "category": "chimpanzee",
    }

    DatasetCatalog.register(CHIMPNSEE_DATASET_NAME, empty_load_callback)
    MetadataCatalog.get(CHIMPNSEE_DATASET_NAME).set(**metadata)
+import contextlib +import io +import logging +import os +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional +from fvcore.common.timer import Timer + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.structures import BoxMode +from detectron2.utils.file_io import PathManager + +from ..utils import maybe_prepend_base_path + +DENSEPOSE_MASK_KEY = "dp_masks" +DENSEPOSE_IUV_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_I", "dp_U", "dp_V"] +DENSEPOSE_CSE_KEYS_WITHOUT_MASK = ["dp_x", "dp_y", "dp_vertex", "ref_model"] +DENSEPOSE_ALL_POSSIBLE_KEYS = set( + DENSEPOSE_IUV_KEYS_WITHOUT_MASK + DENSEPOSE_CSE_KEYS_WITHOUT_MASK + [DENSEPOSE_MASK_KEY] +) +DENSEPOSE_METADATA_URL_PREFIX = "https://dl.fbaipublicfiles.com/densepose/data/" + + +@dataclass +class CocoDatasetInfo: + name: str + images_root: str + annotations_fpath: str + + +DATASETS = [ + CocoDatasetInfo( + name="densepose_coco_2014_train", + images_root="coco/train2014", + annotations_fpath="coco/annotations/densepose_train2014.json", + ), + CocoDatasetInfo( + name="densepose_coco_2014_minival", + images_root="coco/val2014", + annotations_fpath="coco/annotations/densepose_minival2014.json", + ), + CocoDatasetInfo( + name="densepose_coco_2014_minival_100", + images_root="coco/val2014", + annotations_fpath="coco/annotations/densepose_minival2014_100.json", + ), + CocoDatasetInfo( + name="densepose_coco_2014_valminusminival", + images_root="coco/val2014", + annotations_fpath="coco/annotations/densepose_valminusminival2014.json", + ), + CocoDatasetInfo( + name="densepose_coco_2014_train_cse", + images_root="coco/train2014", + annotations_fpath="coco_cse/densepose_train2014_cse.json", + ), + CocoDatasetInfo( + name="densepose_coco_2014_minival_cse", + images_root="coco/val2014", + annotations_fpath="coco_cse/densepose_minival2014_cse.json", + ), + CocoDatasetInfo( + 
def get_metadata(base_path: Optional[str]) -> Dict[str, Any]:
    """
    Builds the metadata dictionary shared by COCO DensePose datasets

    Each entry maps a metadata key to the location of a `.mat` file, obtained
    by passing `base_path` and the file name to `maybe_prepend_base_path`.

    Args:
        base_path: Optional[str]
            Base path the metadata files are resolved against

    Returns:
        Dict[str, Any]
            Dictionary with the keys `densepose_transform_src`,
            `densepose_smpl_subdiv` and `densepose_smpl_subdiv_transform`
    """
    metadata_files = (
        ("densepose_transform_src", "UV_symmetry_transforms.mat"),
        ("densepose_smpl_subdiv", "SMPL_subdiv.mat"),
        ("densepose_smpl_subdiv_transform", "SMPL_SUBDIV_TRANSFORM.mat"),
    )
    return {
        key: maybe_prepend_base_path(base_path, file_name) for key, file_name in metadata_files
    }
def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
    """
    Copies keypoints from an annotation into `obj`, shifting coordinates by 0.5

    The flat keypoint list is read in groups of three values. The first two
    values of each group (the coordinates) get 0.5 added, the third value
    (the visibility flag in the COCO keypoint format) is kept unchanged.
    Nothing is added to `obj` if the annotation has no "keypoints" entry.

    Args:
        obj: Dict[str, Any]
            Output object record; receives the "keypoints" entry
        ann_dict: Dict[str, Any]
            Source annotation; it is only read, never modified
    """
    if "keypoints" not in ann_dict:
        return
    # COCO's segmentation coordinates are floating points in [0, H or W],
    # but keypoint coordinates are integers in [0, H-1 or W-1]
    # Therefore we assume the coordinates are "pixel indices" and
    # add 0.5 to convert to floating point coordinates.
    # A new list is built so that the caller's annotation is left untouched;
    # shifting in place would offset the values again if the same annotation
    # were ever converted a second time.
    obj["keypoints"] = [
        v if idx % 3 == 2 else v + 0.5 for idx, v in enumerate(ann_dict["keypoints"])
    ]
def get_contiguous_id_to_category_id_map(metadata):
    """
    Inverts the category ID -> contiguous ID mapping stored in the metadata

    When several category IDs are mapped to the same contiguous ID, the
    category ID that comes first in the iteration order of
    `metadata.thing_dataset_id_to_contiguous_id` is the one that is kept.

    Args:
        metadata
            Object that provides `thing_dataset_id_to_contiguous_id`

    Returns:
        Dictionary that maps contiguous IDs to category IDs
    """
    inverse_map = {}
    for category_id, contiguous_id in metadata.thing_dataset_id_to_contiguous_id.items():
        # setdefault keeps the entry that was registered first
        inverse_map.setdefault(contiguous_id, category_id)
    return inverse_map
def load_coco_json(
    annotations_json_file: str, image_root: str, dataset_name: str
) -> List[Dict[str, Any]]:
    """
    Loads a JSON file with annotations in COCO instances format.
    Replaces `detectron2.data.datasets.coco.load_coco_json` to handle metadata
    in a more flexible way. Postpones category mapping to a later stage to be
    able to combine several datasets with different (but coherent) sets of
    categories.

    As a side effect, the categories found in the file are stored in the
    metadata registered under `dataset_name`.

    Args:

        annotations_json_file: str
            Path to the JSON file with annotations in COCO instances format.
        image_root: str
            directory that contains all the images
        dataset_name: str
            the name that identifies a dataset, e.g. "densepose_coco_2014_train"

    Returns:
        List of per-image records, each with its list of annotations, as
        produced by `_combine_images_with_annotations`
    """
    coco_api = _load_coco_annotations(PathManager.get_local_path(annotations_json_file))
    _add_categories_metadata(dataset_name, coco_api.loadCats(coco_api.getCatIds()))
    # sort indices for reproducible results
    img_ids = sorted(coco_api.imgs.keys())
    # imgs is a list of dicts, each looks something like:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = coco_api.loadImgs(img_ids)
    logger = logging.getLogger(__name__)
    logger.info("Loaded {} images in COCO format from {}".format(len(imgs), annotations_json_file))
    # anns is a list[list[dict]], where each dict is an annotation
    # record for an object. The inner list enumerates the objects in an image
    # and the outer list enumerates over images.
    anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
    # the uniqueness check receives the original (unresolved) path, which is
    # what it uses for its white-listing and its error message
    _verify_annotations_have_unique_ids(annotations_json_file, anns)
    dataset_records = _combine_images_with_annotations(dataset_name, image_root, imgs, anns)
    return dataset_records
**get_metadata(DENSEPOSE_METADATA_URL_PREFIX) + ) + + +def register_datasets( + datasets_data: Iterable[CocoDatasetInfo], datasets_root: Optional[str] = None +): + """ + Registers provided COCO DensePose datasets + + Args: + datasets_data: Iterable[CocoDatasetInfo] + An iterable of dataset datas + datasets_root: Optional[str] + Datasets root folder (default: None) + """ + for dataset_data in datasets_data: + register_dataset(dataset_data, datasets_root) diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/datasets/dataset_type.py b/data_processing/detectron2/projects/DensePose/densepose/data/datasets/dataset_type.py new file mode 100644 index 0000000..ed8f8f2 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/datasets/dataset_type.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from enum import Enum + + +class DatasetType(Enum): + """ + Dataset type, mostly used for datasets that contain data to bootstrap models on + """ + + VIDEO_LIST = "video_list" diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/datasets/lvis.py b/data_processing/detectron2/projects/DensePose/densepose/data/datasets/lvis.py new file mode 100644 index 0000000..b4af9fa --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/datasets/lvis.py @@ -0,0 +1,257 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
def _load_lvis_annotations(json_file: str):
    """
    Load LVIS annotations from a JSON file

    Args:
        json_file: str
            Path to the file to load annotations from
    Returns:
        Instance of `lvis.LVIS` that provides access to annotations
        data
    """
    # imported inside the function, so the `lvis` package is only needed
    # when annotations are actually loaded
    from lvis import LVIS

    # NOTE(review): `load_lvis_json` already passes a path resolved with
    # `PathManager.get_local_path`, so the path is resolved a second time here
    json_file = PathManager.get_local_path(json_file)
    logger = logging.getLogger(__name__)
    timer = Timer()
    lvis_api = LVIS(json_file)
    # only report loading times that exceed one second
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
    return lvis_api
def _verify_annotations_have_unique_ids(json_file: str, anns: List[List[Dict[str, Any]]]) -> None:
    """
    Checks that no annotation ID occurs more than once across all images

    Args:
        json_file: str
            Path of the annotations file, used only in the error message
        anns: List[List[Dict[str, Any]]]
            Annotations grouped per image; each annotation must have an "id"

    Raises:
        AssertionError: if at least one annotation ID is repeated
    """
    collected_ids = []
    for image_annotations in anns:
        for annotation in image_annotations:
            collected_ids.append(annotation["id"])
    distinct_count = len(set(collected_ids))
    assert distinct_count == len(collected_ids), "Annotation ids in '{}' are not unique!".format(
        json_file
    )
def _combine_images_with_annotations(
    dataset_name: str,
    image_root: str,
    img_datas: Iterable[Dict[str, Any]],
    ann_datas: Iterable[Iterable[Dict[str, Any]]],
):
    """
    Merges image descriptions and their annotations into dataset records

    Images and annotation lists are paired positionally. For every image a
    record with file name, size, LVIS category bookkeeping, image ID and
    dataset name is created, and each of its annotations is converted into
    an object dictionary stored under "annotations".

    Args:
        dataset_name: str
            Dataset name, stored in every record under "dataset"
        image_root: str
            Prefix of the image folder; the split folder name is appended
            to it directly, without a path separator
        img_datas: Iterable[Dict[str, Any]]
            Image descriptions; each must provide "coco_url", "height",
            "width" and "id"
        ann_datas: Iterable[Iterable[Dict[str, Any]]]
            Annotations of each image, in the same order as `img_datas`

    Returns:
        List of records, one per image
    """
    records = []
    for image_info, image_annotations in zip(img_datas, ann_datas):
        # Determine the path including the split folder ("train2017", "val2017",
        # "test2017") from the last two components of the coco_url field. Example:
        # 'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg'
        split_folder, file_name = image_info["coco_url"].split("/")[-2:]
        record = {
            "file_name": os.path.join(image_root + split_folder, file_name),
            "height": image_info["height"],
            "width": image_info["width"],
            "not_exhaustive_category_ids": image_info.get("not_exhaustive_category_ids", []),
            "neg_category_ids": image_info.get("neg_category_ids", []),
            "image_id": image_info["id"],
            "dataset": dataset_name,
        }

        converted_annotations = []
        for annotation in image_annotations:
            # every annotation must belong to the image it is paired with
            assert annotation["image_id"] == record["image_id"]
            converted = {}
            _maybe_add_bbox(converted, annotation)
            converted["iscrowd"] = annotation.get("iscrowd", 0)
            converted["category_id"] = annotation["category_id"]
            _maybe_add_segm(converted, annotation)
            _maybe_add_keypoints(converted, annotation)
            _maybe_add_densepose(converted, annotation)
            converted_annotations.append(converted)
        record["annotations"] = converted_annotations
        records.append(record)
    return records
Postpones category mapping to a later stage to be + able to combine several datasets with different (but coherent) sets of + categories. + + Args: + + annotations_json_file: str + Path to the JSON file with annotations in COCO instances format. + image_root: str + directory that contains all the images + dataset_name: str + the name that identifies a dataset, e.g. "densepose_coco_2014_train" + extra_annotation_keys: Optional[List[str]] + If provided, these keys are used to extract additional data from + the annotations. + """ + lvis_api = _load_lvis_annotations(PathManager.get_local_path(annotations_json_file)) + + _add_categories_metadata(dataset_name) + + # sort indices for reproducible results + img_ids = sorted(lvis_api.imgs.keys()) + # imgs is a list of dicts, each looks something like: + # {'license': 4, + # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg', + # 'file_name': 'COCO_val2014_000000001268.jpg', + # 'height': 427, + # 'width': 640, + # 'date_captured': '2013-11-17 05:57:24', + # 'id': 1268} + imgs = lvis_api.load_imgs(img_ids) + logger = logging.getLogger(__name__) + logger.info("Loaded {} images in LVIS format from {}".format(len(imgs), annotations_json_file)) + # anns is a list[list[dict]], where each dict is an annotation + # record for an object. The inner list enumerates the objects in an image + # and the outer list enumerates over images. 
def register_dataset(dataset_data: CocoDatasetInfo, datasets_root: Optional[str] = None) -> None:
    """
    Registers a single LVIS DensePose dataset

    The dataset loader is registered in `DatasetCatalog`, and the annotations
    path, the images root, the "lvis" evaluator type and the DensePose
    metadata are stored in `MetadataCatalog` under the dataset name.

    Args:
        dataset_data: CocoDatasetInfo
            Dataset data
        datasets_root: Optional[str]
            Datasets root folder (default: None)
    """
    images_and_annotations = {
        "json_file": maybe_prepend_base_path(datasets_root, dataset_data.annotations_fpath),
        "image_root": maybe_prepend_base_path(datasets_root, dataset_data.images_root),
    }

    # annotations are loaded lazily, only when the catalog invokes the loader
    DatasetCatalog.register(
        dataset_data.name,
        lambda: load_lvis_json(
            annotations_json_file=images_and_annotations["json_file"],
            image_root=images_and_annotations["image_root"],
            dataset_name=dataset_data.name,
        ),
    )
    dataset_metadata = MetadataCatalog.get(dataset_data.name)
    dataset_metadata.set(
        json_file=images_and_annotations["json_file"],
        image_root=images_and_annotations["image_root"],
        evaluator_type="lvis",
        **get_metadata(DENSEPOSE_METADATA_URL_PREFIX),
    )
+ +import logging +import numpy as np +from typing import Any, Callable, Dict, List, Optional, Union +import torch +from torch.utils.data.dataset import Dataset + +from detectron2.data.detection_utils import read_image + +ImageTransform = Callable[[torch.Tensor], torch.Tensor] + + +class ImageListDataset(Dataset): + """ + Dataset that provides images from a list. + """ + + _EMPTY_IMAGE = torch.empty((0, 3, 1, 1)) + + def __init__( + self, + image_list: List[str], + category_list: Union[str, List[str], None] = None, + transform: Optional[ImageTransform] = None, + ): + """ + Args: + image_list (List[str]): list of paths to image files + category_list (Union[str, List[str], None]): list of animal categories for + each image. If it is a string, or None, this applies to all images + """ + if type(category_list) == list: + self.category_list = category_list + else: + self.category_list = [category_list] * len(image_list) + assert len(image_list) == len( + self.category_list + ), "length of image and category lists must be equal" + self.image_list = image_list + self.transform = transform + + def __getitem__(self, idx: int) -> Dict[str, Any]: + """ + Gets selected images from the list + + Args: + idx (int): video index in the video list file + Returns: + A dictionary containing two keys: + images (torch.Tensor): tensor of size [N, 3, H, W] (N = 1, or 0 for _EMPTY_IMAGE) + categories (List[str]): categories of the frames + """ + categories = [self.category_list[idx]] + fpath = self.image_list[idx] + transform = self.transform + + try: + image = torch.from_numpy(np.ascontiguousarray(read_image(fpath, format="BGR"))) + image = image.permute(2, 0, 1).unsqueeze(0).float() # HWC -> NCHW + if transform is not None: + image = transform(image) + return {"images": image, "categories": categories} + except (OSError, RuntimeError) as e: + logger = logging.getLogger(__name__) + logger.warning(f"Error opening image file container {fpath}: {e}") + + return {"images": self._EMPTY_IMAGE, 
"categories": []} + + def __len__(self): + return len(self.image_list) diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/inference_based_loader.py b/data_processing/detectron2/projects/DensePose/densepose/data/inference_based_loader.py new file mode 100644 index 0000000..dde4c0f --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/inference_based_loader.py @@ -0,0 +1,173 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import random +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple +import torch +from torch import nn + +SampledData = Any +ModelOutput = Any + + +def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]: + """ + Group elements of an iterable by chunks of size `n`, e.g. + grouper(range(9), 4) -> + (0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None) + """ + it = iter(iterable) + while True: + values = [] + for _ in range(n): + try: + value = next(it) + except StopIteration: + if values: + values.extend([fillvalue] * (n - len(values))) + yield tuple(values) + return + values.append(value) + yield tuple(values) + + +class ScoreBasedFilter: + """ + Filters entries in model output based on their scores + Discards all entries with score less than the specified minimum + """ + + def __init__(self, min_score: float = 0.8): + self.min_score = min_score + + def __call__(self, model_output: ModelOutput) -> ModelOutput: + for model_output_i in model_output: + instances = model_output_i["instances"] + if not instances.has("scores"): + continue + print('in inference based loader') + instances_filtered = instances[instances.scores >= self.min_score] + model_output_i["instances"] = instances_filtered + return model_output + + +class InferenceBasedLoader: + """ + Data loader based on results inferred by a model. 
Consists of: + - a data loader that provides batches of images + - a model that is used to infer the results + - a data sampler that converts inferred results to annotations + """ + + def __init__( + self, + model: nn.Module, + data_loader: Iterable[List[Dict[str, Any]]], + data_sampler: Optional[Callable[[ModelOutput], List[SampledData]]] = None, + data_filter: Optional[Callable[[ModelOutput], ModelOutput]] = None, + shuffle: bool = True, + batch_size: int = 4, + inference_batch_size: int = 4, + drop_last: bool = False, + category_to_class_mapping: Optional[dict] = None, + ): + """ + Constructor + + Args: + model (torch.nn.Module): model used to produce data + data_loader (Iterable[List[Dict[str, Any]]]): iterable that provides + dictionaries with "images" and "categories" fields to perform inference on + data_sampler (Callable: ModelOutput -> SampledData): functor + that produces annotation data from inference results; + (optional, default: None) + data_filter (Callable: ModelOutput -> ModelOutput): filter + that selects model outputs for further processing + (optional, default: None) + shuffle (bool): if True, the input images get shuffled + batch_size (int): batch size for the produced annotation data + inference_batch_size (int): batch size for input images + drop_last (bool): if True, drop the last batch if it is undersized + category_to_class_mapping (dict): category to class mapping + """ + self.model = model + self.model.eval() + self.data_loader = data_loader + self.data_sampler = data_sampler + self.data_filter = data_filter + self.shuffle = shuffle + self.batch_size = batch_size + self.inference_batch_size = inference_batch_size + self.drop_last = drop_last + if category_to_class_mapping is not None: + self.category_to_class_mapping = category_to_class_mapping + else: + self.category_to_class_mapping = {} + + def __iter__(self) -> Iterator[List[SampledData]]: + for batch in self.data_loader: + # batch : List[Dict[str: Tensor[N, C, H, W], str: 
    def _produce_data(
        self, images_and_categories: List[Dict[str, Any]]
    ) -> Iterator[List[SampledData]]:
        """
        Produce batches of data from images

        Images are grouped into chunks of `inference_batch_size`, moved to the
        device of the model and passed through it without gradient tracking.
        The outputs are optionally filtered and sampled, and outputs that
        contain at least one instance are collected into batches of
        `batch_size` entries.

        Args:
            images_and_categories (List[Dict[str, Any]]):
                list of dictionaries with the keys "image" (image tensor) and
                "category" (corresponding category) to process

        Returns:
            Iterator over batches of data sampled from model outputs
        """
        data_batches: List[SampledData] = []
        category_to_class_mapping = self.category_to_class_mapping
        batched_images_and_categories = _grouper(images_and_categories, self.inference_batch_size)
        for batch in batched_images_and_categories:
            # the last chunk is padded with None by `_grouper`; drop the padding
            batch = [
                {
                    "image": image_and_category["image"].to(self.model.device),
                    "category": image_and_category["category"],
                }
                for image_and_category in batch
                if image_and_category is not None
            ]
            if not batch:
                continue
            with torch.no_grad():
                model_output = self.model(batch)
            for model_output_i, batch_i in zip(model_output, batch):
                # each image is expected to have exactly three dimensions
                assert len(batch_i["image"].shape) == 3
                model_output_i["image"] = batch_i["image"]
                # categories that are missing from the mapping get class 0
                instance_class = category_to_class_mapping.get(batch_i["category"], 0)
                # every instance of an image is assigned the class of that image
                model_output_i["instances"].dataset_classes = torch.tensor(
                    [instance_class] * len(model_output_i["instances"])
                )
            model_output_filtered = (
                model_output if self.data_filter is None else self.data_filter(model_output)
            )
            data = (
                model_output_filtered
                if self.data_sampler is None
                else self.data_sampler(model_output_filtered)
            )
            # keep only the entries that contain at least one instance
            for data_i in data:
                if len(data_i["instances"]):
                    data_batches.append(data_i)
            # NOTE(review): at most one batch is emitted per inference step, so
            # leftovers accumulate when one step yields more than `batch_size`
            # entries, and the final batch below can then exceed `batch_size`
            if len(data_batches) >= self.batch_size:
                yield data_batches[: self.batch_size]
                data_batches = data_batches[self.batch_size :]
        # emit the remaining (possibly undersized) batch unless it must be dropped
        if not self.drop_last and data_batches:
            yield data_batches
from .catalog import MeshInfo, register_meshes

# Base location of the mesh files; every relative path listed in MESHES is
# resolved against it when the meshes are registered below.
# The files are referenced directly on their hosting server rather than
# through a third-party proxy URL, so that the downloaded `.pkl` files do
# not pass through an intermediary.
DENSEPOSE_MESHES_DIR = "https://dl.fbaipublicfiles.com/densepose/meshes/"

# Built-in meshes: for each mesh the paths of the mesh data and of the
# auxiliary geodesic distance, symmetry and texture coordinate files
MESHES = [
    MeshInfo(
        name="smpl_27554",
        data="smpl_27554.pkl",
        geodists="geodists/geodists_smpl_27554.pkl",
        symmetry="symmetry/symmetry_smpl_27554.pkl",
        texcoords="texcoords/texcoords_smpl_27554.pkl",
    ),
    MeshInfo(
        name="chimp_5029",
        data="chimp_5029.pkl",
        geodists="geodists/geodists_chimp_5029.pkl",
        symmetry="symmetry/symmetry_chimp_5029.pkl",
        texcoords="texcoords/texcoords_chimp_5029.pkl",
    ),
    MeshInfo(
        name="cat_5001",
        data="cat_5001.pkl",
        geodists="geodists/geodists_cat_5001.pkl",
        symmetry="symmetry/symmetry_cat_5001.pkl",
        texcoords="texcoords/texcoords_cat_5001.pkl",
    ),
    MeshInfo(
        name="cat_7466",
        data="cat_7466.pkl",
        geodists="geodists/geodists_cat_7466.pkl",
        symmetry="symmetry/symmetry_cat_7466.pkl",
        texcoords="texcoords/texcoords_cat_7466.pkl",
    ),
    MeshInfo(
        name="sheep_5004",
        data="sheep_5004.pkl",
        geodists="geodists/geodists_sheep_5004.pkl",
        symmetry="symmetry/symmetry_sheep_5004.pkl",
        texcoords="texcoords/texcoords_sheep_5004.pkl",
    ),
    MeshInfo(
        name="zebra_5002",
        data="zebra_5002.pkl",
        geodists="geodists/geodists_zebra_5002.pkl",
        symmetry="symmetry/symmetry_zebra_5002.pkl",
        texcoords="texcoords/texcoords_zebra_5002.pkl",
    ),
    MeshInfo(
        name="horse_5004",
        data="horse_5004.pkl",
        geodists="geodists/geodists_horse_5004.pkl",
        symmetry="symmetry/symmetry_horse_5004.pkl",
        # NOTE(review): the horse mesh points at the zebra texture coordinates,
        # unlike every other entry, which uses its own file. This may be a
        # copy-paste slip or a deliberate reuse - confirm that a horse_5004
        # texcoords file exists before changing it.
        texcoords="texcoords/texcoords_zebra_5002.pkl",
    ),
    MeshInfo(
        name="giraffe_5002",
        data="giraffe_5002.pkl",
        geodists="geodists/geodists_giraffe_5002.pkl",
        symmetry="symmetry/symmetry_giraffe_5002.pkl",
        texcoords="texcoords/texcoords_giraffe_5002.pkl",
    ),
    MeshInfo(
        name="elephant_5002",
        data="elephant_5002.pkl",
        geodists="geodists/geodists_elephant_5002.pkl",
        symmetry="symmetry/symmetry_elephant_5002.pkl",
        texcoords="texcoords/texcoords_elephant_5002.pkl",
    ),
    MeshInfo(
        name="dog_5002",
        data="dog_5002.pkl",
        geodists="geodists/geodists_dog_5002.pkl",
        symmetry="symmetry/symmetry_dog_5002.pkl",
        texcoords="texcoords/texcoords_dog_5002.pkl",
    ),
    MeshInfo(
        name="dog_7466",
        data="dog_7466.pkl",
        geodists="geodists/geodists_dog_7466.pkl",
        symmetry="symmetry/symmetry_dog_7466.pkl",
        texcoords="texcoords/texcoords_dog_7466.pkl",
    ),
    MeshInfo(
        name="cow_5002",
        data="cow_5002.pkl",
        geodists="geodists/geodists_cow_5002.pkl",
        symmetry="symmetry/symmetry_cow_5002.pkl",
        texcoords="texcoords/texcoords_cow_5002.pkl",
    ),
    MeshInfo(
        name="bear_4936",
        data="bear_4936.pkl",
        geodists="geodists/geodists_bear_4936.pkl",
        symmetry="symmetry/symmetry_bear_4936.pkl",
        texcoords="texcoords/texcoords_bear_4936.pkl",
    ),
]

# Registration happens at import time: importing this module populates the
# mesh catalog with all entries above.
register_meshes(MESHES, DENSEPOSE_MESHES_DIR)
class _MeshCatalog(UserDict):
    """
    Mapping from mesh name to mesh description that also assigns every
    name an integer mesh ID.

    IDs are handed out sequentially (0, 1, 2, ...) in order of first
    registration. Re-registering an existing name overwrites the stored value,
    logs a warning and keeps the ID originally assigned to that name.
    """

    def __init__(self, *args, **kwargs):
        # The bookkeeping attributes must exist *before* UserDict.__init__ runs:
        # it fills the mapping from `args` / `kwargs` via `update()`, which
        # routes every item through `__setitem__` below. Initializing them
        # afterwards made `_MeshCatalog({...})` fail with AttributeError.
        self.mesh_ids = {}
        self.mesh_names = {}
        self.max_mesh_id = -1
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        """Store `value` under `key`, allocating a new mesh ID for unseen keys."""
        if key in self:
            logger = logging.getLogger(__name__)
            logger.warning(
                f"Overwriting mesh catalog entry '{key}': old value {self[key]}"
                f", new value {value}"
            )
            # an overwritten entry keeps its previously assigned ID
            mesh_id = self.mesh_ids[key]
        else:
            self.max_mesh_id += 1
            mesh_id = self.max_mesh_id
        super().__setitem__(key, value)
        self.mesh_ids[key] = mesh_id
        self.mesh_names[mesh_id] = key

    def get_mesh_id(self, shape_name: str) -> int:
        """Return the ID assigned to `shape_name`; raises KeyError if unknown."""
        return self.mesh_ids[shape_name]

    def get_mesh_name(self, mesh_id: int) -> str:
        """Return the name registered under `mesh_id`; raises KeyError if unknown."""
        return self.mesh_names[mesh_id]
b/data_processing/detectron2/projects/DensePose/densepose/data/samplers/__init__.py new file mode 100644 index 0000000..7dba87e --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/samplers/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from .densepose_uniform import DensePoseUniformSampler +from .densepose_confidence_based import DensePoseConfidenceBasedSampler +from .densepose_cse_uniform import DensePoseCSEUniformSampler +from .densepose_cse_confidence_based import DensePoseCSEConfidenceBasedSampler +from .mask_from_densepose import MaskFromDensePoseSampler +from .prediction_to_gt import PredictionToGroundTruthSampler diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_base.py b/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_base.py new file mode 100644 index 0000000..4d499d8 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_base.py @@ -0,0 +1,203 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import Any, Dict, List, Tuple +import torch +from torch.nn import functional as F + +from detectron2.structures import BoxMode, Instances + +from densepose.converters import ToChartResultConverter +from densepose.converters.base import IntTupleBox, make_int_box +from densepose.structures import DensePoseDataRelative, DensePoseList + + +class DensePoseBaseSampler: + """ + Base DensePose sampler to produce DensePose data from DensePose predictions. + Samples for each class are drawn according to some distribution over all pixels estimated + to belong to that class. 
+ """ + + def __init__(self, count_per_class: int = 8): + """ + Constructor + + Args: + count_per_class (int): the sampler produces at most `count_per_class` + samples for each category + """ + self.count_per_class = count_per_class + + def __call__(self, instances: Instances) -> DensePoseList: + """ + Convert DensePose predictions (an instance of `DensePoseChartPredictorOutput`) + into DensePose annotations data (an instance of `DensePoseList`) + """ + boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu() + boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) + dp_datas = [] + for i in range(len(boxes_xywh_abs)): + annotation_i = self._sample(instances[i], make_int_box(boxes_xywh_abs[i])) + annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask( # pyre-ignore[6] + instances[i].pred_densepose + ) + dp_datas.append(DensePoseDataRelative(annotation_i)) + # create densepose annotations on CPU + dp_list = DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size) + return dp_list + + def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]: + """ + Sample DensPoseDataRelative from estimation results + """ + labels, dp_result = self._produce_labels_and_results(instance) + annotation = { + DensePoseDataRelative.X_KEY: [], + DensePoseDataRelative.Y_KEY: [], + DensePoseDataRelative.U_KEY: [], + DensePoseDataRelative.V_KEY: [], + DensePoseDataRelative.I_KEY: [], + } + n, h, w = dp_result.shape + for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1): + # indices - tuple of 3 1D tensors of size k + # 0: index along the first dimension N + # 1: index along H dimension + # 2: index along W dimension + indices = torch.nonzero(labels.expand(n, h, w) == part_id, as_tuple=True) + # values - an array of size [n, k] + # n: number of channels (U, V, confidences) + # k: number of points labeled with part_id + values = dp_result[indices].view(n, -1) + k = values.shape[1] + count = 
min(self.count_per_class, k) + if count <= 0: + continue + index_sample = self._produce_index_sample(values, count) + sampled_values = values[:, index_sample] + sampled_y = indices[1][index_sample] + 0.5 + sampled_x = indices[2][index_sample] + 0.5 + # prepare / normalize data + x = (sampled_x / w * 256.0).cpu().tolist() + y = (sampled_y / h * 256.0).cpu().tolist() + u = sampled_values[0].clamp(0, 1).cpu().tolist() + v = sampled_values[1].clamp(0, 1).cpu().tolist() + fine_segm_labels = [part_id] * count + # extend annotations + annotation[DensePoseDataRelative.X_KEY].extend(x) + annotation[DensePoseDataRelative.Y_KEY].extend(y) + annotation[DensePoseDataRelative.U_KEY].extend(u) + annotation[DensePoseDataRelative.V_KEY].extend(v) + annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels) + return annotation + + def _produce_index_sample(self, values: torch.Tensor, count: int): + """ + Abstract method to produce a sample of indices to select data + To be implemented in descendants + + Args: + values (torch.Tensor): an array of size [n, k] that contains + estimated values (U, V, confidences); + n: number of channels (U, V, confidences) + k: number of points labeled with part_id + count (int): number of samples to produce, should be positive and <= k + + Return: + list(int): indices of values (along axis 1) selected as a sample + """ + raise NotImplementedError + + def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Method to get labels and DensePose results from an instance + + Args: + instance (Instances): an instance of `DensePoseChartPredictorOutput` + + Return: + labels (torch.Tensor): shape [H, W], DensePose segmentation labels + dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v + """ + converter = ToChartResultConverter + chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes) + labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu() + 
return labels, dp_result + + def _resample_mask(self, output: Any) -> torch.Tensor: + """ + Convert DensePose predictor output to segmentation annotation - tensors of size + (256, 256) and type `int64`. + + Args: + output: DensePose predictor output with the following attributes: + - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse + segmentation scores + - fine_segm: tensor of size [N, C, H, W] with unnormalized fine + segmentation scores + Return: + Tensor of size (S, S) and type `int64` with coarse segmentation annotations, + where S = DensePoseDataRelative.MASK_SIZE + """ + sz = DensePoseDataRelative.MASK_SIZE + S = ( + F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False) + .argmax(dim=1) + .long() + ) + I = ( + ( + F.interpolate( + output.fine_segm, + (sz, sz), + mode="bilinear", + align_corners=False, + ).argmax(dim=1) + * (S > 0).long() + ) + .squeeze() + .cpu() + ) + # Map fine segmentation results to coarse segmentation ground truth + # TODO: extract this into separate classes + # coarse segmentation: 1 = Torso, 2 = Right Hand, 3 = Left Hand, + # 4 = Left Foot, 5 = Right Foot, 6 = Upper Leg Right, 7 = Upper Leg Left, + # 8 = Lower Leg Right, 9 = Lower Leg Left, 10 = Upper Arm Left, + # 11 = Upper Arm Right, 12 = Lower Arm Left, 13 = Lower Arm Right, + # 14 = Head + # fine segmentation: 1, 2 = Torso, 3 = Right Hand, 4 = Left Hand, + # 5 = Left Foot, 6 = Right Foot, 7, 9 = Upper Leg Right, + # 8, 10 = Upper Leg Left, 11, 13 = Lower Leg Right, + # 12, 14 = Lower Leg Left, 15, 17 = Upper Arm Left, + # 16, 18 = Upper Arm Right, 19, 21 = Lower Arm Left, + # 20, 22 = Lower Arm Right, 23, 24 = Head + FINE_TO_COARSE_SEGMENTATION = { + 1: 1, + 2: 1, + 3: 2, + 4: 3, + 5: 4, + 6: 5, + 7: 6, + 8: 7, + 9: 6, + 10: 7, + 11: 8, + 12: 9, + 13: 8, + 14: 9, + 15: 10, + 16: 11, + 17: 10, + 18: 11, + 19: 12, + 20: 13, + 21: 12, + 22: 13, + 23: 14, + 24: 14, + } + mask = torch.zeros((sz, sz), dtype=torch.int64, 
device=torch.device("cpu")) + for i in range(DensePoseDataRelative.N_PART_LABELS): + mask[I == i + 1] = FINE_TO_COARSE_SEGMENTATION[i + 1] + return mask diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_confidence_based.py b/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_confidence_based.py new file mode 100644 index 0000000..48e325b --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_confidence_based.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import random +from typing import Optional, Tuple +import torch + +from densepose.converters import ToChartResultConverterWithConfidences + +from .densepose_base import DensePoseBaseSampler + + +class DensePoseConfidenceBasedSampler(DensePoseBaseSampler): + """ + Samples DensePose data from DensePose predictions. + Samples for each class are drawn using confidence value estimates. + """ + + def __init__( + self, + confidence_channel: str, + count_per_class: int = 8, + search_count_multiplier: Optional[float] = None, + search_proportion: Optional[float] = None, + ): + """ + Constructor + + Args: + confidence_channel (str): confidence channel to use for sampling; + possible values: + "sigma_2": confidences for UV values + "fine_segm_confidence": confidences for fine segmentation + "coarse_segm_confidence": confidences for coarse segmentation + (default: "sigma_2") + count_per_class (int): the sampler produces at most `count_per_class` + samples for each category (default: 8) + search_count_multiplier (float or None): if not None, the total number + of the most confident estimates of a given class to consider is + defined as `min(search_count_multiplier * count_per_class, N)`, + where `N` is the total number of estimates of the class; cannot be + specified together with `search_proportion` (default: None) + search_proportion (float or None): if not None, the total number of 
the + of the most confident estimates of a given class to consider is + defined as `min(max(search_proportion * N, count_per_class), N)`, + where `N` is the total number of estimates of the class; cannot be + specified together with `search_count_multiplier` (default: None) + """ + super().__init__(count_per_class) + self.confidence_channel = confidence_channel + self.search_count_multiplier = search_count_multiplier + self.search_proportion = search_proportion + assert (search_count_multiplier is None) or (search_proportion is None), ( + f"Cannot specify both search_count_multiplier (={search_count_multiplier})" + f"and search_proportion (={search_proportion})" + ) + + def _produce_index_sample(self, values: torch.Tensor, count: int): + """ + Produce a sample of indices to select data based on confidences + + Args: + values (torch.Tensor): an array of size [n, k] that contains + estimated values (U, V, confidences); + n: number of channels (U, V, confidences) + k: number of points labeled with part_id + count (int): number of samples to produce, should be positive and <= k + + Return: + list(int): indices of values (along axis 1) selected as a sample + """ + k = values.shape[1] + if k == count: + index_sample = list(range(k)) + else: + # take the best count * search_count_multiplier pixels, + # sample from them uniformly + # (here best = smallest variance) + _, sorted_confidence_indices = torch.sort(values[2]) + if self.search_count_multiplier is not None: + search_count = min(int(count * self.search_count_multiplier), k) + elif self.search_proportion is not None: + search_count = min(max(int(k * self.search_proportion), count), k) + else: + search_count = min(count, k) + sample_from_top = random.sample(range(search_count), count) + index_sample = sorted_confidence_indices[:search_count][sample_from_top] + return index_sample + + def _produce_labels_and_results(self, instance) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Method to get labels and DensePose results 
from an instance, with confidences + + Args: + instance (Instances): an instance of `DensePoseChartPredictorOutputWithConfidences` + + Return: + labels (torch.Tensor): shape [H, W], DensePose segmentation labels + dp_result (torch.Tensor): shape [3, H, W], DensePose results u and v + stacked with the confidence channel + """ + converter = ToChartResultConverterWithConfidences + chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes) + labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu() + dp_result = torch.cat( + (dp_result, getattr(chart_result, self.confidence_channel)[None].cpu()) + ) + + return labels, dp_result diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_cse_base.py b/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_cse_base.py new file mode 100644 index 0000000..845545c --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_cse_base.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import Any, Dict, List, Tuple +import torch +from torch.nn import functional as F + +from detectron2.config import CfgNode +from detectron2.structures import Instances + +from densepose.converters.base import IntTupleBox +from densepose.data.utils import get_class_to_mesh_name_mapping +from densepose.modeling.cse.utils import squared_euclidean_distance_matrix +from densepose.structures import DensePoseDataRelative + +from .densepose_base import DensePoseBaseSampler + + +class DensePoseCSEBaseSampler(DensePoseBaseSampler): + """ + Base DensePose sampler to produce DensePose data from DensePose predictions. + Samples for each class are drawn according to some distribution over all pixels estimated + to belong to that class. 
+ """ + + def __init__( + self, + cfg: CfgNode, + use_gt_categories: bool, + embedder: torch.nn.Module, + count_per_class: int = 8, + ): + """ + Constructor + + Args: + cfg (CfgNode): the config of the model + embedder (torch.nn.Module): necessary to compute mesh vertex embeddings + count_per_class (int): the sampler produces at most `count_per_class` + samples for each category + """ + super().__init__(count_per_class) + self.embedder = embedder + self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg) + self.use_gt_categories = use_gt_categories + + def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]: + """ + Sample DensPoseDataRelative from estimation results + """ + if self.use_gt_categories: + instance_class = instance.dataset_classes.tolist()[0] + else: + instance_class = instance.pred_classes.tolist()[0] + mesh_name = self.class_to_mesh_name[instance_class] + + annotation = { + DensePoseDataRelative.X_KEY: [], + DensePoseDataRelative.Y_KEY: [], + DensePoseDataRelative.VERTEX_IDS_KEY: [], + DensePoseDataRelative.MESH_NAME_KEY: mesh_name, + } + + mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh) + indices = torch.nonzero(mask, as_tuple=True) + selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu() + values = other_values[:, indices[0], indices[1]] + k = values.shape[1] + + count = min(self.count_per_class, k) + if count <= 0: + return annotation + + index_sample = self._produce_index_sample(values, count) + closest_vertices = squared_euclidean_distance_matrix( + selected_embeddings[index_sample], self.embedder(mesh_name) + ) + closest_vertices = torch.argmin(closest_vertices, dim=1) + + sampled_y = indices[0][index_sample] + 0.5 + sampled_x = indices[1][index_sample] + 0.5 + # prepare / normalize data + _, _, w, h = bbox_xywh + x = (sampled_x / w * 256.0).cpu().tolist() + y = (sampled_y / h * 256.0).cpu().tolist() + # extend annotations + 
annotation[DensePoseDataRelative.X_KEY].extend(x) + annotation[DensePoseDataRelative.Y_KEY].extend(y) + annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist()) + return annotation + + def _produce_mask_and_results( + self, instance: Instances, bbox_xywh: IntTupleBox + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Method to get labels and DensePose results from an instance + + Args: + instance (Instances): an instance of `DensePoseEmbeddingPredictorOutput` + bbox_xywh (IntTupleBox): the corresponding bounding box + + Return: + mask (torch.Tensor): shape [H, W], DensePose segmentation mask + embeddings (Tuple[torch.Tensor]): a tensor of shape [D, H, W], + DensePose CSE Embeddings + other_values (Tuple[torch.Tensor]): a tensor of shape [0, H, W], + for potential other values + """ + densepose_output = instance.pred_densepose + S = densepose_output.coarse_segm + E = densepose_output.embedding + _, _, w, h = bbox_xywh + embeddings = F.interpolate(E, size=(h, w), mode="bilinear")[0] + coarse_segm_resized = F.interpolate(S, size=(h, w), mode="bilinear")[0] + mask = coarse_segm_resized.argmax(0) > 0 + other_values = torch.empty((0, h, w), device=E.device) + return mask, embeddings, other_values + + def _resample_mask(self, output: Any) -> torch.Tensor: + """ + Convert DensePose predictor output to segmentation annotation - tensors of size + (256, 256) and type `int64`. 
+ + Args: + output: DensePose predictor output with the following attributes: + - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse + segmentation scores + Return: + Tensor of size (S, S) and type `int64` with coarse segmentation annotations, + where S = DensePoseDataRelative.MASK_SIZE + """ + sz = DensePoseDataRelative.MASK_SIZE + mask = ( + F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False) + .argmax(dim=1) + .long() + .squeeze() + .cpu() + ) + return mask diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_cse_confidence_based.py b/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_cse_confidence_based.py new file mode 100644 index 0000000..964b7f4 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_cse_confidence_based.py @@ -0,0 +1,119 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import random +from typing import Optional, Tuple +import torch +from torch.nn import functional as F + +from detectron2.config import CfgNode +from detectron2.structures import Instances + +from densepose.converters.base import IntTupleBox + +from .densepose_cse_base import DensePoseCSEBaseSampler + + +class DensePoseCSEConfidenceBasedSampler(DensePoseCSEBaseSampler): + """ + Samples DensePose data from DensePose predictions. + Samples for each class are drawn using confidence value estimates. 
+ """ + + def __init__( + self, + cfg: CfgNode, + use_gt_categories: bool, + embedder: torch.nn.Module, + confidence_channel: str, + count_per_class: int = 8, + search_count_multiplier: Optional[float] = None, + search_proportion: Optional[float] = None, + ): + """ + Constructor + + Args: + cfg (CfgNode): the config of the model + embedder (torch.nn.Module): necessary to compute mesh vertex embeddings + confidence_channel (str): confidence channel to use for sampling; + possible values: + "coarse_segm_confidence": confidences for coarse segmentation + (default: "coarse_segm_confidence") + count_per_class (int): the sampler produces at most `count_per_class` + samples for each category (default: 8) + search_count_multiplier (float or None): if not None, the total number + of the most confident estimates of a given class to consider is + defined as `min(search_count_multiplier * count_per_class, N)`, + where `N` is the total number of estimates of the class; cannot be + specified together with `search_proportion` (default: None) + search_proportion (float or None): if not None, the total number of the + of the most confident estimates of a given class to consider is + defined as `min(max(search_proportion * N, count_per_class), N)`, + where `N` is the total number of estimates of the class; cannot be + specified together with `search_count_multiplier` (default: None) + """ + super().__init__(cfg, use_gt_categories, embedder, count_per_class) + self.confidence_channel = confidence_channel + self.search_count_multiplier = search_count_multiplier + self.search_proportion = search_proportion + assert (search_count_multiplier is None) or (search_proportion is None), ( + f"Cannot specify both search_count_multiplier (={search_count_multiplier})" + f"and search_proportion (={search_proportion})" + ) + + def _produce_index_sample(self, values: torch.Tensor, count: int): + """ + Produce a sample of indices to select data based on confidences + + Args: + values (torch.Tensor): 
a tensor of length k that contains confidences + k: number of points labeled with part_id + count (int): number of samples to produce, should be positive and <= k + + Return: + list(int): indices of values (along axis 1) selected as a sample + """ + k = values.shape[1] + if k == count: + index_sample = list(range(k)) + else: + # take the best count * search_count_multiplier pixels, + # sample from them uniformly + # (here best = smallest variance) + _, sorted_confidence_indices = torch.sort(values[0]) + if self.search_count_multiplier is not None: + search_count = min(int(count * self.search_count_multiplier), k) + elif self.search_proportion is not None: + search_count = min(max(int(k * self.search_proportion), count), k) + else: + search_count = min(count, k) + sample_from_top = random.sample(range(search_count), count) + index_sample = sorted_confidence_indices[-search_count:][sample_from_top] + return index_sample + + def _produce_mask_and_results( + self, instance: Instances, bbox_xywh: IntTupleBox + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Method to get labels and DensePose results from an instance + + Args: + instance (Instances): an instance of + `DensePoseEmbeddingPredictorOutputWithConfidences` + bbox_xywh (IntTupleBox): the corresponding bounding box + + Return: + mask (torch.Tensor): shape [H, W], DensePose segmentation mask + embeddings (Tuple[torch.Tensor]): a tensor of shape [D, H, W] + DensePose CSE Embeddings + other_values: a tensor of shape [1, H, W], DensePose CSE confidence + """ + _, _, w, h = bbox_xywh + densepose_output = instance.pred_densepose + mask, embeddings, _ = super()._produce_mask_and_results(instance, bbox_xywh) + other_values = F.interpolate( + getattr(densepose_output, self.confidence_channel), + size=(h, w), + mode="bilinear", + )[0].cpu() + return mask, embeddings, other_values diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/samplers/densepose_cse_uniform.py 
class DensePoseUniformSampler(DensePoseBaseSampler):
    """
    DensePose sampler that picks points uniformly at random.
    For each class, every pixel estimated to belong to that class has the
    same chance of being selected.
    """

    def __init__(self, count_per_class: int = 8):
        """
        Constructor

        Args:
            count_per_class (int): upper bound on the number of samples
                produced for each category
        """
        super().__init__(count_per_class)

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Draw `count` distinct column indices of `values` uniformly at random

        Args:
            values (torch.Tensor): an array of size [n, k]; only its second
                dimension k (the number of candidate points) is used
            count (int): number of samples to produce, should be positive and <= k

        Return:
            list(int): indices of values (along axis 1) selected as a sample
        """
        candidates = range(values.shape[1])
        return random.sample(candidates, count)
class MaskFromDensePoseSampler:
    """
    Produce mask GT from DensePose predictions
    This sampler delegates to `ToMaskConverter` to turn the DensePose
    predictions of a set of instances into mask data
    """

    def __call__(self, instances: Instances) -> BitMasks:
        """
        Converts predicted data from `instances` into the GT mask data

        Args:
            instances (Instances): predicted results, expected to have the
                `pred_densepose`, `pred_boxes` and `image_size` fields

        Returns:
            The result of `ToMaskConverter.convert` for the predictions,
            their boxes and the image size
        """
        densepose_predictions = instances.pred_densepose
        predicted_boxes = instances.pred_boxes
        image_size = instances.image_size
        return ToMaskConverter.convert(densepose_predictions, predicted_boxes, image_size)
ModelOutput = Dict[str, Any]
SampledData = Dict[str, Any]


@dataclass
class _Sampler:
    """
    One entry of the sampler registry:
     - src (str): field the value is read from (removed once sampling is done)
     - dst (Optional[str]): field the sampled value is written to; None means
         the source field is only removed
     - func (Optional[Callable: Any -> Any]): sampling function; when None the
         source value is stored under `dst` by reference
    """

    src: str
    dst: Optional[str]
    func: Optional[Callable[[Any], Any]]


class PredictionToGroundTruthSampler:
    """
    Turns predictions into GT by running the registered per-field samplers
    over the `Instances` of every model output entry.
    """

    def __init__(self, dataset_name: str = ""):
        self.dataset_name = dataset_name
        self._samplers = {}
        # boxes and classes are carried over as-is (no sampling function)
        for src_field, dst_field in (("pred_boxes", "gt_boxes"), ("pred_classes", "gt_classes")):
            self.register_sampler(src_field, dst_field, None)
        # scores have no GT counterpart: registering without a destination
        # only causes the field to be deleted
        self.register_sampler("scores")

    def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
        """
        Transform model output into ground truth data through sampling.
        The entries (and their instances) are modified in place.

        Args:
            model_output (List[Dict[str, Any]]): model outputs, each with an
                "instances" entry
        Returns:
            List[Dict[str, Any]]: the same list, with sampled data and a
                "dataset" entry set to the dataset name
        """
        for entry in model_output:
            instances: Instances = entry["instances"]
            # pass 1: write every destination field
            for sampler in self._samplers.values():
                if not instances.has(sampler.src):
                    continue
                if sampler.dst is None:
                    continue
                if sampler.func is None:
                    sampled = instances.get(sampler.src)
                else:
                    sampled = sampler.func(instances)
                instances.set(sampler.dst, sampled)
            # pass 2: drop the source fields; done separately so that several
            # samplers may read from the same source field
            for sampler in self._samplers.values():
                if sampler.src != sampler.dst and instances.has(sampler.src):
                    instances.remove(sampler.src)
            entry["dataset"] = self.dataset_name
        return model_output

    def register_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
        func: Optional[Callable[[Any], Any]] = None,
    ):
        """
        Register sampler for a field; an existing registration for the same
        (prediction_attr, gt_attr) pair is replaced

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
            func (Optional[Callable: Any -> Any]): sampler function
        """
        entry = _Sampler(src=prediction_attr, dst=gt_attr, func=func)
        self._samplers[(prediction_attr, gt_attr)] = entry

    def remove_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
    ):
        """
        Remove sampler for a field; the (prediction_attr, gt_attr) pair is
        expected to be registered

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
        """
        key = (prediction_attr, gt_attr)
        assert key in self._samplers
        del self._samplers[key]
class ImageResizeTransform:
    """
    Resizes a batch of images (NCHW layout) with bilinear interpolation and
    converts it to float32. A single scale factor is used for both spatial
    dimensions, chosen from the configured min / max sizes.
    """

    def __init__(self, min_size: int = 800, max_size: int = 1333):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, images: torch.Tensor) -> torch.Tensor:
        """
        Args:
            images (torch.Tensor): tensor of size [N, 3, H, W] that contains
                image data (typically in uint8)
        Returns:
            images (torch.Tensor): float32 tensor of size [N, 3, H1, W1];
                H1 and W1 result from scaling H and W by the same factor,
                the smaller of `min_size / min(H, W)` and `max_size / max(H, W)`
        """
        as_float = images.float()
        height, width = as_float.shape[-2:]
        shorter_side = min(height, width)
        longer_side = max(height, width)
        # taking the smaller ratio keeps the longer side within max_size
        scale = min(self.min_size / shorter_side, self.max_size / longer_side)
        return torch.nn.functional.interpolate(
            as_float,
            scale_factor=scale,
            mode="bilinear",
            align_corners=False,
        )
import os
from typing import Dict, Optional

from detectron2.config import CfgNode


def is_relative_local_path(path: str) -> bool:
    """
    Tell whether `path` is a relative path without a URL-style scheme,
    i.e. it contains no "://" and is not absolute.
    """
    decoded = os.fsdecode(path)
    has_scheme = "://" in decoded
    return not (has_scheme or os.path.isabs(path))


def maybe_prepend_base_path(base_path: Optional[str], path: str):
    """
    Joins `base_path` and `path` when both conditions hold:
     1) base path is not None;
     2) path is a relative local path (see `is_relative_local_path`).
    Otherwise `path` is returned unchanged.
    """
    if base_path is not None and is_relative_local_path(path):
        return os.path.join(base_path, path)
    return path


def get_class_to_mesh_name_mapping(cfg: CfgNode) -> Dict[int, str]:
    """
    Copy `cfg.DATASETS.CLASS_TO_MESH_NAME_MAPPING` into a plain dict whose
    keys are converted to `int`.
    """
    mapping = {}
    for class_id, mesh_name in cfg.DATASETS.CLASS_TO_MESH_NAME_MAPPING.items():
        mapping[int(class_id)] = mesh_name
    return mapping


def get_category_to_class_mapping(dataset_cfg: CfgNode) -> Dict[str, int]:
    """
    Copy `dataset_cfg.CATEGORY_TO_CLASS_MAPPING` into a plain dict whose
    values are converted to `int`.
    """
    mapping = {}
    for category, class_id in dataset_cfg.CATEGORY_TO_CLASS_MAPPING.items():
        mapping[category] = int(class_id)
    return mapping


# (patch boundary) The next file added by this diff is
# data_processing/detectron2/projects/DensePose/densepose/data/video/__init__.py
# Copyright (c) Facebook, Inc. and its affiliates.
+ +from .frame_selector import ( + FrameSelectionStrategy, + RandomKFramesSelector, + FirstKFramesSelector, + LastKFramesSelector, + FrameTsList, + FrameSelector, +) + +from .video_keyframe_dataset import ( + VideoKeyframeDataset, + video_list_from_file, + list_keyframes, + read_keyframes, +) diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/video/frame_selector.py b/data_processing/detectron2/projects/DensePose/densepose/data/video/frame_selector.py new file mode 100644 index 0000000..c28f0e9 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/video/frame_selector.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import random +from collections.abc import Callable +from enum import Enum +from typing import Callable as TCallable +from typing import List + +FrameTsList = List[int] +FrameSelector = TCallable[[FrameTsList], FrameTsList] + + +class FrameSelectionStrategy(Enum): + """ + Frame selection strategy used with videos: + - "random_k": select k random frames + - "first_k": select k first frames + - "last_k": select k last frames + - "all": select all frames + """ + + # fmt: off + RANDOM_K = "random_k" + FIRST_K = "first_k" + LAST_K = "last_k" + ALL = "all" + # fmt: on + + +class RandomKFramesSelector(Callable): # pyre-ignore[39] + """ + Selector that retains at most `k` random frames + """ + + def __init__(self, k: int): + self.k = k + + def __call__(self, frame_tss: FrameTsList) -> FrameTsList: + """ + Select `k` random frames + + Args: + frames_tss (List[int]): timestamps of input frames + Returns: + List[int]: timestamps of selected frames + """ + return random.sample(frame_tss, min(self.k, len(frame_tss))) + + +class FirstKFramesSelector(Callable): # pyre-ignore[39] + """ + Selector that retains at most `k` first frames + """ + + def __init__(self, k: int): + self.k = k + + def __call__(self, frame_tss: FrameTsList) -> FrameTsList: + """ + Select `k` first frames + + Args: + 
frames_tss (List[int]): timestamps of input frames + Returns: + List[int]: timestamps of selected frames + """ + return frame_tss[: self.k] + + +class LastKFramesSelector(Callable): # pyre-ignore[39] + """ + Selector that retains at most `k` last frames from video data + """ + + def __init__(self, k: int): + self.k = k + + def __call__(self, frame_tss: FrameTsList) -> FrameTsList: + """ + Select `k` last frames + + Args: + frames_tss (List[int]): timestamps of input frames + Returns: + List[int]: timestamps of selected frames + """ + return frame_tss[-self.k :] diff --git a/data_processing/detectron2/projects/DensePose/densepose/data/video/video_keyframe_dataset.py b/data_processing/detectron2/projects/DensePose/densepose/data/video/video_keyframe_dataset.py new file mode 100644 index 0000000..214365c --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/data/video/video_keyframe_dataset.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import csv +import logging +import numpy as np +from typing import Any, Callable, Dict, List, Optional, Union +import av +import torch +from torch.utils.data.dataset import Dataset + +from detectron2.utils.file_io import PathManager + +from ..utils import maybe_prepend_base_path +from .frame_selector import FrameSelector, FrameTsList + +FrameList = List[av.frame.Frame] # pyre-ignore[16] +FrameTransform = Callable[[torch.Tensor], torch.Tensor] + + +def list_keyframes(video_fpath: str, video_stream_idx: int = 0) -> FrameTsList: + """ + Traverses all keyframes of a video file. Returns a list of keyframe + timestamps. Timestamps are counts in timebase units. 
+ + Args: + video_fpath (str): Video file path + video_stream_idx (int): Video stream index (default: 0) + Returns: + List[int]: list of keyframe timestaps (timestamp is a count in timebase + units) + """ + try: + with PathManager.open(video_fpath, "rb") as io: + container = av.open(io, mode="r") + stream = container.streams.video[video_stream_idx] + keyframes = [] + pts = -1 + # Note: even though we request forward seeks for keyframes, sometimes + # a keyframe in backwards direction is returned. We introduce tolerance + # as a max count of ignored backward seeks + tolerance_backward_seeks = 2 + while True: + try: + container.seek(pts + 1, backward=False, any_frame=False, stream=stream) + except av.AVError as e: + # the exception occurs when the video length is exceeded, + # we then return whatever data we've already collected + logger = logging.getLogger(__name__) + logger.debug( + f"List keyframes: Error seeking video file {video_fpath}, " + f"video stream {video_stream_idx}, pts {pts + 1}, AV error: {e}" + ) + return keyframes + except OSError as e: + logger = logging.getLogger(__name__) + logger.warning( + f"List keyframes: Error seeking video file {video_fpath}, " + f"video stream {video_stream_idx}, pts {pts + 1}, OS error: {e}" + ) + return [] + packet = next(container.demux(video=video_stream_idx)) + if packet.pts is not None and packet.pts <= pts: + logger = logging.getLogger(__name__) + logger.warning( + f"Video file {video_fpath}, stream {video_stream_idx}: " + f"bad seek for packet {pts + 1} (got packet {packet.pts}), " + f"tolerance {tolerance_backward_seeks}." 
+ ) + tolerance_backward_seeks -= 1 + if tolerance_backward_seeks == 0: + return [] + pts += 1 + continue + tolerance_backward_seeks = 2 + pts = packet.pts + if pts is None: + return keyframes + if packet.is_keyframe: + keyframes.append(pts) + return keyframes + except OSError as e: + logger = logging.getLogger(__name__) + logger.warning( + f"List keyframes: Error opening video file container {video_fpath}, " f"OS error: {e}" + ) + except RuntimeError as e: + logger = logging.getLogger(__name__) + logger.warning( + f"List keyframes: Error opening video file container {video_fpath}, " + f"Runtime error: {e}" + ) + return [] + + +def read_keyframes( + video_fpath: str, keyframes: FrameTsList, video_stream_idx: int = 0 +) -> FrameList: # pyre-ignore[11] + """ + Reads keyframe data from a video file. + + Args: + video_fpath (str): Video file path + keyframes (List[int]): List of keyframe timestamps (as counts in + timebase units to be used in container seek operations) + video_stream_idx (int): Video stream index (default: 0) + Returns: + List[Frame]: list of frames that correspond to the specified timestamps + """ + try: + with PathManager.open(video_fpath, "rb") as io: + container = av.open(io) + stream = container.streams.video[video_stream_idx] + frames = [] + for pts in keyframes: + try: + container.seek(pts, any_frame=False, stream=stream) + frame = next(container.decode(video=0)) + frames.append(frame) + except av.AVError as e: + logger = logging.getLogger(__name__) + logger.warning( + f"Read keyframes: Error seeking video file {video_fpath}, " + f"video stream {video_stream_idx}, pts {pts}, AV error: {e}" + ) + container.close() + return frames + except OSError as e: + logger = logging.getLogger(__name__) + logger.warning( + f"Read keyframes: Error seeking video file {video_fpath}, " + f"video stream {video_stream_idx}, pts {pts}, OS error: {e}" + ) + container.close() + return frames + except StopIteration: + logger = logging.getLogger(__name__) + 
logger.warning( + f"Read keyframes: Error decoding frame from {video_fpath}, " + f"video stream {video_stream_idx}, pts {pts}" + ) + container.close() + return frames + + container.close() + return frames + except OSError as e: + logger = logging.getLogger(__name__) + logger.warning( + f"Read keyframes: Error opening video file container {video_fpath}, OS error: {e}" + ) + except RuntimeError as e: + logger = logging.getLogger(__name__) + logger.warning( + f"Read keyframes: Error opening video file container {video_fpath}, Runtime error: {e}" + ) + return [] + + +def video_list_from_file(video_list_fpath: str, base_path: Optional[str] = None): + """ + Create a list of paths to video files from a text file. + + Args: + video_list_fpath (str): path to a plain text file with the list of videos + base_path (str): base path for entries from the video list (default: None) + """ + video_list = [] + with PathManager.open(video_list_fpath, "r") as io: + for line in io: + video_list.append(maybe_prepend_base_path(base_path, str(line.strip()))) + return video_list + + +def read_keyframe_helper_data(fpath: str): + """ + Read keyframe data from a file in CSV format: the header should contain + "video_id" and "keyframes" fields. 
Value specifications are: + video_id: int + keyframes: list(int) + Example of contents: + video_id,keyframes + 2,"[1,11,21,31,41,51,61,71,81]" + + Args: + fpath (str): File containing keyframe data + + Return: + video_id_to_keyframes (dict: int -> list(int)): for a given video ID it + contains a list of keyframes for that video + """ + video_id_to_keyframes = {} + try: + with PathManager.open(fpath, "r") as io: + csv_reader = csv.reader(io) # pyre-ignore[6] + header = next(csv_reader) + video_id_idx = header.index("video_id") + keyframes_idx = header.index("keyframes") + for row in csv_reader: + video_id = int(row[video_id_idx]) + assert ( + video_id not in video_id_to_keyframes + ), f"Duplicate keyframes entry for video {fpath}" + video_id_to_keyframes[video_id] = ( + [int(v) for v in row[keyframes_idx][1:-1].split(",")] + if len(row[keyframes_idx]) > 2 + else [] + ) + except Exception as e: + logger = logging.getLogger(__name__) + logger.warning(f"Error reading keyframe helper data from {fpath}: {e}") + return video_id_to_keyframes + + +class VideoKeyframeDataset(Dataset): + """ + Dataset that provides keyframes for a set of videos. + """ + + _EMPTY_FRAMES = torch.empty((0, 3, 1, 1)) + + def __init__( + self, + video_list: List[str], + category_list: Union[str, List[str], None] = None, + frame_selector: Optional[FrameSelector] = None, + transform: Optional[FrameTransform] = None, + keyframe_helper_fpath: Optional[str] = None, + ): + """ + Dataset constructor + + Args: + video_list (List[str]): list of paths to video files + category_list (Union[str, List[str], None]): list of animal categories for each + video file. If it is a string, or None, this applies to all videos + frame_selector (Callable: KeyFrameList -> KeyFrameList): + selects keyframes to process, keyframes are given by + packet timestamps in timebase counts. 
If None, all keyframes + are selected (default: None) + transform (Callable: torch.Tensor -> torch.Tensor): + transforms a batch of RGB images (tensors of size [B, 3, H, W]), + returns a tensor of the same size. If None, no transform is + applied (default: None) + + """ + if type(category_list) == list: + self.category_list = category_list + else: + self.category_list = [category_list] * len(video_list) + assert len(video_list) == len( + self.category_list + ), "length of video and category lists must be equal" + self.video_list = video_list + self.frame_selector = frame_selector + self.transform = transform + self.keyframe_helper_data = ( + read_keyframe_helper_data(keyframe_helper_fpath) + if keyframe_helper_fpath is not None + else None + ) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + """ + Gets selected keyframes from a given video + + Args: + idx (int): video index in the video list file + Returns: + A dictionary containing two keys: + images (torch.Tensor): tensor of size [N, H, W, 3] or of size + defined by the transform that contains keyframes data + categories (List[str]): categories of the frames + """ + categories = [self.category_list[idx]] + fpath = self.video_list[idx] + keyframes = ( + list_keyframes(fpath) + if self.keyframe_helper_data is None or idx not in self.keyframe_helper_data + else self.keyframe_helper_data[idx] + ) + transform = self.transform + frame_selector = self.frame_selector + if not keyframes: + return {"images": self._EMPTY_FRAMES, "categories": []} + if frame_selector is not None: + keyframes = frame_selector(keyframes) + frames = read_keyframes(fpath, keyframes) + if not frames: + return {"images": self._EMPTY_FRAMES, "categories": []} + frames = np.stack([frame.to_rgb().to_ndarray() for frame in frames]) + frames = torch.as_tensor(frames, device=torch.device("cpu")) + frames = frames[..., [2, 1, 0]] # RGB -> BGR + frames = frames.permute(0, 3, 1, 2).float() # NHWC -> NCHW + if transform is not None: + frames = 
transform(frames) + return {"images": frames, "categories": categories} + + def __len__(self): + return len(self.video_list) diff --git a/data_processing/detectron2/projects/DensePose/densepose/engine/__init__.py b/data_processing/detectron2/projects/DensePose/densepose/engine/__init__.py new file mode 100644 index 0000000..539b93a --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/engine/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from .trainer import Trainer diff --git a/data_processing/detectron2/projects/DensePose/densepose/engine/trainer.py b/data_processing/detectron2/projects/DensePose/densepose/engine/trainer.py new file mode 100644 index 0000000..a8ffe82 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/engine/trainer.py @@ -0,0 +1,258 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import logging +import os +from collections import OrderedDict +from typing import List, Optional, Union +import torch +from torch import nn + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import CfgNode +from detectron2.engine import DefaultTrainer +from detectron2.evaluation import ( + DatasetEvaluator, + DatasetEvaluators, + inference_on_dataset, + print_csv_format, +) +from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping +from detectron2.utils import comm +from detectron2.utils.events import EventWriter, get_event_storage + +from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg +from densepose.data import ( + DatasetMapper, + build_combined_loader, + build_detection_test_loader, + build_detection_train_loader, + build_inference_based_loaders, + has_inference_based_loaders, +) +from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter +from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, 
build_densepose_evaluator_storage +from densepose.modeling.cse import Embedder + + +class SampleCountingLoader: + def __init__(self, loader): + self.loader = loader + + def __iter__(self): + it = iter(self.loader) + storage = get_event_storage() + while True: + try: + batch = next(it) + num_inst_per_dataset = {} + for data in batch: + dataset_name = data["dataset"] + if dataset_name not in num_inst_per_dataset: + num_inst_per_dataset[dataset_name] = 0 + num_inst = len(data["instances"]) + num_inst_per_dataset[dataset_name] += num_inst + for dataset_name in num_inst_per_dataset: + storage.put_scalar(f"batch/{dataset_name}", num_inst_per_dataset[dataset_name]) + yield batch + except StopIteration: + break + + +class SampleCountMetricPrinter(EventWriter): + def __init__(self): + self.logger = logging.getLogger(__name__) + + def write(self): + storage = get_event_storage() + batch_stats_strs = [] + for key, buf in storage.histories().items(): + if key.startswith("batch/"): + batch_stats_strs.append(f"{key} {buf.avg(20)}") + self.logger.info(", ".join(batch_stats_strs)) + + +class Trainer(DefaultTrainer): + @classmethod + def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]: + if isinstance(model, nn.parallel.DistributedDataParallel): + model = model.module + if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"): + return model.roi_heads.embedder + return None + + # TODO: the only reason to copy the base class code here is to pass the embedder from + # the model to the evaluator; that should be refactored to avoid unnecessary copy-pasting + @classmethod + def test( + cls, + cfg: CfgNode, + model: nn.Module, + evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None, + ): + """ + Args: + cfg (CfgNode): + model (nn.Module): + evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call + :meth:`build_evaluator`. Otherwise, must have the same length as + ``cfg.DATASETS.TEST``. 
+ + Returns: + dict: a dict of result metrics + """ + logger = logging.getLogger(__name__) + if isinstance(evaluators, DatasetEvaluator): + evaluators = [evaluators] + if evaluators is not None: + assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format( + len(cfg.DATASETS.TEST), len(evaluators) + ) + + results = OrderedDict() + for idx, dataset_name in enumerate(cfg.DATASETS.TEST): + data_loader = cls.build_test_loader(cfg, dataset_name) + # When evaluators are passed in as arguments, + # implicitly assume that evaluators can be created before data_loader. + if evaluators is not None: + evaluator = evaluators[idx] + else: + try: + embedder = cls.extract_embedder_from_model(model) + evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder) + except NotImplementedError: + logger.warn( + "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, " + "or implement its `build_evaluator` method." + ) + results[dataset_name] = {} + continue + if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process(): + results_i = inference_on_dataset(model, data_loader, evaluator) + else: + results_i = {} + results[dataset_name] = results_i + if comm.is_main_process(): + assert isinstance( + results_i, dict + ), "Evaluator must return a dict on the main process. Got {} instead.".format( + results_i + ) + logger.info("Evaluation results for {} in csv format:".format(dataset_name)) + print_csv_format(results_i) + + if len(results) == 1: + results = list(results.values())[0] + return results + + @classmethod + def build_evaluator( + cls, + cfg: CfgNode, + dataset_name: str, + output_folder: Optional[str] = None, + embedder: Optional[Embedder] = None, + ) -> DatasetEvaluators: + if output_folder is None: + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") + evaluators = [] + distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE + # Note: we currently use COCO evaluator for both COCO and LVIS datasets + # to have compatible metrics. 
LVIS bbox evaluator could also be used + # with an adapter to properly handle filtered / mapped categories + # evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type + # if evaluator_type == "coco": + # evaluators.append(COCOEvaluator(dataset_name, output_dir=output_folder)) + # elif evaluator_type == "lvis": + # evaluators.append(LVISEvaluator(dataset_name, output_dir=output_folder)) + evaluators.append( + Detectron2COCOEvaluatorAdapter( + dataset_name, output_dir=output_folder, distributed=distributed + ) + ) + if cfg.MODEL.DENSEPOSE_ON: + storage = build_densepose_evaluator_storage(cfg, output_folder) + evaluators.append( + DensePoseCOCOEvaluator( + dataset_name, + distributed, + output_folder, + evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE, + min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD, + storage=storage, + embedder=embedder, + should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT, + mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES, + ) + ) + return DatasetEvaluators(evaluators) + + @classmethod + def build_optimizer(cls, cfg: CfgNode, model: nn.Module): + params = get_default_optimizer_params( + model, + base_lr=cfg.SOLVER.BASE_LR, + weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM, + bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR, + weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS, + overrides={ + "features": { + "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR, + }, + "embeddings": { + "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR, + }, + }, + ) + optimizer = torch.optim.SGD( + params, + cfg.SOLVER.BASE_LR, + momentum=cfg.SOLVER.MOMENTUM, + nesterov=cfg.SOLVER.NESTEROV, + weight_decay=cfg.SOLVER.WEIGHT_DECAY, + ) + # pyre-fixme[6]: For 2nd param expected `Type[Optimizer]` but got `SGD`. 
+ return maybe_add_gradient_clipping(cfg, optimizer) + + @classmethod + def build_test_loader(cls, cfg: CfgNode, dataset_name): + return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False)) + + @classmethod + def build_train_loader(cls, cfg: CfgNode): + data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True)) + if not has_inference_based_loaders(cfg): + return data_loader + model = cls.build_model(cfg) + model.to(cfg.BOOTSTRAP_MODEL.DEVICE) + DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False) + inference_based_loaders, ratios = build_inference_based_loaders(cfg, model) + loaders = [data_loader] + inference_based_loaders + ratios = [1.0] + ratios + combined_data_loader = build_combined_loader(cfg, loaders, ratios) + sample_counting_loader = SampleCountingLoader(combined_data_loader) + return sample_counting_loader + + def build_writers(self): + writers = super().build_writers() + writers.append(SampleCountMetricPrinter()) + return writers + + @classmethod + def test_with_TTA(cls, cfg: CfgNode, model): + logger = logging.getLogger("detectron2.trainer") + # In the end of training, run an evaluation with TTA + # Only support some R-CNN models. 
+ logger.info("Running inference with test-time augmentation ...") + transform_data = load_from_cfg(cfg) + model = DensePoseGeneralizedRCNNWithTTA( + cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg) + ) + evaluators = [ + cls.build_evaluator( + cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") + ) + for name in cfg.DATASETS.TEST + ] + res = cls.test(cfg, model, evaluators) # pyre-ignore[6] + res = OrderedDict({k + "_TTA": v for k, v in res.items()}) + return res diff --git a/data_processing/detectron2/projects/DensePose/densepose/evaluation/__init__.py b/data_processing/detectron2/projects/DensePose/densepose/evaluation/__init__.py new file mode 100644 index 0000000..e5ae1f2 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/evaluation/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from .evaluator import DensePoseCOCOEvaluator diff --git a/data_processing/detectron2/projects/DensePose/densepose/evaluation/d2_evaluator_adapter.py b/data_processing/detectron2/projects/DensePose/densepose/evaluation/d2_evaluator_adapter.py new file mode 100644 index 0000000..1fbc526 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/evaluation/d2_evaluator_adapter.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
from detectron2.data.catalog import Metadata
from detectron2.evaluation import COCOEvaluator

from densepose.data.datasets.coco import (
    get_contiguous_id_to_category_id_map,
    maybe_filter_categories_cocoapi,
)


def _maybe_add_iscrowd_annotations(cocoapi) -> None:
    """
    Give every annotation in `cocoapi.dataset["annotations"]` an "iscrowd"
    entry, using 0 where the key is absent. Annotations are modified in place.
    """
    for annotation in cocoapi.dataset["annotations"]:
        annotation.setdefault("iscrowd", 0)


class Detectron2COCOEvaluatorAdapter(COCOEvaluator):
    """
    COCOEvaluator variant that, after the base-class setup, filters the
    categories of the loaded COCO API object, fills in missing "iscrowd"
    flags and, if needed, makes the category id -> contiguous id mapping
    in the metadata injective.
    """

    def __init__(self, dataset_name, output_dir=None, distributed=True):
        super().__init__(dataset_name, output_dir=output_dir, distributed=distributed)
        coco_api = self._coco_api
        maybe_filter_categories_cocoapi(dataset_name, coco_api)
        _maybe_add_iscrowd_annotations(coco_api)
        # substitute category metadata to account for categories
        # that are mapped to the same contiguous id
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            self._maybe_substitute_metadata()

    def _maybe_substitute_metadata(self):
        """
        Replace `self._metadata` with a copy whose
        `thing_dataset_id_to_contiguous_id` keeps only the category ids that
        the contiguous id -> category id map points back to. Nothing is done
        when both maps already have the same number of entries.
        """
        metadata = self._metadata
        cont_id_2_cat_id = get_contiguous_id_to_category_id_map(metadata)
        cat_id_2_cont_id = metadata.thing_dataset_id_to_contiguous_id
        if len(cont_id_2_cat_id) == len(cat_id_2_cont_id):
            return

        cat_id_2_cont_id_injective = {
            cat_id: cont_id
            for cat_id, cont_id in cat_id_2_cont_id.items()
            if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id)
        }

        metadata_new = Metadata(name=metadata.name)
        for key, value in metadata.__dict__.items():
            replacement = (
                cat_id_2_cont_id_injective
                if key == "thing_dataset_id_to_contiguous_id"
                else value
            )
            setattr(metadata_new, key, replacement)
        self._metadata = metadata_new


# (patch boundary) The next file added by this diff is
# data_processing/detectron2/projects/DensePose/densepose/evaluation/densepose_coco_evaluation.py
b/data_processing/detectron2/projects/DensePose/densepose/evaluation/densepose_coco_evaluation.py @@ -0,0 +1,1303 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# This is a modified version of cocoeval.py where we also have the densepose evaluation. + +__author__ = "tsungyi" + +import copy +import datetime +import logging +import numpy as np +import pickle +import time +from collections import defaultdict +from enum import Enum +from typing import Any, Dict, Tuple +import scipy.spatial.distance as ssd +import torch +import torch.nn.functional as F +from pycocotools import mask as maskUtils +from scipy.io import loadmat +from scipy.ndimage import zoom as spzoom + +from detectron2.utils.file_io import PathManager + +from densepose.converters.chart_output_to_chart_result import resample_uv_tensors_to_bbox +from densepose.converters.segm_to_mask import ( + resample_coarse_segm_tensor_to_bbox, + resample_fine_and_coarse_segm_tensors_to_bbox, +) +from densepose.modeling.cse.utils import squared_euclidean_distance_matrix +from densepose.structures import DensePoseDataRelative +from densepose.structures.mesh import create_mesh + +logger = logging.getLogger(__name__) + + +class DensePoseEvalMode(str, Enum): + # use both masks and geodesic distances (GPS * IOU) to compute scores + GPSM = "gpsm" + # use only geodesic distances (GPS) to compute scores + GPS = "gps" + # use only masks (IOU) to compute scores + IOU = "iou" + + +class DensePoseDataMode(str, Enum): + # use estimated IUV data (default mode) + IUV_DT = "iuvdt" + # use ground truth IUV data + IUV_GT = "iuvgt" + # use ground truth labels I and set UV to 0 + I_GT_UV_0 = "igtuv0" + # use ground truth labels I and estimated UV coordinates + I_GT_UV_DT = "igtuvdt" + # use estimated labels I and set UV to 0 + I_DT_UV_0 = "idtuv0" + + +class DensePoseCocoEval(object): 
+ # Interface for evaluating detection on the Microsoft COCO dataset. + # + # The usage for CocoEval is as follows: + # cocoGt=..., cocoDt=... # load dataset and results + # E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object + # E.params.recThrs = ...; # set parameters as desired + # E.evaluate(); # run per image evaluation + # E.accumulate(); # accumulate per image results + # E.summarize(); # display summary metrics of results + # For example usage see evalDemo.m and http://mscoco.org/. + # + # The evaluation parameters are as follows (defaults in brackets): + # imgIds - [all] N img ids to use for evaluation + # catIds - [all] K cat ids to use for evaluation + # iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation + # recThrs - [0:.01:1] R=101 recall thresholds for evaluation + # areaRng - [...] A=4 object area ranges for evaluation + # maxDets - [1 10 100] M=3 thresholds on max detections per image + # iouType - ['segm'] set iouType to 'segm', 'bbox', 'keypoints' or 'densepose' + # iouType replaced the now DEPRECATED useSegm parameter. + # useCats - [1] if true use category labels for evaluation + # Note: if useCats=0 category labels are ignored as in proposal scoring. + # Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified. 
+ # + # evaluate(): evaluates detections on every image and every category and + # concats the results into the "evalImgs" with fields: + # dtIds - [1xD] id for each of the D detections (dt) + # gtIds - [1xG] id for each of the G ground truths (gt) + # dtMatches - [TxD] matching gt id at each IoU or 0 + # gtMatches - [TxG] matching dt id at each IoU or 0 + # dtScores - [1xD] confidence of each dt + # gtIgnore - [1xG] ignore flag for each gt + # dtIgnore - [TxD] ignore flag for each dt at each IoU + # + # accumulate(): accumulates the per-image, per-category evaluation + # results in "evalImgs" into the dictionary "eval" with fields: + # params - parameters used for evaluation + # date - date evaluation was performed + # counts - [T,R,K,A,M] parameter dimensions (see above) + # precision - [TxRxKxAxM] precision for every evaluation setting + # recall - [TxKxAxM] max recall for every evaluation setting + # Note: precision and recall==-1 for settings with no gt objects. + # + # See also coco, mask, pycocoDemo, pycocoEvalDemo + # + # Microsoft COCO Toolbox. version 2.0 + # Data, paper, and tutorials available at: http://mscoco.org/ + # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 
+ # Licensed under the Simplified BSD License [see coco/license.txt] + def __init__( + self, + cocoGt=None, + cocoDt=None, + iouType: str = "densepose", + multi_storage=None, + embedder=None, + dpEvalMode: DensePoseEvalMode = DensePoseEvalMode.GPS, + dpDataMode: DensePoseDataMode = DensePoseDataMode.IUV_DT, + ): + """ + Initialize CocoEval using coco APIs for gt and dt + :param cocoGt: coco object with ground truth annotations + :param cocoDt: coco object with detection results + :return: None + """ + self.cocoGt = cocoGt # ground truth COCO API + self.cocoDt = cocoDt # detections COCO API + self.multi_storage = multi_storage + self.embedder = embedder + self._dpEvalMode = dpEvalMode + self._dpDataMode = dpDataMode + self.evalImgs = defaultdict(list) # per-image per-category eval results [KxAxI] + self.eval = {} # accumulated evaluation results + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + self.params = Params(iouType=iouType) # parameters + self._paramsEval = {} # parameters for evaluation + self.stats = [] # result summarization + self.ious = {} # ious between all gts and dts + if cocoGt is not None: + self.params.imgIds = sorted(cocoGt.getImgIds()) + self.params.catIds = sorted(cocoGt.getCatIds()) + self.ignoreThrBB = 0.7 + self.ignoreThrUV = 0.9 + + def _loadGEval(self): + smpl_subdiv_fpath = PathManager.get_local_path( + "https://dl.fbaipublicfiles.com/densepose/data/SMPL_subdiv.mat" + ) + pdist_transform_fpath = PathManager.get_local_path( + "https://dl.fbaipublicfiles.com/densepose/data/SMPL_SUBDIV_TRANSFORM.mat" + ) + pdist_matrix_fpath = PathManager.get_local_path( + "https://dl.fbaipublicfiles.com/densepose/data/Pdist_matrix.pkl", timeout_sec=120 + ) + SMPL_subdiv = loadmat(smpl_subdiv_fpath) + self.PDIST_transform = 
loadmat(pdist_transform_fpath) + self.PDIST_transform = self.PDIST_transform["index"].squeeze() + UV = np.array([SMPL_subdiv["U_subdiv"], SMPL_subdiv["V_subdiv"]]).squeeze() + ClosestVertInds = np.arange(UV.shape[1]) + 1 + self.Part_UVs = [] + self.Part_ClosestVertInds = [] + for i in np.arange(24): + self.Part_UVs.append(UV[:, SMPL_subdiv["Part_ID_subdiv"].squeeze() == (i + 1)]) + self.Part_ClosestVertInds.append( + ClosestVertInds[SMPL_subdiv["Part_ID_subdiv"].squeeze() == (i + 1)] + ) + + with open(pdist_matrix_fpath, "rb") as hFile: + arrays = pickle.load(hFile, encoding="latin1") + self.Pdist_matrix = arrays["Pdist_matrix"] + self.Part_ids = np.array(SMPL_subdiv["Part_ID_subdiv"].squeeze()) + # Mean geodesic distances for parts. + self.Mean_Distances = np.array([0, 0.351, 0.107, 0.126, 0.237, 0.173, 0.142, 0.128, 0.150]) + # Coarse Part labels. + self.CoarseParts = np.array( + [0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8] + ) + + def _prepare(self): + """ + Prepare ._gts and ._dts for evaluation based on params + :return: None + """ + + def _toMask(anns, coco): + # modify ann['segmentation'] by reference + for ann in anns: + # safeguard for invalid segmentation annotation; + # annotations containing empty lists exist in the posetrack + # dataset. 
This is not a correct segmentation annotation + # in terms of COCO format; we need to deal with it somehow + segm = ann["segmentation"] + if type(segm) == list and len(segm) == 0: + ann["segmentation"] = None + continue + rle = coco.annToRLE(ann) + ann["segmentation"] = rle + + def _getIgnoreRegion(iid, coco): + img = coco.imgs[iid] + + if "ignore_regions_x" not in img.keys(): + return None + + if len(img["ignore_regions_x"]) == 0: + return None + + rgns_merged = [ + [v for xy in zip(region_x, region_y) for v in xy] + for region_x, region_y in zip(img["ignore_regions_x"], img["ignore_regions_y"]) + ] + rles = maskUtils.frPyObjects(rgns_merged, img["height"], img["width"]) + rle = maskUtils.merge(rles) + return maskUtils.decode(rle) + + def _checkIgnore(dt, iregion): + if iregion is None: + return True + + bb = np.array(dt["bbox"]).astype(np.int) + x1, y1, x2, y2 = bb[0], bb[1], bb[0] + bb[2], bb[1] + bb[3] + x2 = min([x2, iregion.shape[1]]) + y2 = min([y2, iregion.shape[0]]) + + if bb[2] * bb[3] == 0: + return False + + crop_iregion = iregion[y1:y2, x1:x2] + + if crop_iregion.sum() == 0: + return True + + if "densepose" not in dt.keys(): # filtering boxes + return crop_iregion.sum() / bb[2] / bb[3] < self.ignoreThrBB + + # filtering UVs + ignoremask = np.require(crop_iregion, requirements=["F"]) + mask = self._extract_mask(dt) + uvmask = np.require(np.asarray(mask > 0), dtype=np.uint8, requirements=["F"]) + uvmask_ = maskUtils.encode(uvmask) + ignoremask_ = maskUtils.encode(ignoremask) + uviou = maskUtils.iou([uvmask_], [ignoremask_], [1])[0] + return uviou < self.ignoreThrUV + + p = self.params + + if p.useCats: + gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) + dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) + else: + gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds)) + dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds)) + + imns = 
self.cocoGt.loadImgs(p.imgIds) + self.size_mapping = {} + for im in imns: + self.size_mapping[im["id"]] = [im["height"], im["width"]] + + # if iouType == 'uv', add point gt annotations + if p.iouType == "densepose": + self._loadGEval() + + # convert ground truth to mask if iouType == 'segm' + if p.iouType == "segm": + _toMask(gts, self.cocoGt) + _toMask(dts, self.cocoDt) + + # set ignore flag + for gt in gts: + gt["ignore"] = gt["ignore"] if "ignore" in gt else 0 + gt["ignore"] = "iscrowd" in gt and gt["iscrowd"] + if p.iouType == "keypoints": + gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"] + if p.iouType == "densepose": + gt["ignore"] = ("dp_x" in gt) == 0 + if p.iouType == "segm": + gt["ignore"] = gt["segmentation"] is None + + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + self._igrgns = defaultdict(list) + + for gt in gts: + iid = gt["image_id"] + if iid not in self._igrgns.keys(): + self._igrgns[iid] = _getIgnoreRegion(iid, self.cocoGt) + if _checkIgnore(gt, self._igrgns[iid]): + self._gts[iid, gt["category_id"]].append(gt) + for dt in dts: + iid = dt["image_id"] + if (iid not in self._igrgns) or _checkIgnore(dt, self._igrgns[iid]): + self._dts[iid, dt["category_id"]].append(dt) + + self.evalImgs = defaultdict(list) # per-image per-category evaluation results + self.eval = {} # accumulated evaluation results + + def evaluate(self): + """ + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + """ + tic = time.time() + logger.info("Running per image DensePose evaluation... {}".format(self.params.iouType)) + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = "segm" if p.useSegm == 1 else "bbox" + logger.info("useSegm (deprecated) is not None. 
Running DensePose evaluation") + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType in ["segm", "bbox"]: + computeIoU = self.computeIoU + elif p.iouType == "keypoints": + computeIoU = self.computeOks + elif p.iouType == "densepose": + computeIoU = self.computeOgps + if self._dpEvalMode in {DensePoseEvalMode.GPSM, DensePoseEvalMode.IOU}: + self.real_ious = { + (imgId, catId): self.computeDPIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds + } + + self.ious = { + (imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds + } + + evaluateImg = self.evaluateImg + maxDet = p.maxDets[-1] + self.evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + self._paramsEval = copy.deepcopy(self.params) + toc = time.time() + logger.info("DensePose evaluation DONE (t={:0.2f}s).".format(toc - tic)) + + def getDensePoseMask(self, polys): + maskGen = np.zeros([256, 256]) + stop = min(len(polys) + 1, 15) + for i in range(1, stop): + if polys[i - 1]: + currentMask = maskUtils.decode(polys[i - 1]) + maskGen[currentMask > 0] = i + return maskGen + + def _generate_rlemask_on_image(self, mask, imgId, data): + bbox_xywh = np.array(data["bbox"]) + x, y, w, h = bbox_xywh + im_h, im_w = self.size_mapping[imgId] + im_mask = np.zeros((im_h, im_w), dtype=np.uint8) + if mask is not None: + x0 = max(int(x), 0) + x1 = min(int(x + w), im_w, int(x) + mask.shape[1]) + y0 = max(int(y), 0) + y1 = min(int(y + h), im_h, int(y) + mask.shape[0]) + y = int(y) + x = int(x) + im_mask[y0:y1, x0:x1] = mask[y0 - y : y1 - y, x0 - x : x1 - x] + im_mask = np.require(np.asarray(im_mask > 0), dtype=np.uint8, requirements=["F"]) + rle_mask = maskUtils.encode(np.array(im_mask[:, 
:, np.newaxis], order="F"))[0] + return rle_mask + + def computeDPIoU(self, imgId, catId): + p = self.params + if p.useCats: + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] + else: + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + if len(gt) == 0 and len(dt) == 0: + return [] + inds = np.argsort([-d["score"] for d in dt], kind="mergesort") + dt = [dt[i] for i in inds] + if len(dt) > p.maxDets[-1]: + dt = dt[0 : p.maxDets[-1]] + + gtmasks = [] + for g in gt: + if DensePoseDataRelative.S_KEY in g: + # convert DensePose mask to a binary mask + mask = np.minimum(self.getDensePoseMask(g[DensePoseDataRelative.S_KEY]), 1.0) + _, _, w, h = g["bbox"] + scale_x = float(max(w, 1)) / mask.shape[1] + scale_y = float(max(h, 1)) / mask.shape[0] + mask = spzoom(mask, (scale_y, scale_x), order=1, prefilter=False) + mask = np.array(mask > 0.5, dtype=np.uint8) + rle_mask = self._generate_rlemask_on_image(mask, imgId, g) + elif "segmentation" in g: + segmentation = g["segmentation"] + if isinstance(segmentation, list) and segmentation: + # polygons + im_h, im_w = self.size_mapping[imgId] + rles = maskUtils.frPyObjects(segmentation, im_h, im_w) + rle_mask = maskUtils.merge(rles) + elif isinstance(segmentation, dict): + if isinstance(segmentation["counts"], list): + # uncompressed RLE + im_h, im_w = self.size_mapping[imgId] + rle_mask = maskUtils.frPyObjects(segmentation, im_h, im_w) + else: + # compressed RLE + rle_mask = segmentation + else: + rle_mask = self._generate_rlemask_on_image(None, imgId, g) + else: + rle_mask = self._generate_rlemask_on_image(None, imgId, g) + gtmasks.append(rle_mask) + + dtmasks = [] + for d in dt: + mask = self._extract_mask(d) + mask = np.require(np.asarray(mask > 0), dtype=np.uint8, requirements=["F"]) + rle_mask = self._generate_rlemask_on_image(mask, imgId, d) + dtmasks.append(rle_mask) + + # compute iou between each dt and gt region + iscrowd = 
[int(o.get("iscrowd", 0)) for o in gt] + iousDP = maskUtils.iou(dtmasks, gtmasks, iscrowd) + return iousDP + + def computeIoU(self, imgId, catId): + p = self.params + if p.useCats: + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] + else: + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + if len(gt) == 0 and len(dt) == 0: + return [] + inds = np.argsort([-d["score"] for d in dt], kind="mergesort") + dt = [dt[i] for i in inds] + if len(dt) > p.maxDets[-1]: + dt = dt[0 : p.maxDets[-1]] + + if p.iouType == "segm": + g = [g["segmentation"] for g in gt if g["segmentation"] is not None] + d = [d["segmentation"] for d in dt if d["segmentation"] is not None] + elif p.iouType == "bbox": + g = [g["bbox"] for g in gt] + d = [d["bbox"] for d in dt] + else: + raise Exception("unknown iouType for iou computation") + + # compute iou between each dt and gt region + iscrowd = [int(o.get("iscrowd", 0)) for o in gt] + ious = maskUtils.iou(d, g, iscrowd) + return ious + + def computeOks(self, imgId, catId): + p = self.params + # dimension here should be Nxm + gts = self._gts[imgId, catId] + dts = self._dts[imgId, catId] + inds = np.argsort([-d["score"] for d in dts], kind="mergesort") + dts = [dts[i] for i in inds] + if len(dts) > p.maxDets[-1]: + dts = dts[0 : p.maxDets[-1]] + # if len(gts) == 0 and len(dts) == 0: + if len(gts) == 0 or len(dts) == 0: + return [] + ious = np.zeros((len(dts), len(gts))) + sigmas = ( + np.array( + [ + 0.26, + 0.25, + 0.25, + 0.35, + 0.35, + 0.79, + 0.79, + 0.72, + 0.72, + 0.62, + 0.62, + 1.07, + 1.07, + 0.87, + 0.87, + 0.89, + 0.89, + ] + ) + / 10.0 + ) + vars = (sigmas * 2) ** 2 + k = len(sigmas) + # compute oks between each detection and ground truth object + for j, gt in enumerate(gts): + # create bounds for ignore regions(double the gt bbox) + g = np.array(gt["keypoints"]) + xg = g[0::3] + yg = g[1::3] + vg = g[2::3] + k1 = np.count_nonzero(vg > 0) + bb = 
gt["bbox"] + x0 = bb[0] - bb[2] + x1 = bb[0] + bb[2] * 2 + y0 = bb[1] - bb[3] + y1 = bb[1] + bb[3] * 2 + for i, dt in enumerate(dts): + d = np.array(dt["keypoints"]) + xd = d[0::3] + yd = d[1::3] + if k1 > 0: + # measure the per-keypoint distance if keypoints visible + dx = xd - xg + dy = yd - yg + else: + # measure minimum distance to keypoints in (x0,y0) & (x1,y1) + z = np.zeros(k) + dx = np.max((z, x0 - xd), axis=0) + np.max((z, xd - x1), axis=0) + dy = np.max((z, y0 - yd), axis=0) + np.max((z, yd - y1), axis=0) + e = (dx**2 + dy**2) / vars / (gt["area"] + np.spacing(1)) / 2 + if k1 > 0: + e = e[vg > 0] + ious[i, j] = np.sum(np.exp(-e)) / e.shape[0] + return ious + + def _extract_mask(self, dt: Dict[str, Any]) -> np.ndarray: + if "densepose" in dt: + densepose_results_quantized = dt["densepose"] + return densepose_results_quantized.labels_uv_uint8[0].numpy() + elif "cse_mask" in dt: + return dt["cse_mask"] + elif "coarse_segm" in dt: + dy = max(int(dt["bbox"][3]), 1) + dx = max(int(dt["bbox"][2]), 1) + return ( + F.interpolate( + dt["coarse_segm"].unsqueeze(0), + (dy, dx), + mode="bilinear", + align_corners=False, + ) + .squeeze(0) + .argmax(0) + .numpy() + .astype(np.uint8) + ) + elif "record_id" in dt: + assert ( + self.multi_storage is not None + ), f"Storage record id encountered in a detection {dt}, but no storage provided!" 
+ record = self.multi_storage.get(dt["rank"], dt["record_id"]) + coarse_segm = record["coarse_segm"] + dy = max(int(dt["bbox"][3]), 1) + dx = max(int(dt["bbox"][2]), 1) + return ( + F.interpolate( + coarse_segm.unsqueeze(0), + (dy, dx), + mode="bilinear", + align_corners=False, + ) + .squeeze(0) + .argmax(0) + .numpy() + .astype(np.uint8) + ) + else: + raise Exception(f"No mask data in the detection: {dt}") + raise ValueError('The prediction dict needs to contain either "densepose" or "cse_mask"') + + def _extract_iuv( + self, densepose_data: np.ndarray, py: np.ndarray, px: np.ndarray, gt: Dict[str, Any] + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Extract arrays of I, U and V values at given points as numpy arrays + given the data mode stored in self._dpDataMode + """ + if self._dpDataMode == DensePoseDataMode.IUV_DT: + # estimated labels and UV (default) + ipoints = densepose_data[0, py, px] + upoints = densepose_data[1, py, px] / 255.0 # convert from uint8 by /255. + vpoints = densepose_data[2, py, px] / 255.0 + elif self._dpDataMode == DensePoseDataMode.IUV_GT: + # ground truth + ipoints = np.array(gt["dp_I"]) + upoints = np.array(gt["dp_U"]) + vpoints = np.array(gt["dp_V"]) + elif self._dpDataMode == DensePoseDataMode.I_GT_UV_0: + # ground truth labels, UV = 0 + ipoints = np.array(gt["dp_I"]) + upoints = upoints * 0.0 + vpoints = vpoints * 0.0 + elif self._dpDataMode == DensePoseDataMode.I_GT_UV_DT: + # ground truth labels, estimated UV + ipoints = np.array(gt["dp_I"]) + upoints = densepose_data[1, py, px] / 255.0 # convert from uint8 by /255. 
+ vpoints = densepose_data[2, py, px] / 255.0 + elif self._dpDataMode == DensePoseDataMode.I_DT_UV_0: + # estimated labels, UV = 0 + ipoints = densepose_data[0, py, px] + upoints = upoints * 0.0 + vpoints = vpoints * 0.0 + else: + raise ValueError(f"Unknown data mode: {self._dpDataMode}") + return ipoints, upoints, vpoints + + def computeOgps_single_pair(self, dt, gt, py, px, pt_mask): + if "densepose" in dt: + ipoints, upoints, vpoints = self.extract_iuv_from_quantized(dt, gt, py, px, pt_mask) + return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints) + elif "u" in dt: + ipoints, upoints, vpoints = self.extract_iuv_from_raw(dt, gt, py, px, pt_mask) + return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints) + elif "record_id" in dt: + assert ( + self.multi_storage is not None + ), f"Storage record id encountered in detection {dt}, but no storage provided!" + record = self.multi_storage.get(dt["rank"], dt["record_id"]) + record["bbox"] = dt["bbox"] + if "u" in record: + ipoints, upoints, vpoints = self.extract_iuv_from_raw(record, gt, py, px, pt_mask) + return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints) + elif "embedding" in record: + return self.computeOgps_single_pair_cse( + dt, + gt, + py, + px, + pt_mask, + record["coarse_segm"], + record["embedding"], + record["bbox"], + ) + else: + raise Exception(f"Unknown record format: {record}") + elif "embedding" in dt: + return self.computeOgps_single_pair_cse( + dt, gt, py, px, pt_mask, dt["coarse_segm"], dt["embedding"], dt["bbox"] + ) + raise Exception(f"Unknown detection format: {dt}") + + def extract_iuv_from_quantized(self, dt, gt, py, px, pt_mask): + densepose_results_quantized = dt["densepose"] + ipoints, upoints, vpoints = self._extract_iuv( + densepose_results_quantized.labels_uv_uint8.numpy(), py, px, gt + ) + ipoints[pt_mask == -1] = 0 + return ipoints, upoints, vpoints + + def extract_iuv_from_raw(self, dt, gt, py, px, pt_mask): + labels_dt = 
resample_fine_and_coarse_segm_tensors_to_bbox( + dt["fine_segm"].unsqueeze(0), + dt["coarse_segm"].unsqueeze(0), + dt["bbox"], + ) + uv = resample_uv_tensors_to_bbox( + dt["u"].unsqueeze(0), dt["v"].unsqueeze(0), labels_dt.squeeze(0), dt["bbox"] + ) + labels_uv_uint8 = torch.cat((labels_dt.byte(), (uv * 255).clamp(0, 255).byte())) + ipoints, upoints, vpoints = self._extract_iuv(labels_uv_uint8.numpy(), py, px, gt) + ipoints[pt_mask == -1] = 0 + return ipoints, upoints, vpoints + + def computeOgps_single_pair_iuv(self, dt, gt, ipoints, upoints, vpoints): + cVertsGT, ClosestVertsGTTransformed = self.findAllClosestVertsGT(gt) + cVerts = self.findAllClosestVertsUV(upoints, vpoints, ipoints) + # Get pairwise geodesic distances between gt and estimated mesh points. + dist = self.getDistancesUV(ClosestVertsGTTransformed, cVerts) + # Compute the Ogps measure. + # Find the mean geodesic normalization distance for + # each GT point, based on which part it is on. + Current_Mean_Distances = self.Mean_Distances[ + self.CoarseParts[self.Part_ids[cVertsGT[cVertsGT > 0].astype(int) - 1]] + ] + return dist, Current_Mean_Distances + + def computeOgps_single_pair_cse( + self, dt, gt, py, px, pt_mask, coarse_segm, embedding, bbox_xywh_abs + ): + # 0-based mesh vertex indices + cVertsGT = torch.as_tensor(gt["dp_vertex"], dtype=torch.int64) + # label for each pixel of the bbox, [H, W] tensor of long + labels_dt = resample_coarse_segm_tensor_to_bbox( + coarse_segm.unsqueeze(0), bbox_xywh_abs + ).squeeze(0) + x, y, w, h = bbox_xywh_abs + # embedding for each pixel of the bbox, [D, H, W] tensor of float32 + embedding = F.interpolate( + embedding.unsqueeze(0), (int(h), int(w)), mode="bilinear", align_corners=False + ).squeeze(0) + # valid locations py, px + py_pt = torch.from_numpy(py[pt_mask > -1]) + px_pt = torch.from_numpy(px[pt_mask > -1]) + cVerts = torch.ones_like(cVertsGT) * -1 + cVerts[pt_mask > -1] = self.findClosestVertsCse( + embedding, py_pt, px_pt, labels_dt, gt["ref_model"] + 
) + # Get pairwise geodesic distances between gt and estimated mesh points. + dist = self.getDistancesCse(cVertsGT, cVerts, gt["ref_model"]) + # normalize distances + if (gt["ref_model"] == "smpl_27554") and ("dp_I" in gt): + Current_Mean_Distances = self.Mean_Distances[ + self.CoarseParts[np.array(gt["dp_I"], dtype=int)] + ] + else: + Current_Mean_Distances = 0.255 + return dist, Current_Mean_Distances + + def computeOgps(self, imgId, catId): + p = self.params + # dimension here should be Nxm + g = self._gts[imgId, catId] + d = self._dts[imgId, catId] + inds = np.argsort([-d_["score"] for d_ in d], kind="mergesort") + d = [d[i] for i in inds] + if len(d) > p.maxDets[-1]: + d = d[0 : p.maxDets[-1]] + # if len(gts) == 0 and len(dts) == 0: + if len(g) == 0 or len(d) == 0: + return [] + ious = np.zeros((len(d), len(g))) + # compute opgs between each detection and ground truth object + # sigma = self.sigma #0.255 # dist = 0.3m corresponds to ogps = 0.5 + # 1 # dist = 0.3m corresponds to ogps = 0.96 + # 1.45 # dist = 1.7m (person height) corresponds to ogps = 0.5) + for j, gt in enumerate(g): + if not gt["ignore"]: + g_ = gt["bbox"] + for i, dt in enumerate(d): + # + dy = int(dt["bbox"][3]) + dx = int(dt["bbox"][2]) + dp_x = np.array(gt["dp_x"]) * g_[2] / 255.0 + dp_y = np.array(gt["dp_y"]) * g_[3] / 255.0 + py = (dp_y + g_[1] - dt["bbox"][1]).astype(np.int) + px = (dp_x + g_[0] - dt["bbox"][0]).astype(np.int) + # + pts = np.zeros(len(px)) + pts[px >= dx] = -1 + pts[py >= dy] = -1 + pts[px < 0] = -1 + pts[py < 0] = -1 + if len(pts) < 1: + ogps = 0.0 + elif np.max(pts) == -1: + ogps = 0.0 + else: + px[pts == -1] = 0 + py[pts == -1] = 0 + dists_between_matches, dist_norm_coeffs = self.computeOgps_single_pair( + dt, gt, py, px, pts + ) + # Compute gps + ogps_values = np.exp( + -(dists_between_matches**2) / (2 * (dist_norm_coeffs**2)) + ) + # + ogps = np.mean(ogps_values) if len(ogps_values) > 0 else 0.0 + ious[i, j] = ogps + + gbb = [gt["bbox"] for gt in g] + dbb = 
[dt["bbox"] for dt in d] + + # compute iou between each dt and gt region + iscrowd = [int(o.get("iscrowd", 0)) for o in g] + ious_bb = maskUtils.iou(dbb, gbb, iscrowd) + return ious, ious_bb + + def evaluateImg(self, imgId, catId, aRng, maxDet): + """ + perform evaluation for single category and image + :return: dict (single image results) + """ + + p = self.params + if p.useCats: + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] + else: + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + if len(gt) == 0 and len(dt) == 0: + return None + + for g in gt: + # g['_ignore'] = g['ignore'] + if g["ignore"] or (g["area"] < aRng[0] or g["area"] > aRng[1]): + g["_ignore"] = True + else: + g["_ignore"] = False + + # sort dt highest score first, sort gt ignore last + gtind = np.argsort([g["_ignore"] for g in gt], kind="mergesort") + gt = [gt[i] for i in gtind] + dtind = np.argsort([-d["score"] for d in dt], kind="mergesort") + dt = [dt[i] for i in dtind[0:maxDet]] + iscrowd = [int(o.get("iscrowd", 0)) for o in gt] + # load computed ious + if p.iouType == "densepose": + # print('Checking the length', len(self.ious[imgId, catId])) + # if len(self.ious[imgId, catId]) == 0: + # print(self.ious[imgId, catId]) + ious = ( + self.ious[imgId, catId][0][:, gtind] + if len(self.ious[imgId, catId]) > 0 + else self.ious[imgId, catId] + ) + ioubs = ( + self.ious[imgId, catId][1][:, gtind] + if len(self.ious[imgId, catId]) > 0 + else self.ious[imgId, catId] + ) + if self._dpEvalMode in {DensePoseEvalMode.GPSM, DensePoseEvalMode.IOU}: + iousM = ( + self.real_ious[imgId, catId][:, gtind] + if len(self.real_ious[imgId, catId]) > 0 + else self.real_ious[imgId, catId] + ) + else: + ious = ( + self.ious[imgId, catId][:, gtind] + if len(self.ious[imgId, catId]) > 0 + else self.ious[imgId, catId] + ) + + T = len(p.iouThrs) + G = len(gt) + D = len(dt) + gtm = np.zeros((T, G)) + dtm = np.zeros((T, D)) + gtIg = 
np.array([g["_ignore"] for g in gt]) + dtIg = np.zeros((T, D)) + if np.all(gtIg) and p.iouType == "densepose": + dtIg = np.logical_or(dtIg, True) + + if len(ious) > 0: # and not p.iouType == 'densepose': + for tind, t in enumerate(p.iouThrs): + for dind, d in enumerate(dt): + # information about best match so far (m=-1 -> unmatched) + iou = min([t, 1 - 1e-10]) + m = -1 + for gind, _g in enumerate(gt): + # if this gt already matched, and not a crowd, continue + if gtm[tind, gind] > 0 and not iscrowd[gind]: + continue + # if dt matched to reg gt, and on ignore gt, stop + if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1: + break + if p.iouType == "densepose": + if self._dpEvalMode == DensePoseEvalMode.GPSM: + new_iou = np.sqrt(iousM[dind, gind] * ious[dind, gind]) + elif self._dpEvalMode == DensePoseEvalMode.IOU: + new_iou = iousM[dind, gind] + elif self._dpEvalMode == DensePoseEvalMode.GPS: + new_iou = ious[dind, gind] + else: + new_iou = ious[dind, gind] + if new_iou < iou: + continue + if new_iou == 0.0: + continue + # if match successful and best so far, store appropriately + iou = new_iou + m = gind + # if match made store id of match for both dt and gt + if m == -1: + continue + dtIg[tind, dind] = gtIg[m] + dtm[tind, dind] = gt[m]["id"] + gtm[tind, m] = d["id"] + + if p.iouType == "densepose": + if not len(ioubs) == 0: + for dind, d in enumerate(dt): + # information about best match so far (m=-1 -> unmatched) + if dtm[tind, dind] == 0: + ioub = 0.8 + m = -1 + for gind, _g in enumerate(gt): + # if this gt already matched, and not a crowd, continue + if gtm[tind, gind] > 0 and not iscrowd[gind]: + continue + # continue to next gt unless better match made + if ioubs[dind, gind] < ioub: + continue + # if match successful and best so far, store appropriately + ioub = ioubs[dind, gind] + m = gind + # if match made store id of match for both dt and gt + if m > -1: + dtIg[:, dind] = gtIg[m] + if gtIg[m]: + dtm[tind, dind] = gt[m]["id"] + gtm[tind, m] = d["id"] + # set 
unmatched detections outside of area range to ignore + a = np.array([d["area"] < aRng[0] or d["area"] > aRng[1] for d in dt]).reshape((1, len(dt))) + dtIg = np.logical_or(dtIg, np.logical_and(dtm == 0, np.repeat(a, T, 0))) + # store results for given image and category + # print('Done with the function', len(self.ious[imgId, catId])) + return { + "image_id": imgId, + "category_id": catId, + "aRng": aRng, + "maxDet": maxDet, + "dtIds": [d["id"] for d in dt], + "gtIds": [g["id"] for g in gt], + "dtMatches": dtm, + "gtMatches": gtm, + "dtScores": [d["score"] for d in dt], + "gtIgnore": gtIg, + "dtIgnore": dtIg, + } + + def accumulate(self, p=None): + """ + Accumulate per image evaluation results and store the result in self.eval + :param p: input params for evaluation + :return: None + """ + logger.info("Accumulating evaluation results...") + tic = time.time() + if not self.evalImgs: + logger.info("Please run evaluate() first") + # allows input customized parameters + if p is None: + p = self.params + p.catIds = p.catIds if p.useCats == 1 else [-1] + T = len(p.iouThrs) + R = len(p.recThrs) + K = len(p.catIds) if p.useCats else 1 + A = len(p.areaRng) + M = len(p.maxDets) + precision = -(np.ones((T, R, K, A, M))) # -1 for the precision of absent categories + recall = -(np.ones((T, K, A, M))) + + # create dictionary for future indexing + logger.info("Categories: {}".format(p.catIds)) + _pe = self._paramsEval + catIds = _pe.catIds if _pe.useCats else [-1] + setK = set(catIds) + setA = set(map(tuple, _pe.areaRng)) + setM = set(_pe.maxDets) + setI = set(_pe.imgIds) + # get inds to evaluate + k_list = [n for n, k in enumerate(p.catIds) if k in setK] + m_list = [m for n, m in enumerate(p.maxDets) if m in setM] + a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA] + i_list = [n for n, i in enumerate(p.imgIds) if i in setI] + I0 = len(_pe.imgIds) + A0 = len(_pe.areaRng) + # retrieve E at each category, area range, and max number of detections + 
for k, k0 in enumerate(k_list): + Nk = k0 * A0 * I0 + for a, a0 in enumerate(a_list): + Na = a0 * I0 + for m, maxDet in enumerate(m_list): + E = [self.evalImgs[Nk + Na + i] for i in i_list] + E = [e for e in E if e is not None] + if len(E) == 0: + continue + dtScores = np.concatenate([e["dtScores"][0:maxDet] for e in E]) + + # different sorting method generates slightly different results. + # mergesort is used to be consistent as Matlab implementation. + inds = np.argsort(-dtScores, kind="mergesort") + + dtm = np.concatenate([e["dtMatches"][:, 0:maxDet] for e in E], axis=1)[:, inds] + dtIg = np.concatenate([e["dtIgnore"][:, 0:maxDet] for e in E], axis=1)[:, inds] + gtIg = np.concatenate([e["gtIgnore"] for e in E]) + npig = np.count_nonzero(gtIg == 0) + if npig == 0: + continue + tps = np.logical_and(dtm, np.logical_not(dtIg)) + fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg)) + tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float) + fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float) + for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): + tp = np.array(tp) + fp = np.array(fp) + nd = len(tp) + rc = tp / npig + pr = tp / (fp + tp + np.spacing(1)) + q = np.zeros((R,)) + + if nd: + recall[t, k, a, m] = rc[-1] + else: + recall[t, k, a, m] = 0 + + # numpy is slow without cython optimization for accessing elements + # use python array gets significant speed improvement + pr = pr.tolist() + q = q.tolist() + + for i in range(nd - 1, 0, -1): + if pr[i] > pr[i - 1]: + pr[i - 1] = pr[i] + + inds = np.searchsorted(rc, p.recThrs, side="left") + try: + for ri, pi in enumerate(inds): + q[ri] = pr[pi] + except Exception: + pass + precision[t, :, k, a, m] = np.array(q) + logger.info( + "Final: max precision {}, min precision {}".format(np.max(precision), np.min(precision)) + ) + self.eval = { + "params": p, + "counts": [T, R, K, A, M], + "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "precision": precision, + "recall": recall, + } + toc = 
time.time() + logger.info("DONE (t={:0.2f}s).".format(toc - tic)) + + def summarize(self): + """ + Compute and display summary metrics for evaluation results. + Note this function can *only* be applied on the default parameter setting + """ + + def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100): + p = self.params + iStr = " {:<18} {} @[ {}={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}" + titleStr = "Average Precision" if ap == 1 else "Average Recall" + typeStr = "(AP)" if ap == 1 else "(AR)" + measure = "IoU" + if self.params.iouType == "keypoints": + measure = "OKS" + elif self.params.iouType == "densepose": + measure = "OGPS" + iouStr = ( + "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1]) + if iouThr is None + else "{:0.2f}".format(iouThr) + ) + + aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] + mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] + if ap == 1: + # dimension of precision: [TxRxKxAxM] + s = self.eval["precision"] + # IoU + if iouThr is not None: + t = np.where(np.abs(iouThr - p.iouThrs) < 0.001)[0] + s = s[t] + s = s[:, :, :, aind, mind] + else: + # dimension of recall: [TxKxAxM] + s = self.eval["recall"] + if iouThr is not None: + t = np.where(np.abs(iouThr - p.iouThrs) < 0.001)[0] + s = s[t] + s = s[:, :, aind, mind] + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + logger.info(iStr.format(titleStr, typeStr, measure, iouStr, areaRng, maxDets, mean_s)) + return mean_s + + def _summarizeDets(): + stats = np.zeros((12,)) + stats[0] = _summarize(1) + stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2]) + stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2]) + stats[3] = _summarize(1, areaRng="small", maxDets=self.params.maxDets[2]) + stats[4] = _summarize(1, areaRng="medium", maxDets=self.params.maxDets[2]) + stats[5] = _summarize(1, areaRng="large", maxDets=self.params.maxDets[2]) + stats[6] = _summarize(0, 
maxDets=self.params.maxDets[0]) + stats[7] = _summarize(0, maxDets=self.params.maxDets[1]) + stats[8] = _summarize(0, maxDets=self.params.maxDets[2]) + stats[9] = _summarize(0, areaRng="small", maxDets=self.params.maxDets[2]) + stats[10] = _summarize(0, areaRng="medium", maxDets=self.params.maxDets[2]) + stats[11] = _summarize(0, areaRng="large", maxDets=self.params.maxDets[2]) + return stats + + def _summarizeKps(): + stats = np.zeros((10,)) + stats[0] = _summarize(1, maxDets=20) + stats[1] = _summarize(1, maxDets=20, iouThr=0.5) + stats[2] = _summarize(1, maxDets=20, iouThr=0.75) + stats[3] = _summarize(1, maxDets=20, areaRng="medium") + stats[4] = _summarize(1, maxDets=20, areaRng="large") + stats[5] = _summarize(0, maxDets=20) + stats[6] = _summarize(0, maxDets=20, iouThr=0.5) + stats[7] = _summarize(0, maxDets=20, iouThr=0.75) + stats[8] = _summarize(0, maxDets=20, areaRng="medium") + stats[9] = _summarize(0, maxDets=20, areaRng="large") + return stats + + def _summarizeUvs(): + stats = [_summarize(1, maxDets=self.params.maxDets[0])] + min_threshold = self.params.iouThrs.min() + if min_threshold <= 0.201: + stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.2)] + if min_threshold <= 0.301: + stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.3)] + if min_threshold <= 0.401: + stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.4)] + stats += [ + _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.5), + _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.75), + _summarize(1, maxDets=self.params.maxDets[0], areaRng="medium"), + _summarize(1, maxDets=self.params.maxDets[0], areaRng="large"), + _summarize(0, maxDets=self.params.maxDets[0]), + _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.5), + _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.75), + _summarize(0, maxDets=self.params.maxDets[0], areaRng="medium"), + _summarize(0, maxDets=self.params.maxDets[0], areaRng="large"), + ] + return 
np.array(stats) + + def _summarizeUvsOld(): + stats = np.zeros((18,)) + stats[0] = _summarize(1, maxDets=self.params.maxDets[0]) + stats[1] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.5) + stats[2] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.55) + stats[3] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.60) + stats[4] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.65) + stats[5] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.70) + stats[6] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.75) + stats[7] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.80) + stats[8] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.85) + stats[9] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.90) + stats[10] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.95) + stats[11] = _summarize(1, maxDets=self.params.maxDets[0], areaRng="medium") + stats[12] = _summarize(1, maxDets=self.params.maxDets[0], areaRng="large") + stats[13] = _summarize(0, maxDets=self.params.maxDets[0]) + stats[14] = _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.5) + stats[15] = _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.75) + stats[16] = _summarize(0, maxDets=self.params.maxDets[0], areaRng="medium") + stats[17] = _summarize(0, maxDets=self.params.maxDets[0], areaRng="large") + return stats + + if not self.eval: + raise Exception("Please run accumulate() first") + iouType = self.params.iouType + if iouType in ["segm", "bbox"]: + summarize = _summarizeDets + elif iouType in ["keypoints"]: + summarize = _summarizeKps + elif iouType in ["densepose"]: + summarize = _summarizeUvs + self.stats = summarize() + + def __str__(self): + self.summarize() + + # ================ functions for dense pose ============================== + def findAllClosestVertsUV(self, U_points, V_points, Index_points): + ClosestVerts = np.ones(Index_points.shape) * -1 + for i in np.arange(24): + # + if (i + 1) in Index_points: + 
UVs = np.array( + [U_points[Index_points == (i + 1)], V_points[Index_points == (i + 1)]] + ) + Current_Part_UVs = self.Part_UVs[i] + Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i] + D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze() + ClosestVerts[Index_points == (i + 1)] = Current_Part_ClosestVertInds[ + np.argmin(D, axis=0) + ] + ClosestVertsTransformed = self.PDIST_transform[ClosestVerts.astype(int) - 1] + ClosestVertsTransformed[ClosestVerts < 0] = 0 + return ClosestVertsTransformed + + def findClosestVertsCse(self, embedding, py, px, mask, mesh_name): + mesh_vertex_embeddings = self.embedder(mesh_name) + pixel_embeddings = embedding[:, py, px].t().to(device="cuda") + mask_vals = mask[py, px] + edm = squared_euclidean_distance_matrix(pixel_embeddings, mesh_vertex_embeddings) + vertex_indices = edm.argmin(dim=1).cpu() + vertex_indices[mask_vals <= 0] = -1 + return vertex_indices + + def findAllClosestVertsGT(self, gt): + # + I_gt = np.array(gt["dp_I"]) + U_gt = np.array(gt["dp_U"]) + V_gt = np.array(gt["dp_V"]) + # + # print(I_gt) + # + ClosestVertsGT = np.ones(I_gt.shape) * -1 + for i in np.arange(24): + if (i + 1) in I_gt: + UVs = np.array([U_gt[I_gt == (i + 1)], V_gt[I_gt == (i + 1)]]) + Current_Part_UVs = self.Part_UVs[i] + Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i] + D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze() + ClosestVertsGT[I_gt == (i + 1)] = Current_Part_ClosestVertInds[np.argmin(D, axis=0)] + # + ClosestVertsGTTransformed = self.PDIST_transform[ClosestVertsGT.astype(int) - 1] + ClosestVertsGTTransformed[ClosestVertsGT < 0] = 0 + return ClosestVertsGT, ClosestVertsGTTransformed + + def getDistancesCse(self, cVertsGT, cVerts, mesh_name): + geodists_vertices = torch.ones_like(cVertsGT) * float("inf") + selected = (cVertsGT >= 0) * (cVerts >= 0) + mesh = create_mesh(mesh_name, "cpu") + geodists_vertices[selected] = mesh.geodists[cVertsGT[selected], cVerts[selected]] + return 
geodists_vertices.numpy() + + def getDistancesUV(self, cVertsGT, cVerts): + # + n = 27554 + dists = [] + for d in range(len(cVertsGT)): + if cVertsGT[d] > 0: + if cVerts[d] > 0: + i = cVertsGT[d] - 1 + j = cVerts[d] - 1 + if j == i: + dists.append(0) + elif j > i: + ccc = i + i = j + j = ccc + i = n - i - 1 + j = n - j - 1 + k = (n * (n - 1) / 2) - (n - i) * ((n - i) - 1) / 2 + j - i - 1 + k = (n * n - n) / 2 - k - 1 + dists.append(self.Pdist_matrix[int(k)][0]) + else: + i = n - i - 1 + j = n - j - 1 + k = (n * (n - 1) / 2) - (n - i) * ((n - i) - 1) / 2 + j - i - 1 + k = (n * n - n) / 2 - k - 1 + dists.append(self.Pdist_matrix[int(k)][0]) + else: + dists.append(np.inf) + return np.atleast_1d(np.array(dists).squeeze()) + + +class Params: + """ + Params for coco evaluation api + """ + + def setDetParams(self): + self.imgIds = [] + self.catIds = [] + # np.arange causes trouble. the data point on arange is slightly larger than the true value + self.iouThrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True) + self.recThrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True) + self.maxDets = [1, 10, 100] + self.areaRng = [ + [0**2, 1e5**2], + [0**2, 32**2], + [32**2, 96**2], + [96**2, 1e5**2], + ] + self.areaRngLbl = ["all", "small", "medium", "large"] + self.useCats = 1 + + def setKpParams(self): + self.imgIds = [] + self.catIds = [] + # np.arange causes trouble. 
the data point on arange is slightly larger than the true value + self.iouThrs = np.linspace(0.5, 0.95, np.round((0.95 - 0.5) / 0.05) + 1, endpoint=True) + self.recThrs = np.linspace(0.0, 1.00, np.round((1.00 - 0.0) / 0.01) + 1, endpoint=True) + self.maxDets = [20] + self.areaRng = [[0**2, 1e5**2], [32**2, 96**2], [96**2, 1e5**2]] + self.areaRngLbl = ["all", "medium", "large"] + self.useCats = 1 + + def setUvParams(self): + self.imgIds = [] + self.catIds = [] + self.iouThrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True) + self.recThrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True) + self.maxDets = [20] + self.areaRng = [[0**2, 1e5**2], [32**2, 96**2], [96**2, 1e5**2]] + self.areaRngLbl = ["all", "medium", "large"] + self.useCats = 1 + + def __init__(self, iouType="segm"): + if iouType == "segm" or iouType == "bbox": + self.setDetParams() + elif iouType == "keypoints": + self.setKpParams() + elif iouType == "densepose": + self.setUvParams() + else: + raise Exception("iouType not supported") + self.iouType = iouType + # useSegm is deprecated + self.useSegm = None diff --git a/data_processing/detectron2/projects/DensePose/densepose/evaluation/evaluator.py b/data_processing/detectron2/projects/DensePose/densepose/evaluation/evaluator.py new file mode 100644 index 0000000..d5d1d78 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/evaluation/evaluator.py @@ -0,0 +1,421 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import contextlib +import copy +import io +import itertools +import logging +import numpy as np +import os +from collections import OrderedDict +from typing import Dict, Iterable, List, Optional +import pycocotools.mask as mask_utils +import torch +from pycocotools.coco import COCO +from tabulate import tabulate + +from detectron2.config import CfgNode +from detectron2.data import MetadataCatalog +from detectron2.evaluation import DatasetEvaluator +from detectron2.structures import BoxMode +from detectron2.utils.comm import gather, get_rank, is_main_process, synchronize +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import create_small_table + +from densepose.converters import ToChartResultConverter, ToMaskConverter +from densepose.data.datasets.coco import maybe_filter_and_map_categories_cocoapi +from densepose.structures import ( + DensePoseChartPredictorOutput, + DensePoseEmbeddingPredictorOutput, + quantize_densepose_chart_result, +) + +from .densepose_coco_evaluation import DensePoseCocoEval, DensePoseEvalMode +from .mesh_alignment_evaluator import MeshAlignmentEvaluator +from .tensor_storage import ( + SingleProcessFileTensorStorage, + SingleProcessRamTensorStorage, + SingleProcessTensorStorage, + SizeData, + storage_gather, +) + + +class DensePoseCOCOEvaluator(DatasetEvaluator): + def __init__( + self, + dataset_name, + distributed, + output_dir=None, + evaluator_type: str = "iuv", + min_iou_threshold: float = 0.5, + storage: Optional[SingleProcessTensorStorage] = None, + embedder=None, + should_evaluate_mesh_alignment: bool = False, + mesh_alignment_mesh_names: Optional[List[str]] = None, + ): + self._embedder = embedder + self._distributed = distributed + self._output_dir = output_dir + self._evaluator_type = evaluator_type + self._storage = storage + self._should_evaluate_mesh_alignment = should_evaluate_mesh_alignment + + assert not ( + should_evaluate_mesh_alignment and embedder is None + ), "Mesh alignment evaluation 
is activated, but no vertex embedder provided!" + if should_evaluate_mesh_alignment: + self._mesh_alignment_evaluator = MeshAlignmentEvaluator( + embedder, + mesh_alignment_mesh_names, + ) + + self._cpu_device = torch.device("cpu") + self._logger = logging.getLogger(__name__) + + self._metadata = MetadataCatalog.get(dataset_name) + self._min_threshold = min_iou_threshold + json_file = PathManager.get_local_path(self._metadata.json_file) + with contextlib.redirect_stdout(io.StringIO()): + self._coco_api = COCO(json_file) + maybe_filter_and_map_categories_cocoapi(dataset_name, self._coco_api) + + def reset(self): + self._predictions = [] + + def process(self, inputs, outputs): + """ + Args: + inputs: the inputs to a COCO model (e.g., GeneralizedRCNN). + It is a list of dict. Each dict corresponds to an image and + contains keys like "height", "width", "file_name", "image_id". + outputs: the outputs of a COCO model. It is a list of dicts with key + "instances" that contains :class:`Instances`. + The :class:`Instances` object needs to have `densepose` field. 
+ """ + for input, output in zip(inputs, outputs): + instances = output["instances"].to(self._cpu_device) + if not instances.has("pred_densepose"): + continue + prediction_list = prediction_to_dict( + instances, + input["image_id"], + self._embedder, + self._metadata.class_to_mesh_name, + self._storage is not None, + ) + if self._storage is not None: + for prediction_dict in prediction_list: + dict_to_store = {} + for field_name in self._storage.data_schema: + dict_to_store[field_name] = prediction_dict[field_name] + record_id = self._storage.put(dict_to_store) + prediction_dict["record_id"] = record_id + prediction_dict["rank"] = get_rank() + for field_name in self._storage.data_schema: + del prediction_dict[field_name] + self._predictions.extend(prediction_list) + + def evaluate(self, img_ids=None): + if self._distributed: + synchronize() + predictions = gather(self._predictions) + predictions = list(itertools.chain(*predictions)) + else: + predictions = self._predictions + + multi_storage = storage_gather(self._storage) if self._storage is not None else None + + if not is_main_process(): + return + return copy.deepcopy(self._eval_predictions(predictions, multi_storage, img_ids)) + + def _eval_predictions(self, predictions, multi_storage=None, img_ids=None): + """ + Evaluate predictions on densepose. + Return results with the metrics of the tasks. 
+ """ + self._logger.info("Preparing results for COCO format ...") + + if self._output_dir: + PathManager.mkdirs(self._output_dir) + file_path = os.path.join(self._output_dir, "coco_densepose_predictions.pth") + with PathManager.open(file_path, "wb") as f: + torch.save(predictions, f) + + self._logger.info("Evaluating predictions ...") + res = OrderedDict() + results_gps, results_gpsm, results_segm = _evaluate_predictions_on_coco( + self._coco_api, + predictions, + multi_storage, + self._embedder, + class_names=self._metadata.get("thing_classes"), + min_threshold=self._min_threshold, + img_ids=img_ids, + ) + res["densepose_gps"] = results_gps + res["densepose_gpsm"] = results_gpsm + res["densepose_segm"] = results_segm + if self._should_evaluate_mesh_alignment: + res["densepose_mesh_alignment"] = self._evaluate_mesh_alignment() + return res + + def _evaluate_mesh_alignment(self): + self._logger.info("Mesh alignment evaluation ...") + mean_ge, mean_gps, per_mesh_metrics = self._mesh_alignment_evaluator.evaluate() + results = { + "GE": mean_ge * 100, + "GPS": mean_gps * 100, + } + mesh_names = set() + for metric_name in per_mesh_metrics: + for mesh_name, value in per_mesh_metrics[metric_name].items(): + results[f"{metric_name}-{mesh_name}"] = value * 100 + mesh_names.add(mesh_name) + self._print_mesh_alignment_results(results, mesh_names) + return results + + def _print_mesh_alignment_results(self, results: Dict[str, float], mesh_names: Iterable[str]): + self._logger.info("Evaluation results for densepose, mesh alignment:") + self._logger.info(f'| {"Mesh":13s} | {"GErr":7s} | {"GPS":7s} |') + self._logger.info("| :-----------: | :-----: | :-----: |") + for mesh_name in mesh_names: + ge_key = f"GE-{mesh_name}" + ge_str = f"{results[ge_key]:.4f}" if ge_key in results else " " + gps_key = f"GPS-{mesh_name}" + gps_str = f"{results[gps_key]:.4f}" if gps_key in results else " " + self._logger.info(f"| {mesh_name:13s} | {ge_str:7s} | {gps_str:7s} |") + self._logger.info("| 
:-------------------------------: |") + ge_key = "GE" + ge_str = f"{results[ge_key]:.4f}" if ge_key in results else " " + gps_key = "GPS" + gps_str = f"{results[gps_key]:.4f}" if gps_key in results else " " + self._logger.info(f'| {"MEAN":13s} | {ge_str:7s} | {gps_str:7s} |') + + +def prediction_to_dict(instances, img_id, embedder, class_to_mesh_name, use_storage): + """ + Args: + instances (Instances): the output of the model + img_id (str): the image id in COCO + + Returns: + list[dict]: the results in densepose evaluation format + """ + scores = instances.scores.tolist() + classes = instances.pred_classes.tolist() + raw_boxes_xywh = BoxMode.convert( + instances.pred_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS + ) + + if isinstance(instances.pred_densepose, DensePoseEmbeddingPredictorOutput): + results_densepose = densepose_cse_predictions_to_dict( + instances, embedder, class_to_mesh_name, use_storage + ) + elif isinstance(instances.pred_densepose, DensePoseChartPredictorOutput): + if not use_storage: + results_densepose = densepose_chart_predictions_to_dict(instances) + else: + results_densepose = densepose_chart_predictions_to_storage_dict(instances) + + results = [] + for k in range(len(instances)): + result = { + "image_id": img_id, + "category_id": classes[k], + "bbox": raw_boxes_xywh[k].tolist(), + "score": scores[k], + } + results.append({**result, **results_densepose[k]}) + return results + + +def densepose_chart_predictions_to_dict(instances): + segmentations = ToMaskConverter.convert( + instances.pred_densepose, instances.pred_boxes, instances.image_size + ) + + results = [] + for k in range(len(instances)): + densepose_results_quantized = quantize_densepose_chart_result( + ToChartResultConverter.convert(instances.pred_densepose[k], instances.pred_boxes[k]) + ) + densepose_results_quantized.labels_uv_uint8 = ( + densepose_results_quantized.labels_uv_uint8.cpu() + ) + segmentation = segmentations.tensor[k] + segmentation_encoded = 
mask_utils.encode( + np.require(segmentation.numpy(), dtype=np.uint8, requirements=["F"]) + ) + segmentation_encoded["counts"] = segmentation_encoded["counts"].decode("utf-8") + result = { + "densepose": densepose_results_quantized, + "segmentation": segmentation_encoded, + } + results.append(result) + return results + + +def densepose_chart_predictions_to_storage_dict(instances): + results = [] + for k in range(len(instances)): + densepose_predictor_output = instances.pred_densepose[k] + result = { + "coarse_segm": densepose_predictor_output.coarse_segm.squeeze(0).cpu(), + "fine_segm": densepose_predictor_output.fine_segm.squeeze(0).cpu(), + "u": densepose_predictor_output.u.squeeze(0).cpu(), + "v": densepose_predictor_output.v.squeeze(0).cpu(), + } + results.append(result) + return results + + +def densepose_cse_predictions_to_dict(instances, embedder, class_to_mesh_name, use_storage): + results = [] + for k in range(len(instances)): + cse = instances.pred_densepose[k] + results.append( + { + "coarse_segm": cse.coarse_segm[0].cpu(), + "embedding": cse.embedding[0].cpu(), + } + ) + return results + + +def _evaluate_predictions_on_coco( + coco_gt, + coco_results, + multi_storage=None, + embedder=None, + class_names=None, + min_threshold: float = 0.5, + img_ids=None, +): + logger = logging.getLogger(__name__) + + densepose_metrics = _get_densepose_metrics(min_threshold) + if len(coco_results) == 0: # cocoapi does not handle empty results very well + logger.warn("No predictions from the model! 
Set scores to -1") + results_gps = {metric: -1 for metric in densepose_metrics} + results_gpsm = {metric: -1 for metric in densepose_metrics} + results_segm = {metric: -1 for metric in densepose_metrics} + return results_gps, results_gpsm, results_segm + + coco_dt = coco_gt.loadRes(coco_results) + + results = [] + for eval_mode_name in ["GPS", "GPSM", "IOU"]: + eval_mode = getattr(DensePoseEvalMode, eval_mode_name) + coco_eval = DensePoseCocoEval( + coco_gt, coco_dt, "densepose", multi_storage, embedder, dpEvalMode=eval_mode + ) + result = _derive_results_from_coco_eval( + coco_eval, eval_mode_name, densepose_metrics, class_names, min_threshold, img_ids + ) + results.append(result) + return results + + +def _get_densepose_metrics(min_threshold: float = 0.5): + metrics = ["AP"] + if min_threshold <= 0.201: + metrics += ["AP20"] + if min_threshold <= 0.301: + metrics += ["AP30"] + if min_threshold <= 0.401: + metrics += ["AP40"] + metrics.extend(["AP50", "AP75", "APm", "APl", "AR", "AR50", "AR75", "ARm", "ARl"]) + return metrics + + +def _derive_results_from_coco_eval( + coco_eval, eval_mode_name, metrics, class_names, min_threshold: float, img_ids +): + if img_ids is not None: + coco_eval.params.imgIds = img_ids + coco_eval.params.iouThrs = np.linspace( + min_threshold, 0.95, int(np.round((0.95 - min_threshold) / 0.05)) + 1, endpoint=True + ) + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + results = {metric: float(coco_eval.stats[idx] * 100) for idx, metric in enumerate(metrics)} + logger = logging.getLogger(__name__) + logger.info( + f"Evaluation results for densepose, {eval_mode_name} metric: \n" + + create_small_table(results) + ) + if class_names is None or len(class_names) <= 1: + return results + + # Compute per-category AP, the same way as it is done in D2 + # (see detectron2/evaluation/coco_evaluation.py): + precisions = coco_eval.eval["precision"] + # precision has dims (iou, recall, cls, area range, max dets) + assert 
len(class_names) == precisions.shape[2] + + results_per_category = [] + for idx, name in enumerate(class_names): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + ap = np.mean(precision) if precision.size else float("nan") + results_per_category.append((f"{name}", float(ap * 100))) + + # tabulate it + n_cols = min(6, len(results_per_category) * 2) + results_flatten = list(itertools.chain(*results_per_category)) + results_2d = itertools.zip_longest(*[results_flatten[i::n_cols] for i in range(n_cols)]) + table = tabulate( + results_2d, + tablefmt="pipe", + floatfmt=".3f", + headers=["category", "AP"] * (n_cols // 2), + numalign="left", + ) + logger.info(f"Per-category {eval_mode_name} AP: \n" + table) + + results.update({"AP-" + name: ap for name, ap in results_per_category}) + return results + + +def build_densepose_evaluator_storage(cfg: CfgNode, output_folder: str): + storage_spec = cfg.DENSEPOSE_EVALUATION.STORAGE + if storage_spec == "none": + return None + evaluator_type = cfg.DENSEPOSE_EVALUATION.TYPE + # common output tensor sizes + hout = cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE + wout = cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE + n_csc = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS + # specific output tensors + if evaluator_type == "iuv": + n_fsc = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES + 1 + schema = { + "coarse_segm": SizeData(dtype="float32", shape=(n_csc, hout, wout)), + "fine_segm": SizeData(dtype="float32", shape=(n_fsc, hout, wout)), + "u": SizeData(dtype="float32", shape=(n_fsc, hout, wout)), + "v": SizeData(dtype="float32", shape=(n_fsc, hout, wout)), + } + elif evaluator_type == "cse": + embed_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE + schema = { + "coarse_segm": SizeData(dtype="float32", shape=(n_csc, hout, wout)), + "embedding": SizeData(dtype="float32", shape=(embed_size, hout, wout)), + } + else: + 
raise ValueError(f"Unknown evaluator type: {evaluator_type}") + # storage types + if storage_spec == "ram": + storage = SingleProcessRamTensorStorage(schema, io.BytesIO()) + elif storage_spec == "file": + fpath = os.path.join(output_folder, f"DensePoseEvaluatorStorage.{get_rank()}.bin") + PathManager.mkdirs(output_folder) + storage = SingleProcessFileTensorStorage(schema, fpath, "wb") + else: + raise ValueError(f"Unknown storage specification: {storage_spec}") + return storage diff --git a/data_processing/detectron2/projects/DensePose/densepose/evaluation/mesh_alignment_evaluator.py b/data_processing/detectron2/projects/DensePose/densepose/evaluation/mesh_alignment_evaluator.py new file mode 100644 index 0000000..9d67c1a --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/evaluation/mesh_alignment_evaluator.py @@ -0,0 +1,66 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import json +import logging +from typing import List, Optional +import torch +from torch import nn + +from detectron2.utils.file_io import PathManager + +from densepose.structures.mesh import create_mesh + + +class MeshAlignmentEvaluator: + """ + Class for evaluation of 3D mesh alignment based on the learned vertex embeddings + """ + + def __init__(self, embedder: nn.Module, mesh_names: Optional[List[str]]): + self.embedder = embedder + # use the provided mesh names if not None and not an empty list + self.mesh_names = mesh_names if mesh_names else embedder.mesh_names + self.logger = logging.getLogger(__name__) + with PathManager.open( + "https://dl.fbaipublicfiles.com/densepose/data/cse/mesh_keyvertices_v0.json", "r" + ) as f: + self.mesh_keyvertices = json.load(f) + + def evaluate(self): + ge_per_mesh = {} + gps_per_mesh = {} + for mesh_name_1 in self.mesh_names: + avg_errors = [] + avg_gps = [] + embeddings_1 = self.embedder(mesh_name_1) + keyvertices_1 = self.mesh_keyvertices[mesh_name_1] 
+ keyvertex_names_1 = list(keyvertices_1.keys()) + keyvertex_indices_1 = [keyvertices_1[name] for name in keyvertex_names_1] + for mesh_name_2 in self.mesh_names: + if mesh_name_1 == mesh_name_2: + continue + embeddings_2 = self.embedder(mesh_name_2) + keyvertices_2 = self.mesh_keyvertices[mesh_name_2] + sim_matrix_12 = embeddings_1[keyvertex_indices_1].mm(embeddings_2.T) + vertices_2_matching_keyvertices_1 = sim_matrix_12.argmax(axis=1) + mesh_2 = create_mesh(mesh_name_2, embeddings_2.device) + geodists = mesh_2.geodists[ + vertices_2_matching_keyvertices_1, + [keyvertices_2[name] for name in keyvertex_names_1], + ] + Current_Mean_Distances = 0.255 + gps = (-(geodists**2) / (2 * (Current_Mean_Distances**2))).exp() + avg_errors.append(geodists.mean().item()) + avg_gps.append(gps.mean().item()) + + ge_mean = torch.as_tensor(avg_errors).mean().item() + gps_mean = torch.as_tensor(avg_gps).mean().item() + ge_per_mesh[mesh_name_1] = ge_mean + gps_per_mesh[mesh_name_1] = gps_mean + ge_mean_global = torch.as_tensor(list(ge_per_mesh.values())).mean().item() + gps_mean_global = torch.as_tensor(list(gps_per_mesh.values())).mean().item() + per_mesh_metrics = { + "GE": ge_per_mesh, + "GPS": gps_per_mesh, + } + return ge_mean_global, gps_mean_global, per_mesh_metrics diff --git a/data_processing/detectron2/projects/DensePose/densepose/evaluation/tensor_storage.py b/data_processing/detectron2/projects/DensePose/densepose/evaluation/tensor_storage.py new file mode 100644 index 0000000..72e3cb6 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/evaluation/tensor_storage.py @@ -0,0 +1,238 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import io +import numpy as np +import os +from dataclasses import dataclass +from functools import reduce +from operator import mul +from typing import BinaryIO, Dict, Optional, Tuple +import torch + +from detectron2.utils.comm import gather, get_rank +from detectron2.utils.file_io import PathManager + + +@dataclass +class SizeData: + dtype: str + shape: Tuple[int] + + +def _calculate_record_field_size_b(data_schema: Dict[str, SizeData], field_name: str) -> int: + schema = data_schema[field_name] + element_size_b = np.dtype(schema.dtype).itemsize + record_field_size_b = reduce(mul, schema.shape) * element_size_b + return record_field_size_b + + +def _calculate_record_size_b(data_schema: Dict[str, SizeData]) -> int: + record_size_b = 0 + for field_name in data_schema: + record_field_size_b = _calculate_record_field_size_b(data_schema, field_name) + record_size_b += record_field_size_b + return record_size_b + + +def _calculate_record_field_sizes_b(data_schema: Dict[str, SizeData]) -> Dict[str, int]: + field_sizes_b = {} + for field_name in data_schema: + field_sizes_b[field_name] = _calculate_record_field_size_b(data_schema, field_name) + return field_sizes_b + + +class SingleProcessTensorStorage: + """ + Compact tensor storage to keep tensor data of predefined size and type. + """ + + def __init__(self, data_schema: Dict[str, SizeData], storage_impl: BinaryIO): + """ + Construct tensor storage based on information on data shape and size. + Internally uses numpy to interpret the type specification. + The storage must support operations `seek(offset, whence=os.SEEK_SET)` and + `read(size)` to be able to perform the `get` operation. + The storage must support operation `write(bytes)` to be able to perform + the `put` operation. + + Args: + data_schema (dict: str -> SizeData): dictionary which maps tensor name + to its size data (shape and data type), e.g. 
+ ``` + { + "coarse_segm": SizeData(dtype="float32", shape=(112, 112)), + "embedding": SizeData(dtype="float32", shape=(16, 112, 112)), + } + ``` + storage_impl (BinaryIO): io instance that handles file-like seek, read + and write operations, e.g. a file handle or a memory buffer like io.BytesIO + """ + self.data_schema = data_schema + self.record_size_b = _calculate_record_size_b(data_schema) + self.record_field_sizes_b = _calculate_record_field_sizes_b(data_schema) + self.storage_impl = storage_impl + self.next_record_id = 0 + + def get(self, record_id: int) -> Dict[str, torch.Tensor]: + """ + Load tensors from the storage by record ID + + Args: + record_id (int): Record ID, for which to load the data + + Return: + dict: str -> tensor: tensor name mapped to tensor data, recorded under the provided ID + """ + self.storage_impl.seek(record_id * self.record_size_b, os.SEEK_SET) + data_bytes = self.storage_impl.read(self.record_size_b) + assert len(data_bytes) == self.record_size_b, ( + f"Expected data size {self.record_size_b} B could not be read: " + f"got {len(data_bytes)} B" + ) + record = {} + cur_idx = 0 + # it's important to read and write in the same order + for field_name in sorted(self.data_schema): + schema = self.data_schema[field_name] + field_size_b = self.record_field_sizes_b[field_name] + chunk = data_bytes[cur_idx : cur_idx + field_size_b] + data_np = np.frombuffer( + chunk, dtype=schema.dtype, count=reduce(mul, schema.shape) + ).reshape(schema.shape) + record[field_name] = torch.from_numpy(data_np) + cur_idx += field_size_b + return record + + def put(self, data: Dict[str, torch.Tensor]) -> int: + """ + Store tensors in the storage + + Args: + data (dict: str -> tensor): data to store, a dictionary which maps + tensor names into tensors; tensor shapes must match those specified + in data schema. 
+ Return: + int: record ID, under which the data is stored + """ + # it's important to read and write in the same order + for field_name in sorted(self.data_schema): + assert ( + field_name in data + ), f"Field '{field_name}' not present in data: data keys are {data.keys()}" + value = data[field_name] + assert value.shape == self.data_schema[field_name].shape, ( + f"Mismatched tensor shapes for field '{field_name}': " + f"expected {self.data_schema[field_name].shape}, got {value.shape}" + ) + data_bytes = value.cpu().numpy().tobytes() + assert len(data_bytes) == self.record_field_sizes_b[field_name], ( + f"Expected field {field_name} to be of size " + f"{self.record_field_sizes_b[field_name]} B, got {len(data_bytes)} B" + ) + self.storage_impl.write(data_bytes) + record_id = self.next_record_id + self.next_record_id += 1 + return record_id + + +class SingleProcessFileTensorStorage(SingleProcessTensorStorage): + """ + Implementation of a single process tensor storage which stores data in a file + """ + + def __init__(self, data_schema: Dict[str, SizeData], fpath: str, mode: str): + self.fpath = fpath + assert "b" in mode, f"Tensor storage should be opened in binary mode, got '{mode}'" + if "w" in mode: + file_h = PathManager.open(fpath, mode) + elif "r" in mode: + local_fpath = PathManager.get_local_path(fpath) + file_h = open(local_fpath, mode) + else: + raise ValueError(f"Unsupported file mode {mode}, supported modes: rb, wb") + super().__init__(data_schema, file_h) # pyre-ignore[6] + + +class SingleProcessRamTensorStorage(SingleProcessTensorStorage): + """ + Implementation of a single process tensor storage which stores data in RAM + """ + + def __init__(self, data_schema: Dict[str, SizeData], buf: io.BytesIO): + super().__init__(data_schema, buf) + + +class MultiProcessTensorStorage: + """ + Representation of a set of tensor storages created by individual processes, + allows to access those storages from a single owner process. 
The storages + should either be shared or broadcasted to the owner process. + The processes are identified by their rank, data is uniquely defined by + the rank of the process and the record ID. + """ + + def __init__(self, rank_to_storage: Dict[int, SingleProcessTensorStorage]): + self.rank_to_storage = rank_to_storage + + def get(self, rank: int, record_id: int) -> Dict[str, torch.Tensor]: + storage = self.rank_to_storage[rank] + return storage.get(record_id) + + def put(self, rank: int, data: Dict[str, torch.Tensor]) -> int: + storage = self.rank_to_storage[rank] + return storage.put(data) + + +class MultiProcessFileTensorStorage(MultiProcessTensorStorage): + def __init__(self, data_schema: Dict[str, SizeData], rank_to_fpath: Dict[int, str], mode: str): + rank_to_storage = { + rank: SingleProcessFileTensorStorage(data_schema, fpath, mode) + for rank, fpath in rank_to_fpath.items() + } + super().__init__(rank_to_storage) # pyre-ignore[6] + + +class MultiProcessRamTensorStorage(MultiProcessTensorStorage): + def __init__(self, data_schema: Dict[str, SizeData], rank_to_buffer: Dict[int, io.BytesIO]): + rank_to_storage = { + rank: SingleProcessRamTensorStorage(data_schema, buf) + for rank, buf in rank_to_buffer.items() + } + super().__init__(rank_to_storage) # pyre-ignore[6] + + +def _ram_storage_gather( + storage: SingleProcessRamTensorStorage, dst_rank: int = 0 +) -> Optional[MultiProcessRamTensorStorage]: + storage.storage_impl.seek(0, os.SEEK_SET) + # TODO: overhead, pickling a bytes object, can just pass bytes in a tensor directly + # see detectron2/utils.comm.py + data_list = gather(storage.storage_impl.read(), dst=dst_rank) + if get_rank() != dst_rank: + return None + rank_to_buffer = {i: io.BytesIO(data_list[i]) for i in range(len(data_list))} + multiprocess_storage = MultiProcessRamTensorStorage(storage.data_schema, rank_to_buffer) + return multiprocess_storage + + +def _file_storage_gather( + storage: SingleProcessFileTensorStorage, + dst_rank: int = 0, + 
mode: str = "rb", +) -> Optional[MultiProcessFileTensorStorage]: + storage.storage_impl.close() + fpath_list = gather(storage.fpath, dst=dst_rank) + if get_rank() != dst_rank: + return None + rank_to_fpath = {i: fpath_list[i] for i in range(len(fpath_list))} + return MultiProcessFileTensorStorage(storage.data_schema, rank_to_fpath, mode) + + +def storage_gather( + storage: SingleProcessTensorStorage, dst_rank: int = 0 +) -> Optional[MultiProcessTensorStorage]: + if isinstance(storage, SingleProcessRamTensorStorage): + return _ram_storage_gather(storage, dst_rank) + elif isinstance(storage, SingleProcessFileTensorStorage): + return _file_storage_gather(storage, dst_rank) + raise Exception(f"Unsupported storage for gather operation: {storage}") diff --git a/data_processing/detectron2/projects/DensePose/densepose/modeling/__init__.py b/data_processing/detectron2/projects/DensePose/densepose/modeling/__init__.py new file mode 100644 index 0000000..4c49f6d --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/modeling/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from .confidence import DensePoseConfidenceModelConfig, DensePoseUVConfidenceType +from .filter import DensePoseDataFilter +from .inference import densepose_inference +from .utils import initialize_module_params +from .build import ( + build_densepose_data_filter, + build_densepose_embedder, + build_densepose_head, + build_densepose_losses, + build_densepose_predictor, +) diff --git a/data_processing/detectron2/projects/DensePose/densepose/modeling/build.py b/data_processing/detectron2/projects/DensePose/densepose/modeling/build.py new file mode 100644 index 0000000..bb7f54b --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/modeling/build.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
def build_densepose_predictor(cfg: CfgNode, input_channels: int):
    """
    Instantiate the DensePose predictor selected in the configuration.

    Args:
        cfg (CfgNode): configuration options; the predictor is looked up in the
            registry under `cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME`
        input_channels (int): input tensor size along the channel dimension
    Return:
        An instance of DensePose predictor
    """
    # the registry is imported at call time, as in the other builders of this module
    from .predictors import DENSEPOSE_PREDICTOR_REGISTRY

    predictor_factory = DENSEPOSE_PREDICTOR_REGISTRY.get(
        cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME
    )
    return predictor_factory(cfg, input_channels)
class DensePoseUVConfidenceType(Enum):
    """
    Statistical model type for confidence learning, possible values:
     - "iid_iso": statistically independent identically distributed residuals
         with isotropic covariance
     - "indep_aniso": statistically independent residuals with anisotropic
         covariances
    The string values are what `DensePoseUVConfidenceType(...)` accepts, e.g.
    when built from `cfg.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.TYPE`.
    For details, see:
    N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
    Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
    """

    # fmt: off
    IID_ISO         = "iid_iso"
    INDEP_ANISO     = "indep_aniso"
    # fmt: on
def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module:
    """
    Build the embedder described by a configuration node.

    The embedder is constructed according to `embedder_spec.TYPE`, optionally
    initialized from `embedder_spec.INIT_FILE` (skipped when the path is an
    empty string) and frozen when `embedder_spec.IS_TRAINABLE` is false.

    Args:
        embedder_spec (CfgNode): embedder configuration
        embedder_dim (int): embedding space dimensionality
    Return:
        An embedder instance for the specified configuration
    Raises ValueError, in case of unexpected embedder type
    """
    embedder_type = EmbedderType(embedder_spec.TYPE)
    if embedder_type == EmbedderType.VERTEX_DIRECT:
        embedder = VertexDirectEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            embed_dim=embedder_dim,
        )
    elif embedder_type == EmbedderType.VERTEX_FEATURE:
        embedder = VertexFeatureEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            feature_dim=embedder_spec.FEATURE_DIM,
            embed_dim=embedder_dim,
            train_features=embedder_spec.FEATURES_TRAINABLE,
        )
    else:
        raise ValueError(f"Unexpected embedder type {embedder_type}")

    # both embedder kinds are initialized from file in the same way, right
    # after construction and before the trainability switch
    init_file = embedder_spec.INIT_FILE
    if init_file != "":
        embedder.load(init_file)

    if not embedder_spec.IS_TRAINABLE:
        embedder.requires_grad_(False)

    return embedder
+ state_dict = torch.load(hFile, map_location=torch.device("cpu")) + if state_dict is not None and "model" in state_dict: + state_dict_local = {} + for key in state_dict["model"]: + if key.startswith(prefix): + v_key = state_dict["model"][key] + if isinstance(v_key, np.ndarray): + v_key = torch.from_numpy(v_key) + state_dict_local[key[len(prefix) :]] = v_key + # non-strict loading to finetune on different meshes + self.load_state_dict(state_dict_local, strict=False) + + def forward(self, mesh_name: str) -> torch.Tensor: + """ + Produce vertex embeddings for the specific mesh; vertex embeddings are + a tensor of shape [N, D] where: + N = number of vertices + D = number of dimensions in the embedding space + Args: + mesh_name (str): name of a mesh for which to obtain vertex embeddings + Return: + Vertex embeddings, a tensor of shape [N, D] + """ + return getattr(self, f"embedder_{mesh_name}")() + + def has_embeddings(self, mesh_name: str) -> bool: + return hasattr(self, f"embedder_{mesh_name}") diff --git a/data_processing/detectron2/projects/DensePose/densepose/modeling/cse/utils.py b/data_processing/detectron2/projects/DensePose/densepose/modeling/cse/utils.py new file mode 100644 index 0000000..5c57998 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/modeling/cse/utils.py @@ -0,0 +1,82 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
def squared_euclidean_distance_matrix(pts1: torch.Tensor, pts2: torch.Tensor) -> torch.Tensor:
    """
    Compute pairwise squared Euclidean distances between two point sets

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * <a, b>, so the
    full [M, N, D] difference tensor is never materialized.

    Args:
        pts1: Tensor [M x D], M is the number of points, D is feature dimensionality
        pts2: Tensor [N x D], N is the number of points, D is feature dimensionality

    Return:
        Tensor [M, N]: matrix of squared Euclidean distances; at index (m, n)
        it contains || pts1[m] - pts2[n] ||^2
    """
    sq_norms_1 = (pts1 * pts1).sum(1, keepdim=True)  # [M, 1]
    sq_norms_2 = (pts2 * pts2).sum(1, keepdim=True).t()  # [1, N]
    distances = torch.mm(-2 * pts1, pts2.t())  # cross term, [M, N]
    distances += sq_norms_1 + sq_norms_2  # broadcast to [M, N]
    return distances.contiguous()
def get_closest_vertices_mask_from_ES(
    E: torch.Tensor,
    S: torch.Tensor,
    h: int,
    w: int,
    mesh_vertex_embeddings: torch.Tensor,
    device: torch.device,
    max_vertices: int = 6890,
):
    """
    Interpolate Embeddings and Segmentations to the size of a given bounding box,
    and compute closest vertices and the segmentation mask

    Args:
        E (tensor [1, D, H, W]): D-dimensional embedding vectors for every point of the
            default-sized box
        S (tensor [1, 2, H, W]): 2-dimensional segmentation mask for every point of the
            default-sized box
        h (int): height of the target bounding box
        w (int): width of the target bounding box
        mesh_vertex_embeddings (tensor [N, D]): vertex embeddings for a chosen mesh
            N is the number of vertices in the mesh, D is feature dimensionality
        device (torch.device): device to move the tensors to
        max_vertices (int): only the first `max_vertices` rows of
            `mesh_vertex_embeddings` are considered as candidates, so returned
            indices are always below this value. The default keeps the
            previously hard-coded limit of 6890
            (NOTE(review): presumably the SMPL body mesh vertex count - confirm)
    Return:
        Closest Vertices (tensor [h, w]), int, for every point of the resulting box
            (0 for points outside the mask)
        Segmentation mask (tensor [h, w]), boolean, for every point of the resulting box
    """
    # restrict the candidate set; a mesh with fewer rows is left untouched
    mesh_vertex_embeddings = mesh_vertex_embeddings[:max_vertices, :]
    embedding_resized = F.interpolate(E, size=(h, w), mode="bilinear")[0].to(device)
    coarse_segm_resized = F.interpolate(S, size=(h, w), mode="bilinear")[0].to(device)
    # foreground = any segmentation channel other than channel 0 wins the argmax
    mask = coarse_segm_resized.argmax(0) > 0
    closest_vertices = torch.zeros(mask.shape, dtype=torch.long, device=device)
    all_embeddings = embedding_resized[:, mask].t()  # [num_foreground_points, D]
    if len(all_embeddings) == 0:
        # nothing to match (also avoids torch.cat on an empty list)
        return closest_vertices, mask
    size_chunk = 10_000  # Chunking to avoid possible OOM
    closest_per_chunk = [
        torch.argmin(
            squared_euclidean_distance_matrix(chunk_embeddings, mesh_vertex_embeddings), dim=1
        )
        for chunk_embeddings in all_embeddings.split(size_chunk)
    ]
    closest_vertices[mask] = torch.cat(closest_per_chunk)
    return closest_vertices, mask
class VertexDirectEmbedder(nn.Module):
    """
    Class responsible for embedding vertices. Vertex embeddings take
    the form of a tensor of size [N, D], where
        N = number of vertices
        D = number of dimensions in the embedding space
    The embeddings are a learnable parameter that starts out as all zeros;
    meaningful values come from `load`, a checkpoint or training.
    """

    def __init__(self, num_vertices: int, embed_dim: int):
        """
        Initialize embedder with all-zero embeddings

        Args:
            num_vertices (int): number of vertices to embed
            embed_dim (int): number of dimensions in the embedding space
        """
        super().__init__()
        # allocated uninitialized, then zeroed by reset_parameters()
        self.embeddings = nn.Parameter(torch.empty(num_vertices, embed_dim))
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """
        Reset embeddings to zeros
        """
        self.embeddings.zero_()

    def forward(self) -> torch.Tensor:
        """
        Produce vertex embeddings, a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space

        Return:
            Full vertex embeddings passed through `normalize_embeddings`,
            a tensor of shape [N, D]
        """
        return normalize_embeddings(self.embeddings)

    @torch.no_grad()
    def load(self, fpath: str):
        """
        Load data from a file

        The file is expected to hold a pickled mapping; its "embeddings" entry,
        if present, is converted to float and copied into the parameter in
        place. A file without that entry leaves the embedder unchanged.

        Args:
            fpath (str): file path to load data from
        """
        # NOTE(security): pickle.load can execute arbitrary code while
        # unpickling - only point this at trusted files.
        with PathManager.open(fpath, "rb") as hFile:
            data = pickle.load(hFile)  # pyre-ignore[6]
        if "embeddings" in data:
            # as_tensor accepts arrays, lists and tensors alike and does not
            # emit the copy-construct warning torch.tensor raises for tensors
            self.embeddings.copy_(
                torch.as_tensor(data["embeddings"]).float().to(device=self.embeddings.device)
            )
self.features.zero_() + self.embeddings.zero_() + + def forward(self) -> torch.Tensor: + """ + Produce vertex embeddings, a tensor of shape [N, D] where: + N = number of vertices + D = number of dimensions in the embedding space + + Return: + Full vertex embeddings, a tensor of shape [N, D] + """ + return normalize_embeddings(torch.mm(self.features, self.embeddings)) + + @torch.no_grad() + def load(self, fpath: str): + """ + Load data from a file + + Args: + fpath (str): file path to load data from + """ + with PathManager.open(fpath, "rb") as hFile: + data = pickle.load(hFile) # pyre-ignore[6] + for name in ["features", "embeddings"]: + if name in data: + getattr(self, name).copy_( + torch.tensor(data[name]).float().to(device=getattr(self, name).device) + ) diff --git a/data_processing/detectron2/projects/DensePose/densepose/modeling/densepose_checkpoint.py b/data_processing/detectron2/projects/DensePose/densepose/modeling/densepose_checkpoint.py new file mode 100644 index 0000000..8c2b4f2 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/modeling/densepose_checkpoint.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from collections import OrderedDict + +from detectron2.checkpoint import DetectionCheckpointer + + +def _rename_HRNet_weights(weights): + # We detect and rename HRNet weights for DensePose. 1956 and 1716 are values that are + # common to all HRNet pretrained weights, and should be enough to accurately identify them + if ( + len(weights["model"].keys()) == 1956 + and len([k for k in weights["model"].keys() if k.startswith("stage")]) == 1716 + ): + hrnet_weights = OrderedDict() + for k in weights["model"].keys(): + hrnet_weights["backbone.bottom_up." 
+ str(k)] = weights["model"][k] + return {"model": hrnet_weights} + else: + return weights + + +class DensePoseCheckpointer(DetectionCheckpointer): + """ + Same as :class:`DetectionCheckpointer`, but is able to handle HRNet weights + """ + + def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables): + super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables) + + def _load_file(self, filename: str) -> object: + """ + Adding hrnet support + """ + weights = super()._load_file(filename) + return _rename_HRNet_weights(weights) diff --git a/data_processing/detectron2/projects/DensePose/densepose/modeling/filter.py b/data_processing/detectron2/projects/DensePose/densepose/modeling/filter.py new file mode 100644 index 0000000..18a8567 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/modeling/filter.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import List +import torch + +from detectron2.config import CfgNode +from detectron2.structures import Instances +from detectron2.structures.boxes import matched_pairwise_iou + + +class DensePoseDataFilter(object): + def __init__(self, cfg: CfgNode): + self.iou_threshold = cfg.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD + self.keep_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS + + @torch.no_grad() + def __call__(self, features: List[torch.Tensor], proposals_with_targets: List[Instances]): + """ + Filters proposals with targets to keep only the ones relevant for + DensePose training + + Args: + features (list[Tensor]): input data as a list of features, + each feature is a tensor. Axis 0 represents the number of + images `N` in the input data; axes 1-3 are channels, + height, and width, which may vary between features + (e.g., if a feature pyramid is used). + proposals_with_targets (list[Instances]): length `N` list of + `Instances`. 
The i-th `Instances` contains instances + (proposals, GT) for the i-th input image, + Returns: + list[Tensor]: filtered features + list[Instances]: filtered proposals + """ + proposals_filtered = [] + # TODO: the commented out code was supposed to correctly deal with situations + # where no valid DensePose GT is available for certain images. The corresponding + # image features were sliced and proposals were filtered. This led to performance + # deterioration, both in terms of runtime and in terms of evaluation results. + # + # feature_mask = torch.ones( + # len(proposals_with_targets), + # dtype=torch.bool, + # device=features[0].device if len(features) > 0 else torch.device("cpu"), + # ) + for i, proposals_per_image in enumerate(proposals_with_targets): + if not proposals_per_image.has("gt_densepose") and ( + not proposals_per_image.has("gt_masks") or not self.keep_masks + ): + # feature_mask[i] = 0 + continue + gt_boxes = proposals_per_image.gt_boxes + est_boxes = proposals_per_image.proposal_boxes + # apply match threshold for densepose head + iou = matched_pairwise_iou(gt_boxes, est_boxes) + iou_select = iou > self.iou_threshold + proposals_per_image = proposals_per_image[iou_select] # pyre-ignore[6] + + N_gt_boxes = len(proposals_per_image.gt_boxes) + assert N_gt_boxes == len(proposals_per_image.proposal_boxes), ( + f"The number of GT boxes {N_gt_boxes} is different from the " + f"number of proposal boxes {len(proposals_per_image.proposal_boxes)}" + ) + # filter out any target without suitable annotation + if self.keep_masks: + gt_masks = ( + proposals_per_image.gt_masks + if hasattr(proposals_per_image, "gt_masks") + else [None] * N_gt_boxes + ) + else: + gt_masks = [None] * N_gt_boxes + gt_densepose = ( + proposals_per_image.gt_densepose + if hasattr(proposals_per_image, "gt_densepose") + else [None] * N_gt_boxes + ) + assert len(gt_masks) == N_gt_boxes + assert len(gt_densepose) == N_gt_boxes + selected_indices = [ + i + for i, (dp_target, mask_target) in 
enumerate(zip(gt_densepose, gt_masks)) + if (dp_target is not None) or (mask_target is not None) + ] + # if not len(selected_indices): + # feature_mask[i] = 0 + # continue + if len(selected_indices) != N_gt_boxes: + proposals_per_image = proposals_per_image[selected_indices] # pyre-ignore[6] + assert len(proposals_per_image.gt_boxes) == len(proposals_per_image.proposal_boxes) + proposals_filtered.append(proposals_per_image) + # features_filtered = [feature[feature_mask] for feature in features] + # return features_filtered, proposals_filtered + return features, proposals_filtered diff --git a/data_processing/detectron2/projects/DensePose/densepose/modeling/hrfpn.py b/data_processing/detectron2/projects/DensePose/densepose/modeling/hrfpn.py new file mode 100644 index 0000000..08ec420 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/modeling/hrfpn.py @@ -0,0 +1,182 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +""" +MIT License +Copyright (c) 2019 Microsoft +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
class HRFPN(Backbone):
    """HRFPN (High Resolution Feature Pyramids)
    Transforms outputs of HRNet backbone so they are suitable for the ROI_heads
    arXiv: https://arxiv.org/abs/1904.04514
    Adapted from https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/necks/hrfpn.py
    Args:
        bottom_up: backbone that is called on the input; its result is indexed
            with the names listed in `in_features`
        in_features (list): names of the input features (output of HRNet)
        n_out_features (int): number of output stages
        in_channels (list): number of channels for each branch
        out_channels (int): output channels of feature pyramids
        pooling (str): pooling function selector (from {MAX, AVG}); the chosen
            function is stored in `self.pooling` but is not used by `forward`
        share_conv (bool): Have one conv per output, or share one with all the outputs
    """

    def __init__(
        self,
        bottom_up,
        in_features,
        n_out_features,
        in_channels,
        out_channels,
        pooling="AVG",
        share_conv=False,
    ):
        super().__init__()
        assert isinstance(in_channels, list)
        self.bottom_up = bottom_up
        self.in_features = in_features
        self.n_out_features = n_out_features
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.share_conv = share_conv

        # NOTE: sub-modules are created in the same order as before
        # (fpn_conv, interp_conv, reduction_pooling_conv) so that parameter
        # registration order is unchanged.
        if self.share_conv:
            self.fpn_conv = nn.Conv2d(
                in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1
            )
        else:
            self.fpn_conv = nn.ModuleList(
                [
                    nn.Conv2d(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        kernel_size=3,
                        padding=1,
                    )
                    for _ in range(self.n_out_features)
                ]
            )

        # Custom change: Replaces a simple bilinear interpolation
        # branch i is upsampled by a transposed convolution with stride 2**i
        self.interp_conv = nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(
                        in_channels=in_channels[branch],
                        out_channels=in_channels[branch],
                        kernel_size=4,
                        stride=2**branch,
                        padding=0,
                        output_padding=0,
                        bias=False,
                    ),
                    nn.BatchNorm2d(in_channels[branch], momentum=0.1),
                    nn.ReLU(inplace=True),
                )
                for branch in range(len(self.in_features))
            ]
        )

        # Custom change: Replaces a couple (reduction conv + pooling) by one conv
        # output level i uses kernel size and stride 2**i on the fused features
        self.reduction_pooling_conv = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        sum(in_channels), out_channels, kernel_size=2**level, stride=2**level
                    ),
                    nn.BatchNorm2d(out_channels, momentum=0.1),
                    nn.ReLU(inplace=True),
                )
                for level in range(self.n_out_features)
            ]
        )

        self.pooling = F.max_pool2d if pooling == "MAX" else F.avg_pool2d

        # outputs are named p1..pN, with declared strides 4, 8, 16, ...
        self._out_features = ["p%d" % (level + 1) for level in range(self.n_out_features)]
        self._out_feature_channels = {name: self.out_channels for name in self._out_features}
        self._out_feature_strides = {
            name: 2 ** (level + 2) for level, name in enumerate(self._out_features)
        }

    # default init_weights for conv(msra) and norm in ConvModule
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputs):
        """
        Run the bottom-up network and turn its branches into a feature pyramid.

        Return:
            dict mapping the output names ("p1", "p2", ...) to feature tensors
        """
        bottom_up_features = self.bottom_up(inputs)
        assert len(bottom_up_features) == len(self.in_features)
        branch_inputs = [bottom_up_features[name] for name in self.in_features]

        # bring every branch to (roughly) the same resolution
        upsampled = [self.interp_conv[i](feat) for i, feat in enumerate(branch_inputs)]
        # crop to the smallest spatial size so the branches can be concatenated
        min_h = min(t.shape[2] for t in upsampled)
        min_w = min(t.shape[3] for t in upsampled)
        fused = torch.cat([t[:, :, :min_h, :min_w] for t in upsampled], dim=1)

        pyramid = [self.reduction_pooling_conv[i](fused) for i in range(self.n_out_features)]
        # Make shapes consistent: walking from the coarsest level upwards, level
        # -1 - i is cropped to at most 2**i times the coarsest spatial size
        for i in range(len(pyramid)):
            max_h = pyramid[-1].shape[2] * 2**i
            max_w = pyramid[-1].shape[3] * 2**i
            pyramid[-1 - i] = pyramid[-1 - i][:, :, :max_h, :max_w]

        if self.share_conv:
            outputs = [self.fpn_conv(level_feat) for level_feat in pyramid]
        else:
            outputs = [self.fpn_conv[i](level_feat) for i, level_feat in enumerate(pyramid)]

        assert len(self._out_features) == len(outputs)
        return dict(zip(self._out_features, outputs))
class BasicBlock(nn.Module):
    """
    Residual block: two 3x3 convolutions, each followed by batch normalization,
    with a skip connection added before the final ReLU.

    Args:
        inplanes: number of input channels of the first convolution
        planes: number of output channels of both convolutions
        stride: stride of the first convolution (the second one uses the
            default stride of `conv3x3`)
        downsample: optional callable applied to the block input to produce
            the skip connection; when None the input itself is used
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # main path: conv -> bn -> relu -> conv -> bn
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # skip path is evaluated after the main path, as before
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class HighResolutionModule(nn.Module):
    """HighResolutionModule
    Building block of the PoseHigherResolutionNet (see lower)
    arXiv: https://arxiv.org/abs/1908.10357
    Args:
        num_branches (int): number of branches of the module
        blocks: block class used to build every branch; it is called as
            `block(inplanes, planes, stride, downsample)` and must expose an
            `expansion` attribute
        num_blocks (sequence of int): number of blocks in each branch
            (one entry per branch)
        num_inchannels (list of int): number of input channels of each branch;
            NOTE: this list is stored by reference and updated in place by
            `_make_one_branch`
        num_channels (sequence of int): number of channels of each branch
            (multiplied by `block.expansion` to get the branch output width)
        multi_scale_output (bool): only used by the last module of PoseHigherResolutionNet;
            when False, fuse layers are built for the first branch only, so
            `forward` returns a single fused tensor
    """

    def __init__(
        self,
        num_branches,
        blocks,
        num_blocks,
        num_inchannels,
        num_channels,
        multi_scale_output=True,
    ):
        super(HighResolutionModule, self).__init__()
        # Fail early if the per-branch lists disagree with `num_branches`.
        self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        # Branches must be built before the fuse layers: `_make_branches`
        # updates `self.num_inchannels`, which `_make_fuse_layers` then reads.
        self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
        """Validate that every per-branch list has exactly `num_branches` entries.

        Raises:
            ValueError: if `num_blocks`, `num_channels` or `num_inchannels`
                has a different length than `num_branches` (the message is
                also logged at error level).
        """
        if num_branches != len(num_blocks):
            error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
                num_branches, len(num_channels)
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
                num_branches, len(num_inchannels)
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        """Build the `nn.Sequential` of `num_blocks[branch_index]` blocks for one branch.

        Side effect: `self.num_inchannels[branch_index]` is overwritten with
        the branch output width (`num_channels[branch_index] * block.expansion`).
        """
        downsample = None
        # A projection shortcut is needed for the first block whenever the
        # stride or the channel count changes.
        if (
            stride != 1
            or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion
        ):
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(
            block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)
        )
        # From here on the branch operates at its output width.
        self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
        for _ in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build all branches and return them as an `nn.ModuleList`."""
        branches = []

        for i in range(num_branches):
            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build the modules that convert branch `j` features for output branch `i`.

        Returns:
            None for a single-branch module; otherwise an `nn.ModuleList`
            indexed as `[i][j]`, with one row per output branch (a single row
            when `multi_scale_output` is False).
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Higher-index branch -> 1x1 conv to match channels, then
                    # nearest-neighbour upsampling by 2 ** (j - i).
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
                            nn.BatchNorm2d(num_inchannels[i]),
                            nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"),
                        )
                    )
                elif j == i:
                    # Same branch: no conversion; `forward` adds x[j] directly.
                    fuse_layer.append(None)
                else:
                    # Lower-index branch -> chain of (i - j) stride-2 3x3 convs.
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            # Last conv of the chain: switch to branch i's
                            # channel count; no ReLU here (it is applied after
                            # the sum in `forward`).
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False,
                                    ),
                                    nn.BatchNorm2d(num_outchannels_conv3x3),
                                )
                            )
                        else:
                            # Intermediate conv: keeps branch j's channel count.
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False,
                                    ),
                                    nn.BatchNorm2d(num_outchannels_conv3x3),
                                    nn.ReLU(True),
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return the per-branch channel list (as updated while building the branches)."""
        return self.num_inchannels

    def forward(self, x):
        """Run every branch, then fuse the branch outputs.

        Args:
            x (list): one input per branch, indexed by branch; for a
                multi-branch module the entries are replaced in place by the
                branch outputs.

        Returns:
            list: a one-element list for a single-branch module; otherwise one
            fused tensor per row of `self.fuse_layers`.
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []

        for i in range(len(self.fuse_layers)):
            # Start the sum with branch 0 (converted unless it is the target).
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    # Crop the converted features to y's spatial size before adding.
                    z = self.fuse_layers[i][j](x[j])[:, :, : y.shape[2], : y.shape[3]]
                    y = y + z
            x_fuse.append(self.relu(y))

        return x_fuse
def _get_deconv_cfg(self, deconv_kernel):
    """Return ``(kernel, padding, output_padding)`` for a deconvolution kernel size.

    Supported kernel sizes and the values they map to:

    * 4 -> padding 1, output_padding 0
    * 3 -> padding 1, output_padding 1
    * 2 -> padding 0, output_padding 0

    Args:
        deconv_kernel (int): kernel size; must be 2, 3 or 4.

    Returns:
        tuple: ``(deconv_kernel, padding, output_padding)``.

    Raises:
        ValueError: if ``deconv_kernel`` is not one of the supported sizes.
            (Previously this surfaced as an ``UnboundLocalError`` because
            ``padding`` was never assigned.)
    """
    if deconv_kernel == 4:
        padding, output_padding = 1, 0
    elif deconv_kernel == 3:
        padding, output_padding = 1, 1
    elif deconv_kernel == 2:
        padding, output_padding = 0, 0
    else:
        raise ValueError(
            "Unsupported deconv kernel size {!r}; expected 2, 3 or 4".format(deconv_kernel)
        )

    return deconv_kernel, padding, output_padding
def forward(self, x):
    """Run the stem and the three multi-branch stages.

    Args:
        x: input passed to the stem (``conv1``).

    Returns:
        dict: maps each name in ``self._out_features`` to the corresponding
        output of the last stage, in branch order.
    """
    # Stem: two conv/bn/relu groups followed by layer1.
    x = self.relu(self.bn1(self.conv1(x)))
    x = self.relu(self.bn2(self.conv2(x)))
    x = self.layer1(x)

    # Stage 2: every branch is derived from the single stem output.
    branch_inputs = []
    for idx in range(self.stage2_cfg.NUM_BRANCHES):
        transition = self.transition1[idx]
        branch_inputs.append(x if transition is None else transition(x))
    features = self.stage2(branch_inputs)

    # Stages 3 and 4: a branch without a transition reuses the previous
    # stage's output at the same index; a branch with one is fed from the
    # previous stage's last output.
    later_stages = (
        (self.stage3_cfg, self.transition2, self.stage3),
        (self.stage4_cfg, self.transition3, self.stage4),
    )
    for stage_cfg, transitions, stage in later_stages:
        branch_inputs = []
        for idx in range(stage_cfg.NUM_BRANCHES):
            transition = transitions[idx]
            if transition is None:
                branch_inputs.append(features[idx])
            else:
                branch_inputs.append(transition(features[-1]))
        features = stage(branch_inputs)

    assert len(self._out_features) == len(features)
    return dict(zip(self._out_features, features))  # final_outputs
def densepose_inference(densepose_predictor_output: Any, detections: List[Instances]) -> None:
    """
    Splits DensePose predictor outputs into chunks, each chunk corresponds to
    detections on one image. Predictor output chunks are stored in `pred_densepose`
    attribute of the corresponding `Instances` object.

    Args:
        densepose_predictor_output: a dataclass instance (can be of different types,
            depending on predictor used for inference). Each field can be `None`
            (if the corresponding output was not inferred) or a tensor of size
            [N, ...], where N = N_1 + N_2 + .. + N_k is a total number of
            detections on all images, N_1 is the number of detections on image 1,
            N_2 is the number of detections on image 2, etc. If the whole
            object is `None`, no `pred_densepose` attribute is added.
        detections: a list of objects of type `Instance`, k-th object corresponds
            to detections on k-th image.
    """
    if densepose_predictor_output is None:
        # don't add `pred_densepose` attribute
        return

    # The output type does not change between images, so look it up once.
    PredictorOutput = type(densepose_predictor_output)

    k = 0  # offset of the current image's first detection in the packed tensors
    for detection_i in detections:
        n_i = len(detection_i)
        # we assume here that `densepose_predictor_output` is a dataclass object;
        # tensors are sliced per image, every other field value is passed through
        output_i_dict = {
            field.name: (
                field_value[k : k + n_i]
                if isinstance(field_value, torch.Tensor)
                else field_value
            )
            for field in fields(densepose_predictor_output)
            for field_value in (getattr(densepose_predictor_output, field.name),)
        }
        detection_i.pred_densepose = PredictorOutput(**output_i_dict)
        k += n_i
+ +from typing import Any, List +import torch +from torch.nn import functional as F + +from detectron2.config import CfgNode +from detectron2.structures import Instances + +from .mask_or_segm import MaskOrSegmentationLoss +from .registry import DENSEPOSE_LOSS_REGISTRY +from .utils import ( + BilinearInterpolationHelper, + ChartBasedAnnotationsAccumulator, + LossDict, + extract_packed_annotations_from_matches, +) + + +@DENSEPOSE_LOSS_REGISTRY.register() +class DensePoseChartLoss: + """ + DensePose loss for chart-based training. A mesh is split into charts, + each chart is given a label (I) and parametrized by 2 coordinates referred to + as U and V. Ground truth consists of a number of points annotated with + I, U and V values and coarse segmentation S defined for all pixels of the + object bounding box. In some cases (see `COARSE_SEGM_TRAINED_BY_MASKS`), + semantic segmentation annotations can be used as ground truth inputs as well. + + Estimated values are tensors: + * U coordinates, tensor of shape [N, C, S, S] + * V coordinates, tensor of shape [N, C, S, S] + * fine segmentation estimates, tensor of shape [N, C, S, S] with raw unnormalized + scores for each fine segmentation label at each location + * coarse segmentation estimates, tensor of shape [N, D, S, S] with raw unnormalized + scores for each coarse segmentation label at each location + where N is the number of detections, C is the number of fine segmentation + labels, S is the estimate size ( = width = height) and D is the number of + coarse segmentation channels. 
def __call__(
    self, proposals_with_gt: List[Instances], densepose_predictor_outputs: Any, **kwargs
) -> LossDict:
    """
    Compute the chart-based DensePose losses for a batch.

    Args:
        proposals_with_gt (list of Instances): detections with associated ground truth data
        densepose_predictor_outputs: an object of a dataclass that contains predictor outputs
            with estimated values; assumed to have the following attributes:
            * coarse_segm - coarse segmentation estimates, tensor of shape [N, D, S, S]
            * fine_segm - fine segmentation estimates, tensor of shape [N, C, S, S]
            * u - U coordinate estimates per fine labels, tensor of shape [N, C, S, S]
            * v - V coordinate estimates per fine labels, tensor of shape [N, C, S, S]
            The outputs cover all boxes of all images, e.g. a batch of 4 images
            with (3, 1, 2, 1) proposals gives size(0) == 7.

    Return:
        dict: str -> tensor: the union of the dicts returned by
        `produce_densepose_losses_uv` and `produce_densepose_losses_segm`;
        when there is no usable ground truth, the zero-valued dict from
        `produce_fake_densepose_losses` is returned instead.
    """
    if len(proposals_with_gt) == 0:
        return self.produce_fake_densepose_losses(densepose_predictor_outputs)

    packed_annotations = extract_packed_annotations_from_matches(
        proposals_with_gt, ChartBasedAnnotationsAccumulator()
    )

    # NOTE: the computation graph must be the same on all GPUs for the
    # reduction to work, so even without data a fake (zero) loss of the
    # form Tensor.sum() * 0 is produced.
    if packed_annotations is None:
        return self.produce_fake_densepose_losses(densepose_predictor_outputs)

    height, width = densepose_predictor_outputs.u.shape[2:]
    interpolator = BilinearInterpolationHelper.from_matches(
        packed_annotations, (height, width)
    )

    # Keep only valid points whose fine segmentation label is foreground (> 0).
    is_foreground = packed_annotations.fine_segm_labels_gt > 0
    j_valid_fg = interpolator.j_valid * is_foreground
    if not torch.any(j_valid_fg):
        return self.produce_fake_densepose_losses(densepose_predictor_outputs)

    shared_args = (
        proposals_with_gt,
        densepose_predictor_outputs,
        packed_annotations,
        interpolator,
        j_valid_fg,
    )
    losses = dict(self.produce_densepose_losses_uv(*shared_args))
    losses.update(self.produce_densepose_losses_segm(*shared_args))
    return losses
def produce_fake_densepose_losses_segm(self, densepose_predictor_outputs: Any) -> LossDict:
    """
    Fake losses for fine / coarse segmentation. These are used when
    no suitable ground truth data was found in a batch. The loss has a value 0
    and is primarily used to construct the computation graph, so that
    `DistributedDataParallel` has similar graphs on all GPUs and can perform
    reduction properly.

    Args:
        densepose_predictor_outputs: DensePose predictor outputs, an object
            of a dataclass that is assumed to have the following attributes:
             * fine_segm - fine segmentation estimates, tensor of shape [N, C, S, S]
             * coarse_segm - coarse segmentation estimates, tensor of shape [N, D, S, S]
    Return:
        dict: str -> tensor: dict of losses with the following entries:
         * `loss_densepose_I`: has value 0
         * `loss_densepose_S`: always present; its value is whatever
             `self.segm_loss.fake_value` returns for these outputs
    """
    return {
        "loss_densepose_I": densepose_predictor_outputs.fine_segm.sum() * 0,
        # The coarse segmentation placeholder is delegated to the segmentation
        # loss object, which is handed the full predictor outputs.
        "loss_densepose_S": self.segm_loss.fake_value(densepose_predictor_outputs),
    }
def produce_densepose_losses_segm(
    self,
    proposals_with_gt: List[Instances],
    densepose_predictor_outputs: Any,
    packed_annotations: Any,
    interpolator: BilinearInterpolationHelper,
    j_valid_fg: torch.Tensor,
) -> LossDict:
    """
    Segmentation losses: cross-entropy on fine segmentation scores sampled at
    the annotated points, plus the coarse segmentation loss computed by
    `self.segm_loss`.

    Args:
        proposals_with_gt (list of Instances): detections with associated ground truth data
        densepose_predictor_outputs: DensePose predictor outputs; `fine_segm`
            is read here and the whole object is forwarded to `self.segm_loss`
        packed_annotations: packed ground truth; `fine_segm_labels_gt` is read here
        interpolator: helper providing `j_valid`, the four bilinear weights
            and `extract_at_points`
        j_valid_fg: not used by this method (kept for a uniform signature)
    Return:
        dict: str -> tensor: dict of losses with the following entries:
         * `loss_densepose_I`: cross entropy of the fine segmentation scores
             at valid points, scaled by `self.w_part`
         * `loss_densepose_S`: output of `self.segm_loss`, scaled by `self.w_segm`
    """
    valid = interpolator.j_valid
    labels_at_valid = packed_annotations.fine_segm_labels_gt[valid]

    # Each bilinear weight gets a trailing axis before being passed on.
    weight_names = ("w_ylo_xlo", "w_ylo_xhi", "w_yhi_xlo", "w_yhi_xhi")
    weights = {name: getattr(interpolator, name)[:, None] for name in weight_names}
    scores_at_points = interpolator.extract_at_points(
        densepose_predictor_outputs.fine_segm,
        slice_fine_segm=slice(None),
        **weights,
    )
    scores_at_valid = scores_at_points[valid, :]

    fine_loss = F.cross_entropy(scores_at_valid, labels_at_valid.long()) * self.w_part
    coarse_loss = (
        self.segm_loss(proposals_with_gt, densepose_predictor_outputs, packed_annotations)
        * self.w_segm
    )
    return {"loss_densepose_I": fine_loss, "loss_densepose_S": coarse_loss}
@DENSEPOSE_LOSS_REGISTRY.register()
class DensePoseChartWithConfidenceLoss(DensePoseChartLoss):
    """
    Chart-based DensePose loss whose U/V term can use predicted confidences.

    When U/V confidence is enabled in the configuration, the separate
    `loss_densepose_U` / `loss_densepose_V` entries are replaced by a single
    `loss_densepose_UV` entry computed by a Gaussian negative log likelihood
    loss (`IIDIsotropicGaussianUVLoss` or `IndepAnisotropicGaussianUVLoss`,
    depending on the configured confidence type). Otherwise the behavior of
    `DensePoseChartLoss` is used unchanged.
    """

    def __init__(self, cfg: CfgNode):
        """
        Initialize the loss from configuration options.

        Args:
            cfg (CfgNode): configuration options
        """
        super().__init__(cfg)
        self.confidence_model_cfg = DensePoseConfidenceModelConfig.from_cfg(cfg)
        uv_confidence = self.confidence_model_cfg.uv_confidence
        # `uv_loss_with_confidences` is only created for the two known types.
        if uv_confidence.type == DensePoseUVConfidenceType.IID_ISO:
            self.uv_loss_with_confidences = IIDIsotropicGaussianUVLoss(uv_confidence.epsilon)
        elif uv_confidence.type == DensePoseUVConfidenceType.INDEP_ANISO:
            self.uv_loss_with_confidences = IndepAnisotropicGaussianUVLoss(
                uv_confidence.epsilon
            )

    def produce_fake_densepose_losses_uv(self, densepose_predictor_outputs: Any) -> LossDict:
        """
        Zero-valued U/V losses that still touch the confidence outputs.

        Used when no suitable ground truth data was found in a batch: the value
        is 0, but the expression keeps every relevant predictor output in the
        computation graph, so that `DistributedDataParallel` has similar graphs
        on all GPUs and can perform reduction properly.

        Args:
            densepose_predictor_outputs: DensePose predictor outputs; `u` and
                `v` are always read, `sigma_2` for both confidence types and
                `kappa_u` / `kappa_v` for the anisotropic type.
        Return:
            dict: str -> tensor: `{"loss_densepose_UV": 0}` when U/V confidence
            is enabled; otherwise the parent's `loss_densepose_U` /
            `loss_densepose_V` entries.
        """
        uv_confidence = self.confidence_model_cfg.uv_confidence
        if not uv_confidence.enabled:
            return super().produce_fake_densepose_losses_uv(densepose_predictor_outputs)

        outputs = densepose_predictor_outputs
        loss_uv = (outputs.u.sum() + outputs.v.sum()) * 0
        if uv_confidence.type == DensePoseUVConfidenceType.IID_ISO:
            loss_uv += outputs.sigma_2.sum() * 0
        elif uv_confidence.type == DensePoseUVConfidenceType.INDEP_ANISO:
            confidence_total = (
                outputs.sigma_2.sum() + outputs.kappa_u.sum() + outputs.kappa_v.sum()
            )
            loss_uv += confidence_total * 0
        return {"loss_densepose_UV": loss_uv}

    def produce_densepose_losses_uv(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: Any,
        interpolator: BilinearInterpolationHelper,
        j_valid_fg: torch.Tensor,
    ) -> LossDict:
        """
        U/V loss with confidences, evaluated at valid foreground points.

        Return:
            dict: str -> tensor: `{"loss_densepose_UV": ...}` scaled by
            `self.w_points` for the two known confidence types; in every other
            case the result of the parent implementation.
        """

        def fall_back() -> LossDict:
            return super(DensePoseChartWithConfidenceLoss, self).produce_densepose_losses_uv(
                proposals_with_gt,
                densepose_predictor_outputs,
                packed_annotations,
                interpolator,
                j_valid_fg,
            )

        uv_confidence = self.confidence_model_cfg.uv_confidence
        if not uv_confidence.enabled:
            return fall_back()

        def at_valid_points(tensor):
            # Sample a predictor output at the annotated points, keep valid foreground.
            return interpolator.extract_at_points(tensor)[j_valid_fg]

        outputs = densepose_predictor_outputs
        u_gt = packed_annotations.u_gt[j_valid_fg]
        u_est = at_valid_points(outputs.u)
        v_gt = packed_annotations.v_gt[j_valid_fg]
        v_est = at_valid_points(outputs.v)
        sigma_2_est = at_valid_points(outputs.sigma_2)

        if uv_confidence.type == DensePoseUVConfidenceType.IID_ISO:
            loss = self.uv_loss_with_confidences(u_est, v_est, sigma_2_est, u_gt, v_gt)
            return {"loss_densepose_UV": loss * self.w_points}

        if uv_confidence.type == DensePoseUVConfidenceType.INDEP_ANISO:
            kappa_u_est = at_valid_points(outputs.kappa_u)
            kappa_v_est = at_valid_points(outputs.kappa_v)
            loss = self.uv_loss_with_confidences(
                u_est, v_est, sigma_2_est, kappa_u_est, kappa_v_est, u_gt, v_gt
            )
            return {"loss_densepose_UV": loss * self.w_points}

        return fall_back()
+ delta_t_delta = (u - target_u) ** 2 + (v - target_v) ** 2 + # the total loss from the formula above: + loss = 0.5 * (self.log2pi + 2 * torch.log(sigma2) + delta_t_delta / sigma2) + return loss.sum() + + +class IndepAnisotropicGaussianUVLoss(nn.Module): + """ + Loss for the case of independent residuals with anisotropic covariances: + $Sigma_i = sigma_i^2 I + r_i r_i^T$ + The loss (negative log likelihood) is then: + $1/2 sum_{i=1}^n (log(2 pi) + + log sigma_i^2 (sigma_i^2 + ||r_i||^2) + + ||delta_i||^2 / sigma_i^2 + - ^2 / (sigma_i^2 * (sigma_i^2 + ||r_i||^2)))$, + where $delta_i=(u - u', v - v')$ is a 2D vector containing UV coordinates + difference between estimated and ground truth UV values + For details, see: + N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning + Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019 + """ + + def __init__(self, sigma_lower_bound: float): + super(IndepAnisotropicGaussianUVLoss, self).__init__() + self.sigma_lower_bound = sigma_lower_bound + self.log2pi = math.log(2 * math.pi) + + def forward( + self, + u: torch.Tensor, + v: torch.Tensor, + sigma_u: torch.Tensor, + kappa_u_est: torch.Tensor, + kappa_v_est: torch.Tensor, + target_u: torch.Tensor, + target_v: torch.Tensor, + ): + # compute $\sigma_i^2$ + sigma2 = F.softplus(sigma_u) + self.sigma_lower_bound + # compute \|r_i\|^2 + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. + r_sqnorm2 = kappa_u_est**2 + kappa_v_est**2 + delta_u = u - target_u + delta_v = v - target_v + # compute \|delta_i\|^2 + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. + delta_sqnorm = delta_u**2 + delta_v**2 + delta_u_r_u = delta_u * kappa_u_est + delta_v_r_v = delta_v * kappa_v_est + # compute the scalar product + delta_r = delta_u_r_u + delta_v_r_v + # compute squared scalar product ^2 + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. 
+ delta_r_sqnorm = delta_r**2 + denom2 = sigma2 * (sigma2 + r_sqnorm2) + loss = 0.5 * ( + self.log2pi + torch.log(denom2) + delta_sqnorm / sigma2 - delta_r_sqnorm / denom2 + ) + return loss.sum() diff --git a/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/cse.py b/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/cse.py new file mode 100644 index 0000000..dd561ad --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/cse.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +from typing import Any, List +from torch import nn + +from detectron2.config import CfgNode +from detectron2.structures import Instances + +from .cycle_pix2shape import PixToShapeCycleLoss +from .cycle_shape2shape import ShapeToShapeCycleLoss +from .embed import EmbeddingLoss +from .embed_utils import CseAnnotationsAccumulator +from .mask_or_segm import MaskOrSegmentationLoss +from .registry import DENSEPOSE_LOSS_REGISTRY +from .soft_embed import SoftEmbeddingLoss +from .utils import BilinearInterpolationHelper, LossDict, extract_packed_annotations_from_matches + + +@DENSEPOSE_LOSS_REGISTRY.register() +class DensePoseCseLoss: + """ """ + + _EMBED_LOSS_REGISTRY = { + EmbeddingLoss.__name__: EmbeddingLoss, + SoftEmbeddingLoss.__name__: SoftEmbeddingLoss, + } + + def __init__(self, cfg: CfgNode): + """ + Initialize CSE loss from configuration options + + Args: + cfg (CfgNode): configuration options + """ + self.w_segm = cfg.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS + self.w_embed = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT + self.segm_loss = MaskOrSegmentationLoss(cfg) + self.embed_loss = DensePoseCseLoss.create_embed_loss(cfg) + self.do_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.ENABLED + if self.do_shape2shape: + self.w_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT + self.shape2shape_loss = 
ShapeToShapeCycleLoss(cfg) + self.do_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.ENABLED + if self.do_pix2shape: + self.w_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT + self.pix2shape_loss = PixToShapeCycleLoss(cfg) + + @classmethod + def create_embed_loss(cls, cfg: CfgNode): + # registry not used here, since embedding losses are currently local + # and are not used anywhere else + return cls._EMBED_LOSS_REGISTRY[cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME](cfg) + + def __call__( + self, + proposals_with_gt: List[Instances], + densepose_predictor_outputs: Any, + embedder: nn.Module, + ) -> LossDict: + if not len(proposals_with_gt): + return self.produce_fake_losses(densepose_predictor_outputs, embedder) + accumulator = CseAnnotationsAccumulator() + packed_annotations = extract_packed_annotations_from_matches(proposals_with_gt, accumulator) + if packed_annotations is None: + return self.produce_fake_losses(densepose_predictor_outputs, embedder) + h, w = densepose_predictor_outputs.embedding.shape[2:] + interpolator = BilinearInterpolationHelper.from_matches( + packed_annotations, + (h, w), + ) + meshid_to_embed_losses = self.embed_loss( + proposals_with_gt, + densepose_predictor_outputs, + packed_annotations, + interpolator, + embedder, + ) + embed_loss_dict = { + f"loss_densepose_E{meshid}": self.w_embed * meshid_to_embed_losses[meshid] + for meshid in meshid_to_embed_losses + } + all_loss_dict = { + "loss_densepose_S": self.w_segm + * self.segm_loss(proposals_with_gt, densepose_predictor_outputs, packed_annotations), + **embed_loss_dict, + } + if self.do_shape2shape: + all_loss_dict["loss_shape2shape"] = self.w_shape2shape * self.shape2shape_loss(embedder) + if self.do_pix2shape: + all_loss_dict["loss_pix2shape"] = self.w_pix2shape * self.pix2shape_loss( + proposals_with_gt, densepose_predictor_outputs, packed_annotations, embedder + ) + return all_loss_dict + + def produce_fake_losses( + self, 
densepose_predictor_outputs: Any, embedder: nn.Module + ) -> LossDict: + meshname_to_embed_losses = self.embed_loss.fake_values( + densepose_predictor_outputs, embedder=embedder + ) + embed_loss_dict = { + f"loss_densepose_E{mesh_name}": meshname_to_embed_losses[mesh_name] + for mesh_name in meshname_to_embed_losses + } + all_loss_dict = { + "loss_densepose_S": self.segm_loss.fake_value(densepose_predictor_outputs), + **embed_loss_dict, + } + if self.do_shape2shape: + all_loss_dict["loss_shape2shape"] = self.shape2shape_loss.fake_value(embedder) + if self.do_pix2shape: + all_loss_dict["loss_pix2shape"] = self.pix2shape_loss.fake_value( + densepose_predictor_outputs, embedder + ) + return all_loss_dict diff --git a/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/cycle_pix2shape.py b/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/cycle_pix2shape.py new file mode 100644 index 0000000..e173918 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/cycle_pix2shape.py @@ -0,0 +1,154 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
def _create_pixel_dist_matrix(grid_size: int) -> torch.Tensor:
    """
    Build the matrix of squared Euclidean distances between all pixels of a
    `grid_size` x `grid_size` grid.

    Args:
        grid_size (int): side of the square pixel grid
    Return:
        Tensor produced by `squared_euclidean_distance_matrix` for the
        flattened pixel coordinates (one row / column per pixel, pixels
        enumerated in row-major order)
    """
    rows = torch.arange(grid_size)
    cols = torch.arange(grid_size)
    # at index `i` contains [row, col], where
    # row = i // grid_size
    # col = i % grid_size
    # (built by explicit repetition rather than `torch.meshgrid`, whose call
    # without the `indexing` argument emits a warning on recent torch versions)
    pix_coords = torch.stack(
        (rows.repeat_interleave(grid_size), cols.repeat(grid_size)), -1
    ).float()
    return squared_euclidean_distance_matrix(pix_coords, pix_coords)


def _sample_fg_pixels_randperm(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
    """
    Sample foreground (non-zero) pixels of a mask uniformly without
    replacement, using a random permutation.

    Args:
        fg_mask (tensor): mask, any non-zero entry is treated as foreground
        sample_size (int): maximum number of pixels to sample; values <= 0
            disable sampling
    Return:
        1D tensor with indices (into the flattened mask) of the selected
        pixels; all foreground indices in ascending order if sampling is
        disabled or there are at most `sample_size` foreground pixels
    """
    fg_pixel_indices = fg_mask.reshape((-1,)).nonzero(as_tuple=True)[0]
    # count the foreground pixels themselves rather than summing the mask:
    # the sum is only a valid count for strictly binary masks and would make
    # the permutation index out of range otherwise
    num_pixels = fg_pixel_indices.numel()
    if (sample_size <= 0) or (num_pixels <= sample_size):
        return fg_pixel_indices
    sample_indices = torch.randperm(num_pixels, device=fg_mask.device)[:sample_size]
    return fg_pixel_indices[sample_indices]


def _sample_fg_pixels_multinomial(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
    """
    Sample foreground (non-zero) pixels of a mask uniformly without
    replacement, using multinomial sampling.

    Args:
        fg_mask (tensor): mask, any non-zero entry is treated as foreground
        sample_size (int): maximum number of pixels to sample; values <= 0
            disable sampling
    Return:
        1D tensor with indices (into the flattened mask) of the selected
        pixels; all foreground indices in ascending order if sampling is
        disabled or there are at most `sample_size` foreground pixels
    """
    # binarize first, so that the pixel count and the sampling weights do not
    # depend on the magnitude of the mask values
    fg_flags = fg_mask.reshape((-1,)) != 0
    num_pixels = int(fg_flags.sum().item())
    if (sample_size <= 0) or (num_pixels <= sample_size):
        return fg_flags.nonzero(as_tuple=True)[0]
    return fg_flags.float().multinomial(sample_size, replacement=False)
cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE + self.norm_p = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P + self.use_all_meshes_not_gt_only = ( + cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY + ) + self.num_pixels_to_sample = ( + cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE + ) + self.pix_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA + self.temperature_pix_to_vertex = ( + cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX + ) + self.temperature_vertex_to_pix = ( + cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL + ) + self.pixel_dists = _create_pixel_dist_matrix(cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE) + + def forward( + self, + proposals_with_gt: List[Instances], + densepose_predictor_outputs: Any, + packed_annotations: PackedCseAnnotations, + embedder: nn.Module, + ): + """ + Args: + proposals_with_gt (list of Instances): detections with associated + ground truth data; each item corresponds to instances detected + on 1 image; the number of items corresponds to the number of + images in a batch + densepose_predictor_outputs: an object of a dataclass that contains predictor + outputs with estimated values; assumed to have the following attributes: + * embedding - embedding estimates, tensor of shape [N, D, S, S], where + N = number of instances (= sum N_i, where N_i is the number of + instances on image i) + D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE) + S = output size (width and height) + packed_annotations (PackedCseAnnotations): contains various data useful + for loss computation, each data is packed into a single tensor + embedder (nn.Module): module that computes vertex embeddings for different meshes + """ + pix_embeds = densepose_predictor_outputs.embedding + if self.pixel_dists.device != pix_embeds.device: + # should normally be done 
only once + self.pixel_dists = self.pixel_dists.to(device=pix_embeds.device) + with torch.no_grad(): + mask_loss_data = extract_data_for_mask_loss_from_matches( + proposals_with_gt, densepose_predictor_outputs.coarse_segm + ) + # GT masks - tensor of shape [N, S, S] of int64 + masks_gt = mask_loss_data.masks_gt.long() # pyre-ignore[16] + assert len(pix_embeds) == len(masks_gt), ( + f"Number of instances with embeddings {len(pix_embeds)} != " + f"number of instances with GT masks {len(masks_gt)}" + ) + losses = [] + mesh_names = ( + self.shape_names + if self.use_all_meshes_not_gt_only + else [ + MeshCatalog.get_mesh_name(mesh_id.item()) + for mesh_id in packed_annotations.vertex_mesh_ids_gt.unique() + ] + ) + for pixel_embeddings, mask_gt in zip(pix_embeds, masks_gt): + # pixel_embeddings [D, S, S] + # mask_gt [S, S] + for mesh_name in mesh_names: + mesh_vertex_embeddings = embedder(mesh_name) + # pixel indices [M] + pixel_indices_flattened = _sample_fg_pixels_randperm( + mask_gt, self.num_pixels_to_sample + ) + # pixel distances [M, M] + pixel_dists = self.pixel_dists.to(pixel_embeddings.device)[ + torch.meshgrid(pixel_indices_flattened, pixel_indices_flattened) + ] + # pixel embeddings [M, D] + pixel_embeddings_sampled = normalize_embeddings( + pixel_embeddings.reshape((self.embed_size, -1))[:, pixel_indices_flattened].T + ) + # pixel-vertex similarity [M, K] + sim_matrix = pixel_embeddings_sampled.mm(mesh_vertex_embeddings.T) + c_pix_vertex = F.softmax(sim_matrix / self.temperature_pix_to_vertex, dim=1) + c_vertex_pix = F.softmax(sim_matrix.T / self.temperature_vertex_to_pix, dim=1) + c_cycle = c_pix_vertex.mm(c_vertex_pix) + loss_cycle = torch.norm(pixel_dists * c_cycle, p=self.norm_p) + losses.append(loss_cycle) + + if len(losses) == 0: + return pix_embeds.sum() * 0 + return torch.stack(losses, dim=0).mean() + + def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module): + losses = [ + embedder(mesh_name).sum() * 0 for mesh_name in 
class ShapeToShapeCycleLoss(nn.Module):
    """
    Cycle Loss for Shapes.
    Inspired by:
    "Mapping in a Cycle: Sinkhorn Regularized Unsupervised Learning for Point Cloud Shapes".
    """

    def __init__(self, cfg: CfgNode):
        """
        Initialize the loss from configuration options

        Args:
            cfg (CfgNode): configuration options; mesh names are the keys of
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS, the remaining values
                are read from MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS
        """
        super().__init__()
        self.shape_names = list(cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.keys())
        # all unordered pairs of different mesh names
        self.all_shape_pairs = [
            (x, y) for i, x in enumerate(self.shape_names) for y in self.shape_names[i + 1 :]
        ]
        random.shuffle(self.all_shape_pairs)
        # position of the next pair to be served by `_sample_random_pair`
        self.cur_pos = 0
        self.norm_p = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.NORM_P
        self.temperature = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.TEMPERATURE
        self.max_num_vertices = (
            cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES
        )

    def _sample_random_pair(self) -> Tuple[str, str]:
        """
        Produce a random pair of different mesh names. Pairs are served from a
        shuffled list, which is reshuffled once all pairs have been used.

        Return:
            tuple(str, str): a pair of different mesh names
        Raises:
            IndexError: if there are no pairs (fewer than two meshes configured)
        """
        if self.cur_pos >= len(self.all_shape_pairs):
            random.shuffle(self.all_shape_pairs)
            self.cur_pos = 0
        shape_pair = self.all_shape_pairs[self.cur_pos]
        self.cur_pos += 1
        return shape_pair

    def forward(self, embedder: nn.Module):
        """
        Do a forward pass with a random (src, dst) pair of shapes. If fewer
        than two meshes are configured there is no pair to compare, and a zero
        loss attached to the embedder graph is returned instead.

        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            Tensor containing the loss value
        """
        if not self.all_shape_pairs:
            # no pairs: previously this failed with an IndexError deep inside
            # `_sample_random_pair`
            return self.fake_value(embedder)
        src_mesh_name, dst_mesh_name = self._sample_random_pair()
        return self._forward_one_pair(embedder, src_mesh_name, dst_mesh_name)

    def fake_value(self, embedder: nn.Module):
        """
        Zero-valued loss that still depends on the embeddings of every mesh
        known to the embedder.

        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            Scalar tensor with value 0
        """
        losses = [
            embedder(mesh_name).sum() * 0 for mesh_name in embedder.mesh_names  # pyre-ignore[29]
        ]
        return torch.mean(torch.stack(losses))

    def _get_embeddings_and_geodists_for_mesh(
        self, embedder: nn.Module, mesh_name: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Produces embeddings and geodesic distance tensors for a given mesh. May subsample
        the mesh, if it contains too many vertices (controlled by the
        SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES parameter).

        Args:
            embedder (nn.Module): module that computes embeddings for mesh vertices
            mesh_name (str): mesh name
        Return:
            embeddings (torch.Tensor of size [N, D]): embeddings for selected mesh
                vertices (N = number of selected vertices, D = embedding space dim)
            geodists (torch.Tensor of size [N, N]): geodesic distances for the selected
                mesh vertices (N = number of selected vertices)
        """
        embeddings = embedder(mesh_name)
        indices = sample_random_indices(
            embeddings.shape[0], self.max_num_vertices, embeddings.device
        )
        mesh = create_mesh(mesh_name, embeddings.device)
        geodists = mesh.geodists
        if indices is not None:
            embeddings = embeddings[indices]
            # select the [N, N] submatrix for the sampled vertices; broadcast
            # indexing picks the same elements as indexing with
            # `torch.meshgrid(indices, indices)` without the warning that
            # call emits when `indexing` is not specified
            geodists = geodists[indices[:, None], indices[None, :]]
        return embeddings, geodists

    def _forward_one_pair(
        self, embedder: nn.Module, mesh_name_1: str, mesh_name_2: str
    ) -> torch.Tensor:
        """
        Do a forward pass with a selected pair of meshes

        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
            mesh_name_1 (str): first mesh name
            mesh_name_2 (str): second mesh name
        Return:
            Tensor containing the loss value
        """
        embeddings_1, geodists_1 = self._get_embeddings_and_geodists_for_mesh(embedder, mesh_name_1)
        embeddings_2, geodists_2 = self._get_embeddings_and_geodists_for_mesh(embedder, mesh_name_2)
        sim_matrix_12 = embeddings_1.mm(embeddings_2.T)

        # soft correspondences mesh 1 -> mesh 2 and mesh 2 -> mesh 1
        c_12 = F.softmax(sim_matrix_12 / self.temperature, dim=1)
        c_21 = F.softmax(sim_matrix_12.T / self.temperature, dim=1)
        # round trips 1 -> 2 -> 1 and 2 -> 1 -> 2
        c_11 = c_12.mm(c_21)
        c_22 = c_21.mm(c_12)

        # round-trip weights are multiplied by the geodesic distance between
        # the start and the end vertex
        loss_cycle_11 = torch.norm(geodists_1 * c_11, p=self.norm_p)
        loss_cycle_22 = torch.norm(geodists_2 * c_22, p=self.norm_p)

        return loss_cycle_11 + loss_cycle_22
class EmbeddingLoss:
    """
    Computes losses for estimated embeddings given annotated vertices.
    Instances in a minibatch that correspond to the same mesh are grouped
    together. For each group, loss is computed as cross-entropy for
    unnormalized scores given ground truth mesh vertex ids.
    Scores are based on squared distances between estimated vertex embeddings
    and mesh vertex embeddings.
    """

    def __init__(self, cfg: CfgNode):
        """
        Initialize embedding loss from config

        Args:
            cfg (CfgNode): configuration options; reads
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA
        """
        self.embdist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        interpolator: BilinearInterpolationHelper,
        embedder: nn.Module,
    ) -> Dict[str, torch.Tensor]:
        """
        Produces losses for estimated embeddings given annotated vertices.
        Embeddings for all the vertices of a mesh are computed by the embedder.
        Embeddings for observed pixels are estimated by a predictor.
        Losses are computed as cross-entropy for squared distances between
        observed vertex embeddings and all mesh vertex embeddings given
        ground truth vertex IDs.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                  N = number of instances (= sum N_i, where N_i is the number of
                      instances on image i)
                  D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                  S = output size (width and height)
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each data is packed into a single tensor
            interpolator (BilinearInterpolationHelper): bilinear interpolation helper
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            dict(str -> tensor): losses keyed by mesh name; there is an entry
                for every mesh in `embedder.mesh_names`, meshes without valid
                annotated points in the batch get a zero-valued fake loss
        """
        losses = {}
        for mesh_id_tensor in packed_annotations.vertex_mesh_ids_gt.unique():
            mesh_id = mesh_id_tensor.item()
            mesh_name = MeshCatalog.get_mesh_name(mesh_id)
            # valid points are those that fall into estimated bbox
            # and correspond to the current mesh
            j_valid = interpolator.j_valid * (  # pyre-ignore[16]
                packed_annotations.vertex_mesh_ids_gt == mesh_id
            )
            if not torch.any(j_valid):
                continue
            # extract estimated embeddings for valid points
            # -> tensor [J, D]
            vertex_embeddings_i = normalize_embeddings(
                interpolator.extract_at_points(
                    densepose_predictor_outputs.embedding,
                    slice_fine_segm=slice(None),
                    w_ylo_xlo=interpolator.w_ylo_xlo[:, None],  # pyre-ignore[16]
                    w_ylo_xhi=interpolator.w_ylo_xhi[:, None],  # pyre-ignore[16]
                    w_yhi_xlo=interpolator.w_yhi_xlo[:, None],  # pyre-ignore[16]
                    w_yhi_xhi=interpolator.w_yhi_xhi[:, None],  # pyre-ignore[16]
                )[j_valid, :]
            )
            # extract vertex ids for valid points
            # -> tensor [J]
            vertex_indices_i = packed_annotations.vertex_ids_gt[j_valid]
            # embeddings for all mesh vertices
            # -> tensor [K, D]
            mesh_vertex_embeddings = embedder(mesh_name)
            # unnormalized scores for valid points
            # -> tensor [J, K]
            scores = squared_euclidean_distance_matrix(
                vertex_embeddings_i, mesh_vertex_embeddings
            ) / (-self.embdist_gauss_sigma)
            # the dict is keyed by mesh name, not by the numeric mesh id
            losses[mesh_name] = F.cross_entropy(scores, vertex_indices_i, ignore_index=-1)

        # pyre-fixme[29]:
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self,
        #  torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], nn.Module,
        #  torch.Tensor]` is not a function.
        for mesh_name in embedder.mesh_names:
            if mesh_name not in losses:
                # no loss was computed for this mesh, add a zero-valued
                # placeholder so that every mesh has an entry
                losses[mesh_name] = self.fake_value(
                    densepose_predictor_outputs, embedder, mesh_name
                )
        return losses

    def fake_values(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> Dict[str, torch.Tensor]:
        """
        Zero-valued losses for every mesh known to the embedder.

        Args:
            densepose_predictor_outputs: predictor outputs with `embedding` attribute
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            dict(str -> tensor): zero-valued losses keyed by mesh name
        """
        # pyre-fixme[29]:
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self,
        #  torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], nn.Module,
        #  torch.Tensor]` is not a function.
        return {
            mesh_name: self.fake_value(densepose_predictor_outputs, embedder, mesh_name)
            for mesh_name in embedder.mesh_names
        }

    def fake_value(
        self, densepose_predictor_outputs: Any, embedder: nn.Module, mesh_name: str
    ) -> torch.Tensor:
        """
        Zero-valued loss that depends on both the predicted embeddings and the
        vertex embeddings of the given mesh.

        Args:
            densepose_predictor_outputs: predictor outputs with `embedding` attribute
            embedder (nn.Module): module that computes vertex embeddings for different meshes
            mesh_name (str): mesh name
        Return:
            Scalar tensor with value 0
        """
        return densepose_predictor_outputs.embedding.sum() * 0 + embedder(mesh_name).sum() * 0
@dataclass
class PackedCseAnnotations:
    """
    CSE annotations of all annotated instances in a batch, packed into single
    tensors (see `CseAnnotationsAccumulator.pack`). Per-point tensors are
    concatenations over instances along dimension 0.
    """

    # X coordinates of annotated points (concatenated `dp_gt.x`)
    x_gt: torch.Tensor
    # Y coordinates of annotated points (concatenated `dp_gt.y`)
    y_gt: torch.Tensor
    # coarse segmentation, one `segm` entry per annotated instance;
    # None unless every accumulated instance provides `segm`
    coarse_segm_gt: Optional[torch.Tensor]
    # for every annotated point: ID of the mesh its instance refers to
    vertex_mesh_ids_gt: torch.Tensor
    # for every annotated point: ground truth vertex ID (concatenated `dp_gt.vertex_ids`)
    vertex_ids_gt: torch.Tensor
    # ground truth boxes of annotated instances in XYWH_ABS format, one row of 4 per instance
    bbox_xywh_gt: torch.Tensor
    # proposal boxes of annotated instances in XYWH_ABS format, one row of 4 per instance
    bbox_xywh_est: torch.Tensor
    # for every annotated point: index of its box, counted only among boxes
    # that have DensePose annotations
    point_bbox_with_dp_indices: torch.Tensor
    # for every annotated point: index of its box, counted among all matched boxes
    point_bbox_indices: torch.Tensor
    # int64 tensor with the index (among all matched boxes) of every box that
    # has DensePose annotations
    bbox_indices: torch.Tensor
box_xywh_est, box_xywh_gt, dp_gt in zip( + boxes_xywh_est, boxes_xywh_gt, instances_one_image.gt_densepose + ): + if (dp_gt is not None) and (len(dp_gt.x) > 0): + self._do_accumulate(box_xywh_gt, box_xywh_est, dp_gt) + self.nxt_bbox_index += 1 + + def _do_accumulate(self, box_xywh_gt: torch.Tensor, box_xywh_est: torch.Tensor, dp_gt: Any): + """ + Accumulate instances data for one image, given that the data is not empty + + Args: + box_xywh_gt (tensor): GT bounding box + box_xywh_est (tensor): estimated bounding box + dp_gt: GT densepose data with the following attributes: + - x: normalized X coordinates + - y: normalized Y coordinates + - segm: tensor of size [S, S] with coarse segmentation + - + """ + self.x_gt.append(dp_gt.x) + self.y_gt.append(dp_gt.y) + if hasattr(dp_gt, "segm"): + self.s_gt.append(dp_gt.segm.unsqueeze(0)) + self.vertex_ids_gt.append(dp_gt.vertex_ids) + self.vertex_mesh_ids_gt.append(torch.full_like(dp_gt.vertex_ids, dp_gt.mesh_id)) + self.bbox_xywh_gt.append(box_xywh_gt.view(-1, 4)) + self.bbox_xywh_est.append(box_xywh_est.view(-1, 4)) + self.point_bbox_with_dp_indices.append( + torch.full_like(dp_gt.vertex_ids, self.nxt_bbox_with_dp_index) + ) + self.point_bbox_indices.append(torch.full_like(dp_gt.vertex_ids, self.nxt_bbox_index)) + self.bbox_indices.append(self.nxt_bbox_index) + self.nxt_bbox_with_dp_index += 1 + + def pack(self) -> Optional[PackedCseAnnotations]: + """ + Pack data into tensors + """ + if not len(self.x_gt): + # TODO: + # returning proper empty annotations would require + # creating empty tensors of appropriate shape and + # type on an appropriate device; + # we return None so far to indicate empty annotations + return None + return PackedCseAnnotations( + x_gt=torch.cat(self.x_gt, 0), + y_gt=torch.cat(self.y_gt, 0), + vertex_mesh_ids_gt=torch.cat(self.vertex_mesh_ids_gt, 0), + vertex_ids_gt=torch.cat(self.vertex_ids_gt, 0), + # ignore segmentation annotations, if not all the instances contain those + 
@dataclass
class DataForMaskLoss:
    """
    Contains mask GT and estimated data for proposals from multiple images:
    """

    # tensor of size (K, H, W) containing GT labels
    masks_gt: Optional[torch.Tensor] = None
    # tensor of size (K, C, H, W) containing estimated scores
    masks_est: Optional[torch.Tensor] = None


def extract_data_for_mask_loss_from_matches(
    proposals_targets: Iterable[Instances], estimated_segm: torch.Tensor
) -> DataForMaskLoss:
    """
    Extract data for mask loss from instances that contain matched GT and
    estimated bounding boxes.

    Args:
        proposals_targets: Iterable[Instances]
            matched GT and estimated results, each item in the iterable
            corresponds to data in 1 image
        estimated_segm: tensor(K, C, S, S) of float - raw unnormalized
            segmentation scores, here S is the size to which GT masks are
            to be resized
    Return:
        DataForMaskLoss with
            masks_est: tensor(K, C, S, S) of float - class scores
            masks_gt: tensor(K, S, S) - GT masks resized to S x S
        both fields are None if there are no proposals
    Raises:
        AssertionError: if `estimated_segm` is not square or the total number
            of proposals differs from the number of estimates
    """
    # the proposals are traversed twice (count, then crop); materialize them
    # so that one-shot iterables (e.g. generators) are not silently exhausted
    # by the first pass
    proposals_targets = list(proposals_targets)
    data = DataForMaskLoss()
    masks_gt = []
    assert estimated_segm.shape[2] == estimated_segm.shape[3], (
        f"Expected estimated segmentation to have a square shape, "
        f"but the actual shape is {estimated_segm.shape[2:]}"
    )
    mask_size = estimated_segm.shape[2]
    num_proposals = sum(inst.proposal_boxes.tensor.size(0) for inst in proposals_targets)
    num_estimated = estimated_segm.shape[0]
    assert (
        num_proposals == num_estimated
    ), "The number of proposals {} must be equal to the number of estimates {}".format(
        num_proposals, num_estimated
    )

    for proposals_targets_per_image in proposals_targets:
        n_i = proposals_targets_per_image.proposal_boxes.tensor.size(0)
        if not n_i:
            # no proposals on this image
            continue
        # GT masks cropped to the proposal boxes and resized to the estimate size
        gt_masks_per_image = proposals_targets_per_image.gt_masks.crop_and_resize(
            proposals_targets_per_image.proposal_boxes.tensor, mask_size
        ).to(device=estimated_segm.device)
        masks_gt.append(gt_masks_per_image)
    if masks_gt:
        data.masks_est = estimated_segm
        data.masks_gt = torch.cat(masks_gt, dim=0)
    return data
+ """ + + def __call__( + self, proposals_with_gt: List[Instances], densepose_predictor_outputs: Any + ) -> torch.Tensor: + """ + Computes segmentation loss as cross-entropy for raw unnormalized + scores given ground truth labels. + + Args: + proposals_with_gt (list of Instances): detections with associated ground truth data + densepose_predictor_outputs: an object of a dataclass that contains predictor outputs + with estimated values; assumed to have the following attribute: + * coarse_segm (tensor of shape [N, D, S, S]): coarse segmentation estimates + as raw unnormalized scores + where N is the number of detections, S is the estimate size ( = width = height) + and D is the number of coarse segmentation channels. + Return: + Cross entropy for raw unnormalized scores for coarse segmentation given + ground truth labels from masks + """ + if not len(proposals_with_gt): + return self.fake_value(densepose_predictor_outputs) + # densepose outputs are computed for all images and all bounding boxes; + # i.e. if a batch has 4 images with (3, 1, 2, 1) proposals respectively, + # the outputs will have size(0) == 3+1+2+1 == 7 + with torch.no_grad(): + mask_loss_data = extract_data_for_mask_loss_from_matches( + proposals_with_gt, densepose_predictor_outputs.coarse_segm + ) + if (mask_loss_data.masks_gt is None) or (mask_loss_data.masks_est is None): + return self.fake_value(densepose_predictor_outputs) + return F.cross_entropy(mask_loss_data.masks_est, mask_loss_data.masks_gt.long()) + + def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor: + """ + Fake segmentation loss used when no suitable ground truth data + was found in a batch. The loss has a value 0 and is primarily used to + construct the computation graph, so that `DistributedDataParallel` + has similar graphs on all GPUs and can perform reduction properly. 
+ + Args: + densepose_predictor_outputs: DensePose predictor outputs, an object + of a dataclass that is assumed to have `coarse_segm` + attribute + Return: + Zero value loss with proper computation graph + """ + return densepose_predictor_outputs.coarse_segm.sum() * 0 diff --git a/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/mask_or_segm.py b/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/mask_or_segm.py new file mode 100644 index 0000000..98b773d --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/mask_or_segm.py @@ -0,0 +1,72 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +from typing import Any, List +import torch + +from detectron2.config import CfgNode +from detectron2.structures import Instances + +from .mask import MaskLoss +from .segm import SegmentationLoss + + +class MaskOrSegmentationLoss: + """ + Mask or segmentation loss as cross-entropy for raw unnormalized scores + given ground truth labels. Ground truth labels are either defined by coarse + segmentation annotation, or by mask annotation, depending on the config + value MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS + """ + + def __init__(self, cfg: CfgNode): + """ + Initialize segmentation loss from configuration options + + Args: + cfg (CfgNode): configuration options + """ + self.segm_trained_by_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS + if self.segm_trained_by_masks: + self.mask_loss = MaskLoss() + self.segm_loss = SegmentationLoss(cfg) + + def __call__( + self, + proposals_with_gt: List[Instances], + densepose_predictor_outputs: Any, + packed_annotations: Any, + ) -> torch.Tensor: + """ + Compute segmentation loss as cross-entropy between aligned unnormalized + score estimates and ground truth; with ground truth given + either by masks, or by coarse segmentation annotations. 
+ + Args: + proposals_with_gt (list of Instances): detections with associated ground truth data + densepose_predictor_outputs: an object of a dataclass that contains predictor outputs + with estimated values; assumed to have the following attributes: + * coarse_segm - coarse segmentation estimates, tensor of shape [N, D, S, S] + packed_annotations: packed annotations for efficient loss computation + Return: + tensor: loss value as cross-entropy for raw unnormalized scores + given ground truth labels + """ + if self.segm_trained_by_masks: + return self.mask_loss(proposals_with_gt, densepose_predictor_outputs) + return self.segm_loss(proposals_with_gt, densepose_predictor_outputs, packed_annotations) + + def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor: + """ + Fake segmentation loss used when no suitable ground truth data + was found in a batch. The loss has a value 0 and is primarily used to + construct the computation graph, so that `DistributedDataParallel` + has similar graphs on all GPUs and can perform reduction properly. + + Args: + densepose_predictor_outputs: DensePose predictor outputs, an object + of a dataclass that is assumed to have `coarse_segm` + attribute + Return: + Zero value loss with proper computation graph + """ + return densepose_predictor_outputs.coarse_segm.sum() * 0 diff --git a/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/registry.py b/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/registry.py new file mode 100644 index 0000000..d9c8817 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/registry.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
class SegmentationLoss:
    """
    Cross-entropy segmentation loss on raw unnormalized scores given ground
    truth labels. The ground truth labels are defined for the bounding box
    of interest at a fixed resolution [S, S], with
    S = MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE.
    """

    def __init__(self, cfg: CfgNode):
        """
        Read the loss settings from configuration options.

        Args:
            cfg (CfgNode): configuration options
        """
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        self.heatmap_size = head_cfg.HEATMAP_SIZE
        self.n_segm_chan = head_cfg.NUM_COARSE_SEGM_CHANNELS

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: Any,
    ) -> torch.Tensor:
        """
        Evaluate the segmentation loss as cross-entropy between estimated
        scores and ground truth segmentation resampled onto the estimated
        boxes.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data (not read by this method)
            densepose_predictor_outputs: predictor outputs object; its
                `coarse_segm` attribute (tensor of shape [N, D, S, S]) holds
                the coarse segmentation estimates
            packed_annotations: packed annotations for efficient loss
                computation; the following attributes are used:
                 - coarse_segm_gt
                 - bbox_xywh_gt
                 - bbox_xywh_est
                 - bbox_indices
        Return:
            Cross entropy loss value, or a zero-valued placeholder (see
            `fake_value`) when `coarse_segm_gt` is None
        """
        segm_gt = packed_annotations.coarse_segm_gt
        if segm_gt is None:
            return self.fake_value(densepose_predictor_outputs)
        # keep only the estimates for boxes selected by the packed annotations
        scores_est = densepose_predictor_outputs.coarse_segm[packed_annotations.bbox_indices]
        out_size = self.heatmap_size
        with torch.no_grad():
            resampled = resample_data(
                segm_gt.unsqueeze(1),
                packed_annotations.bbox_xywh_gt,
                packed_annotations.bbox_xywh_est,
                out_size,
                out_size,
                mode="nearest",
                padding_mode="zeros",
            )
            labels_gt = resampled.squeeze(1)
        if self.n_segm_chan == 2:
            # two channels: collapse all non-zero labels into one foreground class
            labels_gt = labels_gt > 0
        return F.cross_entropy(scores_est, labels_gt.long())

    def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor:
        """
        Placeholder loss for batches without suitable ground truth data.

        The value is 0, but it is derived from `coarse_segm` so that it stays
        attached to the computation graph; this keeps the graphs similar on
        all GPUs so that `DistributedDataParallel` can perform reduction
        properly.

        Args:
            densepose_predictor_outputs: DensePose predictor outputs, an
                object that is assumed to have a `coarse_segm` attribute
        Return:
            Zero value loss with proper computation graph
        """
        scores_total = densepose_predictor_outputs.coarse_segm.sum()
        return scores_total * 0
+ Scores are based on: + 1) squared distances between estimated vertex embeddings + and mesh vertex embeddings; + 2) geodesic distances between vertices of a mesh + """ + + def __init__(self, cfg: CfgNode): + """ + Initialize embedding loss from config + """ + self.embdist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA + self.geodist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA + + def __call__( + self, + proposals_with_gt: List[Instances], + densepose_predictor_outputs: Any, + packed_annotations: PackedCseAnnotations, + interpolator: BilinearInterpolationHelper, + embedder: nn.Module, + ) -> Dict[int, torch.Tensor]: + """ + Produces losses for estimated embeddings given annotated vertices. + Embeddings for all the vertices of a mesh are computed by the embedder. + Embeddings for observed pixels are estimated by a predictor. + Losses are computed as cross-entropy for unnormalized scores given + ground truth vertex IDs. + 1) squared distances between estimated vertex embeddings + and mesh vertex embeddings; + 2) geodesic distances between vertices of a mesh + + Args: + proposals_with_gt (list of Instances): detections with associated + ground truth data; each item corresponds to instances detected + on 1 image; the number of items corresponds to the number of + images in a batch + densepose_predictor_outputs: an object of a dataclass that contains predictor + outputs with estimated values; assumed to have the following attributes: + * embedding - embedding estimates, tensor of shape [N, D, S, S], where + N = number of instances (= sum N_i, where N_i is the number of + instances on image i) + D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE) + S = output size (width and height) + packed_annotations (PackedCseAnnotations): contains various data useful + for loss computation, each data is packed into a single tensor + interpolator (BilinearInterpolationHelper): bilinear interpolation 
helper + embedder (nn.Module): module that computes vertex embeddings for different meshes + Return: + dict(int -> tensor): losses for different mesh IDs + """ + losses = {} + for mesh_id_tensor in packed_annotations.vertex_mesh_ids_gt.unique(): + mesh_id = mesh_id_tensor.item() + mesh_name = MeshCatalog.get_mesh_name(mesh_id) + # valid points are those that fall into estimated bbox + # and correspond to the current mesh + j_valid = interpolator.j_valid * ( # pyre-ignore[16] + packed_annotations.vertex_mesh_ids_gt == mesh_id + ) + if not torch.any(j_valid): + continue + # extract estimated embeddings for valid points + # -> tensor [J, D] + vertex_embeddings_i = normalize_embeddings( + interpolator.extract_at_points( + densepose_predictor_outputs.embedding, + slice_fine_segm=slice(None), + w_ylo_xlo=interpolator.w_ylo_xlo[:, None], # pyre-ignore[16] + w_ylo_xhi=interpolator.w_ylo_xhi[:, None], # pyre-ignore[16] + w_yhi_xlo=interpolator.w_yhi_xlo[:, None], # pyre-ignore[16] + w_yhi_xhi=interpolator.w_yhi_xhi[:, None], # pyre-ignore[16] + )[j_valid, :] + ) + # extract vertex ids for valid points + # -> tensor [J] + vertex_indices_i = packed_annotations.vertex_ids_gt[j_valid] + # embeddings for all mesh vertices + # -> tensor [K, D] + mesh_vertex_embeddings = embedder(mesh_name) + # softmax values of geodesic distances for GT mesh vertices + # -> tensor [J, K] + mesh = create_mesh(mesh_name, mesh_vertex_embeddings.device) + geodist_softmax_values = F.softmax( + mesh.geodists[vertex_indices_i] / (-self.geodist_gauss_sigma), dim=1 + ) + # logsoftmax values for valid points + # -> tensor [J, K] + embdist_logsoftmax_values = F.log_softmax( + squared_euclidean_distance_matrix(vertex_embeddings_i, mesh_vertex_embeddings) + / (-self.embdist_gauss_sigma), + dim=1, + ) + losses[mesh_name] = (-geodist_softmax_values * embdist_logsoftmax_values).sum(1).mean() + + # pyre-fixme[29]: + # `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self, + # torch.Tensor)], 
typing.Iterator[typing.Any]], torch.Tensor], nn.Module, + # torch.Tensor]` is not a function. + for mesh_name in embedder.mesh_names: + if mesh_name not in losses: + losses[mesh_name] = self.fake_value( + densepose_predictor_outputs, embedder, mesh_name + ) + return losses + + def fake_values(self, densepose_predictor_outputs: Any, embedder: nn.Module): + losses = {} + # pyre-fixme[29]: + # `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self, + # torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], nn.Module, + # torch.Tensor]` is not a function. + for mesh_name in embedder.mesh_names: + losses[mesh_name] = self.fake_value(densepose_predictor_outputs, embedder, mesh_name) + return losses + + def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module, mesh_name: str): + return densepose_predictor_outputs.embedding.sum() * 0 + embedder(mesh_name).sum() * 0 diff --git a/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/utils.py b/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/utils.py new file mode 100644 index 0000000..4c172ae --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/modeling/losses/utils.py @@ -0,0 +1,441 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple +import torch +from torch.nn import functional as F + +from detectron2.structures import BoxMode, Instances + +from densepose import DensePoseDataRelative + +LossDict = Dict[str, torch.Tensor] + + +def _linear_interpolation_utilities(v_norm, v0_src, size_src, v0_dst, size_dst, size_z): + """ + Computes utility values for linear interpolation at points v. 
+ The points are given as normalized offsets in the source interval + (v0_src, v0_src + size_src), more precisely: + v = v0_src + v_norm * size_src / 256.0 + The computed utilities include lower points v_lo, upper points v_hi, + interpolation weights v_w and flags j_valid indicating whether the + points falls into the destination interval (v0_dst, v0_dst + size_dst). + + Args: + v_norm (:obj: `torch.Tensor`): tensor of size N containing + normalized point offsets + v0_src (:obj: `torch.Tensor`): tensor of size N containing + left bounds of source intervals for normalized points + size_src (:obj: `torch.Tensor`): tensor of size N containing + source interval sizes for normalized points + v0_dst (:obj: `torch.Tensor`): tensor of size N containing + left bounds of destination intervals + size_dst (:obj: `torch.Tensor`): tensor of size N containing + destination interval sizes + size_z (int): interval size for data to be interpolated + + Returns: + v_lo (:obj: `torch.Tensor`): int tensor of size N containing + indices of lower values used for interpolation, all values are + integers from [0, size_z - 1] + v_hi (:obj: `torch.Tensor`): int tensor of size N containing + indices of upper values used for interpolation, all values are + integers from [0, size_z - 1] + v_w (:obj: `torch.Tensor`): float tensor of size N containing + interpolation weights + j_valid (:obj: `torch.Tensor`): uint8 tensor of size N containing + 0 for points outside the estimation interval + (v0_est, v0_est + size_est) and 1 otherwise + """ + v = v0_src + v_norm * size_src / 256.0 + j_valid = (v - v0_dst >= 0) * (v - v0_dst < size_dst) + v_grid = (v - v0_dst) * size_z / size_dst + v_lo = v_grid.floor().long().clamp(min=0, max=size_z - 1) + v_hi = (v_lo + 1).clamp(max=size_z - 1) + v_grid = torch.min(v_hi.float(), v_grid) + v_w = v_grid - v_lo.float() + return v_lo, v_hi, v_w, j_valid + + +class BilinearInterpolationHelper: + """ + Args: + packed_annotations: object that contains packed annotations + 
j_valid (:obj: `torch.Tensor`): uint8 tensor of size M containing + 0 for points to be discarded and 1 for points to be selected + y_lo (:obj: `torch.Tensor`): int tensor of indices of upper values + in z_est for each point + y_hi (:obj: `torch.Tensor`): int tensor of indices of lower values + in z_est for each point + x_lo (:obj: `torch.Tensor`): int tensor of indices of left values + in z_est for each point + x_hi (:obj: `torch.Tensor`): int tensor of indices of right values + in z_est for each point + w_ylo_xlo (:obj: `torch.Tensor`): float tensor of size M; + contains upper-left value weight for each point + w_ylo_xhi (:obj: `torch.Tensor`): float tensor of size M; + contains upper-right value weight for each point + w_yhi_xlo (:obj: `torch.Tensor`): float tensor of size M; + contains lower-left value weight for each point + w_yhi_xhi (:obj: `torch.Tensor`): float tensor of size M; + contains lower-right value weight for each point + """ + + def __init__( + self, + packed_annotations: Any, + j_valid: torch.Tensor, + y_lo: torch.Tensor, + y_hi: torch.Tensor, + x_lo: torch.Tensor, + x_hi: torch.Tensor, + w_ylo_xlo: torch.Tensor, + w_ylo_xhi: torch.Tensor, + w_yhi_xlo: torch.Tensor, + w_yhi_xhi: torch.Tensor, + ): + for k, v in locals().items(): + if k != "self": + setattr(self, k, v) + + @staticmethod + def from_matches( + packed_annotations: Any, densepose_outputs_size_hw: Tuple[int, int] + ) -> "BilinearInterpolationHelper": + """ + Args: + packed_annotations: annotations packed into tensors, the following + attributes are required: + - bbox_xywh_gt + - bbox_xywh_est + - x_gt + - y_gt + - point_bbox_with_dp_indices + - point_bbox_indices + densepose_outputs_size_hw (tuple [int, int]): resolution of + DensePose predictor outputs (H, W) + Return: + An instance of `BilinearInterpolationHelper` used to perform + interpolation for the given annotation points and output resolution + """ + + zh, zw = densepose_outputs_size_hw + x0_gt, y0_gt, w_gt, h_gt = 
packed_annotations.bbox_xywh_gt[ + packed_annotations.point_bbox_with_dp_indices + ].unbind(dim=1) + x0_est, y0_est, w_est, h_est = packed_annotations.bbox_xywh_est[ + packed_annotations.point_bbox_with_dp_indices + ].unbind(dim=1) + x_lo, x_hi, x_w, jx_valid = _linear_interpolation_utilities( + packed_annotations.x_gt, x0_gt, w_gt, x0_est, w_est, zw + ) + y_lo, y_hi, y_w, jy_valid = _linear_interpolation_utilities( + packed_annotations.y_gt, y0_gt, h_gt, y0_est, h_est, zh + ) + j_valid = jx_valid * jy_valid + + w_ylo_xlo = (1.0 - x_w) * (1.0 - y_w) + w_ylo_xhi = x_w * (1.0 - y_w) + w_yhi_xlo = (1.0 - x_w) * y_w + w_yhi_xhi = x_w * y_w + + return BilinearInterpolationHelper( + packed_annotations, + j_valid, + y_lo, + y_hi, + x_lo, + x_hi, + w_ylo_xlo, # pyre-ignore[6] + w_ylo_xhi, + # pyre-fixme[6]: Expected `Tensor` for 9th param but got `float`. + w_yhi_xlo, + w_yhi_xhi, + ) + + def extract_at_points( + self, + z_est, + slice_fine_segm=None, + w_ylo_xlo=None, + w_ylo_xhi=None, + w_yhi_xlo=None, + w_yhi_xhi=None, + ): + """ + Extract ground truth values z_gt for valid point indices and estimated + values z_est using bilinear interpolation over top-left (y_lo, x_lo), + top-right (y_lo, x_hi), bottom-left (y_hi, x_lo) and bottom-right + (y_hi, x_hi) values in z_est with corresponding weights: + w_ylo_xlo, w_ylo_xhi, w_yhi_xlo and w_yhi_xhi. 
def resample_data(
    z, bbox_xywh_src, bbox_xywh_dst, wout, hout, mode: str = "nearest", padding_mode: str = "zeros"
):
    """
    Resample per-ROI data defined over source boxes onto a regular grid laid
    over the destination boxes.

    Args:
        z (:obj: `torch.Tensor`): tensor of size (N,C,H,W) with data to be
            resampled
        bbox_xywh_src (:obj: `torch.Tensor`): tensor of size (N,4) containing
            source bounding boxes in format XYWH
        bbox_xywh_dst (:obj: `torch.Tensor`): tensor of size (N,4) containing
            destination bounding boxes in format XYWH
        wout (int): width of the output grid
        hout (int): height of the output grid
        mode (str): interpolation mode forwarded to `F.grid_sample`
        padding_mode (str): padding mode forwarded to `F.grid_sample`
    Return:
        zresampled (:obj: `torch.Tensor`): tensor of size (N, C, Hout, Wout)
            with resampled values of z
    """
    n = bbox_xywh_src.size(0)
    assert n == bbox_xywh_dst.size(0), (
        "The number of "
        "source ROIs for resampling ({}) should be equal to the number "
        "of destination ROIs ({})".format(bbox_xywh_src.size(0), bbox_xywh_dst.size(0))
    )
    src_x0, src_y0, src_w, src_h = bbox_xywh_src.unbind(dim=1)
    dst_x0, dst_y0, dst_w, dst_h = bbox_xywh_dst.unbind(dim=1)

    # destination box corners in source-box coordinates normalized to [-1, 1]
    left = 2 * (dst_x0 - src_x0) / src_w - 1
    top = 2 * (dst_y0 - src_y0) / src_h - 1
    right = 2 * (dst_x0 + dst_w - src_x0) / src_w - 1
    bottom = 2 * (dst_y0 + dst_h - src_y0) / src_h - 1

    # fractional steps k / wout and k / hout along each output axis
    steps_x = torch.arange(wout, device=z.device, dtype=torch.float) / wout
    steps_y = torch.arange(hout, device=z.device, dtype=torch.float) / hout

    def _per_roi(values):
        # broadcast one value per ROI over the whole (hout, wout) output grid
        return values[:, None, None].expand(n, hout, wout)

    steps_x_full = steps_x[None, None, :].expand(n, hout, wout)
    steps_y_full = steps_y[None, :, None].expand(n, hout, wout)
    grid_x = steps_x_full * _per_roi(right - left) + _per_roi(left)
    grid_y = steps_y_full * _per_roi(bottom - top) + _per_roi(top)
    sampling_grid = torch.stack((grid_x, grid_y), dim=3)
    # resample Z from (N, C, H, W) into (N, C, Hout, Wout)
    return F.grid_sample(
        z, sampling_grid, mode=mode, padding_mode=padding_mode, align_corners=True
    )
The following attributes + are defined: + - fine_segm_labels_gt (tensor [K] of `int64`): GT fine segmentation point labels + - x_gt (tensor [K] of `float32`): GT normalized X point coordinates + - y_gt (tensor [K] of `float32`): GT normalized Y point coordinates + - u_gt (tensor [K] of `float32`): GT point U values + - v_gt (tensor [K] of `float32`): GT point V values + - coarse_segm_gt (tensor [N, S, S] of `float32`): GT segmentation for bounding boxes + - bbox_xywh_gt (tensor [N, 4] of `float32`): selected GT bounding boxes in + XYWH format + - bbox_xywh_est (tensor [N, 4] of `float32`): selected matching estimated + bounding boxes in XYWH format + - point_bbox_with_dp_indices (tensor [K] of `int64`): indices of bounding boxes + with DensePose annotations that correspond to the point data + - point_bbox_indices (tensor [K] of `int64`): indices of bounding boxes + (not necessarily the selected ones with DensePose data) that correspond + to the point data + - bbox_indices (tensor [N] of `int64`): global indices of selected bounding + boxes with DensePose annotations; these indices could be used to access + features that are computed for all bounding boxes, not only the ones with + DensePose annotations. + Here K is the total number of points and N is the total number of instances + with DensePose annotations. + """ + + fine_segm_labels_gt: torch.Tensor + x_gt: torch.Tensor + y_gt: torch.Tensor + u_gt: torch.Tensor + v_gt: torch.Tensor + coarse_segm_gt: Optional[torch.Tensor] + bbox_xywh_gt: torch.Tensor + bbox_xywh_est: torch.Tensor + point_bbox_with_dp_indices: torch.Tensor + point_bbox_indices: torch.Tensor + bbox_indices: torch.Tensor + + +class ChartBasedAnnotationsAccumulator(AnnotationsAccumulator): + """ + Accumulates annotations by batches that correspond to objects detected on + individual images. Can pack them together into single tensors. 
+ """ + + def __init__(self): + self.i_gt = [] + self.x_gt = [] + self.y_gt = [] + self.u_gt = [] + self.v_gt = [] + self.s_gt = [] + self.bbox_xywh_gt = [] + self.bbox_xywh_est = [] + self.point_bbox_with_dp_indices = [] + self.point_bbox_indices = [] + self.bbox_indices = [] + self.nxt_bbox_with_dp_index = 0 + self.nxt_bbox_index = 0 + + def accumulate(self, instances_one_image: Instances): + """ + Accumulate instances data for one image + + Args: + instances_one_image (Instances): instances data to accumulate + """ + boxes_xywh_est = BoxMode.convert( + instances_one_image.proposal_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS + ) + boxes_xywh_gt = BoxMode.convert( + instances_one_image.gt_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS + ) + n_matches = len(boxes_xywh_gt) + assert n_matches == len( + boxes_xywh_est + ), f"Got {len(boxes_xywh_est)} proposal boxes and {len(boxes_xywh_gt)} GT boxes" + if not n_matches: + # no detection - GT matches + return + if ( + not hasattr(instances_one_image, "gt_densepose") + or instances_one_image.gt_densepose is None + ): + # no densepose GT for the detections, just increase the bbox index + self.nxt_bbox_index += n_matches + return + for box_xywh_est, box_xywh_gt, dp_gt in zip( + boxes_xywh_est, boxes_xywh_gt, instances_one_image.gt_densepose + ): + if (dp_gt is not None) and (len(dp_gt.x) > 0): + self._do_accumulate(box_xywh_gt, box_xywh_est, dp_gt) + self.nxt_bbox_index += 1 + + def _do_accumulate( + self, box_xywh_gt: torch.Tensor, box_xywh_est: torch.Tensor, dp_gt: DensePoseDataRelative + ): + """ + Accumulate instances data for one image, given that the data is not empty + + Args: + box_xywh_gt (tensor): GT bounding box + box_xywh_est (tensor): estimated bounding box + dp_gt (DensePoseDataRelative): GT densepose data + """ + self.i_gt.append(dp_gt.i) + self.x_gt.append(dp_gt.x) + self.y_gt.append(dp_gt.y) + self.u_gt.append(dp_gt.u) + self.v_gt.append(dp_gt.v) + if hasattr(dp_gt, "segm"): + 
def sample_random_indices(
    n_indices: int, n_samples: int, device: Optional[torch.device] = None
) -> Optional[torch.Tensor]:
    """
    Draw `n_samples` distinct random indices from the range `[0..n_indices - 1]`.

    No sampling takes place when `n_samples` is not positive or when
    `n_indices` does not exceed `n_samples`; `None` is returned in that case,
    meaning that all indices are selected.

    Args:
        n_indices (int): total number of indices
        n_samples (int): number of indices to sample
        device (torch.device): the desired device of returned tensor
    Return:
        Tensor of selected vertex indices, or `None`, if all vertices are selected
    """
    sampling_disabled = n_samples <= 0
    nothing_to_drop = n_indices <= n_samples
    if sampling_disabled or nothing_to_drop:
        return None
    # a prefix of a random permutation gives distinct indices
    shuffled = torch.randperm(n_indices, device=device)
    return shuffled[:n_samples]
@DENSEPOSE_PREDICTOR_REGISTRY.register()
class DensePoseChartPredictor(nn.Module):
    """
    Predictor (last layers of a DensePose model) that consumes DensePose head
    outputs and produces 4 tensors representing DensePose results for
    predefined body parts (patches / charts):
     * coarse segmentation, a tensor of shape [N, K, Hout, Wout]
     * fine segmentation, a tensor of shape [N, C, Hout, Wout]
     * U coordinates, a tensor of shape [N, C, Hout, Wout]
     * V coordinates, a tensor of shape [N, C, Hout, Wout]
    where
     - N is the number of instances
     - K is the number of coarse segmentation channels (
         2 = foreground / background,
         15 = one of 14 body parts / background)
     - C is the number of fine segmentation channels
         (NUM_PATCHES + 1, i.e. fine body parts / background)
     - Hout and Wout are height and width of predictions
    """

    def __init__(self, cfg: CfgNode, input_channels: int):
        """
        Build the predictor layers from configuration options.

        Args:
            cfg (CfgNode): configuration options
            input_channels (int): input tensor size along the channel dimension
        """
        super().__init__()
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        coarse_channels = head_cfg.NUM_COARSE_SEGM_CHANNELS
        patch_channels = head_cfg.NUM_PATCHES + 1
        kernel_size = head_cfg.DECONV_KERNEL
        # all four deconvolutions share the same stride and padding
        deconv_kwargs = {"stride": 2, "padding": int(kernel_size / 2 - 1)}
        # coarse segmentation
        self.ann_index_lowres = ConvTranspose2d(
            input_channels, coarse_channels, kernel_size, **deconv_kwargs
        )
        # fine segmentation
        self.index_uv_lowres = ConvTranspose2d(
            input_channels, patch_channels, kernel_size, **deconv_kwargs
        )
        # U
        self.u_lowres = ConvTranspose2d(
            input_channels, patch_channels, kernel_size, **deconv_kwargs
        )
        # V
        self.v_lowres = ConvTranspose2d(
            input_channels, patch_channels, kernel_size, **deconv_kwargs
        )
        self.scale_factor = head_cfg.UP_SCALE
        initialize_module_params(self)

    def interp2d(self, tensor_nchw: torch.Tensor):
        """
        Upscale a tensor with bilinear interpolation.

        Args:
            tensor_nchw (tensor): tensor of shape (N, C, H, W)
        Return:
            tensor of shape (N, C, Hout, Wout), where Hout and Wout are computed
            by applying the scale factor to H and W
        """
        upscaled = interpolate(
            tensor_nchw, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
        )
        return upscaled

    def forward(self, head_outputs: torch.Tensor):
        """
        Run the predictor on DensePose head outputs.

        Args:
            head_outputs (tensor): DensePose head outputs, tensor of shape [N, D, H, W]
        Return:
            An instance of DensePoseChartPredictorOutput
        """
        # each branch: deconvolution followed by bilinear upscaling
        coarse_segm = self.interp2d(self.ann_index_lowres(head_outputs))
        fine_segm = self.interp2d(self.index_uv_lowres(head_outputs))
        u = self.interp2d(self.u_lowres(head_outputs))
        v = self.interp2d(self.v_lowres(head_outputs))
        return DensePoseChartPredictorOutput(
            coarse_segm=coarse_segm,
            fine_segm=fine_segm,
            u=u,
            v=v,
        )
class DensePoseChartConfidencePredictorMixin:
    """
    Mixin that adds confidence estimation to a chart-based DensePose predictor.

    A predictor holds the last layers of a DensePose model: it takes DensePose head
    outputs as an input and produces model outputs. This mixin generates confidences
    for segmentation and UV tensors estimated by some base predictor.
    Several assumptions need to hold for the base predictor:
    1) the `forward` method must return SIUV estimates (S = coarse segmentation,
       I = fine segmentation, U and V are intrinsic chart coordinates) as an
       object exposing `coarse_segm` and `fine_segm` attributes
    2) `interp2d` method must be defined to perform bilinear interpolation;
       the same method is typically used for SIUV and confidences
    Confidence predictor mixin provides confidence estimates, as described in:
        N. Neverova et al., Correlated Uncertainty for Learning Dense Correspondences
            from Noisy Labels, NeurIPS 2019
        A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020
    """

    def __init__(self, cfg: CfgNode, input_channels: int):
        """
        Initialize confidence predictor using configuration options.

        Args:
            cfg (CfgNode): configuration options
            input_channels (int): number of input channels
        """
        # we rely on base predictor to call nn.Module.__init__
        super().__init__(cfg, input_channels)  # pyre-ignore[19]
        self.confidence_model_cfg = DensePoseConfidenceModelConfig.from_cfg(cfg)
        self._initialize_confidence_estimation_layers(cfg, input_channels)
        self._registry = {}
        initialize_module_params(self)  # pyre-ignore[6]

    def _initialize_confidence_estimation_layers(self, cfg: CfgNode, dim_in: int):
        """
        Initialize confidence estimation layers based on configuration options.

        Args:
            cfg (CfgNode): configuration options
            dim_in (int): number of input channels
        Raises:
            ValueError: if UV confidence is enabled with an unsupported type
        """
        dim_out_patches = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES + 1
        kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL
        padding = int(kernel_size / 2 - 1)

        def make_deconv(dim_out):
            # All confidence branches share the deconvolution geometry.
            return ConvTranspose2d(dim_in, dim_out, kernel_size, stride=2, padding=padding)

        uv_confidence = self.confidence_model_cfg.uv_confidence
        if uv_confidence.enabled:
            if uv_confidence.type == DensePoseUVConfidenceType.IID_ISO:
                self.sigma_2_lowres = make_deconv(dim_out_patches)  # pyre-ignore[16]
            elif uv_confidence.type == DensePoseUVConfidenceType.INDEP_ANISO:
                self.sigma_2_lowres = make_deconv(dim_out_patches)
                self.kappa_u_lowres = make_deconv(dim_out_patches)  # pyre-ignore[16]
                self.kappa_v_lowres = make_deconv(dim_out_patches)  # pyre-ignore[16]
            else:
                # Report the value that was actually tested above; the previous
                # message read `confidence_model_cfg.confidence_model_type`, which
                # no other code in this class uses.
                raise ValueError(f"Unknown confidence model type: {uv_confidence.type}")
        if self.confidence_model_cfg.segm_confidence.enabled:
            # One confidence channel per segmentation head.
            self.fine_segm_confidence_lowres = make_deconv(1)  # pyre-ignore[16]
            self.coarse_segm_confidence_lowres = make_deconv(1)  # pyre-ignore[16]

    def forward(self, head_outputs: torch.Tensor):
        """
        Perform forward operation on head outputs used as inputs for the predictor.
        Calls forward method from the base predictor and uses its outputs to compute
        confidences.

        Args:
            head_outputs (Tensor): head outputs used as predictor inputs
        Return:
            An instance of outputs with confidences,
            see `decorate_predictor_output_class_with_confidences`
        Raises:
            ValueError: if UV confidence is enabled with an unsupported type
        """
        # assuming base class returns SIUV estimates in its first result
        base_predictor_outputs = super().forward(head_outputs)  # pyre-ignore[16]

        # create output instance by extending base predictor outputs:
        output = self._create_output_instance(base_predictor_outputs)

        uv_confidence = self.confidence_model_cfg.uv_confidence
        if uv_confidence.enabled:
            # assuming base class defines interp2d method for bilinear interpolation
            if uv_confidence.type == DensePoseUVConfidenceType.IID_ISO:
                output.sigma_2 = self.interp2d(self.sigma_2_lowres(head_outputs))  # pyre-ignore[16]
            elif uv_confidence.type == DensePoseUVConfidenceType.INDEP_ANISO:
                output.sigma_2 = self.interp2d(self.sigma_2_lowres(head_outputs))
                output.kappa_u = self.interp2d(self.kappa_u_lowres(head_outputs))  # pyre-ignore[16]
                output.kappa_v = self.interp2d(self.kappa_v_lowres(head_outputs))  # pyre-ignore[16]
            else:
                raise ValueError(f"Unknown confidence model type: {uv_confidence.type}")

        segm_confidence = self.confidence_model_cfg.segm_confidence
        if segm_confidence.enabled:
            # base predictor outputs are assumed to have `fine_segm` and `coarse_segm` attributes
            # base predictor is assumed to define `interp2d` method for bilinear interpolation
            # softplus keeps the confidence positive; epsilon is added on top of it
            output.fine_segm_confidence = (
                F.softplus(
                    self.interp2d(self.fine_segm_confidence_lowres(head_outputs))  # pyre-ignore[16]
                )
                + segm_confidence.epsilon
            )
            # scale every fine segmentation channel by the single-channel confidence
            output.fine_segm = base_predictor_outputs.fine_segm * torch.repeat_interleave(
                output.fine_segm_confidence, base_predictor_outputs.fine_segm.shape[1], dim=1
            )
            output.coarse_segm_confidence = (
                F.softplus(
                    self.interp2d(
                        self.coarse_segm_confidence_lowres(head_outputs)  # pyre-ignore[16]
                    )
                )
                + segm_confidence.epsilon
            )
            output.coarse_segm = base_predictor_outputs.coarse_segm * torch.repeat_interleave(
                output.coarse_segm_confidence, base_predictor_outputs.coarse_segm.shape[1], dim=1
            )

        return output

    def _create_output_instance(self, base_predictor_outputs: Any):
        """
        Create an instance of predictor outputs by copying the outputs from the
        base predictor and initializing confidence fields to `None`.

        Args:
            base_predictor_outputs: an instance of base predictor outputs
                (the outputs type is assumed to be a dataclass)
        Return:
           An instance of outputs with confidences
        """
        PredictorOutput = decorate_predictor_output_class_with_confidences(
            type(base_predictor_outputs)  # pyre-ignore[6]
        )
        # base_predictor_outputs is assumed to be a dataclass
        # reassign all the fields from base_predictor_outputs (no deep copy!), add new fields
        output = PredictorOutput(
            **base_predictor_outputs.__dict__,
            coarse_segm_confidence=None,
            fine_segm_confidence=None,
            sigma_1=None,
            sigma_2=None,
            kappa_u=None,
            kappa_v=None,
        )
        return output
# --- predictors/chart_with_confidence.py ---
@DENSEPOSE_PREDICTOR_REGISTRY.register()
class DensePoseChartWithConfidencePredictor(
    DensePoseChartConfidencePredictorMixin, DensePoseChartPredictor
):
    """
    Chart predictor combined with chart confidence estimation.

    All behavior comes from the two bases; the confidence mixin is listed first.
    """


# --- predictors/cse.py ---
@DENSEPOSE_PREDICTOR_REGISTRY.register()
class DensePoseEmbeddingPredictor(nn.Module):
    """
    Final layers of a DensePose model for continuous surface embeddings (CSE):
    takes DensePose head outputs and produces a coarse segmentation and an embedding.
    """

    def __init__(self, cfg: CfgNode, input_channels: int):
        """
        Build the coarse segmentation and embedding branches from configuration options.

        Args:
            cfg (CfgNode): configuration options
            input_channels (int): input tensor size along the channel dimension
        """
        super().__init__()
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        n_coarse_channels = head_cfg.NUM_COARSE_SEGM_CHANNELS
        embedding_size = head_cfg.CSE.EMBED_SIZE
        kernel_size = head_cfg.DECONV_KERNEL
        # int() truncates toward zero; both branches share this padding.
        padding = int(kernel_size / 2 - 1)
        # coarse segmentation
        self.coarse_segm_lowres = ConvTranspose2d(
            input_channels, n_coarse_channels, kernel_size, stride=2, padding=padding
        )
        # embedding
        self.embed_lowres = ConvTranspose2d(
            input_channels, embedding_size, kernel_size, stride=2, padding=padding
        )
        self.scale_factor = head_cfg.UP_SCALE
        initialize_module_params(self)

    def interp2d(self, tensor_nchw: torch.Tensor):
        """
        Upscale a tensor with bilinear interpolation.

        Args:
            tensor_nchw (tensor): tensor of shape (N, C, H, W)
        Return:
            tensor of shape (N, C, Hout, Wout), where Hout and Wout are computed
            by applying the scale factor to H and W
        """
        return interpolate(
            tensor_nchw, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
        )

    def forward(self, head_outputs):
        """
        Run both branches on the DensePose head outputs and upscale the results.

        Args:
            head_outputs (tensor): DensePose head outputs, tensor of shape [N, D, H, W]
        Return:
            An instance of DensePoseEmbeddingPredictorOutput
        """
        embedding = self.interp2d(self.embed_lowres(head_outputs))
        coarse_segm = self.interp2d(self.coarse_segm_lowres(head_outputs))
        return DensePoseEmbeddingPredictorOutput(embedding=embedding, coarse_segm=coarse_segm)
class DensePoseEmbeddingConfidencePredictorMixin:
    """
    Mixin that adds coarse segmentation confidence to a CSE DensePose predictor.

    A predictor holds the last layers of a DensePose model: it takes DensePose head
    outputs as an input and produces model outputs. This mixin generates confidences
    for the coarse segmentation estimated by some base predictor.
    Several assumptions need to hold for the base predictor:
    1) its `forward` method accepts CSE DensePose head outputs,
       tensor of shape [N, D, H, W], and returns an object with a
       `coarse_segm` attribute
    2) `interp2d` method must be defined to perform bilinear interpolation;
       the same method is typically used for masks and confidences
    Confidence predictor mixin provides confidence estimates, as described in:
        N. Neverova et al., Correlated Uncertainty for Learning Dense Correspondences
            from Noisy Labels, NeurIPS 2019
        A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020
    """

    def __init__(self, cfg: CfgNode, input_channels: int):
        """
        Initialize confidence predictor using configuration options.

        Args:
            cfg (CfgNode): configuration options
            input_channels (int): number of input channels
        """
        # nn.Module.__init__ is expected to be called by the base predictor
        super().__init__(cfg, input_channels)  # pyre-ignore[19]
        self.confidence_model_cfg = DensePoseConfidenceModelConfig.from_cfg(cfg)
        self._initialize_confidence_estimation_layers(cfg, input_channels)
        self._registry = {}
        initialize_module_params(self)  # pyre-ignore[6]

    def _initialize_confidence_estimation_layers(self, cfg: CfgNode, dim_in: int):
        """
        Create the confidence branch when segmentation confidence is enabled.

        Args:
            cfg (CfgNode): configuration options
            dim_in (int): number of input channels
        """
        kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL
        if not self.confidence_model_cfg.segm_confidence.enabled:
            return
        # single-channel confidence map
        self.coarse_segm_confidence_lowres = ConvTranspose2d(  # pyre-ignore[16]
            dim_in, 1, kernel_size, stride=2, padding=int(kernel_size / 2 - 1)
        )

    def forward(self, head_outputs: torch.Tensor):
        """
        Run the base predictor, then attach the coarse segmentation confidence.

        Args:
            head_outputs (Tensor): head outputs used as predictor inputs
        Return:
            An instance of outputs with confidences,
            see `decorate_cse_predictor_output_class_with_confidences`
        """
        # the base class is expected to return the CSE predictor outputs
        base_outputs = super().forward(head_outputs)  # pyre-ignore[16]

        # extend the base outputs with (initially empty) confidence fields
        output = self._create_output_instance(base_outputs)

        segm_confidence_cfg = self.confidence_model_cfg.segm_confidence
        if not segm_confidence_cfg.enabled:
            return output

        # base predictor outputs are assumed to have a `coarse_segm` attribute and the
        # base predictor is assumed to define `interp2d` for bilinear interpolation
        confidence_lowres = self.coarse_segm_confidence_lowres(head_outputs)  # pyre-ignore[16]
        upscaled = self.interp2d(confidence_lowres)  # pyre-ignore[16]
        # softplus keeps the confidence positive; epsilon is added on top of it
        output.coarse_segm_confidence = F.softplus(upscaled) + segm_confidence_cfg.epsilon
        # scale every coarse segmentation channel by the single-channel confidence
        n_channels = base_outputs.coarse_segm.shape[1]
        output.coarse_segm = base_outputs.coarse_segm * torch.repeat_interleave(
            output.coarse_segm_confidence, n_channels, dim=1
        )
        return output

    def _create_output_instance(self, base_predictor_outputs: Any):
        """
        Copy the base predictor outputs into an outputs-with-confidences instance,
        with the confidence field initialized to `None`.

        Args:
            base_predictor_outputs: an instance of base predictor outputs
                (the outputs type is assumed to be a dataclass)
        Return:
           An instance of outputs with confidences
        """
        output_cls = decorate_cse_predictor_output_class_with_confidences(
            type(base_predictor_outputs)  # pyre-ignore[6]
        )
        # fields are reassigned from the base outputs (no deep copy!), new field added
        return output_cls(
            **base_predictor_outputs.__dict__,
            coarse_segm_confidence=None,
        )
@DENSEPOSE_PREDICTOR_REGISTRY.register()
class DensePoseEmbeddingWithConfidencePredictor(
    DensePoseEmbeddingConfidencePredictorMixin, DensePoseEmbeddingPredictor
):
    """
    CSE predictor combined with CSE confidence estimation.

    All behavior comes from the two bases; the confidence mixin is listed first.
    """
@ROI_DENSEPOSE_HEAD_REGISTRY.register()
class DensePoseDeepLabHead(nn.Module):
    """
    DensePose head built around the DeepLabV3 ASPP module from
    "Rethinking Atrous Convolution for Semantic Image Segmentation"
    (arXiv:1706.05587), optionally followed by a non-local block, and then a
    stack of convolutions each followed by ReLU.
    """

    def __init__(self, cfg: CfgNode, input_channels: int):
        """
        Args:
            cfg (CfgNode): configuration options
            input_channels (int): number of channels of the input features
        """
        super(DensePoseDeepLabHead, self).__init__()
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        # fmt: off
        hidden_dim           = head_cfg.CONV_HEAD_DIM
        kernel_size          = head_cfg.CONV_HEAD_KERNEL
        norm                 = head_cfg.DEEPLAB.NORM
        self.n_stacked_convs = head_cfg.NUM_STACKED_CONVS
        self.use_nonlocal    = head_cfg.DEEPLAB.NONLOCAL_ON
        # fmt: on
        padding = kernel_size // 2

        # Attribute assignment registers the submodule under the same name.
        self.ASPP = ASPP(input_channels, [6, 12, 56], input_channels)

        if self.use_nonlocal:
            self.NLBlock = NONLocalBlock2D(input_channels, bn_layer=True)

        in_dim = input_channels
        for idx in range(self.n_stacked_convs):
            # GroupNorm is attached only for the "GN" setting; the bias is dropped
            # whenever `norm` is truthy.
            conv = Conv2d(
                in_dim,
                hidden_dim,
                kernel_size,
                stride=1,
                padding=padding,
                bias=not norm,
                norm=nn.GroupNorm(32, hidden_dim) if norm == "GN" else None,
            )
            weight_init.c2_msra_fill(conv)
            self.add_module(self._get_layer_name(idx), conv)
            in_dim = hidden_dim
        self.n_out_channels = hidden_dim

    def forward(self, features):
        """Apply ASPP, the optional non-local block, then the conv + ReLU stack."""
        x = self.ASPP(features)
        if self.use_nonlocal:
            x = self.NLBlock(x)
        for idx in range(self.n_stacked_convs):
            conv = getattr(self, self._get_layer_name(idx))
            x = F.relu(conv(x))
        return x

    def _get_layer_name(self, i: int):
        """Return the attribute name of the i-th (0-based) stacked convolution."""
        return "body_conv_fcn{}".format(i + 1)
# Adapted from torchvision's DeepLabV3 implementation
# (torchvision/models/segmentation/deeplabv3.py); see arXiv:1706.05587 for details.
class ASPPConv(nn.Sequential):
    """3x3 atrous convolution (no bias) followed by GroupNorm(32 groups) and ReLU."""

    def __init__(self, in_channels, out_channels, dilation):
        """
        Args:
            in_channels (int): number of input channels
            out_channels (int): number of output channels (normalized in 32 groups)
            dilation (int): dilation rate, also used as padding so that the
                spatial size is preserved
        """
        modules = [
            nn.Conv2d(
                in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False
            ),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(),
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    """Global average pooling branch: pool to 1x1, 1x1 conv, GroupNorm, ReLU, upsample."""

    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels (int): number of input channels
            out_channels (int): number of output channels (normalized in 32 groups)
        """
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return the pooled features resized back to the spatial size of `x`."""
        size = x.shape[-2:]
        x = super(ASPPPooling, self).forward(x)
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


class ASPP(nn.Module):
    """
    Atrous spatial pyramid pooling: a 1x1 branch, one atrous branch per rate and
    a global pooling branch, concatenated along channels and projected by a
    1x1 convolution followed by ReLU.
    """

    def __init__(self, in_channels, atrous_rates, out_channels):
        """
        Args:
            in_channels (int): number of input channels
            atrous_rates (iterable of int): dilation rate of each atrous branch;
                any number of rates is accepted (three rates give the same
                layout as the previous fixed three-rate version)
            out_channels (int): number of channels of every branch and of the output
        """
        super(ASPP, self).__init__()
        branches = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.GroupNorm(32, out_channels),
                nn.ReLU(),
            )
        ]
        branches.extend(ASPPConv(in_channels, out_channels, rate) for rate in atrous_rates)
        branches.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(branches)

        # The projection input width follows the actual number of branches.
        # No normalization or dropout is applied in the projection.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run every branch on `x`, concatenate along channels and project."""
        res = torch.cat([conv(x) for conv in self.convs], dim=1)
        return self.project(res)


# Adapted from the AlexHex7/Non-local_pytorch repository
# (lib/non_local_embedded_gaussian.py); see arXiv:1711.07971 for details.
class _NonLocalBlockND(nn.Module):
    """Embedded-Gaussian non-local block for 1D, 2D or 3D feature maps."""

    def __init__(
        self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True
    ):
        """
        Args:
            in_channels (int): number of input (and output) channels
            inter_channels (int or None): number of channels of the embeddings;
                defaults to `in_channels // 2`, but at least 1
            dimension (int): number of spatial dimensions, one of 1, 2, 3
            sub_sample (bool): if True, max-pool the `g` and `phi` branches
            bn_layer (bool): if True, follow the output convolution by
                GroupNorm(32 groups); in both cases the last layer of `W` is
                zero-initialized
        """
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=2)

        def pointwise(n_in, n_out):
            # Every convolution in the block is a 1x1 (pointwise) convolution.
            return conv_nd(
                in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0
            )

        self.g = pointwise(self.in_channels, self.inter_channels)

        if bn_layer:
            # GroupNorm is used for every dimensionality.
            self.W = nn.Sequential(
                pointwise(self.inter_channels, self.in_channels),
                nn.GroupNorm(32, self.in_channels),
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = pointwise(self.inter_channels, self.in_channels)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = pointwise(self.in_channels, self.inter_channels)
        self.phi = pointwise(self.in_channels, self.inter_channels)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        """
        :param x: tensor of shape (b, c, *spatial), e.g. (b, c, t, h, w) for
            dimension 3 or (b, c, h, w) for dimension 2
        :return: tensor of the same shape as `x` (residual connection around the block)
        """
        batch_size = x.size(0)

        # (b, positions, inter_channels)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        # pairwise affinities, softmax-normalized over the last axis
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z


class NONLocalBlock2D(_NonLocalBlockND):
    """Non-local block specialized to 2D feature maps of shape (b, c, h, w)."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )
class Decoder(nn.Module):
    """
    Semantic segmentation head in the style of the Panoptic Feature Pyramid
    Networks paper (arXiv:1901.02446): takes FPN features as input and merges
    information from all FPN levels into a single output.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec], in_features):
        """
        Args:
            cfg: configuration options
            input_shape (dict[str, ShapeSpec]): shape of every available feature map
            in_features: names of the feature maps consumed by the decoder, in order
        """
        super(Decoder, self).__init__()

        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        # fmt: off
        self.in_features   = in_features
        strides            = {name: spec.stride for name, spec in input_shape.items()}
        channels           = {name: spec.channels for name, spec in input_shape.items()}
        num_classes        = head_cfg.DECODER_NUM_CLASSES
        conv_dims          = head_cfg.DECODER_CONV_DIMS
        self.common_stride = head_cfg.DECODER_COMMON_STRIDE
        norm               = head_cfg.DECODER_NORM
        # fmt: on

        self.scale_heads = []
        for name in self.in_features:
            stride = strides[name]
            # one stage per factor of two between this level and the common stride,
            # but always at least one
            n_stages = max(1, int(np.log2(stride) - np.log2(self.common_stride)))
            stages = []
            for stage_idx in range(n_stages):
                stage_in = channels[name] if stage_idx == 0 else conv_dims
                conv = Conv2d(
                    stage_in,
                    conv_dims,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=not norm,
                    norm=get_norm(norm, conv_dims),
                    activation=F.relu,
                )
                weight_init.c2_msra_fill(conv)
                stages.append(conv)
                if stride != self.common_stride:
                    stages.append(
                        nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
                    )
            head = nn.Sequential(*stages)
            self.scale_heads.append(head)
            # scale_heads is a plain list, so each head is registered explicitly
            self.add_module(name, head)
        self.predictor = Conv2d(conv_dims, num_classes, kernel_size=1, stride=1, padding=0)
        weight_init.c2_msra_fill(self.predictor)

    def forward(self, features: List[torch.Tensor]):
        """
        Sum the per-level head outputs and apply the 1x1 predictor.

        Args:
            features (list[Tensor]): feature maps, in the order of `in_features`
        """
        for idx, head in enumerate(self.scale_heads):
            level_out = head(features[idx])
            x = level_out if idx == 0 else x + level_out
        x = self.predictor(x)
        return x
@ROI_HEADS_REGISTRY.register()
class DensePoseROIHeads(StandardROIHeads):
    """
    A Standard ROIHeads which contains an addition of DensePose head.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)
        self._init_densepose_head(cfg, input_shape)

    def _init_densepose_head(self, cfg, input_shape):
        """
        Build the DensePose branch (data filter, optional decoder, pooler, head,
        predictor, losses, embedder). Only `densepose_on` is set when the
        branch is disabled in the configuration.
        """
        # fmt: off
        self.densepose_on          = cfg.MODEL.DENSEPOSE_ON
        if not self.densepose_on:
            return
        self.densepose_data_filter = build_densepose_data_filter(cfg)
        dp_pooler_resolution       = cfg.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION
        dp_pooler_sampling_ratio   = cfg.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO
        dp_pooler_type             = cfg.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE
        self.use_decoder           = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECODER_ON
        # fmt: on
        if self.use_decoder:
            # the decoder merges all levels into one map, pooled at the scale of
            # the first input feature
            dp_pooler_scales = (1.0 / input_shape[self.in_features[0]].stride,)
        else:
            dp_pooler_scales = tuple(1.0 / input_shape[k].stride for k in self.in_features)
        # the head is sized from the channels of the first input feature
        in_channels = input_shape[self.in_features[0]].channels

        if self.use_decoder:
            self.decoder = Decoder(cfg, input_shape, self.in_features)

        self.densepose_pooler = ROIPooler(
            output_size=dp_pooler_resolution,
            scales=dp_pooler_scales,
            sampling_ratio=dp_pooler_sampling_ratio,
            pooler_type=dp_pooler_type,
        )
        self.densepose_head = build_densepose_head(cfg, in_channels)
        self.densepose_predictor = build_densepose_predictor(
            cfg, self.densepose_head.n_out_channels
        )
        self.densepose_losses = build_densepose_losses(cfg)
        self.embedder = build_densepose_embedder(cfg)

    def _forward_densepose(self, features: Dict[str, torch.Tensor], instances: List[Instances]):
        """
        Forward logic of the densepose prediction branch.

        Args:
            features (dict[str, Tensor]): input data as a mapping from feature
                map name to tensor. Axis 0 represents the number of images `N` in
                the input data; axes 1-3 are channels, height, and width, which may
                vary between feature maps (e.g., if a feature pyramid is used).
            instances (list[Instances]): length `N` list of `Instances`. The i-th
                `Instances` contains instances for the i-th input image,
                In training, they can be the proposals.
                In inference, they can be the predicted boxes.

        Returns:
            In training, a dict of losses (empty when the branch is disabled or
            when no proposals are left after filtering).
            In inference, update `instances` with new fields "densepose" and return it.
        """
        if not self.densepose_on:
            return {} if self.training else instances

        features_list = [features[f] for f in self.in_features]
        if self.training:
            proposals, _ = select_foreground_proposals(instances, self.num_classes)
            features_list, proposals = self.densepose_data_filter(features_list, proposals)
            if len(proposals) == 0:
                # Previously this path fell through and returned None, which made
                # `losses.update(...)` in `forward` raise a TypeError.
                # NOTE(review): skipping the branch for a step leaves its parameters
                # without gradients - confirm this is acceptable for distributed runs.
                return {}

            proposal_boxes = [x.proposal_boxes for x in proposals]

            if self.use_decoder:
                # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a
                #  function.
                features_list = [self.decoder(features_list)]

            features_dp = self.densepose_pooler(features_list, proposal_boxes)
            densepose_head_outputs = self.densepose_head(features_dp)
            densepose_predictor_outputs = self.densepose_predictor(densepose_head_outputs)
            densepose_loss_dict = self.densepose_losses(
                proposals, densepose_predictor_outputs, embedder=self.embedder
            )
            return densepose_loss_dict

        pred_boxes = [x.pred_boxes for x in instances]

        if self.use_decoder:
            # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
            features_list = [self.decoder(features_list)]

        features_dp = self.densepose_pooler(features_list, pred_boxes)
        if len(features_dp) > 0:
            densepose_head_outputs = self.densepose_head(features_dp)
            densepose_predictor_outputs = self.densepose_predictor(densepose_head_outputs)
        else:
            # no boxes to process: inference is still invoked, with no predictor outputs
            densepose_predictor_outputs = None

        densepose_inference(densepose_predictor_outputs, instances)
        return instances

    def forward(
        self,
        images: ImageList,
        features: Dict[str, torch.Tensor],
        proposals: List[Instances],
        targets: Optional[List[Instances]] = None,
    ):
        """
        Run the standard ROI heads and, in training, add the DensePose losses.

        Returns:
            The `(instances, losses)` pair produced by the parent `forward`, with
            `losses` extended by the DensePose losses when training.
        """
        instances, losses = super().forward(images, features, proposals, targets)
        del targets, images

        if self.training:
            losses.update(self._forward_densepose(features, instances))
        return instances, losses

    def forward_with_given_boxes(
        self, features: Dict[str, torch.Tensor], instances: List[Instances]
    ):
        """
        Use the given boxes in `instances` to produce other (non-box) per-ROI outputs.

        This is useful for downstream tasks where a box is known, but need to obtain
        other attributes (outputs of other heads).
        Test-time augmentation also uses this.

        Args:
            features: same as in `forward()`
            instances (list[Instances]): instances to predict other outputs. Expect the keys
                "pred_boxes" and "pred_classes" to exist.

        Returns:
            instances (list[Instances]):
                the same `Instances` objects, with extra
                fields such as `pred_masks` or `pred_keypoints`.
        """

        instances = super().forward_with_given_boxes(features, instances)
        instances = self._forward_densepose(features, instances)
        return instances
@ROI_DENSEPOSE_HEAD_REGISTRY.register()
class DensePoseV1ConvXHead(nn.Module):
    """
    Fully convolutional DensePose head: a stack of convolutions, each followed by ReLU.
    """

    def __init__(self, cfg: CfgNode, input_channels: int):
        """
        Create the stacked convolution layers and initialize their parameters.

        Args:
            cfg (CfgNode): configuration options
            input_channels (int): number of input channels
        """
        super().__init__()
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        width = head_cfg.CONV_HEAD_DIM
        ksize = head_cfg.CONV_HEAD_KERNEL
        self.n_stacked_convs = head_cfg.NUM_STACKED_CONVS
        # padding of half the kernel size, paired with stride 1 below
        padding = ksize // 2
        in_ch = input_channels
        for idx in range(self.n_stacked_convs):
            conv = Conv2d(in_ch, width, ksize, stride=1, padding=padding)
            self.add_module(self._get_layer_name(idx), conv)
            in_ch = width
        # equals input_channels when no convolutions are stacked
        self.n_out_channels = in_ch
        initialize_module_params(self)

    def forward(self, features: torch.Tensor):
        """
        Apply DensePose fully convolutional head to the input features

        Args:
            features (tensor): input features
        Result:
            A tensor of DensePose head outputs
        """
        out = features
        for idx in range(self.n_stacked_convs):
            conv = getattr(self, self._get_layer_name(idx))
            out = F.relu(conv(out))
        return out

    def _get_layer_name(self, i: int):
        # layer names are 1-based: body_conv_fcn1, body_conv_fcn2, ...
        return "body_conv_fcn{}".format(i + 1)
class DensePoseDatasetMapperTTA(DatasetMapperTTA):
    """
    TTA mapper that extends the parent's augmented inputs with one rotated
    copy of the image per configured angle.
    """

    def __init__(self, cfg):
        super().__init__(cfg=cfg)
        # one additional augmented input is produced per entry
        self.angles = cfg.TEST.AUG.ROTATION_ANGLES

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): dataset dict whose "image" entry is a CHW tensor

        Returns:
            list[dict]: the parent's augmented dicts followed by one dict per
            rotation angle, each with its own "image" and "transforms"
        """
        augmented = super().__call__(dataset_dict=dataset_dict)
        image_hwc = dataset_dict["image"].permute(1, 2, 0).numpy()
        for angle in self.angles:
            rotated_hwc, rotation_tfms = apply_transform_gens(
                [RandomRotation(angle=angle, expand=True)], np.copy(image_hwc)
            )
            rotated_chw = np.ascontiguousarray(rotated_hwc.transpose(2, 0, 1))
            entry = copy.deepcopy(dataset_dict)
            # In DatasetMapperTTA, there is a pre_tfm transform (resize or no-op) that is
            # added at the beginning of each TransformList. That's '.transforms[0]'.
            pre_tfm = augmented[-1]["transforms"].transforms[0]
            entry["transforms"] = TransformList([pre_tfm] + rotation_tfms.transforms)
            entry["image"] = torch.from_numpy(rotated_chw)
            augmented.append(entry)
        return augmented
def _inference_one_image(self, input):
    """
    Run test-time-augmented inference on one image.

    Boxes are detected on every augmented input with the mask, keypoint and
    densepose heads switched off, then merged. If masks or DensePose are
    enabled in the config, those heads are run on the merged boxes and their
    predictions are reduced across augmentations.

    Args:
        input (dict): one dataset dict with "image" field being a CHW tensor

    Returns:
        dict: one output dict
    """
    height_width = (input["height"], input["width"])
    # For some reason, resize with uint8 slightly increases box AP but decreases densepose AP
    input["image"] = input["image"].to(torch.uint8)
    aug_inputs, tfms = self._get_augmented_inputs(input)

    # Detect boxes from all augmented versions, with the extra roi heads
    # temporarily disabled
    with self._turn_off_roi_heads(["mask_on", "keypoint_on", "densepose_on"]):
        boxes, scores, classes = self._get_augmented_boxes(aug_inputs, tfms)
    merged = self._merge_detections(boxes, scores, classes, height_width)

    if not (self.cfg.MODEL.MASK_ON or self.cfg.MODEL.DENSEPOSE_ON):
        # box-only model: the merged detections are the final result
        return {"instances": merged}

    # Use the detected boxes to obtain new fields
    aug_instances = self._rescale_detected_boxes(aug_inputs, merged, tfms)
    # run forward on the detected boxes
    outputs = self._batch_inference(aug_inputs, aug_instances)
    # Delete now useless variables to avoid being out of memory
    del aug_inputs, aug_instances

    # average the predictions
    if self.cfg.MODEL.MASK_ON:
        merged.pred_masks = self._reduce_pred_masks(outputs, tfms)
    if self.cfg.MODEL.DENSEPOSE_ON:
        merged.pred_densepose = self._reduce_pred_densepose(outputs, tfms)

    # postprocess
    return {"instances": detector_postprocess(merged, *height_width)}
def _reduce_pred_densepose(self, outputs, tfms):
    """
    Undo the augmentations on the DensePose predictions and average them.

    We assume only rotation, resize & flip are used. The running average is
    accumulated in place on `outputs[0].pred_densepose`, which is returned.

    Args:
        outputs: per-augmentation outputs exposing `pred_densepose` and `pred_boxes`
        tfms: per-augmentation transform lists, aligned with `outputs`
    """
    dp_fields = ("coarse_segm", "fine_segm", "u", "v")
    for idx, (output, tfm) in enumerate(zip(outputs, tfms)):
        # Rotate every field back; _inverse_rotation is called for each
        # transform in the list.
        for t in tfm.transforms:
            for field in dp_fields:
                rotated_back = _inverse_rotation(
                    getattr(output.pred_densepose, field), output.pred_boxes.tensor, t
                )
                setattr(output.pred_densepose, field, rotated_back)
        was_flipped = any(isinstance(t, HFlipTransform) for t in tfm.transforms)
        if was_flipped:
            output.pred_densepose = HFlipConverter.convert(
                output.pred_densepose, self._transform_data
            )
        self._incremental_avg_dp(outputs[0].pred_densepose, output.pred_densepose, idx)
    return outputs[0].pred_densepose
+ def _incremental_avg_dp(self, avg, new_el, idx): + for attr in ["coarse_segm", "fine_segm", "u", "v"]: + setattr(avg, attr, (getattr(avg, attr) * idx + getattr(new_el, attr)) / (idx + 1)) + if idx: + # Deletion of the > 0 index intermediary values to prevent GPU OOM + setattr(new_el, attr, None) + return avg + + +def _inverse_rotation(densepose_attrs, boxes, transform): + # resample outputs to image size and rotate back the densepose preds + # on the rotated images to the space of the original image + if len(boxes) == 0 or not isinstance(transform, RotationTransform): + return densepose_attrs + boxes = boxes.int().cpu().numpy() + wh_boxes = boxes[:, 2:] - boxes[:, :2] # bboxes in the rotated space + inv_boxes = rotate_box_inverse(transform, boxes).astype(int) # bboxes in original image + wh_diff = (inv_boxes[:, 2:] - inv_boxes[:, :2] - wh_boxes) // 2 # diff between new/old bboxes + rotation_matrix = torch.tensor([transform.rm_image]).to(device=densepose_attrs.device).float() + rotation_matrix[:, :, -1] = 0 + # To apply grid_sample for rotation, we need to have enough space to fit the original and + # rotated bboxes. l_bds and r_bds are the left/right bounds that will be used to + # crop the difference once the rotation is done + l_bds = np.maximum(0, -wh_diff) + for i in range(len(densepose_attrs)): + if min(wh_boxes[i]) <= 0: + continue + densepose_attr = densepose_attrs[[i]].clone() + # 1. Interpolate densepose attribute to size of the rotated bbox + densepose_attr = F.interpolate(densepose_attr, wh_boxes[i].tolist()[::-1], mode="bilinear") + # 2. Pad the interpolated attribute so it has room for the original + rotated bbox + densepose_attr = F.pad(densepose_attr, tuple(np.repeat(np.maximum(0, wh_diff[i]), 2))) + # 3. Compute rotation grid and transform + grid = F.affine_grid(rotation_matrix, size=densepose_attr.shape) + densepose_attr = F.grid_sample(densepose_attr, grid) + # 4. 
def rotate_box_inverse(rot_tfm, rotated_box):
    """
    Map boxes from the rotated image back to the original image and shrink
    them to their pre-rotation size.

    rotated_box is a N * 4 array of [x0, y0, x1, y1] boxes
    When a bbox is rotated, it gets bigger, because we need to surround the tilted bbox
    So when a bbox is rotated then inverse-rotated, it is much bigger than the original
    This function aims to invert the rotation on the box, but also resize it to its original size

    Args:
        rot_tfm: rotation transform exposing `inverse()`, `abs_cos` and `abs_sin`
        rotated_box: N * 4 array of boxes in the rotated image

    Returns:
        The array returned by `rot_tfm.inverse().apply_box`, adjusted in place.

    Raises:
        AssertionError: if `2 * abs_sin**2 == 1` (45 degrees), where the size
            equations cannot be inverted.
    """
    # 1. Inverse-rotate the boxes; the result encloses (is bigger than) the original box
    restored = rot_tfm.inverse().apply_box(rotated_box)
    rot_h = rotated_box[:, 3] - rotated_box[:, 1]
    rot_w = rotated_box[:, 2] - rotated_box[:, 0]
    inv_h = restored[:, 3] - restored[:, 1]
    inv_w = restored[:, 2] - restored[:, 0]
    assert 2 * rot_tfm.abs_sin**2 != 1, "45 degrees angle can't be inverted"
    # 2. Inverse the corresponding computation in the rotation transform
    #    to get the original height/width of the rotated boxes
    denom = 1 - 2 * rot_tfm.abs_sin**2
    orig_h = (rot_h * rot_tfm.abs_cos - rot_w * rot_tfm.abs_sin) / denom
    orig_w = (rot_w * rot_tfm.abs_cos - rot_h * rot_tfm.abs_sin) / denom
    # 3. Resize the inverse-rotated bboxes to their original size by trimming
    #    half of the excess from each side
    half_dx = (inv_w - orig_w) / 2
    half_dy = (inv_h - orig_h) / 2
    restored[:, 0] += half_dx
    restored[:, 1] += half_dy
    restored[:, 2] -= half_dx
    restored[:, 3] -= half_dy

    return restored
+ +from .chart import DensePoseChartPredictorOutput +from .chart_confidence import decorate_predictor_output_class_with_confidences +from .cse_confidence import decorate_cse_predictor_output_class_with_confidences +from .chart_result import ( + DensePoseChartResult, + DensePoseChartResultWithConfidences, + quantize_densepose_chart_result, + compress_quantized_densepose_chart_result, + decompress_compressed_densepose_chart_result, +) +from .cse import DensePoseEmbeddingPredictorOutput +from .data_relative import DensePoseDataRelative +from .list import DensePoseList +from .mesh import Mesh, create_mesh +from .transform_data import DensePoseTransformData, normalized_coords_transform diff --git a/data_processing/detectron2/projects/DensePose/densepose/structures/chart.py b/data_processing/detectron2/projects/DensePose/densepose/structures/chart.py new file mode 100644 index 0000000..115cc08 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/structures/chart.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
@dataclass
class DensePoseChartPredictorOutput:
    """
    Predictor output that contains segmentation and inner coordinates predictions for predefined
    body parts:
     * coarse segmentation, a tensor of shape [N, K, Hout, Wout]
     * fine segmentation, a tensor of shape [N, C, Hout, Wout]
     * U coordinates, a tensor of shape [N, C, Hout, Wout]
     * V coordinates, a tensor of shape [N, C, Hout, Wout]
    where
     - N is the number of instances
     - K is the number of coarse segmentation channels (
         2 = foreground / background,
         15 = one of 14 body parts / background)
     - C is the number of fine segmentation channels (
         24 fine body parts / background)
     - Hout and Wout are height and width of predictions
    """

    coarse_segm: torch.Tensor
    fine_segm: torch.Tensor
    u: torch.Tensor
    v: torch.Tensor

    def __len__(self):
        """
        Number of instances (N) in the output
        """
        return self.coarse_segm.size(0)

    def __getitem__(
        self, item: Union[int, slice, torch.BoolTensor]
    ) -> "DensePoseChartPredictorOutput":
        """
        Get outputs for the selected instance(s)

        An integer index keeps the leading instance axis (length 1) so the
        result has the same rank as the full output.

        Args:
            item (int or slice or tensor): selected items
        """

        def pick(tensor):
            chosen = tensor[item]
            return chosen.unsqueeze(0) if isinstance(item, int) else chosen

        return DensePoseChartPredictorOutput(
            coarse_segm=pick(self.coarse_segm),
            fine_segm=pick(self.fine_segm),
            u=pick(self.u),
            v=pick(self.v),
        )

    def to(self, device: torch.device):
        """
        Transfers all tensors to the given device
        """
        return DensePoseChartPredictorOutput(
            coarse_segm=self.coarse_segm.to(device),
            fine_segm=self.fine_segm.to(device),
            u=self.u.to(device),
            v=self.v.to(device),
        )
@lru_cache(maxsize=None)
def decorate_predictor_output_class_with_confidences(BasePredictorOutput: type) -> type:
    """
    Create a new output class from an existing one by adding new attributes
    related to confidence estimation:
    - sigma_1 (tensor)
    - sigma_2 (tensor)
    - kappa_u (tensor)
    - kappa_v (tensor)
    - fine_segm_confidence (tensor)
    - coarse_segm_confidence (tensor)

    Details on confidence estimation parameters can be found in:
    N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
        Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
    A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020

    The new class inherits the provided `BasePredictorOutput` class,
    its name is composed of the name of the provided class and
    "WithConfidences" suffix. The result is cached per base class, so
    repeated calls with the same class return the same derived class.

    Args:
        BasePredictorOutput (type): output type to which confidence data
            is to be added, assumed to be a dataclass
    Return:
        New dataclass derived from the provided one that has attributes
        for confidence estimation
    """
    # Order matters: it defines the positional order of the added fields.
    confidence_names = (
        "sigma_1",
        "sigma_2",
        "kappa_u",
        "kappa_v",
        "fine_segm_confidence",
        "coarse_segm_confidence",
    )

    PredictorOutput = make_dataclass(
        BasePredictorOutput.__name__ + "WithConfidences",
        fields=[(name, Optional[torch.Tensor], None) for name in confidence_names],
        bases=(BasePredictorOutput,),
    )

    # add possibility to index PredictorOutput

    def select_or_none(data, item):
        # None stays None; an int index keeps the leading axis (length 1)
        if data is None:
            return None
        picked = data[item]
        return picked.unsqueeze(0) if isinstance(item, int) else picked

    def PredictorOutput_getitem(self, item):
        cls = type(self)
        # the base class slices its own fields; the confidences are added on top
        base_sliced = super(cls, self).__getitem__(item)
        extras = {
            name: select_or_none(getattr(self, name), item) for name in confidence_names
        }
        return cls(**base_sliced.__dict__, **extras)

    PredictorOutput.__getitem__ = PredictorOutput_getitem

    def PredictorOutput_to(self, device: torch.device):
        """
        Transfers all tensors to the given device
        """
        cls = type(self)
        base_moved = super(cls, self).to(device)  # pyre-ignore[16]

        def move_if_tensor(var: Any):
            return var.to(device) if isinstance(var, torch.Tensor) else var

        extras = {name: move_if_tensor(getattr(self, name)) for name in confidence_names}
        return cls(**base_moved.__dict__, **extras)

    PredictorOutput.to = PredictorOutput_to
    return PredictorOutput
@dataclass
class DensePoseChartResultWithConfidences:
    """
    We add confidence values to DensePoseChartResult
    Thus the results are represented by two tensors:
    - labels (tensor [H, W] of long): contains estimated label for each pixel of
        the detection bounding box of size (H, W)
    - uv (tensor [2, H, W] of float): contains estimated U and V coordinates
        for each pixel of the detection bounding box of size (H, W)
    Plus one [H, W] tensor of float for each confidence type
    """

    labels: torch.Tensor
    uv: torch.Tensor
    sigma_1: Optional[torch.Tensor] = None
    sigma_2: Optional[torch.Tensor] = None
    kappa_u: Optional[torch.Tensor] = None
    kappa_v: Optional[torch.Tensor] = None
    fine_segm_confidence: Optional[torch.Tensor] = None
    coarse_segm_confidence: Optional[torch.Tensor] = None

    def to(self, device: torch.device):
        """
        Transfers all tensors to the given device, except if their value is None
        """
        labels = self.labels.to(device)
        uv = self.uv.to(device)
        confidences = {}
        for name in (
            "sigma_1",
            "sigma_2",
            "kappa_u",
            "kappa_v",
            "fine_segm_confidence",
            "coarse_segm_confidence",
        ):
            value: Any = getattr(self, name)
            # non-tensor values (None in particular) are passed through untouched
            confidences[name] = value.to(device) if isinstance(value, torch.Tensor) else value
        return DensePoseChartResultWithConfidences(labels=labels, uv=uv, **confidences)
def compress_quantized_densepose_chart_result(
    result: DensePoseChartResultQuantized,
) -> DensePoseChartResultCompressed:
    """
    Compresses quantized DensePose chart-based result

    The [3, H, W] uint8 tensor is saved as an optimized PNG image and the PNG
    bytes are Base64-encoded into a string.

    Args:
        result (DensePoseChartResultQuantized): quantized DensePose chart-based result
    Return:
        Compressed DensePose chart-based result (DensePoseChartResultCompressed)
    """
    import base64
    import numpy as np
    from io import BytesIO
    from PIL import Image

    chw = result.labels_uv_uint8.cpu().numpy()
    # PIL expects channels last
    image = Image.fromarray(np.moveaxis(chw, 0, -1))
    buffer = BytesIO()
    image.save(buffer, format="png", optimize=True)
    encoded = base64.encodebytes(buffer.getvalue()).decode()
    return DensePoseChartResultCompressed(labels_uv_str=encoded, shape_chw=chw.shape)
@dataclass
class DensePoseEmbeddingPredictorOutput:
    """
    Predictor output that contains embedding and coarse segmentation data:
     * embedding: float tensor of size [N, D, H, W], contains estimated embeddings
     * coarse_segm: float tensor of size [N, K, H, W]
    Here D = MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
         K = MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
    """

    embedding: torch.Tensor
    coarse_segm: torch.Tensor

    def __len__(self):
        """
        Number of instances (N) in the output
        """
        return self.coarse_segm.size(0)

    def __getitem__(
        self, item: Union[int, slice, torch.BoolTensor]
    ) -> "DensePoseEmbeddingPredictorOutput":
        """
        Get outputs for the selected instance(s)

        An integer index keeps the leading instance axis (length 1).

        Args:
            item (int or slice or tensor): selected items
        """

        def pick(tensor):
            chosen = tensor[item]
            return chosen.unsqueeze(0) if isinstance(item, int) else chosen

        return DensePoseEmbeddingPredictorOutput(
            coarse_segm=pick(self.coarse_segm),
            embedding=pick(self.embedding),
        )

    def to(self, device: torch.device):
        """
        Transfers all tensors to the given device
        """
        return DensePoseEmbeddingPredictorOutput(
            coarse_segm=self.coarse_segm.to(device),
            embedding=self.embedding.to(device),
        )
@lru_cache(maxsize=None)
def decorate_cse_predictor_output_class_with_confidences(BasePredictorOutput: type) -> type:
    """
    Create a new output class from an existing one by adding new attributes
    related to confidence estimation:
    - coarse_segm_confidence (tensor)

    Details on confidence estimation parameters can be found in:
    N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
        Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
    A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020

    The new class inherits the provided `BasePredictorOutput` class,
    its name is composed of the name of the provided class and
    "WithConfidences" suffix. The result is cached per base class, so
    repeated calls with the same class return the same derived class.

    Args:
        BasePredictorOutput (type): output type to which confidence data
            is to be added, assumed to be a dataclass
    Return:
        New dataclass derived from the provided one that has attributes
        for confidence estimation
    """

    derived_name = BasePredictorOutput.__name__ + "WithConfidences"
    PredictorOutput = make_dataclass(
        derived_name,
        fields=[
            ("coarse_segm_confidence", Optional[torch.Tensor], None),
        ],
        bases=(BasePredictorOutput,),
    )

    # add possibility to index PredictorOutput

    def select_or_none(data, item):
        # None stays None; an int index keeps the leading axis (length 1)
        if data is None:
            return None
        picked = data[item]
        return picked.unsqueeze(0) if isinstance(item, int) else picked

    def PredictorOutput_getitem(self, item):
        cls = type(self)
        # the base class slices its own fields; the confidence is added on top
        base_sliced = super(cls, self).__getitem__(item)
        confidence = select_or_none(self.coarse_segm_confidence, item)
        return cls(**base_sliced.__dict__, coarse_segm_confidence=confidence)

    PredictorOutput.__getitem__ = PredictorOutput_getitem

    def PredictorOutput_to(self, device: torch.device):
        """
        Transfers all tensors to the given device
        """
        cls = type(self)
        base_moved = super(cls, self).to(device)  # pyre-ignore[16]

        def move_if_tensor(var: Any):
            return var.to(device) if isinstance(var, torch.Tensor) else var

        confidence = move_if_tensor(self.coarse_segm_confidence)
        return cls(**base_moved.__dict__, coarse_segm_confidence=confidence)

    PredictorOutput.to = PredictorOutput_to
    return PredictorOutput
class DensePoseDataRelative(object):
    """
    Dense pose relative annotations that can be applied to any bounding box:
        x - normalized X coordinates [0, 255] of annotated points
        y - normalized Y coordinates [0, 255] of annotated points
        i - body part labels 0,...,24 for annotated points
        u - body part U coordinates [0, 1] for annotated points
        v - body part V coordinates [0, 1] for annotated points
        segm - 256x256 segmentation mask with values 0,...,14
    To obtain absolute x and y data wrt some bounding box one needs to first
    divide the data by 256, multiply by the respective bounding box size
    and add bounding box offset:
        x_img = x0 + x_norm * w / 256.0
        y_img = y0 + y_norm * h / 256.0
    Segmentation masks are typically sampled to get image-based masks.

    Attributes `i`, `u`, `v`, `vertex_ids`, `mesh_id` and `segm` are only set
    when the corresponding keys are present in the annotation, which is why
    the methods below probe them with `hasattr`.
    """

    # Key for normalized X coordinates in annotation dict
    X_KEY = "dp_x"
    # Key for normalized Y coordinates in annotation dict
    Y_KEY = "dp_y"
    # Key for U part coordinates in annotation dict (used in chart-based annotations)
    U_KEY = "dp_U"
    # Key for V part coordinates in annotation dict (used in chart-based annotations)
    V_KEY = "dp_V"
    # Key for I point labels in annotation dict (used in chart-based annotations)
    I_KEY = "dp_I"
    # Key for segmentation mask in annotation dict
    S_KEY = "dp_masks"
    # Key for vertex ids (used in continuous surface embeddings annotations)
    VERTEX_IDS_KEY = "dp_vertex"
    # Key for mesh id (used in continuous surface embeddings annotations)
    MESH_NAME_KEY = "ref_model"
    # Number of body parts in segmentation masks
    N_BODY_PARTS = 14
    # Number of parts in point labels
    N_PART_LABELS = 24
    # Side length of the normalized annotation frame and of the segmentation mask
    MASK_SIZE = 256

    def __init__(self, annotation, cleanup=False):
        """
        Build relative DensePose data from an annotation dict.

        Args:
            annotation (dict): must contain X_KEY and Y_KEY; IUV data is read
                only if I_KEY, U_KEY and V_KEY are all present; CSE data is
                read only if VERTEX_IDS_KEY and MESH_NAME_KEY are both present;
                the segmentation mask is read if S_KEY is present
            cleanup (bool): if True, the DensePose keys are deleted from
                `annotation` after they have been read
        """
        self.x = torch.as_tensor(annotation[DensePoseDataRelative.X_KEY])
        self.y = torch.as_tensor(annotation[DensePoseDataRelative.Y_KEY])
        if (
            DensePoseDataRelative.I_KEY in annotation
            and DensePoseDataRelative.U_KEY in annotation
            and DensePoseDataRelative.V_KEY in annotation
        ):
            self.i = torch.as_tensor(annotation[DensePoseDataRelative.I_KEY])
            self.u = torch.as_tensor(annotation[DensePoseDataRelative.U_KEY])
            self.v = torch.as_tensor(annotation[DensePoseDataRelative.V_KEY])
        if (
            DensePoseDataRelative.VERTEX_IDS_KEY in annotation
            and DensePoseDataRelative.MESH_NAME_KEY in annotation
        ):
            self.vertex_ids = torch.as_tensor(
                annotation[DensePoseDataRelative.VERTEX_IDS_KEY], dtype=torch.long
            )
            self.mesh_id = MeshCatalog.get_mesh_id(annotation[DensePoseDataRelative.MESH_NAME_KEY])
        if DensePoseDataRelative.S_KEY in annotation:
            self.segm = DensePoseDataRelative.extract_segmentation_mask(annotation)
        # freshly constructed data is always recorded as CPU data
        self.device = torch.device("cpu")
        if cleanup:
            DensePoseDataRelative.cleanup_annotation(annotation)

    def to(self, device):
        """
        Return the data on the given device; `self` is returned unchanged when
        it is already recorded as being on that device, otherwise a new
        instance is created without running `__init__`.
        """
        if self.device == device:
            return self
        # bypass __init__, which expects an annotation dict
        new_data = DensePoseDataRelative.__new__(DensePoseDataRelative)
        new_data.x = self.x.to(device)
        new_data.y = self.y.to(device)
        # optional tensor attributes are copied only if they exist on self
        for attr in ["i", "u", "v", "vertex_ids", "segm"]:
            if hasattr(self, attr):
                setattr(new_data, attr, getattr(self, attr).to(device))
        # mesh_id is copied by plain assignment, no device transfer
        if hasattr(self, "mesh_id"):
            new_data.mesh_id = self.mesh_id
        new_data.device = device
        return new_data

    @staticmethod
    def extract_segmentation_mask(annotation):
        """
        Build the MASK_SIZE x MASK_SIZE float32 segmentation mask from the
        S_KEY entry of the annotation.

        A tensor entry is returned as is; a dict entry is decoded into a
        binary mask (label 1); any other entry is treated as a sequence of
        per-part specs where the spec at position `i` is written with
        label `i + 1`. Empty specs are skipped.
        """
        import pycocotools.mask as mask_utils

        # TODO: annotation instance is accepted if it contains either
        # DensePose segmentation or instance segmentation. However, here we
        # only rely on DensePose segmentation
        poly_specs = annotation[DensePoseDataRelative.S_KEY]
        if isinstance(poly_specs, torch.Tensor):
            # data is already given as mask tensors, no need to decode
            return poly_specs
        segm = torch.zeros((DensePoseDataRelative.MASK_SIZE,) * 2, dtype=torch.float32)
        if isinstance(poly_specs, dict):
            if poly_specs:
                mask = mask_utils.decode(poly_specs)
                segm[mask > 0] = 1
        else:
            for i in range(len(poly_specs)):
                poly_i = poly_specs[i]
                if poly_i:
                    mask_i = mask_utils.decode(poly_i)
                    # later parts overwrite earlier ones where masks overlap
                    segm[mask_i > 0] = i + 1
        return segm

    @staticmethod
    def validate_annotation(annotation):
        """
        Check that the annotation has the keys needed to build an instance.

        Return:
            (True, None) if X/Y keys are present together with either the
            full IUV key set or the full CSE key set; otherwise
            (False, <reason string>)
        """
        for key in [
            DensePoseDataRelative.X_KEY,
            DensePoseDataRelative.Y_KEY,
        ]:
            if key not in annotation:
                return False, "no {key} data in the annotation".format(key=key)
        valid_for_iuv_setting = all(
            key in annotation
            for key in [
                DensePoseDataRelative.I_KEY,
                DensePoseDataRelative.U_KEY,
                DensePoseDataRelative.V_KEY,
            ]
        )
        valid_for_cse_setting = all(
            key in annotation
            for key in [
                DensePoseDataRelative.VERTEX_IDS_KEY,
                DensePoseDataRelative.MESH_NAME_KEY,
            ]
        )
        if not valid_for_iuv_setting and not valid_for_cse_setting:
            return (
                False,
                "expected either {} (IUV setting) or {} (CSE setting) annotations".format(
                    ", ".join(
                        [
                            DensePoseDataRelative.I_KEY,
                            DensePoseDataRelative.U_KEY,
                            DensePoseDataRelative.V_KEY,
                        ]
                    ),
                    ", ".join(
                        [
                            DensePoseDataRelative.VERTEX_IDS_KEY,
                            DensePoseDataRelative.MESH_NAME_KEY,
                        ]
                    ),
                ),
            )
        return True, None

    @staticmethod
    def cleanup_annotation(annotation):
        """
        Delete, in place, every DensePose key that is present in `annotation`.
        """
        for key in [
            DensePoseDataRelative.X_KEY,
            DensePoseDataRelative.Y_KEY,
            DensePoseDataRelative.I_KEY,
            DensePoseDataRelative.U_KEY,
            DensePoseDataRelative.V_KEY,
            DensePoseDataRelative.S_KEY,
            DensePoseDataRelative.VERTEX_IDS_KEY,
            DensePoseDataRelative.MESH_NAME_KEY,
        ]:
            if key in annotation:
                del annotation[key]

    def apply_transform(self, transforms, densepose_transform_data):
        """
        Apply the transform list in place: points always, the segmentation
        mask only if this instance has one.
        """
        self._transform_pts(transforms, densepose_transform_data)
        if hasattr(self, "segm"):
            self._transform_segm(transforms, densepose_transform_data)

    def _transform_pts(self, transforms, dp_transform_data):
        """
        Update point coordinates (and point semantics) for horizontal flips
        and rotations found in `transforms.transforms`; other transform types
        are ignored here.
        """
        import detectron2.data.transforms as T

        # NOTE: This assumes that HorizFlipTransform is the only one that does flip
        # an even number of flips cancels out, so only the parity matters
        do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
        if do_hflip:
            # mirror X within the normalized annotation frame
            self.x = self.MASK_SIZE - self.x
            if hasattr(self, "i"):
                self._flip_iuv_semantics(dp_transform_data)
            if hasattr(self, "vertex_ids"):
                self._flip_vertices()

        for t in transforms.transforms:
            if isinstance(t, T.RotationTransform):
                # scale normalized coords to the transform's (w, h) frame,
                # rotate there, then scale back
                xy_scale = np.array((t.w, t.h)) / DensePoseDataRelative.MASK_SIZE
                xy = t.apply_coords(np.stack((self.x, self.y), axis=1) * xy_scale)
                self.x, self.y = torch.tensor(xy / xy_scale, dtype=self.x.dtype).T

    def _flip_iuv_semantics(self, dp_transform_data: DensePoseTransformData) -> None:
        """
        Remap part labels and U/V values of the annotated points to their
        horizontally mirrored counterparts, using the lookup tables carried
        by `dp_transform_data`.
        """
        # work from a snapshot of the labels so already remapped points are
        # not remapped a second time
        i_old = self.i.clone()
        uv_symmetries = dp_transform_data.uv_symmetries
        pt_label_symmetries = dp_transform_data.point_label_symmetries
        for i in range(self.N_PART_LABELS):
            if i + 1 in i_old:
                annot_indices_i = i_old == i + 1
                if pt_label_symmetries[i + 1] != i + 1:
                    self.i[annot_indices_i] = pt_label_symmetries[i + 1]
                # U/V in [0, 1] are quantized to integer lookup indices
                u_loc = (self.u[annot_indices_i] * 255).long()
                v_loc = (self.v[annot_indices_i] * 255).long()
                # tables are indexed by the 0-based part index `i`, then [v, u]
                self.u[annot_indices_i] = uv_symmetries["U_transforms"][i][v_loc, u_loc].to(
                    device=self.u.device
                )
                self.v[annot_indices_i] = uv_symmetries["V_transforms"][i][v_loc, u_loc].to(
                    device=self.v.device
                )

    def _flip_vertices(self):
        """
        Remap `vertex_ids` through the "vertex_transforms" table of the mesh
        symmetry data loaded for this instance's mesh.
        """
        mesh_info = MeshCatalog[MeshCatalog.get_mesh_name(self.mesh_id)]
        mesh_symmetry = (
            load_mesh_symmetry(mesh_info.symmetry) if mesh_info.symmetry is not None else None
        )
        # NOTE(review): when mesh_info.symmetry is None, mesh_symmetry is None
        # and the subscript below raises TypeError
        self.vertex_ids = mesh_symmetry["vertex_transforms"][self.vertex_ids]

    def _transform_segm(self, transforms, dp_transform_data):
        """
        Update the segmentation mask for horizontal flips and rotations found
        in `transforms.transforms`; other transform types are ignored here.
        """
        import detectron2.data.transforms as T

        # NOTE: This assumes that HorizFlipTransform is the only one that does flip
        do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
        if do_hflip:
            # mirror along dimension 1, then swap labels per the symmetry table
            self.segm = torch.flip(self.segm, [1])
            self._flip_segm_semantics(dp_transform_data)

        for t in transforms.transforms:
            if isinstance(t, T.RotationTransform):
                self._transform_segm_rotation(t)

    def _flip_segm_semantics(self, dp_transform_data):
        """
        Replace each mask label with its entry in `mask_label_symmetries`,
        reading from a snapshot so swapped labels are not swapped back.
        """
        old_segm = self.segm.clone()
        mask_label_symmetries = dp_transform_data.mask_label_symmetries
        for i in range(self.N_BODY_PARTS):
            if mask_label_symmetries[i + 1] != i + 1:
                self.segm[old_segm == i + 1] = mask_label_symmetries[i + 1]

    def _transform_segm_rotation(self, rotation):
        """
        Rotate the mask: resize it to (rotation.h, rotation.w), apply
        `rotation.apply_segmentation`, then resize back to
        MASK_SIZE x MASK_SIZE.
        """
        # add two leading singleton dims for F.interpolate, then go to numpy
        self.segm = F.interpolate(self.segm[None, None, :], (rotation.h, rotation.w)).numpy()
        self.segm = torch.tensor(rotation.apply_segmentation(self.segm[0, 0]))[None, None, :]
        self.segm = F.interpolate(self.segm, [DensePoseDataRelative.MASK_SIZE] * 2)[0, 0]
and its affiliates. +import torch + +from densepose.structures.data_relative import DensePoseDataRelative + + +class DensePoseList(object): + + _TORCH_DEVICE_CPU = torch.device("cpu") + + def __init__(self, densepose_datas, boxes_xyxy_abs, image_size_hw, device=_TORCH_DEVICE_CPU): + assert len(densepose_datas) == len( + boxes_xyxy_abs + ), "Attempt to initialize DensePoseList with {} DensePose datas " "and {} boxes".format( + len(densepose_datas), len(boxes_xyxy_abs) + ) + self.densepose_datas = [] + for densepose_data in densepose_datas: + assert isinstance(densepose_data, DensePoseDataRelative) or densepose_data is None, ( + "Attempt to initialize DensePoseList with DensePose datas " + "of type {}, expected DensePoseDataRelative".format(type(densepose_data)) + ) + densepose_data_ondevice = ( + densepose_data.to(device) if densepose_data is not None else None + ) + self.densepose_datas.append(densepose_data_ondevice) + self.boxes_xyxy_abs = boxes_xyxy_abs.to(device) + self.image_size_hw = image_size_hw + self.device = device + + def to(self, device): + if self.device == device: + return self + return DensePoseList(self.densepose_datas, self.boxes_xyxy_abs, self.image_size_hw, device) + + def __iter__(self): + return iter(self.densepose_datas) + + def __len__(self): + return len(self.densepose_datas) + + def __repr__(self): + s = self.__class__.__name__ + "(" + s += "num_instances={}, ".format(len(self.densepose_datas)) + s += "image_width={}, ".format(self.image_size_hw[1]) + s += "image_height={})".format(self.image_size_hw[0]) + return s + + def __getitem__(self, item): + if isinstance(item, int): + densepose_data_rel = self.densepose_datas[item] + return densepose_data_rel + elif isinstance(item, slice): + densepose_datas_rel = self.densepose_datas[item] + boxes_xyxy_abs = self.boxes_xyxy_abs[item] + return DensePoseList( + densepose_datas_rel, boxes_xyxy_abs, self.image_size_hw, self.device + ) + elif isinstance(item, torch.Tensor) and (item.dtype == 
def _maybe_copy_to_device(
    attribute: Optional[torch.Tensor], device: torch.device
) -> Optional[torch.Tensor]:
    """Return `attribute` moved to `device`, or None if it is None."""
    if attribute is None:
        return None
    return attribute.to(device)


class Mesh:
    def __init__(
        self,
        vertices: Optional[torch.Tensor] = None,
        faces: Optional[torch.Tensor] = None,
        geodists: Optional[torch.Tensor] = None,
        symmetry: Optional[Dict[str, torch.Tensor]] = None,
        texcoords: Optional[torch.Tensor] = None,
        mesh_info: Optional["MeshInfo"] = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            vertices (tensor [N, 3] of float32): vertex coordinates in 3D
            faces (tensor [M, 3] of long): triangular face represented as 3
                vertex indices
            geodists (tensor [N, N] of float32): geodesic distances from
                vertex `i` to vertex `j` (optional, default: None)
            symmetry (dict: str -> tensor): various mesh symmetry data:
                - "vertex_transforms": vertex mapping under horizontal flip,
                  tensor of size [N] of type long; vertex `i` is mapped to
                  vertex `tensor[i]` (optional, default: None)
            texcoords (tensor [N, 2] of float32): texture coordinates, i.e. global
                and normalized mesh UVs (optional, default: None)
            mesh_info (MeshInfo type): necessary to load the attributes on-the-go,
                can be used instead of passing all the variables one by one
            device (torch.device): device of the Mesh. If not provided, will use
                the device of the first provided tensor (vertices, faces,
                geodists, texcoords, then symmetry entries), falling back to CPU
        Raises:
            AssertionError: if neither vertices nor mesh_info is given, if the
                provided tensors live on different devices, or if vertices
                and texcoords have different lengths
        """
        self._vertices = vertices
        self._faces = faces
        self._geodists = geodists
        self._symmetry = symmetry
        self._texcoords = texcoords
        self.mesh_info = mesh_info
        self.device = device

        assert self._vertices is not None or self.mesh_info is not None

        all_fields = [self._vertices, self._faces, self._geodists, self._texcoords]

        if self.device is None:
            # infer the device from the first tensor that was provided
            for field in all_fields:
                if field is not None:
                    self.device = field.device
                    break
            if self.device is None and symmetry is not None:
                for key in symmetry:
                    self.device = symmetry[key].device
                    break
            self.device = torch.device("cpu") if self.device is None else self.device

        assert all([var.device == self.device for var in all_fields if var is not None])
        if symmetry:
            assert all(symmetry[key].device == self.device for key in symmetry)
        # Tensors must be compared against None explicitly: the truth value of
        # a tensor with more than one element raises a RuntimeError.
        if texcoords is not None and vertices is not None:
            assert len(vertices) == len(texcoords)

    def to(self, device: torch.device):
        """Return a new Mesh whose already loaded tensors are moved to `device`."""
        device_symmetry = self._symmetry
        if device_symmetry:
            device_symmetry = {key: value.to(device) for key, value in device_symmetry.items()}
        return Mesh(
            _maybe_copy_to_device(self._vertices, device),
            _maybe_copy_to_device(self._faces, device),
            _maybe_copy_to_device(self._geodists, device),
            device_symmetry,
            _maybe_copy_to_device(self._texcoords, device),
            self.mesh_info,
            device,
        )

    # Each property below loads its data on first access when it was not
    # passed to the constructor and `mesh_info` is available; the loaded
    # value is cached in the corresponding private attribute.

    @property
    def vertices(self):
        if self._vertices is None and self.mesh_info is not None:
            self._vertices = load_mesh_data(self.mesh_info.data, "vertices", self.device)
        return self._vertices

    @property
    def faces(self):
        if self._faces is None and self.mesh_info is not None:
            self._faces = load_mesh_data(self.mesh_info.data, "faces", self.device)
        return self._faces

    @property
    def geodists(self):
        if self._geodists is None and self.mesh_info is not None:
            self._geodists = load_mesh_auxiliary_data(self.mesh_info.geodists, self.device)
        return self._geodists

    @property
    def symmetry(self):
        if self._symmetry is None and self.mesh_info is not None:
            self._symmetry = load_mesh_symmetry(self.mesh_info.symmetry, self.device)
        return self._symmetry

    @property
    def texcoords(self):
        if self._texcoords is None and self.mesh_info is not None:
            self._texcoords = load_mesh_auxiliary_data(self.mesh_info.texcoords, self.device)
        return self._texcoords

    def get_geodists(self):
        """
        Return geodesic distances, computing them if they are neither
        provided nor loadable. Computation is not implemented yet, so the
        result may still be None.
        """
        if self.geodists is None:
            # `geodists` is a read-only property, so the computed value has to
            # be stored in the backing attribute
            self._geodists = self._compute_geodists()
        return self._geodists

    def _compute_geodists(self):
        # TODO: compute using Laplace-Beltrami
        geodists = None
        return geodists
def load_mesh_auxiliary_data(
    fpath: str, device: Optional[torch.device] = None
) -> Optional[torch.Tensor]:
    """
    Load a pickled array from `fpath` and return it as a float tensor on `device`.

    The path is first resolved to a local path through `PathManager`.
    """
    fpath_local = PathManager.get_local_path(fpath)
    with PathManager.open(fpath_local, "rb") as hFile:
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only load mesh data from trusted sources.
        return torch.as_tensor(pickle.load(hFile), dtype=torch.float).to(device)  # pyre-ignore[6]


@lru_cache()
def load_mesh_symmetry(
    symmetry_fpath: str, device: Optional[torch.device] = None
) -> Optional[Dict[str, torch.Tensor]]:
    """
    Load pickled mesh symmetry data and return its "vertex_transforms" entry
    as a long tensor on `device`, wrapped in a dict under the same key.

    Results are cached per (path, device) argument pair.
    """
    with PathManager.open(symmetry_fpath, "rb") as hFile:
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only load mesh data from trusted sources.
        symmetry_loaded = pickle.load(hFile)  # pyre-ignore[6]
    return {
        "vertex_transforms": torch.as_tensor(
            symmetry_loaded["vertex_transforms"], dtype=torch.long
        ).to(device),
    }


@lru_cache()
def create_mesh(mesh_name: str, device: Optional[torch.device] = None) -> Mesh:
    """
    Create a lazily loaded `Mesh` for the catalog entry `mesh_name`.

    Results are cached per (name, device) argument pair.
    """
    return Mesh(mesh_info=MeshCatalog[mesh_name], device=device)
Used for torch.grid_sample to initialize the + grid + """ + + def f(p): + return (2 * (p[0] - x0) / w - 1, 2 * (p[1] - y0) / h - 1) + + return f + + +class DensePoseTransformData(object): + + # Horizontal symmetry label transforms used for horizontal flip + MASK_LABEL_SYMMETRIES = [0, 1, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 14] + # fmt: off + POINT_LABEL_SYMMETRIES = [ 0, 1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, 23] # noqa + # fmt: on + + def __init__(self, uv_symmetries: Dict[str, torch.Tensor], device: torch.device): + self.mask_label_symmetries = DensePoseTransformData.MASK_LABEL_SYMMETRIES + self.point_label_symmetries = DensePoseTransformData.POINT_LABEL_SYMMETRIES + self.uv_symmetries = uv_symmetries + self.device = torch.device("cpu") + + def to(self, device: torch.device, copy: bool = False) -> "DensePoseTransformData": + """ + Convert transform data to the specified device + + Args: + device (torch.device): device to convert the data to + copy (bool): flag that specifies whether to copy or to reference the data + in case the device is the same + Return: + An instance of `DensePoseTransformData` with data stored on the specified device + """ + if self.device == device and not copy: + return self + uv_symmetry_map = {} + for key in self.uv_symmetries: + uv_symmetry_map[key] = self.uv_symmetries[key].to(device=device, copy=copy) + return DensePoseTransformData(uv_symmetry_map, device) + + @staticmethod + def load(io: Union[str, BinaryIO]): + """ + Args: + io: (str or binary file-like object): input file to load data from + Returns: + An instance of `DensePoseTransformData` with transforms loaded from the file + """ + import scipy.io + + uv_symmetry_map = scipy.io.loadmat(io) + uv_symmetry_map_torch = {} + for key in ["U_transforms", "V_transforms"]: + uv_symmetry_map_torch[key] = [] + map_src = uv_symmetry_map[key] + map_dst = uv_symmetry_map_torch[key] + for i in range(map_src.shape[1]): + 
map_dst.append(torch.from_numpy(map_src[0, i]).to(dtype=torch.float)) + uv_symmetry_map_torch[key] = torch.stack(map_dst, dim=0) + transform_data = DensePoseTransformData(uv_symmetry_map_torch, device=torch.device("cpu")) + return transform_data diff --git a/data_processing/detectron2/projects/DensePose/densepose/utils/__init__.py b/data_processing/detectron2/projects/DensePose/densepose/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/detectron2/projects/DensePose/densepose/utils/dbhelper.py b/data_processing/detectron2/projects/DensePose/densepose/utils/dbhelper.py new file mode 100644 index 0000000..65b6157 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/utils/dbhelper.py @@ -0,0 +1,147 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from typing import Any, Dict, Optional, Tuple + + +class EntrySelector(object): + """ + Base class for entry selectors + """ + + @staticmethod + def from_string(spec: str) -> "EntrySelector": + if spec == "*": + return AllEntrySelector() + return FieldEntrySelector(spec) + + +class AllEntrySelector(EntrySelector): + """ + Selector that accepts all entries + """ + + SPECIFIER = "*" + + def __call__(self, entry): + return True + + +class FieldEntrySelector(EntrySelector): + """ + Selector that accepts only entries that match provided field + specifier(s). Only a limited set of specifiers is supported for now: + ::=[] + ::=[] + is a valid identifier + ::= "int" | "str" + ::= "=" + ::= "," + ::= ":" + ::= | + ::= + ::= "-" + is a string without spaces and special symbols + (e.g. 
, , , ) + """ + + _SPEC_DELIM = "," + _TYPE_DELIM = ":" + _RANGE_DELIM = "-" + _EQUAL = "=" + _ERROR_PREFIX = "Invalid field selector specifier" + + class _FieldEntryValuePredicate(object): + """ + Predicate that checks strict equality for the specified entry field + """ + + def __init__(self, name: str, typespec: Optional[str], value: str): + import builtins + + self.name = name + self.type = getattr(builtins, typespec) if typespec is not None else str + self.value = value + + def __call__(self, entry): + return entry[self.name] == self.type(self.value) + + class _FieldEntryRangePredicate(object): + """ + Predicate that checks whether an entry field falls into the specified range + """ + + def __init__(self, name: str, typespec: Optional[str], vmin: str, vmax: str): + import builtins + + self.name = name + self.type = getattr(builtins, typespec) if typespec is not None else str + self.vmin = vmin + self.vmax = vmax + + def __call__(self, entry): + return (entry[self.name] >= self.type(self.vmin)) and ( + entry[self.name] <= self.type(self.vmax) + ) + + def __init__(self, spec: str): + self._predicates = self._parse_specifier_into_predicates(spec) + + def __call__(self, entry: Dict[str, Any]): + for predicate in self._predicates: + if not predicate(entry): + return False + return True + + def _parse_specifier_into_predicates(self, spec: str): + predicates = [] + specs = spec.split(self._SPEC_DELIM) + for subspec in specs: + eq_idx = subspec.find(self._EQUAL) + if eq_idx > 0: + field_name_with_type = subspec[:eq_idx] + field_name, field_type = self._parse_field_name_type(field_name_with_type) + field_value_or_range = subspec[eq_idx + 1 :] + if self._is_range_spec(field_value_or_range): + vmin, vmax = self._get_range_spec(field_value_or_range) + predicate = FieldEntrySelector._FieldEntryRangePredicate( + field_name, field_type, vmin, vmax + ) + else: + predicate = FieldEntrySelector._FieldEntryValuePredicate( + field_name, field_type, field_value_or_range + ) + 
predicates.append(predicate) + elif eq_idx == 0: + self._parse_error(f'"{subspec}", field name is empty!') + else: + self._parse_error(f'"{subspec}", should have format ' "=!") + return predicates + + def _parse_field_name_type(self, field_name_with_type: str) -> Tuple[str, Optional[str]]: + type_delim_idx = field_name_with_type.find(self._TYPE_DELIM) + if type_delim_idx > 0: + field_name = field_name_with_type[:type_delim_idx] + field_type = field_name_with_type[type_delim_idx + 1 :] + elif type_delim_idx == 0: + self._parse_error(f'"{field_name_with_type}", field name is empty!') + else: + field_name = field_name_with_type + field_type = None + # pyre-fixme[61]: `field_name` may not be initialized here. + # pyre-fixme[61]: `field_type` may not be initialized here. + return field_name, field_type + + def _is_range_spec(self, field_value_or_range): + delim_idx = field_value_or_range.find(self._RANGE_DELIM) + return delim_idx > 0 + + def _get_range_spec(self, field_value_or_range): + if self._is_range_spec(field_value_or_range): + delim_idx = field_value_or_range.find(self._RANGE_DELIM) + vmin = field_value_or_range[:delim_idx] + vmax = field_value_or_range[delim_idx + 1 :] + return vmin, vmax + else: + self._parse_error('"field_value_or_range", range of values expected!') + + def _parse_error(self, msg): + raise ValueError(f"{self._ERROR_PREFIX}: {msg}") diff --git a/data_processing/detectron2/projects/DensePose/densepose/utils/logger.py b/data_processing/detectron2/projects/DensePose/densepose/utils/logger.py new file mode 100644 index 0000000..70cd3cb --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/utils/logger.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
def verbosity_to_level(verbosity) -> int:
    """
    Map a verbosity count to a `logging` level.

    None and 0 map to WARNING, 1 maps to INFO, 2 and above map to DEBUG;
    any other comparable value (e.g. a negative one) falls back to WARNING.
    """
    if verbosity is None:
        return logging.WARNING
    if verbosity == 1:
        return logging.INFO
    if verbosity >= 2:
        return logging.DEBUG
    return logging.WARNING
class MatrixVisualizer(object):
    """
    Base visualizer for matrix data

    In this version `visualize` does not render a color map: it writes the
    integer matrix values into two image channels (see `visualize`). The
    `cmap`, `val_scale` and `alpha` settings are stored but not read by
    `visualize`.
    """

    def __init__(
        self,
        inplace=True,
        cmap=cv2.COLORMAP_PARULA,
        val_scale=1.0,
        alpha=0.7,
        interp_method_matrix=cv2.INTER_LINEAR,
        interp_method_mask=cv2.INTER_NEAREST,
    ):
        self.inplace = inplace
        self.cmap = cmap
        self.val_scale = val_scale
        self.alpha = alpha
        self.interp_method_matrix = interp_method_matrix
        self.interp_method_mask = interp_method_mask

    def visualize(self, image_bgr, mask, matrix, bbox_xywh):
        """
        Write the matrix values into the bounding box region of the image.

        Args:
            image_bgr: uint8 image with 3 channels (checked by `_check_image`)
            mask: 2D uint8 array; pixels equal to 0 are treated as background
            matrix: 2D data to encode; must provide `.cpu().numpy()`
            bbox_xywh: bounding box (x, y, w, h), converted to ints
        Return:
            uint8 image; the unmodified input image if the box has
            non-positive width or height
        """
        # The stock rendering (scale values, apply the color map, alpha-blend
        # over the image) has been replaced by the raw value encoding below.
        self._check_image(image_bgr)
        self._check_mask_matrix(mask, matrix)
        if self.inplace:
            image_target_bgr = image_bgr
        else:
            # start from a black canvas of the same shape and dtype
            image_target_bgr = image_bgr * 0
        x, y, w, h = [int(v) for v in bbox_xywh]
        if w <= 0 or h <= 0:
            return image_bgr
        mask, matrix = self._resize(mask, matrix, w, h)
        # background selector replicated over the 3 color channels
        mask_bg = np.tile((mask == 0)[:, :, np.newaxis], [1, 1, 3])

        # NOTE(review): `_resize` hands `matrix` to cv2.resize when its size
        # differs from the box, while the next line calls `.cpu()` on it —
        # confirm which type callers pass and that sizes always match.
        matrix = matrix.cpu().numpy()
        matrix_vis = np.zeros((matrix.shape[0], matrix.shape[1], 3), dtype=np.uint8)
        # channel 0 holds the quotient and channel 1 the remainder of the
        # value divided by 255; channel 2 stays 0
        # NOTE(review): the base is 255, not 256 — confirm the decoder matches
        matrix_vis[:,:,0] = matrix//255
        matrix_vis[:,:,1] = matrix%255

        # background pixels keep the content of the target image
        matrix_vis[mask_bg] = image_target_bgr[y: y + h, x: x + w, :][mask_bg]
        # the encoded values overwrite the region, no alpha blending
        image_target_bgr[y: y + h, x: x + w, :] = matrix_vis

        return image_target_bgr.astype(np.uint8)

    def _resize(self, mask, matrix, w, h):
        """Resize mask and matrix to (w, h) with cv2 when their size differs."""
        if (w != mask.shape[1]) or (h != mask.shape[0]):
            mask = cv2.resize(mask, (w, h), self.interp_method_mask)
        if (w != matrix.shape[1]) or (h != matrix.shape[0]):
            matrix = cv2.resize(matrix, (w, h), self.interp_method_matrix)
        return mask, matrix

    def _check_image(self, image_rgb):
        """Assert that the image is a 3-channel uint8 array."""
        assert len(image_rgb.shape) == 3
        assert image_rgb.shape[2] == 3
        assert image_rgb.dtype == np.uint8

    def _check_mask_matrix(self, mask, matrix):
        """Assert that mask and matrix are 2D and that the mask is uint8."""
        assert len(matrix.shape) == 2
        assert len(mask.shape) == 2
        assert mask.dtype == np.uint8
(18, 127, 15) + + def __init__(self, color_bgr=_COLOR_GREEN, r=5): + self.color_bgr = color_bgr + self.r = r + + def visualize(self, image_bgr, pts_xy, colors_bgr=None, rs=None): + for j, pt_xy in enumerate(pts_xy): + x, y = pt_xy + color_bgr = colors_bgr[j] if colors_bgr is not None else self.color_bgr + r = rs[j] if rs is not None else self.r + cv2.circle(image_bgr, (x, y), r, color_bgr, -1) + return image_bgr + + +class TextVisualizer(object): + + _COLOR_GRAY = (218, 227, 218) + _COLOR_WHITE = (255, 255, 255) + + def __init__( + self, + font_face=cv2.FONT_HERSHEY_SIMPLEX, + font_color_bgr=_COLOR_GRAY, + font_scale=0.35, + font_line_type=cv2.LINE_AA, + font_line_thickness=1, + fill_color_bgr=_COLOR_WHITE, + fill_color_transparency=1.0, + frame_color_bgr=_COLOR_WHITE, + frame_color_transparency=1.0, + frame_thickness=1, + ): + self.font_face = font_face + self.font_color_bgr = font_color_bgr + self.font_scale = font_scale + self.font_line_type = font_line_type + self.font_line_thickness = font_line_thickness + self.fill_color_bgr = fill_color_bgr + self.fill_color_transparency = fill_color_transparency + self.frame_color_bgr = frame_color_bgr + self.frame_color_transparency = frame_color_transparency + self.frame_thickness = frame_thickness + + def visualize(self, image_bgr, txt, topleft_xy): + txt_w, txt_h = self.get_text_size_wh(txt) + topleft_xy = tuple(map(int, topleft_xy)) + x, y = topleft_xy + if self.frame_color_transparency < 1.0: + t = self.frame_thickness + image_bgr[y - t : y + txt_h + t, x - t : x + txt_w + t, :] = ( + image_bgr[y - t : y + txt_h + t, x - t : x + txt_w + t, :] + * self.frame_color_transparency + + np.array(self.frame_color_bgr) * (1.0 - self.frame_color_transparency) + ).astype(np.float) + if self.fill_color_transparency < 1.0: + image_bgr[y : y + txt_h, x : x + txt_w, :] = ( + image_bgr[y : y + txt_h, x : x + txt_w, :] * self.fill_color_transparency + + np.array(self.fill_color_bgr) * (1.0 - self.fill_color_transparency) + 
).astype(np.float) + cv2.putText( + image_bgr, + txt, + topleft_xy, + self.font_face, + self.font_scale, + self.font_color_bgr, + self.font_line_thickness, + self.font_line_type, + ) + return image_bgr + + def get_text_size_wh(self, txt): + ((txt_w, txt_h), _) = cv2.getTextSize( + txt, self.font_face, self.font_scale, self.font_line_thickness + ) + return txt_w, txt_h + + +class CompoundVisualizer(object): + def __init__(self, visualizers): + self.visualizers = visualizers + + def visualize(self, image_bgr, data): + assert len(data) == len( + self.visualizers + ), "The number of datas {} should match the number of visualizers" " {}".format( + len(data), len(self.visualizers) + ) + image = image_bgr + for i, visualizer in enumerate(self.visualizers): + image = visualizer.visualize(image, data[i]) + return image + + def __str__(self): + visualizer_str = ", ".join([str(v) for v in self.visualizers]) + return "Compound Visualizer [{}]".format(visualizer_str) diff --git a/data_processing/detectron2/projects/DensePose/densepose/vis/bounding_box.py b/data_processing/detectron2/projects/DensePose/densepose/vis/bounding_box.py new file mode 100644 index 0000000..4f83957 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose/vis/bounding_box.py @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
from .base import RectangleVisualizer, TextVisualizer


class BoundingBoxVisualizer(object):
    """Outlines every box of a sequence with the default rectangle style."""

    def __init__(self):
        self.rectangle_visualizer = RectangleVisualizer()

    def visualize(self, image_bgr, boxes_xywh):
        """Draw each (x, y, w, h) box in turn and return the resulting image."""
        canvas = image_bgr
        for box in boxes_xywh:
            canvas = self.rectangle_visualizer.visualize(canvas, box)
        return canvas


class ScoredBoundingBoxVisualizer(object):
    """Outlines boxes and prints each box's score at its top-left corner."""

    def __init__(self, bbox_visualizer_params=None, score_visualizer_params=None, **kwargs):
        # None stands for "no overrides"; extra keyword arguments are accepted
        # and ignored.
        rectangle_kwargs = {} if bbox_visualizer_params is None else bbox_visualizer_params
        text_kwargs = {} if score_visualizer_params is None else score_visualizer_params
        self.visualizer_bbox = RectangleVisualizer(**rectangle_kwargs)
        self.visualizer_score = TextVisualizer(**text_kwargs)

    def visualize(self, image_bgr, scored_bboxes):
        """
        Draw boxes with their scores.

        `scored_bboxes` is a pair (boxes_xywh, box_scores) of equal length.
        """
        boxes_xywh, box_scores = scored_bboxes
        box_count = len(boxes_xywh)
        score_count = len(box_scores)
        assert (
            box_count == score_count
        ), "Number of bounding boxes {} should be equal to the number of scores {}".format(
            box_count, score_count
        )
        canvas = image_bgr
        for idx, box in enumerate(boxes_xywh):
            canvas = self.visualizer_bbox.visualize(canvas, box)
            label = "{0:6.4f}".format(box_scores[idx])
            anchor = (box[0], box[1])
            canvas = self.visualizer_score.visualize(canvas, label, anchor)
        return canvas


# ---------------------------------------------------------------------------
# diff --git a/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_data_points.py b/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_data_points.py
# new file mode 100644
# index 0000000..b6839a9
# --- /dev/null
# +++ b/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_data_points.py
# @@ -0,0 +1,106 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import Iterable, Optional, Tuple
import cv2

from densepose.structures import DensePoseDataRelative

from .base import Boxes, Image, MatrixVisualizer, PointsVisualizer


class DensePoseDataCoarseSegmentationVisualizer(object):
    """
    Visualizer for ground truth segmentation
    """

    def __init__(self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, **kwargs):
        scale = 255.0 / DensePoseDataRelative.N_BODY_PARTS
        self.mask_visualizer = MatrixVisualizer(
            inplace=inplace, cmap=cmap, val_scale=scale, alpha=alpha
        )

    def visualize(
        self,
        image_bgr: Image,
        bbox_densepose_datas: Optional[Tuple[Iterable[Boxes], Iterable[DensePoseDataRelative]]],
    ) -> Image:
        """Render the `segm` matrix of each annotation inside its box."""
        if bbox_densepose_datas is None:
            return image_bgr
        boxes, annotations = bbox_densepose_datas
        for box, annotation in zip(boxes, annotations):
            segm = annotation.segm.numpy()
            # Foreground is every pixel with a positive segmentation label.
            foreground = np.zeros(segm.shape, dtype=np.uint8)
            foreground[segm > 0] = 1
            image_bgr = self.mask_visualizer.visualize(image_bgr, foreground, segm, box.numpy())
        return image_bgr


class DensePoseDataPointsVisualizer(object):
    """Draws annotated DensePose points, optionally colored by a value."""

    def __init__(self, densepose_data_to_value_fn=None, cmap=cv2.COLORMAP_PARULA, **kwargs):
        self.points_visualizer = PointsVisualizer()
        self.densepose_data_to_value_fn = densepose_data_to_value_fn
        self.cmap = cmap

    def visualize(
        self,
        image_bgr: Image,
        bbox_densepose_datas: Optional[Tuple[Iterable[Boxes], Iterable[DensePoseDataRelative]]],
    ) -> Image:
        """Draw the points of each annotation, mapped into its bounding box."""
        if bbox_densepose_datas is None:
            return image_bgr
        value_fn = self.densepose_data_to_value_fn
        for box, annotation in zip(*bbox_densepose_datas):
            left, top, width, height = box.numpy()
            # Point coordinates are scaled by box size / 255 and offset by
            # the box origin.
            xs = annotation.x.numpy() * width / 255.0 + left
            ys = annotation.y.numpy() * height / 255.0 + top
            points = zip(xs, ys)
            if value_fn is None:
                image_bgr = self.points_visualizer.visualize(image_bgr, points)
                continue
            values = value_fn(annotation)
            mapped = cv2.applyColorMap(values, self.cmap)
            point_colors = [[int(c) for c in row.ravel()] for row in mapped]
            image_bgr = self.points_visualizer.visualize(image_bgr, points, point_colors)
        return image_bgr


def _densepose_data_u_for_cmap(densepose_data):
    """Map U values, clipped to [0, 1], onto uint8 colormap input."""
    scaled = np.clip(densepose_data.u.numpy(), 0, 1) * 255.0
    return scaled.astype(np.uint8)


def _densepose_data_v_for_cmap(densepose_data):
    """Map V values, clipped to [0, 1], onto uint8 colormap input."""
    scaled = np.clip(densepose_data.v.numpy(), 0, 1) * 255.0
    return scaled.astype(np.uint8)


def _densepose_data_i_for_cmap(densepose_data):
    """Map part labels, clipped to [0, N_PART_LABELS], onto uint8 colormap input."""
    n_labels = DensePoseDataRelative.N_PART_LABELS
    clipped = np.clip(densepose_data.i.numpy(), 0.0, n_labels)
    scaled = clipped * 255.0 / n_labels
    return scaled.astype(np.uint8)


class DensePoseDataPointsUVisualizer(DensePoseDataPointsVisualizer):
    """Points visualizer colored by the U coordinate."""

    def __init__(self, **kwargs):
        super().__init__(densepose_data_to_value_fn=_densepose_data_u_for_cmap, **kwargs)


class DensePoseDataPointsVVisualizer(DensePoseDataPointsVisualizer):
    """Points visualizer colored by the V coordinate."""

    def __init__(self, **kwargs):
        super().__init__(densepose_data_to_value_fn=_densepose_data_v_for_cmap, **kwargs)


class DensePoseDataPointsIVisualizer(DensePoseDataPointsVisualizer):
    """Points visualizer colored by the part label."""

    def __init__(self, **kwargs):
        super().__init__(densepose_data_to_value_fn=_densepose_data_i_for_cmap, **kwargs)


# ---------------------------------------------------------------------------
# diff --git a/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_outputs_iuv.py b/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_outputs_iuv.py
# new file mode 100644
# index 0000000..a32a418
# --- /dev/null
# +++ b/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_outputs_iuv.py
# @@ -0,0 +1,101 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import Optional, Tuple
import cv2

from densepose.structures import DensePoseDataRelative

from ..structures import DensePoseChartPredictorOutput
from .base import Boxes, Image, MatrixVisualizer


class DensePoseOutputsVisualizer(object):
    """
    Visualizes one channel ("I", "U" or "V") of chart-based predictor outputs.
    """

    def __init__(
        self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, to_visualize=None, **kwargs
    ):
        # Membership in a tuple, not in the string "IUV": the string form is a
        # substring test that raises TypeError for the default None and lets
        # "", "IU" or "UV" through.
        assert to_visualize in ("I", "U", "V"), "can only visualize IUV"
        self.to_visualize = to_visualize

        if self.to_visualize == "I":
            val_scale = 255.0 / DensePoseDataRelative.N_PART_LABELS
        else:
            val_scale = 1.0
        self.mask_visualizer = MatrixVisualizer(
            inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha
        )

    def visualize(
        self,
        image_bgr: Image,
        dp_output_with_bboxes: Tuple[Optional[DensePoseChartPredictorOutput], Optional[Boxes]],
    ) -> Image:
        """
        Render the selected channel for every detection.

        Returns `image_bgr` unchanged when either the predictor output or the
        boxes are None.
        """
        densepose_output, bboxes_xywh = dp_output_with_bboxes
        if densepose_output is None or bboxes_xywh is None:
            return image_bgr

        assert isinstance(
            densepose_output, DensePoseChartPredictorOutput
        ), "DensePoseChartPredictorOutput expected, {} encountered".format(type(densepose_output))

        S = densepose_output.coarse_segm
        I = densepose_output.fine_segm  # noqa
        U = densepose_output.u
        V = densepose_output.v
        N = S.size(0)
        assert N == I.size(
            0
        ), "densepose outputs S {} and I {}" " should have equal first dim size".format(
            S.size(), I.size()
        )
        assert N == U.size(
            0
        ), "densepose outputs S {} and U {}" " should have equal first dim size".format(
            S.size(), U.size()
        )
        assert N == V.size(
            0
        ), "densepose outputs S {} and V {}" " should have equal first dim size".format(
            S.size(), V.size()
        )
        assert N == len(
            bboxes_xywh
        ), "number of bounding boxes {}" " should be equal to first dim size of outputs {}".format(
            len(bboxes_xywh), N
        )
        for n in range(N):
            # Fine labels are zeroed wherever the coarse segmentation
            # predicts class 0.
            Sn = S[n].argmax(dim=0)
            In = I[n].argmax(dim=0) * (Sn > 0).long()
            segmentation = In.cpu().numpy().astype(np.uint8)
            mask = np.zeros(segmentation.shape, dtype=np.uint8)
            mask[segmentation > 0] = 1
            bbox_xywh = bboxes_xywh[n]

            if self.to_visualize == "I":
                vis = segmentation
            else:
                # __init__ guarantees "U" or "V" here, so `vis` is always set.
                U_or_Vn = {"U": U, "V": V}[self.to_visualize][n].cpu().numpy().astype(np.float32)
                vis = np.zeros(segmentation.shape, dtype=np.float32)
                for partId in range(U_or_Vn.shape[0]):
                    # Each pixel takes its value from the channel of the part
                    # it was assigned to.
                    part_pixels = segmentation == partId
                    vis[part_pixels] = U_or_Vn[partId][part_pixels].clip(0, 1) * 255

            image_bgr = self.mask_visualizer.visualize(image_bgr, mask, vis, bbox_xywh)

        return image_bgr


class DensePoseOutputsUVisualizer(DensePoseOutputsVisualizer):
    """Visualizer for the predicted U coordinate."""

    def __init__(self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, **kwargs):
        super().__init__(inplace=inplace, cmap=cmap, alpha=alpha, to_visualize="U", **kwargs)


class DensePoseOutputsVVisualizer(DensePoseOutputsVisualizer):
    """Visualizer for the predicted V coordinate."""

    def __init__(self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, **kwargs):
        super().__init__(inplace=inplace, cmap=cmap, alpha=alpha, to_visualize="V", **kwargs)


class DensePoseOutputsFineSegmentationVisualizer(DensePoseOutputsVisualizer):
    """Visualizer for the predicted fine segmentation (part labels)."""

    def __init__(self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, **kwargs):
        super().__init__(inplace=inplace, cmap=cmap, alpha=alpha, to_visualize="I", **kwargs)


# ---------------------------------------------------------------------------
# diff --git a/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_outputs_vertex.py b/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_outputs_vertex.py
# new file mode 100644
# index 0000000..0dfc4ae
# --- /dev/null
# +++ b/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_outputs_vertex.py
# @@ -0,0 +1,248 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All Rights Reserved
import json
import numpy as np
from functools import lru_cache
from typing import Dict, List, Optional, Tuple
import cv2
import torch

from detectron2.utils.file_io import PathManager

from densepose.modeling import build_densepose_embedder
from densepose.modeling.cse.utils import get_closest_vertices_mask_from_ES

from ..data.utils import get_class_to_mesh_name_mapping
from ..structures import DensePoseEmbeddingPredictorOutput
from ..structures.mesh import create_mesh
from .base import Boxes, Image, MatrixVisualizer
from .densepose_results_textures import get_texture_atlas


@lru_cache()
def get_xyz_vertex_embedding(mesh_name: str, device: torch.device):
    """
    Return a per-vertex scalar in [0, 1] for `mesh_name`.

    For "smpl_27554" the values come from the first column of a downloaded
    embedding file; for any other mesh they are the squared, min-max
    normalized sum of the vertex coordinates. Results are memoized per
    (mesh_name, device).
    """
    if mesh_name == "smpl_27554":
        embed_path = PathManager.get_local_path(
            "https://dl.fbaipublicfiles.com/densepose/data/cse/mds_d=256.npy"
        )
        # SECURITY NOTE: allow_pickle=True executes arbitrary pickled code
        # from the downloaded file; this is only acceptable while the URL
        # above is a trusted source.
        embed_map, _ = np.load(embed_path, allow_pickle=True)
        embed_map = torch.tensor(embed_map).float()[:, 0]
        embed_map -= embed_map.min()
        embed_map /= embed_map.max()
    else:
        mesh = create_mesh(mesh_name, device)
        embed_map = mesh.vertices.sum(dim=1)
        embed_map -= embed_map.min()
        embed_map /= embed_map.max()
        embed_map = embed_map**2

    return embed_map


class DensePoseOutputsVertexVisualizer(object):
    """
    Renders, per detection, the closest mesh vertex index of every pixel.

    In this fork ``visualize`` returns one image per detection (a list)
    instead of a single composited image.
    """

    def __init__(
        self,
        cfg,
        inplace=True,
        cmap=cv2.COLORMAP_JET,
        alpha=0.7,
        device="cuda",
        default_class=0,
        **kwargs,
    ):
        self.mask_visualizer = MatrixVisualizer(
            inplace=inplace, cmap=cmap, val_scale=1.0, alpha=alpha
        )
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
        self.embedder = build_densepose_embedder(cfg)
        self.device = torch.device(device)
        self.default_class = default_class

        # Only meshes the embedder actually has embeddings for are cached.
        self.mesh_vertex_embeddings = {
            mesh_name: self.embedder(mesh_name).to(self.device)
            for mesh_name in self.class_to_mesh_name.values()
            if self.embedder.has_embeddings(mesh_name)
        }

    def visualize(
        self,
        image_bgr: Image,
        outputs_boxes_xywh_classes: Tuple[
            Optional[DensePoseEmbeddingPredictorOutput], Optional[Boxes], Optional[List[int]]
        ],
    ) -> List:
        """
        Return a list with one rendered copy of `image_bgr` per detection.

        Detections with non-positive box width or height are skipped; an
        empty list is returned when there is no predictor output.
        """
        if outputs_boxes_xywh_classes[0] is None:
            return []

        S, E, N, bboxes_xywh, pred_classes = self.extract_and_check_outputs_and_boxes(
            outputs_boxes_xywh_classes
        )

        image_bgrs = []
        for n in range(N):
            x, y, w, h = bboxes_xywh[n].int().tolist()
            if w <= 0 or h <= 0:
                continue

            mesh_name = self.class_to_mesh_name[pred_classes[n]]
            closest_vertices, mask = get_closest_vertices_mask_from_ES(
                E[[n]],
                S[[n]],
                h,
                w,
                self.mesh_vertex_embeddings[mesh_name],
                self.device,
            )

            mask_numpy = mask.cpu().numpy().astype(dtype=np.uint8)
            # Each detection is drawn on its own copy of the input image; the
            # raw vertex indices (not a colormapped embedding) are passed on.
            image_bgrs.append(
                self.mask_visualizer.visualize(
                    image_bgr.copy(), mask_numpy, closest_vertices, [x, y, w, h]
                )
            )

        return image_bgrs

    def extract_and_check_outputs_and_boxes(self, outputs_boxes_xywh_classes):
        """
        Unpack the (output, boxes, classes) triple and validate its sizes.

        Returns:
            (S, E, N, bboxes_xywh, pred_classes), where S / E are the coarse
            segmentation and embedding of the output and N is their first
            dimension. Missing classes default to ``self.default_class``.
        """
        densepose_output, bboxes_xywh, pred_classes = outputs_boxes_xywh_classes

        if pred_classes is None:
            pred_classes = [self.default_class] * len(bboxes_xywh)

        assert isinstance(
            densepose_output, DensePoseEmbeddingPredictorOutput
        ), "DensePoseEmbeddingPredictorOutput expected, {} encountered".format(
            type(densepose_output)
        )

        S = densepose_output.coarse_segm
        E = densepose_output.embedding
        N = S.size(0)
        assert N == E.size(
            0
        ), "CSE coarse_segm {} and embeddings {}" " should have equal first dim size".format(
            S.size(), E.size()
        )
        assert N == len(
            bboxes_xywh
        ), "number of bounding boxes {}" " should be equal to first dim size of outputs {}".format(
            len(bboxes_xywh), N
        )
        # Report the class count here; the box count was checked just above.
        assert N == len(pred_classes), (
            "number of predicted classes {}"
            " should be equal to first dim size of outputs {}".format(len(pred_classes), N)
        )

        return S, E, N, bboxes_xywh, pred_classes


def get_texture_atlases(json_str: Optional[str]) -> Optional[Dict[str, Optional[np.ndarray]]]:
    """
    json_str is a JSON string representing a mesh_name -> texture_atlas_path dictionary
    """
    if json_str is None:
        return None

    paths = json.loads(json_str)
    return {mesh_name: get_texture_atlas(path) for mesh_name, path in paths.items()}


class DensePoseOutputsTextureVisualizer(DensePoseOutputsVertexVisualizer):
    """Paints mesh texture atlases onto the detections of an image."""

    def __init__(
        self,
        cfg,
        texture_atlases_dict,
        device="cuda",
        default_class=0,
        **kwargs,
    ):
        self.embedder = build_densepose_embedder(cfg)

        self.texture_image_dict = {}
        self.alpha_dict = {}

        for mesh_name, atlas in texture_atlases_dict.items():
            if atlas.shape[-1] == 4:  # Image with alpha channel
                self.alpha_dict[mesh_name] = atlas[:, :, -1] / 255.0
                self.texture_image_dict[mesh_name] = atlas[:, :, :3]
            else:
                # Without an alpha channel, any non-black texel is opaque.
                self.alpha_dict[mesh_name] = atlas.sum(axis=-1) > 0
                self.texture_image_dict[mesh_name] = atlas

        self.device = torch.device(device)
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
        self.default_class = default_class

        self.mesh_vertex_embeddings = {
            mesh_name: self.embedder(mesh_name).to(self.device)
            for mesh_name in self.class_to_mesh_name.values()
        }

    def visualize(
        self,
        image_bgr: Image,
        outputs_boxes_xywh_classes: Tuple[
            Optional[DensePoseEmbeddingPredictorOutput], Optional[Boxes], Optional[List[int]]
        ],
    ) -> Image:
        """Return a copy of `image_bgr` with every detection textured."""
        image_target_bgr = image_bgr.copy()
        if outputs_boxes_xywh_classes[0] is None:
            return image_target_bgr

        S, E, N, bboxes_xywh, pred_classes = self.extract_and_check_outputs_and_boxes(
            outputs_boxes_xywh_classes
        )

        # One mesh per distinct predicted class, built once up front.
        meshes = {
            p: create_mesh(self.class_to_mesh_name[p], self.device) for p in np.unique(pred_classes)
        }

        for n in range(N):
            x, y, w, h = bboxes_xywh[n].int().cpu().numpy()
            mesh_name = self.class_to_mesh_name[pred_classes[n]]
            closest_vertices, mask = get_closest_vertices_mask_from_ES(
                E[[n]],
                S[[n]],
                h,
                w,
                self.mesh_vertex_embeddings[mesh_name],
                self.device,
            )
            uv_array = meshes[pred_classes[n]].texcoords[closest_vertices].permute((2, 0, 1))
            uv_array = uv_array.cpu().numpy().clip(0, 1)
            textured_image = self.generate_image_with_texture(
                image_target_bgr[y : y + h, x : x + w],
                uv_array,
                mask.cpu().numpy(),
                self.class_to_mesh_name[pred_classes[n]],
            )
            if textured_image is None:
                continue
            image_target_bgr[y : y + h, x : x + w] = textured_image

        return image_target_bgr

    def generate_image_with_texture(self, bbox_image_bgr, uv_array, mask, mesh_name):
        """
        Blend the texture of `mesh_name` into a copy of `bbox_image_bgr`.

        Returns None when no texture or alpha map is registered for the mesh;
        otherwise a uint8 image where masked pixels are alpha-blended with the
        texel addressed by their (U, V) coordinates.
        """
        alpha = self.alpha_dict.get(mesh_name)
        texture_image = self.texture_image_dict.get(mesh_name)
        if alpha is None or texture_image is None:
            return None
        U, V = uv_array
        tex_h, tex_w = texture_image.shape[0], texture_image.shape[1]
        # A coordinate of exactly 1.0 scales to the texture size itself, one
        # past the last valid index, so cap at the last row / column.
        x_index = np.minimum((U * tex_w).astype(int), tex_w - 1)
        y_index = np.minimum((V * tex_h).astype(int), tex_h - 1)
        local_texture = texture_image[y_index, x_index][mask]
        local_alpha = np.expand_dims(alpha[y_index, x_index][mask], -1)
        output_image = bbox_image_bgr.copy()
        output_image[mask] = output_image[mask] * (1 - local_alpha) + local_texture * local_alpha
        return output_image.astype(np.uint8)


# ---------------------------------------------------------------------------
# diff --git a/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_results.py b/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_results.py
# new file mode 100644
# index 0000000..124ed0c
# --- /dev/null
# +++ b/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_results.py
# @@ -0,0 +1,358 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
from typing import List, Optional, Tuple
import cv2
import torch

from densepose.structures import DensePoseDataRelative

from ..structures import DensePoseChartResult
from .base import Boxes, Image, MatrixVisualizer


class DensePoseResultsVisualizer(object):
    """
    Template for chart-result visualizers.

    Subclasses override the context hooks and `visualize_iuv_arr`; the base
    implementations pass the image through unchanged.
    """

    def visualize(
        self,
        image_bgr: Image,
        results_and_boxes_xywh: Tuple[Optional[List[DensePoseChartResult]], Optional[Boxes]],
    ) -> Image:
        """Visualize the results and return the image produced by the context."""
        densepose_result, boxes_xywh = results_and_boxes_xywh
        if densepose_result is None or boxes_xywh is None:
            return image_bgr

        boxes_xywh = boxes_xywh.cpu().numpy()
        context = self.create_visualization_context(image_bgr)
        # Only the first detection is visualized in this fork.
        densepose_result = densepose_result[0:1]
        for i, result in enumerate(densepose_result):
            # Stack labels and UV (scaled to 0..255) into one uint8 IUV array.
            iuv_array = torch.cat(
                (result.labels[None].type(torch.float32), result.uv * 255.0)
            ).type(torch.uint8)
            self.visualize_iuv_arr(context, iuv_array.cpu().numpy(), boxes_xywh[i])
        image_bgr = self.context_to_image_bgr(context)
        return image_bgr

    def create_visualization_context(self, image_bgr: Image):
        return image_bgr

    def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
        pass

    def context_to_image_bgr(self, context):
        return context

    def get_image_bgr_from_context(self, context):
        return context


class DensePoseMaskedColormapResultsVisualizer(DensePoseResultsVisualizer):
    """Renders one extracted IUV channel, masked by an extracted segmentation."""

    def __init__(
        self,
        data_extractor,
        segm_extractor,
        inplace=True,
        cmap=cv2.COLORMAP_PARULA,
        alpha=0.7,
        val_scale=1.0,
        **kwargs,
    ):
        self.mask_visualizer = MatrixVisualizer(
            inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha
        )
        self.data_extractor = data_extractor
        self.segm_extractor = segm_extractor

    def context_to_image_bgr(self, context):
        return context

    def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
        """Draw the extracted channel of `iuv_arr` into the context image."""
        image_bgr = self.get_image_bgr_from_context(context)
        matrix = self.data_extractor(iuv_arr)
        segm = self.segm_extractor(iuv_arr)
        mask = np.zeros(matrix.shape, dtype=np.uint8)
        mask[segm > 0] = 1

        # The return value is not used: this method returns None and
        # `visualize` hands back the context image.
        self.mask_visualizer.visualize(image_bgr, mask, matrix, bbox_xywh)


def _extract_i_from_iuvarr(iuv_arr):
    return iuv_arr[0, :, :]


def _extract_u_from_iuvarr(iuv_arr):
    return iuv_arr[1, :, :]


def _extract_v_from_iuvarr(iuv_arr):
    return iuv_arr[2, :, :]


class DensePoseResultsMplContourVisualizer(DensePoseResultsVisualizer):
    """Contour visualization of U and V drawn with Matplotlib."""

    def __init__(self, levels=10, **kwargs):
        self.levels = levels
        self.plot_args = kwargs

    def create_visualization_context(self, image_bgr: Image):
        """Create a figure sized to the image (at 100 dpi) showing the image."""
        import matplotlib.pyplot as plt
        from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

        context = {}
        context["image_bgr"] = image_bgr
        dpi = 100
        height_inches = float(image_bgr.shape[0]) / dpi
        width_inches = float(image_bgr.shape[1]) / dpi
        fig = plt.figure(figsize=(width_inches, height_inches), dpi=dpi)
        plt.axes([0, 0, 1, 1])
        plt.axis("off")
        context["fig"] = fig
        canvas = FigureCanvas(fig)
        context["canvas"] = canvas
        extent = (0, image_bgr.shape[1], image_bgr.shape[0], 0)
        plt.imshow(image_bgr[:, :, ::-1], extent=extent)
        return context

    def context_to_image_bgr(self, context):
        """Rasterize the figure, release it, and return the pixels as BGR."""
        import matplotlib.pyplot as plt

        fig = context["fig"]
        w, h = map(int, fig.get_size_inches() * fig.get_dpi())
        canvas = context["canvas"]
        canvas.draw()
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        # It returns a read-only view, which is fine because the result is
        # copied below.
        # NOTE(review): tostring_rgb is deprecated in recent Matplotlib
        # releases -- confirm the version in use still provides it.
        image_1d = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
        image_rgb = image_1d.reshape(h, w, 3)
        image_bgr = image_rgb[:, :, ::-1].copy()
        # pyplot keeps every figure alive until it is closed; without this,
        # one figure leaks per visualized image.
        plt.close(fig)
        return image_bgr

    def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh: Boxes) -> None:
        """Plot U and V contour lines over the bounding box extent."""
        import matplotlib.pyplot as plt

        u = _extract_u_from_iuvarr(iuv_arr).astype(float) / 255.0
        v = _extract_v_from_iuvarr(iuv_arr).astype(float) / 255.0
        extent = (
            bbox_xywh[0],
            bbox_xywh[0] + bbox_xywh[2],
            bbox_xywh[1],
            bbox_xywh[1] + bbox_xywh[3],
        )
        plt.contour(u, self.levels, extent=extent, **self.plot_args)
        plt.contour(v, self.levels, extent=extent, **self.plot_args)


class DensePoseResultsCustomContourVisualizer(DensePoseResultsVisualizer):
    """
    Contour visualization using marching squares
    """

    def __init__(self, levels=10, **kwargs):
        # TODO: colormap is hardcoded
        cmap = cv2.COLORMAP_PARULA
        if isinstance(levels, int):
            self.levels = np.linspace(0, 1, levels)
        else:
            self.levels = levels
        if "linewidths" in kwargs:
            self.linewidths = kwargs["linewidths"]
        else:
            self.linewidths = [1] * len(self.levels)
        self.plot_args = kwargs
        img_colors_bgr = cv2.applyColorMap((self.levels * 255).astype(np.uint8), cmap)
        self.level_colors_bgr = [
            [int(v) for v in img_color_bgr.ravel()] for img_color_bgr in img_colors_bgr
        ]

    def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh: Boxes) -> None:
        """Draw U and V contour lines for every body part into the image."""
        image_bgr = self.get_image_bgr_from_context(context)
        segm = _extract_i_from_iuvarr(iuv_arr)
        u = _extract_u_from_iuvarr(iuv_arr).astype(float) / 255.0
        v = _extract_v_from_iuvarr(iuv_arr).astype(float) / 255.0
        self._contours(image_bgr, u, segm, bbox_xywh)
        self._contours(image_bgr, v, segm, bbox_xywh)

    def _contours(self, image_bgr, arr, segm, bbox_xywh):
        """Draw the level lines of `arr`, part by part."""
        for part_idx in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
            mask = segm == part_idx
            if not np.any(mask):
                continue
            arr_min = np.amin(arr[mask])
            arr_max = np.amax(arr[mask])
            # Tight bounding window of the part inside the array.
            I, J = np.nonzero(mask)
            i0 = np.amin(I)
            i1 = np.amax(I) + 1
            j0 = np.amin(J)
            j1 = np.amax(J) + 1
            if (j1 == j0 + 1) or (i1 == i0 + 1):
                continue
            Nw = arr.shape[1] - 1
            Nh = arr.shape[0] - 1
            for level_idx, level in enumerate(self.levels):
                if (level < arr_min) or (level > arr_max):
                    continue
                # 4-bit code per cell: one bit per corner that is >= level
                # (or, for the mask codes, inside the part).
                vp = arr[i0:i1, j0:j1] >= level
                bin_codes = vp[:-1, :-1] + vp[1:, :-1] * 2 + vp[1:, 1:] * 4 + vp[:-1, 1:] * 8
                mp = mask[i0:i1, j0:j1]
                bin_mask_codes = mp[:-1, :-1] + mp[1:, :-1] * 2 + mp[1:, 1:] * 4 + mp[:-1, 1:] * 8
                it = np.nditer(bin_codes, flags=["multi_index"])
                color_bgr = self.level_colors_bgr[level_idx]
                linewidth = self.linewidths[level_idx]
                while not it.finished:
                    # Codes 0 and 15 mean all corners are on the same side.
                    if (it[0] != 0) and (it[0] != 15):
                        i, j = it.multi_index
                        if bin_mask_codes[i, j] != 0:
                            self._draw_line(
                                image_bgr,
                                arr,
                                mask,
                                level,
                                color_bgr,
                                linewidth,
                                it[0],
                                it.multi_index,
                                bbox_xywh,
                                Nw,
                                Nh,
                                (i0, j0),
                            )
                    it.iternext()

    def _draw_line(
        self,
        image_bgr,
        arr,
        mask,
        v,
        color_bgr,
        linewidth,
        bin_code,
        multi_idx,
        bbox_xywh,
        Nw,
        Nh,
        offset,
    ):
        """Draw the contour segment(s) of one cell, scaled into the box."""
        lines = self._bin_code_2_lines(arr, v, bin_code, multi_idx, Nw, Nh, offset)
        x0, y0, w, h = bbox_xywh
        x1 = x0 + w
        y1 = y0 + h
        for line in lines:
            x0r, y0r = line[0]
            x1r, y1r = line[1]
            pt0 = (int(x0 + x0r * (x1 - x0)), int(y0 + y0r * (y1 - y0)))
            pt1 = (int(x0 + x1r * (x1 - x0)), int(y0 + y1r * (y1 - y0)))
            cv2.line(image_bgr, pt0, pt1, color_bgr, linewidth)

    def _bin_code_2_lines(self, arr, v, bin_code, multi_idx, Nw, Nh, offset):
        """
        Return the line segment(s) for one cell as pairs of relative points.

        Corner values are v0 = (i, j), v1 = (i+1, j), v2 = (i+1, j+1),
        v3 = (i, j+1); crossing positions are linearly interpolated along the
        cell edges. Codes without a case below yield no segments.
        """
        i0, j0 = offset
        i, j = multi_idx
        i += i0
        j += j0
        v0, v1, v2, v3 = arr[i, j], arr[i + 1, j], arr[i + 1, j + 1], arr[i, j + 1]
        x0i = float(j) / Nw
        y0j = float(i) / Nh
        He = 1.0 / Nh
        We = 1.0 / Nw
        if (bin_code == 1) or (bin_code == 14):
            a = (v - v0) / (v1 - v0)
            b = (v - v0) / (v3 - v0)
            pt1 = (x0i, y0j + a * He)
            pt2 = (x0i + b * We, y0j)
            return [(pt1, pt2)]
        elif (bin_code == 2) or (bin_code == 13):
            a = (v - v0) / (v1 - v0)
            b = (v - v1) / (v2 - v1)
            pt1 = (x0i, y0j + a * He)
            pt2 = (x0i + b * We, y0j + He)
            return [(pt1, pt2)]
        elif (bin_code == 3) or (bin_code == 12):
            a = (v - v0) / (v3 - v0)
            b = (v - v1) / (v2 - v1)
            pt1 = (x0i + a * We, y0j)
            pt2 = (x0i + b * We, y0j + He)
            return [(pt1, pt2)]
        elif (bin_code == 4) or (bin_code == 11):
            a = (v - v1) / (v2 - v1)
            b = (v - v3) / (v2 - v3)
            pt1 = (x0i + a * We, y0j + He)
            pt2 = (x0i + We, y0j + b * He)
            return [(pt1, pt2)]
        elif (bin_code == 6) or (bin_code == 9):
            a = (v - v0) / (v1 - v0)
            b = (v - v3) / (v2 - v3)
            pt1 = (x0i, y0j + a * He)
            pt2 = (x0i + We, y0j + b * He)
            return [(pt1, pt2)]
        elif (bin_code == 7) or (bin_code == 8):
            a = (v - v0) / (v3 - v0)
            b = (v - v3) / (v2 - v3)
            pt1 = (x0i + a * We, y0j)
            pt2 = (x0i + We, y0j + b * He)
            return [(pt1, pt2)]
        elif bin_code == 5:
            a1 = (v - v0) / (v1 - v0)
            b1 = (v - v1) / (v2 - v1)
            pt11 = (x0i, y0j + a1 * He)
            pt12 = (x0i + b1 * We, y0j + He)
            a2 = (v - v0) / (v3 - v0)
            b2 = (v - v3) / (v2 - v3)
            pt21 = (x0i + a2 * We, y0j)
            pt22 = (x0i + We, y0j + b2 * He)
            return [(pt11, pt12), (pt21, pt22)]
        elif bin_code == 10:
            a1 = (v - v0) / (v3 - v0)
            b1 = (v - v0) / (v1 - v0)
            pt11 = (x0i + a1 * We, y0j)
            pt12 = (x0i, y0j + b1 * He)
            a2 = (v - v1) / (v2 - v1)
            b2 = (v - v3) / (v2 - v3)
            pt21 = (x0i + a2 * We, y0j + He)
            pt22 = (x0i + We, y0j + b2 * He)
            return [(pt11, pt12), (pt21, pt22)]
        return []


# Prefer the Matplotlib contour visualizer; fall back to the marching-squares
# implementation when Matplotlib is not installed.
try:
    import matplotlib

    matplotlib.use("Agg")
    DensePoseResultsContourVisualizer = DensePoseResultsMplContourVisualizer
except ModuleNotFoundError:
    logger = logging.getLogger(__name__)
    logger.warning("Could not import matplotlib, using custom contour visualizer")
    DensePoseResultsContourVisualizer = DensePoseResultsCustomContourVisualizer


class DensePoseResultsFineSegmentationVisualizer(DensePoseMaskedColormapResultsVisualizer):
    """Masked visualizer for the part-label (I) channel."""

    def __init__(self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, **kwargs):
        super(DensePoseResultsFineSegmentationVisualizer, self).__init__(
            _extract_i_from_iuvarr,
            _extract_i_from_iuvarr,
            inplace,
            cmap,
            alpha,
            val_scale=255.0 / DensePoseDataRelative.N_PART_LABELS,
            **kwargs,
        )


class DensePoseResultsUVisualizer(DensePoseMaskedColormapResultsVisualizer):
    """Masked visualizer for the U channel."""

    def __init__(self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, **kwargs):
        super(DensePoseResultsUVisualizer, self).__init__(
            _extract_u_from_iuvarr,
            _extract_i_from_iuvarr,
            inplace,
            cmap,
            alpha,
            val_scale=1.0,
            **kwargs,
        )


class DensePoseResultsVVisualizer(DensePoseMaskedColormapResultsVisualizer):
    """Masked visualizer for the V channel."""

    def __init__(self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, **kwargs):
        super(DensePoseResultsVVisualizer, self).__init__(
            _extract_v_from_iuvarr,
            _extract_i_from_iuvarr,
            inplace,
            cmap,
            alpha,
            val_scale=1.0,
            **kwargs,
        )


# ---------------------------------------------------------------------------
# diff --git a/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_results_textures.py b/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_results_textures.py
# new file mode 100644
# index 0000000..8b02f2b
# --- /dev/null
# +++ b/data_processing/detectron2/projects/DensePose/densepose/vis/densepose_results_textures.py
# @@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import List, Optional, Tuple
import torch

from detectron2.data.detection_utils import read_image

from ..structures import DensePoseChartResult
from .base import Boxes, Image
from .densepose_results import DensePoseResultsVisualizer


def get_texture_atlas(path: Optional[str]) -> Optional[np.ndarray]:
    """Load a texture atlas from `path`; returns None when `path` is None."""
    if path is None:
        return None

    # Reading images like that downsamples 16-bit images to 8-bit
    # If 16-bit images are needed, we can replace that by cv2.imread with the
    # cv2.IMREAD_UNCHANGED flag (with cv2 we also need it to keep alpha channels)
    # The rest of the pipeline would need to be adapted to 16-bit images too
    bgr_image = read_image(path)
    rgb_image = np.copy(bgr_image)  # Convert BGR -> RGB
    rgb_image[:, :, :3] = rgb_image[:, :, 2::-1]  # Works with alpha channel
    return rgb_image


class DensePoseResultsVisualizerWithTexture(DensePoseResultsVisualizer):
    """
    texture_atlas: An image, size 6N * 4N, with N * N squares for each of the 24 body parts.
            It must follow the grid found at https://github.com/facebookresearch/DensePose/blob/master/DensePoseData/demo_data/texture_atlas_200.png  # noqa
            For each body part, U is proportional to the x coordinate, and (1 - V) to y
    """

    def __init__(self, texture_atlas, **kwargs):
        self.texture_atlas = texture_atlas
        self.body_part_size = texture_atlas.shape[0] // 6
        assert self.body_part_size == texture_atlas.shape[1] // 4

    def visualize(
        self,
        image_bgr: Image,
        results_and_boxes_xywh: Tuple[Optional[List[DensePoseChartResult]], Optional[Boxes]],
    ) -> Image:
        """Texture every detection into `image_bgr` in place and return it."""
        densepose_result, boxes_xywh = results_and_boxes_xywh
        if densepose_result is None or boxes_xywh is None:
            return image_bgr

        boxes_xywh = boxes_xywh.int().cpu().numpy()
        texture_image, alpha = self.get_texture()
        for i, result in enumerate(densepose_result):
            iuv_array = torch.cat((result.labels[None], result.uv.clamp(0, 1)))
            x, y, w, h = boxes_xywh[i]
            bbox_image = image_bgr[y : y + h, x : x + w]
            image_bgr[y : y + h, x : x + w] = self.generate_image_with_texture(
                texture_image, alpha, bbox_image, iuv_array.cpu().numpy()
            )
        return image_bgr

    def get_texture(self):
        """Split the atlas into 24 per-part tiles and derive their alpha."""
        N = self.body_part_size
        texture_image = np.zeros([24, N, N, self.texture_atlas.shape[-1]])
        # Parts run down each of the 4 atlas columns, 6 tiles per column.
        for i in range(4):
            for j in range(6):
                texture_image[(6 * i + j), :, :, :] = self.texture_atlas[
                    N * j : N * (j + 1), N * i : N * (i + 1), :
                ]

        if texture_image.shape[-1] == 4:  # Image with alpha channel
            alpha = texture_image[:, :, :, -1] / 255.0
            texture_image = texture_image[:, :, :, :3]
        else:
            alpha = texture_image.sum(axis=-1) > 0

        return texture_image, alpha

    def generate_image_with_texture(self, texture_image, alpha, bbox_image_bgr, iuv_array):
        """Alpha-blend each part's texture tile into a copy of the box image."""

        I, U, V = iuv_array
        generated_image_bgr = bbox_image_bgr.copy()

        for PartInd in range(1, 25):
            x, y = np.where(I == PartInd)
            x_index = (U[x, y] * (self.body_part_size - 1)).astype(int)
            y_index = ((1 - V[x, y]) * (self.body_part_size - 1)).astype(int)
            part_alpha = np.expand_dims(alpha[PartInd - 1, y_index, x_index], -1)
            generated_image_bgr[I == PartInd] = (
                generated_image_bgr[I == PartInd] * (1 - part_alpha)
                + texture_image[PartInd - 1, y_index, x_index] * part_alpha
            )

        return generated_image_bgr.astype(np.uint8)


# ---------------------------------------------------------------------------
# diff --git a/data_processing/detectron2/projects/DensePose/densepose/vis/extractor.py b/data_processing/detectron2/projects/DensePose/densepose/vis/extractor.py
# new file mode 100644
# index 0000000..9297548
# --- /dev/null
# +++ b/data_processing/detectron2/projects/DensePose/densepose/vis/extractor.py
# @@ -0,0 +1,200 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from typing import List, Optional, Sequence, Tuple
import torch

from detectron2.layers.nms import batched_nms
from detectron2.structures.instances import Instances

from densepose.converters import ToChartResultConverterWithConfidences
from densepose.structures import (
    DensePoseChartResultWithConfidences,
    DensePoseEmbeddingPredictorOutput,
)
from densepose.vis.bounding_box import BoundingBoxVisualizer, ScoredBoundingBoxVisualizer
from densepose.vis.densepose_outputs_vertex import DensePoseOutputsVertexVisualizer
from densepose.vis.densepose_results import DensePoseResultsVisualizer

from .base import CompoundVisualizer

Scores = Sequence[float]
DensePoseChartResultsWithConfidences = List[DensePoseChartResultWithConfidences]


def extract_scores_from_instances(instances: Instances, select=None):
    """Return the (optionally indexed) scores, or None when there are none."""
    if instances.has("scores"):
        return instances.scores if select is None else instances.scores[select]
    return None


def extract_boxes_xywh_from_instances(instances: Instances, select=None):
    """Return predicted boxes converted to XYWH, or None when there are none."""
    if instances.has("pred_boxes"):
        # Work on a clone so the instances' own box tensor is not modified.
        boxes_xywh = instances.pred_boxes.tensor.clone()
        boxes_xywh[:, 2] -= boxes_xywh[:, 0]
        boxes_xywh[:, 3] -= boxes_xywh[:, 1]
        return boxes_xywh if select is None else boxes_xywh[select]
    return None
+ + +def create_extractor(visualizer: object): + """ + Create an extractor for the provided visualizer + """ + if isinstance(visualizer, CompoundVisualizer): + extractors = [create_extractor(v) for v in visualizer.visualizers] + return CompoundExtractor(extractors) + elif isinstance(visualizer, DensePoseResultsVisualizer): + return DensePoseResultExtractor() + elif isinstance(visualizer, ScoredBoundingBoxVisualizer): + return CompoundExtractor([extract_boxes_xywh_from_instances, extract_scores_from_instances]) + elif isinstance(visualizer, BoundingBoxVisualizer): + return extract_boxes_xywh_from_instances + elif isinstance(visualizer, DensePoseOutputsVertexVisualizer): + return DensePoseOutputsExtractor() + else: + logger = logging.getLogger(__name__) + logger.error(f"Could not create extractor for {visualizer}") + return None + + +class BoundingBoxExtractor(object): + """ + Extracts bounding boxes from instances + """ + + def __call__(self, instances: Instances): + boxes_xywh = extract_boxes_xywh_from_instances(instances) + return boxes_xywh + + +class ScoredBoundingBoxExtractor(object): + """ + Extracts bounding boxes from instances + """ + + def __call__(self, instances: Instances, select=None): + scores = extract_scores_from_instances(instances) + boxes_xywh = extract_boxes_xywh_from_instances(instances) + if (scores is None) or (boxes_xywh is None): + return (boxes_xywh, scores) + if select is not None: + scores = scores[select] + boxes_xywh = boxes_xywh[select] + return (boxes_xywh, scores) + + +class DensePoseResultExtractor(object): + """ + Extracts DensePose chart result with confidences from instances + """ + + def __call__( + self, instances: Instances, select=None + ) -> Tuple[Optional[DensePoseChartResultsWithConfidences], Optional[torch.Tensor]]: + if instances.has("pred_densepose") and instances.has("pred_boxes"): + dpout = instances.pred_densepose + boxes_xyxy = instances.pred_boxes + boxes_xywh = extract_boxes_xywh_from_instances(instances) + if 
select is not None: + dpout = dpout[select] + boxes_xyxy = boxes_xyxy[select] + converter = ToChartResultConverterWithConfidences() + results = [converter.convert(dpout[i], boxes_xyxy[[i]]) for i in range(len(dpout))] + return results, boxes_xywh + else: + return None, None + + +class DensePoseOutputsExtractor(object): + """ + Extracts DensePose result from instances + """ + + def __call__( + self, + instances: Instances, + select=None, + ) -> Tuple[ + Optional[DensePoseEmbeddingPredictorOutput], Optional[torch.Tensor], Optional[List[int]] + ]: + if not (instances.has("pred_densepose") and instances.has("pred_boxes")): + return None, None, None + + dpout = instances.pred_densepose + boxes_xyxy = instances.pred_boxes + boxes_xywh = extract_boxes_xywh_from_instances(instances) + + if instances.has("pred_classes"): + classes = instances.pred_classes.tolist() + else: + classes = None + + if select is not None: + dpout = dpout[select] + boxes_xyxy = boxes_xyxy[select] + if classes is not None: + classes = classes[select] + + return dpout, boxes_xywh, classes + + +class CompoundExtractor(object): + """ + Extracts data for CompoundVisualizer + """ + + def __init__(self, extractors): + self.extractors = extractors + + def __call__(self, instances: Instances, select=None): + datas = [] + for extractor in self.extractors: + data = extractor(instances, select) + datas.append(data) + return datas + + +class NmsFilteredExtractor(object): + """ + Extracts data in the format accepted by NmsFilteredVisualizer + """ + + def __init__(self, extractor, iou_threshold): + self.extractor = extractor + self.iou_threshold = iou_threshold + + def __call__(self, instances: Instances, select=None): + scores = extract_scores_from_instances(instances) + boxes_xywh = extract_boxes_xywh_from_instances(instances) + if boxes_xywh is None: + return None + select_local_idx = batched_nms( + boxes_xywh, + scores, + torch.zeros(len(scores), dtype=torch.int32), + iou_threshold=self.iou_threshold, + 
).squeeze() + select_local = torch.zeros(len(boxes_xywh), dtype=torch.bool, device=boxes_xywh.device) + select_local[select_local_idx] = True + select = select_local if select is None else (select & select_local) + return self.extractor(instances, select=select) + + +class ScoreThresholdedExtractor(object): + """ + Extracts data in the format accepted by ScoreThresholdedVisualizer + """ + + def __init__(self, extractor, min_score): + self.extractor = extractor + self.min_score = min_score + + def __call__(self, instances: Instances, select=None): + scores = extract_scores_from_instances(instances) + if scores is None: + return None + print('in extractor') + select_local = scores > self.min_score + select = select_local if select is None else (select & select_local) + data = self.extractor(instances, select=select) + return data diff --git a/data_processing/detectron2/projects/DensePose/densepose_methods.py b/data_processing/detectron2/projects/DensePose/densepose_methods.py new file mode 100644 index 0000000..0f2f32d --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/densepose_methods.py @@ -0,0 +1,142 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from scipy.io import loadmat +import scipy.spatial.distance +import os + + +class DensePoseMethods: + def __init__(self): + # + ALP_UV = loadmat(os.path.join(os.path.dirname(__file__), './DensePoseData/UV_Processed.mat')) + self.FaceIndices = np.array(ALP_UV['All_FaceIndices']).squeeze() + self.FacesDensePose = ALP_UV['All_Faces'] - 1 + self.U_norm = ALP_UV['All_U_norm'].squeeze() + self.V_norm = ALP_UV['All_V_norm'].squeeze() + self.All_vertices = ALP_UV['All_vertices'][0] + ## Info to compute symmetries. 
+ self.SemanticMaskSymmetries = [0, 1, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 14] + self.Index_Symmetry_List = [1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, + 23]; + UV_symmetry_filename = os.path.join(os.path.dirname(__file__), + './DensePoseData/UV_symmetry_transforms.mat') + self.UV_symmetry_transformations = loadmat(UV_symmetry_filename) + + def get_symmetric_densepose(self, I, U, V, x, y, Mask): + ### This is a function to get the mirror symmetric UV labels. + Labels_sym = np.zeros(I.shape) + U_sym = np.zeros(U.shape) + V_sym = np.zeros(V.shape) + ### + for i in (range(24)): + if i + 1 in I: + Labels_sym[I == (i + 1)] = self.Index_Symmetry_List[i] + jj = np.where(I == (i + 1)) + ### + U_loc = (U[jj] * 255).astype(np.int64) + V_loc = (V[jj] * 255).astype(np.int64) + ### + V_sym[jj] = self.UV_symmetry_transformations['V_transforms'][0, i][V_loc, U_loc] + U_sym[jj] = self.UV_symmetry_transformations['U_transforms'][0, i][V_loc, U_loc] + ## + Mask_flip = np.fliplr(Mask) + Mask_flipped = np.zeros(Mask.shape) + # + for i in (range(14)): + Mask_flipped[Mask_flip == (i + 1)] = self.SemanticMaskSymmetries[i + 1] + # + [y_max, x_max] = Mask_flip.shape + y_sym = y + x_sym = x_max - x + # + return Labels_sym, U_sym, V_sym, x_sym, y_sym, Mask_flipped + + def barycentric_coordinates_exists(self, P0, P1, P2, P): + u = P1 - P0 + v = P2 - P0 + w = P - P0 + # + vCrossW = np.cross(v, w) + vCrossU = np.cross(v, u) + if (np.dot(vCrossW, vCrossU) < 0): + return False; + # + uCrossW = np.cross(u, w) + uCrossV = np.cross(u, v) + # + if (np.dot(uCrossW, uCrossV) < 0): + return False; + # + denom = np.sqrt((uCrossV ** 2).sum()) + r = np.sqrt((vCrossW ** 2).sum()) / denom + t = np.sqrt((uCrossW ** 2).sum()) / denom + # + return ((r <= 1) & (t <= 1) & (r + t <= 1)) + + def barycentric_coordinates(self, P0, P1, P2, P): + u = P1 - P0 + v = P2 - P0 + w = P - P0 + # + vCrossW = np.cross(v, w) + vCrossU = np.cross(v, u) + # + uCrossW = np.cross(u, w) + 
uCrossV = np.cross(u, v) + # + denom = np.sqrt((uCrossV ** 2).sum()) + r = np.sqrt((vCrossW ** 2).sum()) / denom + t = np.sqrt((uCrossW ** 2).sum()) / denom + # + return (1 - (r + t), r, t) + + def IUV2FBC(self, I_point, U_point, V_point): + P = [U_point, V_point, 0] + FaceIndicesNow = np.where(self.FaceIndices == I_point) + FacesNow = self.FacesDensePose[FaceIndicesNow] + # + P_0 = np.vstack((self.U_norm[FacesNow][:, 0], self.V_norm[FacesNow][:, 0], + np.zeros(self.U_norm[FacesNow][:, 0].shape))).transpose() + P_1 = np.vstack((self.U_norm[FacesNow][:, 1], self.V_norm[FacesNow][:, 1], + np.zeros(self.U_norm[FacesNow][:, 1].shape))).transpose() + P_2 = np.vstack((self.U_norm[FacesNow][:, 2], self.V_norm[FacesNow][:, 2], + np.zeros(self.U_norm[FacesNow][:, 2].shape))).transpose() + # + + for i, [P0, P1, P2] in enumerate(zip(P_0, P_1, P_2)): + if (self.barycentric_coordinates_exists(P0, P1, P2, P)): + [bc1, bc2, bc3] = self.barycentric_coordinates(P0, P1, P2, P) + return (FaceIndicesNow[0][i], bc1, bc2, bc3) + # + # If the found UV is not inside any faces, select the vertex that is closest! + # + print('np.array([U_point, V_point])',np.array([U_point, V_point]).shape) + D1 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], P_0[:, 0:2]).squeeze() + D2 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], P_1[:, 0:2]).squeeze() + D3 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], P_2[:, 0:2]).squeeze() + # + minD1 = D1.min() + minD2 = D2.min() + minD3 = D3.min() + # + if ((minD1 < minD2) & (minD1 < minD3)): + return (FaceIndicesNow[0][np.argmin(D1)], 1., 0., 0.) + elif ((minD2 < minD1) & (minD2 < minD3)): + return (FaceIndicesNow[0][np.argmin(D2)], 0., 1., 0.) + else: + return (FaceIndicesNow[0][np.argmin(D3)], 0., 0., 1.) 
+ + def FBC2PointOnSurface(self, FaceIndex, bc1, bc2, bc3, Vertices): + ## + Vert_indices = self.All_vertices[self.FacesDensePose[FaceIndex]] - 1 + ## + p = Vertices[Vert_indices[0], :] * bc1 + \ + Vertices[Vert_indices[1], :] * bc2 + \ + Vertices[Vert_indices[2], :] * bc3 + ## + return (p) diff --git a/data_processing/detectron2/projects/DensePose/dev/README.md b/data_processing/detectron2/projects/DensePose/dev/README.md new file mode 100644 index 0000000..e3a94b6 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/dev/README.md @@ -0,0 +1,7 @@ + +## Some scripts for developers to use, include: + +- `run_instant_tests.sh`: run training for a few iterations. +- `run_inference_tests.sh`: run inference on a small dataset. +- `../../dev/linter.sh`: lint the codebase before commit +- `../../dev/parse_results.sh`: parse results from log file. diff --git a/data_processing/detectron2/projects/DensePose/dev/run_inference_tests.sh b/data_processing/detectron2/projects/DensePose/dev/run_inference_tests.sh new file mode 100644 index 0000000..46556b8 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/dev/run_inference_tests.sh @@ -0,0 +1,33 @@ +#!/bin/bash -e +# Copyright (c) Facebook, Inc. and its affiliates. + +BIN="python train_net.py" +OUTPUT="inference_test_output" +NUM_GPUS=2 +IMS_PER_GPU=2 +IMS_PER_BATCH=$(( NUM_GPUS * IMS_PER_GPU )) + +CFG_LIST=( "${@:1}" ) + +if [ ${#CFG_LIST[@]} -eq 0 ]; then + CFG_LIST=( ./configs/quick_schedules/*inference_acc_test.yaml ) +fi + +echo "========================================================================" +echo "Configs to run:" +echo "${CFG_LIST[@]}" +echo "========================================================================" + +for cfg in "${CFG_LIST[@]}"; do + echo "========================================================================" + echo "Running $cfg ..." 
+ echo "========================================================================" + $BIN \ + --eval-only \ + --num-gpus $NUM_GPUS \ + --config-file "$cfg" \ + OUTPUT_DIR "$OUTPUT" \ + SOLVER.IMS_PER_BATCH $IMS_PER_BATCH + rm -rf $OUTPUT +done + diff --git a/data_processing/detectron2/projects/DensePose/dev/run_instant_tests.sh b/data_processing/detectron2/projects/DensePose/dev/run_instant_tests.sh new file mode 100644 index 0000000..23a9c67 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/dev/run_instant_tests.sh @@ -0,0 +1,28 @@ +#!/bin/bash -e +# Copyright (c) Facebook, Inc. and its affiliates. + +BIN="python train_net.py" +OUTPUT="instant_test_output" +NUM_GPUS=2 +SOLVER_IMS_PER_BATCH=$((NUM_GPUS * 2)) + +CFG_LIST=( "${@:1}" ) +if [ ${#CFG_LIST[@]} -eq 0 ]; then + CFG_LIST=( ./configs/quick_schedules/*instant_test.yaml ) +fi + +echo "========================================================================" +echo "Configs to run:" +echo "${CFG_LIST[@]}" +echo "========================================================================" + +for cfg in "${CFG_LIST[@]}"; do + echo "========================================================================" + echo "Running $cfg ..." 
+ echo "========================================================================" + $BIN --num-gpus $NUM_GPUS --config-file "$cfg" \ + SOLVER.IMS_PER_BATCH $SOLVER_IMS_PER_BATCH \ + OUTPUT_DIR "$OUTPUT" + rm -rf "$OUTPUT" +done + diff --git a/data_processing/detectron2/projects/DensePose/doc/BOOTSTRAPPING_PIPELINE.md b/data_processing/detectron2/projects/DensePose/doc/BOOTSTRAPPING_PIPELINE.md new file mode 100644 index 0000000..a132686 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/doc/BOOTSTRAPPING_PIPELINE.md @@ -0,0 +1,197 @@ +# Bootstrapping Pipeline + +Bootstrapping pipeline for DensePose was proposed in +[Sanakoyeu et al., 2020](https://arxiv.org/pdf/2003.00080.pdf) +to extend DensePose from humans to proximal animal classes +(chimpanzees). Currently, the pipeline is only implemented for +[chart-based models](DENSEPOSE_IUV.md). +Bootstrapping proceeds in two steps. + +## Master Model Training + +Master model is trained on data from source domain (humans) +and supporting domain (animals). Instances from the source domain +contain full DensePose annotations (`S`, `I`, `U` and `V`) and +instances from the supporting domain have segmentation annotations only. +To ensure segmentation quality in the target domain, only a subset of +supporting domain classes is included into the training. This is achieved +through category filters, e.g. +(see [configs/evolution/Base-RCNN-FPN-Atop10P_CA.yaml](../configs/evolution/Base-RCNN-FPN-Atop10P_CA.yaml)): + +``` + WHITELISTED_CATEGORIES: + "base_coco_2017_train": + - 1 # person + - 16 # bird + - 17 # cat + - 18 # dog + - 19 # horse + - 20 # sheep + - 21 # cow + - 22 # elephant + - 23 # bear + - 24 # zebra + - 25 # girafe +``` +The acronym `Atop10P` in config file names indicates that categories are filtered to +only contain top 10 animals and person. 
+ +The training is performed in a *class-agnostic* manner: all instances +are mapped into the same class (person), e.g. +(see [configs/evolution/Base-RCNN-FPN-Atop10P_CA.yaml](../configs/evolution/Base-RCNN-FPN-Atop10P_CA.yaml)): + +``` + CATEGORY_MAPS: + "base_coco_2017_train": + "16": 1 # bird -> person + "17": 1 # cat -> person + "18": 1 # dog -> person + "19": 1 # horse -> person + "20": 1 # sheep -> person + "21": 1 # cow -> person + "22": 1 # elephant -> person + "23": 1 # bear -> person + "24": 1 # zebra -> person + "25": 1 # girafe -> person +``` +The acronym `CA` in config file names indicates that the training is class-agnostic. + +## Student Model Training + +Student model is trained on data from source domain (humans), +supporting domain (animals) and target domain (chimpanzees). +Annotations in source and supporting domains are similar to the ones +used for the master model training. +Annotations in target domain are obtained by applying the master model +to images that contain instances from the target category and sampling +sparse annotations from dense results. This process is called *bootstrapping*. +Below we give details on how the bootstrapping pipeline is implemented. + +### Data Loaders + +The central components that enable bootstrapping are +[`InferenceBasedLoader`](../densepose/data/inference_based_loader.py) and +[`CombinedDataLoader`](../densepose/data/combined_loader.py). + +`InferenceBasedLoader` takes images from a data loader, applies a model +to the images, filters the model outputs based on the selected criteria and +samples the filtered outputs to produce annotations. + +`CombinedDataLoader` combines data obtained from the loaders based on specified +ratios. The standard data loader has the default ratio of 1.0, +ratios for bootstrap datasets are specified in the configuration file. +The higher the ratio the higher the probability to include samples from the +particular data loader into a batch. 
+ +Here is an example of the bootstrapping configuration taken from +[`configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_uniform.yaml`](../configs/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_uniform.yaml): +``` +BOOTSTRAP_DATASETS: + - DATASET: "chimpnsee" + RATIO: 1.0 + IMAGE_LOADER: + TYPE: "video_keyframe" + SELECT: + STRATEGY: "random_k" + NUM_IMAGES: 4 + TRANSFORM: + TYPE: "resize" + MIN_SIZE: 800 + MAX_SIZE: 1333 + BATCH_SIZE: 8 + NUM_WORKERS: 1 + INFERENCE: + INPUT_BATCH_SIZE: 1 + OUTPUT_BATCH_SIZE: 1 + DATA_SAMPLER: + # supported types: + # densepose_uniform + # densepose_UV_confidence + # densepose_fine_segm_confidence + # densepose_coarse_segm_confidence + TYPE: "densepose_uniform" + COUNT_PER_CLASS: 8 + FILTER: + TYPE: "detection_score" + MIN_VALUE: 0.8 +BOOTSTRAP_MODEL: + WEIGHTS: https://dl.fbaipublicfiles.com/densepose/evolution/densepose_R_50_FPN_DL_WC1M_3x_Atop10P_CA/217578784/model_final_9fe1cc.pkl +``` + +The above example has one bootstrap dataset (`chimpnsee`). This dataset is registered as +a [VIDEO_LIST](../densepose/data/datasets/chimpnsee.py) dataset, which means that +it consists of a number of videos specified in a text file. For videos there can be +different strategies to sample individual images. Here we use `video_keyframe` strategy +which considers only keyframes; this ensures temporal offset between sampled images and +faster seek operations. 
We select at most 4 random keyframes in each video: + +``` +SELECT: + STRATEGY: "random_k" + NUM_IMAGES: 4 +``` + +The frames are then resized + +``` +TRANSFORM: + TYPE: "resize" + MIN_SIZE: 800 + MAX_SIZE: 1333 +``` + +and batched using the standard +[PyTorch DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader): + +``` +BATCH_SIZE: 8 +NUM_WORKERS: 1 +``` + +`InferenceBasedLoader` decomposes those batches into batches of size `INPUT_BATCH_SIZE` +and applies the master model specified by `BOOTSTRAP_MODEL`. Model outputs are filtered +by detection score: + +``` +FILTER: + TYPE: "detection_score" + MIN_VALUE: 0.8 +``` + +and sampled using the specified sampling strategy: + +``` +DATA_SAMPLER: + # supported types: + # densepose_uniform + # densepose_UV_confidence + # densepose_fine_segm_confidence + # densepose_coarse_segm_confidence + TYPE: "densepose_uniform" + COUNT_PER_CLASS: 8 +``` + +The current implementation supports +[uniform sampling](../densepose/data/samplers/densepose_uniform.py) and +[confidence-based sampling](../densepose/data/samplers/densepose_confidence_based.py) +to obtain sparse annotations from dense results. For confidence-based +sampling one needs to use the master model which produces confidence estimates. +The `WC1M` master model used in the example above produces all three types of confidence +estimates. + +Finally, sampled data is grouped into batches of size `OUTPUT_BATCH_SIZE`: + +``` +INFERENCE: + INPUT_BATCH_SIZE: 1 + OUTPUT_BATCH_SIZE: 1 +``` + +The proportion of data from annotated datasets and bootstrapped dataset can be tracked +in the logs, e.g.: + +``` +[... 
densepose.engine.trainer]: batch/ 1.8, batch/base_coco_2017_train 6.4, batch/densepose_coco_2014_train 3.85 +``` + +which means that over the last 20 iterations, on average for 1.8 bootstrapped data samples there were 6.4 samples from `base_coco_2017_train` and 3.85 samples from `densepose_coco_2014_train`. diff --git a/data_processing/detectron2/projects/DensePose/doc/DENSEPOSE_CSE.md b/data_processing/detectron2/projects/DensePose/doc/DENSEPOSE_CSE.md new file mode 100644 index 0000000..d5761ef --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/doc/DENSEPOSE_CSE.md @@ -0,0 +1,336 @@ +# Continuous Surface Embeddings for Dense Pose Estimation for Humans and Animals + +## Overview + +
+ +
+ +The pipeline uses [Faster R-CNN](https://arxiv.org/abs/1506.01497) +with [Feature Pyramid Network](https://arxiv.org/abs/1612.03144) meta architecture +outlined in Figure 1. For each detected object, the model predicts +its coarse segmentation `S` (2 channels: foreground / background) +and the embedding `E` (16 channels). At the same time, the embedder produces vertex +embeddings `Ê` for the corresponding mesh. Universal positional embeddings `E` +and vertex embeddings `Ê` are matched to derive for each pixel its continuous +surface embedding. + +
+ +
+

Figure 1. DensePose continuous surface embeddings architecture based on Faster R-CNN with Feature Pyramid Network (FPN).

+ +### Datasets + +For more details on datasets used for training and validation of +continuous surface embeddings models, +please refer to the [DensePose Datasets](DENSEPOSE_DATASETS.md) page. + +## Model Zoo and Baselines + +### Human CSE Models + +Continuous surface embeddings models for humans trained using the protocols from [Neverova et al, 2020](https://arxiv.org/abs/2011.12438). + +Models trained with hard assignment loss ℒ: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
segm
AP
dp. AP
GPS
dp. AP
GPSm
model iddownload
R_50_FPN_s1xs1x0.3490.0606.361.167.164.465.7251155172model | metrics
R_101_FPN_s1xs1x0.4610.0717.462.367.264.765.8251155500model | metrics
R_50_FPN_DL_s1xs1x0.3990.0617.060.867.865.566.4251156349model | metrics
R_101_FPN_DL_s1xs1x0.5040.0748.361.568.065.666.6251156606model | metrics
+ +Models trained with soft assignment loss ℒσ: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
segm
AP
dp. AP
GPS
dp. AP
GPSm
model iddownload
R_50_FPN_soft_s1xs1x0.3570.0579.761.366.964.365.4250533982model | metrics
R_101_FPN_soft_s1xs1x0.4640.07110.562.167.364.566.0250712522model | metrics
R_50_FPN_DL_soft_s1xs1x0.4270.06211.360.868.066.166.7250713703model | metrics
R_101_FPN_DL_soft_s1xs1x0.4830.07112.261.568.266.267.1250713061model | metrics
+ +### Animal CSE Models + +Models obtained by finetuning human CSE models on animals data from `ds1_train` +(see the [DensePose LVIS](DENSEPOSE_DATASETS.md#continuous-surface-embeddings-annotations-3) +section for more details on the datasets) with soft assignment loss ℒσ: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
segm
AP
dp. AP
GPS
dp. AP
GPSm
model iddownload
R_50_FPN_soft_chimps_finetune_4k4K0.5690.0514.762.059.032.239.6253146869model | metrics
R_50_FPN_soft_animals_finetune_4k4K0.3810.0617.344.955.521.328.8253145793model | metrics
R_50_FPN_soft_animals_CA_finetune_4k4K0.4120.0597.153.459.525.433.4253498611model | metrics
+ +Acronyms: + +`CA`: class agnostic training, where all annotated instances are mapped into a single category + + +Models obtained by finetuning human CSE models on animals data from `ds2_train` dataset +with soft assignment loss ℒσ and, for some schedules, cycle losses. +Please refer to [DensePose LVIS](DENSEPOSE_DATASETS.md#continuous-surface-embeddings-annotations-3) +section for details on the dataset and to [Neverova et al, 2021]() for details on cycle losses. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
segm
AP
dp. AP
GPS
dp. AP
GPSm
GErrGPSmodel iddownload
R_50_FPN_soft_animals_I0_finetune_16k16k0.3860.0588.454.267.029.038.613.285.4270727112model | metrics
R_50_FPN_soft_animals_I0_finetune_m2m_16k16k0.5080.05612.254.167.328.638.412.587.6270982215model | metrics
R_50_FPN_soft_animals_I0_finetune_i2m_16k16k0.4830.0569.754.066.628.938.311.088.9270727461model | metrics
+ +## References + +If you use DensePose methods based on continuous surface embeddings, please take the +references from the following BibTeX entries: + +Continuous surface embeddings: +``` +@InProceedings{Neverova2020ContinuousSurfaceEmbeddings, + title = {Continuous Surface Embeddings}, + author = {Neverova, Natalia and Novotny, David and Khalidov, Vasil and Szafraniec, Marc and Labatut, Patrick and Vedaldi, Andrea}, + journal = {Advances in Neural Information Processing Systems}, + year = {2020}, +} +``` + +Cycle Losses: +``` +@InProceedings{Neverova2021UniversalCanonicalMaps, + title = {Discovering Relationships between Object Categories via Universal Canonical Maps}, + author = {Neverova, Natalia and Sanakoyeu, Artsiom and Novotny, David and Labatut, Patrick and Vedaldi, Andrea}, + journal = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2021}, +} +``` diff --git a/data_processing/detectron2/projects/DensePose/doc/DENSEPOSE_DATASETS.md b/data_processing/detectron2/projects/DensePose/doc/DENSEPOSE_DATASETS.md new file mode 100644 index 0000000..6943741 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/doc/DENSEPOSE_DATASETS.md @@ -0,0 +1,513 @@ +# DensePose Datasets + +We summarize the datasets used in various DensePose training +schedules and describe different available annotation types. + +## Table of Contents + +[General Information](#general-information) + +[DensePose COCO](#densepose-coco) + +[DensePose PoseTrack](#densepose-posetrack) + +[DensePose Chimps](#densepose-chimps) + +[DensePose LVIS](#densepose-lvis) + +## General Information + +DensePose annotations are typically stored in JSON files. 
Their +structure follows the [COCO Data Format](https://cocodataset.org/#format-data), +the basic data structure is outlined below: + +``` +{ + "info": info, + "images": [image], + "annotations": [annotation], + "licenses": [license], +} + +info{ + "year": int, + "version": str, + "description": str, + "contributor": str, + "url": str, + "date_created": datetime, +} + +image{ + "id": int, + "width": int, + "height": int, + "file_name": str, + "license": int, + "flickr_url": str, + "coco_url": str, + "date_captured": datetime, +} + +license{ + "id": int, "name": str, "url": str, +} +``` + +DensePose annotations can be of two types: +*chart-based annotations* or *continuous surface embeddings annotations*. +We give more details on each of the two annotation types below. + +### Chart-based Annotations + +These annotations assume a single 3D model which corresponds to +all the instances in a given dataset. +3D model is assumed to be split into *charts*. Each chart has its own +2D parametrization through inner coordinates `U` and `V`, typically +taking values in `[0, 1]`. + +Chart-based annotations consist of *point-based annotations* and +*segmentation annotations*. Point-based annotations specify, for a given +image point, which model part it belongs to and what are its coordinates +in the corresponding chart. Segmentation annotations specify regions +in an image that are occupied by a given part. In some cases, charts +associated with point annotations are more detailed than the ones +associated with segmentation annotations. In this case we distinguish +*fine segmentation* (associated with points) and *coarse segmentation* +(associated with masks). + +**Point-based annotations**: + +`dp_x` and `dp_y`: image coordinates of the annotated points along +the horizontal and vertical axes respectively. 
The coordinates are defined +with respect to the top-left corner of the annotated bounding box and are +normalized assuming the bounding box size to be `256x256`; + +`dp_I`: for each point specifies the index of the fine segmentation chart +it belongs to; + +`dp_U` and `dp_V`: point coordinates on the corresponding chart. +Each fine segmentation part has its own parametrization in terms of chart +coordinates. + +**Segmentation annotations**: + +`dp_masks`: RLE encoded dense masks (`dict` containing keys `counts` and `size`). +The masks are typically of size `256x256`, they define segmentation within the +bounding box. + +### Continuous Surface Embeddings Annotations + +Continuous surface embeddings annotations also consist of *point-based annotations* +and *segmentation annotations*. Point-based annotations establish correspondence +between image points and 3D model vertices. Segmentation annotations specify +foreground regions for a given instance. + +**Point-based annotations**: + +`dp_x` and `dp_y` specify image point coordinates the same way as for chart-based +annotations; + +`dp_vertex` gives indices of 3D model vertices, which the annotated image points +correspond to; + +`ref_model` specifies 3D model name. + +**Segmentation annotations**: + +Segmentations can either be given by `dp_masks` field or by `segmentation` field. + +`dp_masks`: RLE encoded dense masks (`dict` containing keys `counts` and `size`). +The masks are typically of size `256x256`, they define segmentation within the +bounding box. + +`segmentation`: polygon-based masks stored as a 2D list +`[[x1 y1 x2 y2...],[x1 y1 ...],...]` of polygon vertex coordinates in a given +image. + +## DensePose COCO + +
+ +
+

+ Figure 1. Annotation examples from the DensePose COCO dataset. +

+ +DensePose COCO dataset contains about 50K annotated persons on images from the +[COCO dataset](https://cocodataset.org/#home). +The images are available for download from the +[COCO Dataset download page](https://cocodataset.org/#download): +[train2014](http://images.cocodataset.org/zips/train2014.zip), +[val2014](http://images.cocodataset.org/zips/val2014.zip). +The details on available annotations and their download links are given below. + +### Chart-based Annotations + +Chart-based DensePose COCO annotations are available for the instances of category +`person` and correspond to the model shown in Figure 2. +They include `dp_x`, `dp_y`, `dp_I`, `dp_U` and `dp_V` fields for annotated points +(~100 points per annotated instance) and `dp_masks` field, which encodes +coarse segmentation into 14 parts in the following order: +`Torso`, `Right Hand`, `Left Hand`, `Left Foot`, `Right Foot`, +`Upper Leg Right`, `Upper Leg Left`, `Lower Leg Right`, `Lower Leg Left`, +`Upper Arm Left`, `Upper Arm Right`, `Lower Arm Left`, `Lower Arm Right`, +`Head`. + +
+ +
+

+ Figure 2. Human body charts (fine segmentation) + and the associated 14 body parts depicted with rounded rectangles + (coarse segmentation). +

+ +The dataset splits used in the training schedules are +`train2014`, `valminusminival2014` and `minival2014`. +`train2014` and `valminusminival2014` are used for training, +and `minival2014` is used for validation. +The table with annotation download links, which summarizes the number of annotated +instances and images for each of the dataset splits is given below: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Name# inst# imagesfile sizedownload
densepose_train20143921026437526Mdensepose_train2014.json
densepose_valminusminival201472975984105Mdensepose_valminusminival2014.json
densepose_minival20142243150831Mdensepose_minival2014.json
+ +### Continuous Surface Embeddings Annotations + +DensePose COCO continuous surface embeddings annotations are available for the instances +of category `person`. The annotations correspond to the 3D model shown in Figure 2, +and include `dp_x`, `dp_y` and `dp_vertex` and `ref_model` fields. +All chart-based annotations were also kept for convenience. + +As with chart-based annotations, the dataset splits used in the training schedules are +`train2014`, `valminusminival2014` and `minival2014`. +`train2014` and `valminusminival2014` are used for training, +and `minival2014` is used for validation. +The table with annotation download links, which summarizes the number of annotated +instances and images for each of the dataset splits is given below: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Name# inst# imagesfile sizedownload
densepose_train2014_cse3921026437554Mdensepose_train2014_cse.json
densepose_valminusminival2014_cse72975984110Mdensepose_valminusminival2014_cse.json
densepose_minival2014_cse2243150832Mdensepose_minival2014_cse.json
+ +## DensePose PoseTrack + +
+ +
+

+ Figure 3. Annotation examples from the PoseTrack dataset. +

+ +DensePose PoseTrack dataset contains annotated image sequences. +To download the images for this dataset, please follow the instructions +from the [PoseTrack Download Page](https://posetrack.net/users/download.php). + +### Chart-based Annotations + +Chart-based DensePose PoseTrack annotations are available for the instances with category +`person` and correspond to the model shown in Figure 2. +They include `dp_x`, `dp_y`, `dp_I`, `dp_U` and `dp_V` fields for annotated points +(~100 points per annotated instance) and `dp_masks` field, which encodes +coarse segmentation into the same 14 parts as in DensePose COCO. + +The dataset splits used in the training schedules are +`posetrack_train2017` (train set) and `posetrack_val2017` (validation set). +The table with annotation download links, which summarizes the number of annotated +instances, instance tracks and images for the dataset splits is given below: + + + + + + + + + + + + + + + + + + + + + + + + + + +
Name# inst# images# tracksfile sizedownload
densepose_posetrack_train20178274168036118Mdensepose_posetrack_train2017.json
densepose_posetrack_val201747537824659Mdensepose_posetrack_val2017.json
+ +## DensePose Chimps + +
+ +
+

+ Figure 4. Example images from the DensePose Chimps dataset. +

+ +DensePose Chimps dataset contains annotated images of chimpanzees. +To download the images for this dataset, please use the URL specified in +`image_url` field in the annotations. + +### Chart-based Annotations + +Chart-based DensePose Chimps annotations correspond to the human model shown in Figure 2, +the instances are thus annotated to belong to the `person` category. +They include `dp_x`, `dp_y`, `dp_I`, `dp_U` and `dp_V` fields for annotated points +(~3 points per annotated instance) and `dp_masks` field, which encodes +foreground mask in RLE format. + +Chart-based DensePose Chimps annotations are used for validation only. +The table with annotation download links, which summarizes the number of annotated +instances and images is given below: + + + + + + + + + + + + + + + + + +
Name# inst# imagesfile sizedownload
densepose_chimps9306546Mdensepose_chimps_full_v2.json
+ +### Continuous Surface Embeddings Annotations + +Continuous surface embeddings annotations for DensePose Chimps +include `dp_x`, `dp_y` and `dp_vertex` point-based annotations +(~3 points per annotated instance), `dp_masks` field with the same +contents as for chart-based annotations and `ref_model` field +which refers to a chimpanzee 3D model `chimp_5029`. + +The dataset is split into training and validation subsets. +The table with annotation download links, which summarizes the number of annotated +instances and images for each of the dataset splits is given below: + +The table below outlines the dataset splits: + + + + + + + + + + + + + + + + + + + + + + + +
Name# inst# imagesfile sizedownload
densepose_chimps_cse_train5003503Mdensepose_chimps_cse_train.json
densepose_chimps_cse_val4303043Mdensepose_chimps_cse_val.json
+ +## DensePose LVIS + +
+ +
+

+ Figure 5. Example images from the DensePose LVIS dataset. +

+ +DensePose LVIS dataset contains segmentation and DensePose annotations for animals +on images from the [LVIS dataset](https://www.lvisdataset.org/dataset). +The images are available for download through the links: +[train2017](http://images.cocodataset.org/zips/train2017.zip), +[val2017](http://images.cocodataset.org/zips/val2017.zip). + +### Continuous Surface Embeddings Annotations + +Continuous surface embeddings (CSE) annotations for DensePose LVIS +include `dp_x`, `dp_y` and `dp_vertex` point-based annotations +(~3 points per annotated instance) and a `ref_model` field +which refers to a 3D model that corresponds to the instance. +Instances from 9 animal categories were annotated with CSE DensePose data: +bear, cow, cat, dog, elephant, giraffe, horse, sheep and zebra. + +Foreground masks are available from instance segmentation annotations +(`segmentation` field) in polygon format, they are stored as a 2D list +`[[x1 y1 x2 y2...],[x1 y1 ...],...]`. + +We used two datasets, each constising of one training (`train`) +and validation (`val`) subsets: the first one (`ds1`) +was used in [Neverova et al, 2020](https://arxiv.org/abs/2011.12438). +The second one (`ds2`), was used in [Neverova et al, 2021](). + +The summary of the available datasets is given below: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
All DataSelected Animals
(9 categories)
File
Name# cat# img# segm# img# segm# dpsizedownload
ds1_train55641412398541419472518446Mdensepose_lvis_v1_ds1_train_v1.json
ds1_val2515713281571153710365Mdensepose_lvis_v1_ds1_val_v1.json
ds2_train12039938812701411374646964189321051Mdensepose_lvis_v1_ds2_train_v1.json
ds2_val92690915526909155360424Mdensepose_lvis_v1_ds2_val_v1.json
+ +Legend: + +`#cat` - number of categories in the dataset for which annotations are available; + +`#img` - number of images with annotations in the dataset; + +`#segm` - number of segmentation annotations; + +`#dp` - number of DensePose annotations. + + +Important Notes: + +1. The reference models used for `ds1_train` and `ds1_val` are +`bear_4936`, `cow_5002`, `cat_5001`, `dog_5002`, `elephant_5002`, `giraffe_5002`, +`horse_5004`, `sheep_5004` and `zebra_5002`. The reference models used for +`ds2_train` and `ds2_val` are `bear_4936`, `cow_5002`, `cat_7466`, +`dog_7466`, `elephant_5002`, `giraffe_5002`, `horse_5004`, `sheep_5004` and `zebra_5002`. +So reference models for categories `cat` and `dog` are different for `ds1` and `ds2`. + +2. Some annotations from `ds1_train` are reused in `ds2_train` (4538 DensePose annotations +and 21275 segmentation annotations). The ones for cat and dog categories were remapped +from `cat_5001` and `dog_5002` reference models used in `ds1` to `cat_7466` and `dog_7466` +used in `ds2`. + +3. All annotations from `ds1_val` are included into `ds2_val` after the remapping +procedure mentioned in note 2. + +4. Some annotations from `ds1_train` are part of `ds2_val` (646 DensePose annotations and +1225 segmentation annotations). Thus one should not train on `ds1_train` if evaluating on `ds2_val`. diff --git a/data_processing/detectron2/projects/DensePose/doc/DENSEPOSE_IUV.md b/data_processing/detectron2/projects/DensePose/doc/DENSEPOSE_IUV.md new file mode 100644 index 0000000..de158e0 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/doc/DENSEPOSE_IUV.md @@ -0,0 +1,627 @@ +# Chart-based Dense Pose Estimation for Humans and Animals + +## Overview + +The goal of chart-based DensePose methods is to establish dense correspondences +between image pixels and 3D object mesh by splitting the latter into charts and estimating +for each pixel the corresponding chart index `I` and local chart coordinates `(U, V)`. + +
+ +
+ +The charts used for human DensePose estimation are shown in Figure 1. +The human body is split into 24 parts, each part is parametrized by `U` and `V` +coordinates, each taking values in `[0, 1]`. + +
+ +
+

Figure 1. Partitioning and parametrization of human body surface.

+ +The pipeline uses [Faster R-CNN](https://arxiv.org/abs/1506.01497) +with [Feature Pyramid Network](https://arxiv.org/abs/1612.03144) meta architecture +outlined in Figure 2. For each detected object, the model predicts +its coarse segmentation `S` (2 or 15 channels: foreground / background or +background + 14 predefined body parts), fine segmentation `I` (25 channels: +background + 24 predefined body parts) and local chart coordinates `U` and `V`. + +
+ +
+

Figure 2. DensePose chart-based architecture based on Faster R-CNN with Feature Pyramid Network (FPN).

+ +### Bootstrapping Chart-Based Models + +[Sanakoyeu et al., 2020](https://arxiv.org/pdf/2003.00080.pdf) introduced a pipeline +to transfer DensePose models trained on humans to proximal animal classes (chimpanzees), +which is summarized in Figure 3. The training proceeds in two stages: + +First, a *master* model is trained on data from source domain (humans with full +DensePose annotation `S`, `I`, `U` and `V`) +and supporting domain (animals with segmentation annotation only). +Only selected animal classes are chosen from the supporting +domain through *category filters* to guarantee the quality of target domain results. +The training is done in *class-agnostic manner*: all selected categories are mapped +to a single category (human). + +Second, a *student* model is trained on data from source and supporting domains, +as well as data from target domain obtained by applying the master model, selecting +high-confidence detections and sampling the results. + +
+ +
+

Figure 3. Domain adaptation: master model is trained on data from source and +supporting domains to produce predictions in target domain; student model combines data from source and +supporting domains, as well as sampled predictions from the master model on target domain to improve +target domain predictions quality.

+ +Examples of pretrained master and student models are available in the [Model Zoo](#ModelZooBootstrap). +For more details on the bootstrapping pipeline, please see [Bootstrapping Pipeline](BOOTSTRAPPING_PIPELINE.md). + +### Datasets + +For more details on datasets used for chart-based model training and validation, +please refer to the [DensePose Datasets](DENSEPOSE_DATASETS.md) page. + +## Model Zoo and Baselines + +### Legacy Models + +Baselines trained using schedules from [Güler et al, 2018](https://arxiv.org/pdf/1802.00434.pdf) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
segm
AP
dp. AP
GPS
dp. AP
GPSm
model iddownload
R_50_FPN_s1x_legacys1x0.3070.0513.258.158.252.154.9164832157model | metrics
R_101_FPN_s1x_legacys1x0.3900.0634.359.559.353.256.0164832182model | metrics
+ +### Improved Baselines, Original Fully Convolutional Head + +These models use an improved training schedule and Panoptic FPN head from [Kirillov et al, 2019](https://arxiv.org/abs/1901.02446). + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
segm
AP
dp. AP
GPS
dp. AP
GPSm
model iddownload
R_50_FPN_s1xs1x0.3590.0664.561.267.263.765.3165712039model | metrics
R_101_FPN_s1xs1x0.4280.0795.862.367.864.566.2165712084model | metrics
+ +### Improved Baselines, DeepLabV3 Head + +These models use an improved training schedule, Panoptic FPN head from [Kirillov et al, 2019](https://arxiv.org/abs/1901.02446) and DeepLabV3 head from [Chen et al, 2017](https://arxiv.org/abs/1706.05587). + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
segm
AP
dp. AP
GPS
dp. AP
GPSm
model iddownload
R_50_FPN_DL_s1xs1x0.3920.0706.761.168.365.666.7165712097model | metrics
R_101_FPN_DL_s1xs1x0.4780.0837.062.368.766.367.6165712116model | metrics
+ +###
Baselines with Confidence Estimation + +These models perform additional estimation of confidence in regressed UV coodrinates, along the lines of [Neverova et al., 2019](https://papers.nips.cc/paper/8378-correlated-uncertainty-for-learning-dense-correspondences-from-noisy-labels). + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
segm
AP
dp. AP
GPS
dp. AP
GPSm
model iddownload
R_50_FPN_WC1_s1xs1x0.3530.0644.660.567.064.265.4173862049model | metrics
R_50_FPN_WC2_s1xs1x0.3640.0664.860.766.964.265.7173861455model | metrics
R_50_FPN_DL_WC1_s1xs1x0.3970.0686.761.168.165.867.0173067973model | metrics
R_50_FPN_DL_WC2_s1xs1x0.4100.0706.860.867.965.666.7173859335model | metrics
R_101_FPN_WC1_s1xs1x0.4350.0765.762.567.664.966.3171402969model | metrics
R_101_FPN_WC2_s1xs1x0.4500.0785.762.367.664.866.4173860702model | metrics
R_101_FPN_DL_WC1_s1xs1x0.4790.0817.962.068.466.267.2173858525model | metrics
R_101_FPN_DL_WC2_s1xs1x0.4910.0827.661.768.365.967.2173294801model | metrics
+ +Acronyms: + +`WC1`: with confidence estimation model type 1 for `U` and `V` + +`WC2`: with confidence estimation model type 2 for `U` and `V` + +###
Baselines with Mask Confidence Estimation + +Models that perform estimation of confidence in regressed UV coodrinates +as well as confidences associated with coarse and fine segmentation, +see [Sanakoyeu et al., 2020](https://arxiv.org/pdf/2003.00080.pdf) for details. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
segm
AP
dp. AP
GPS
dp. AP
GPSm
model iddownload
R_50_FPN_WC1M_s1xs1x0.3810.0664.860.666.764.065.4217144516model | metrics
R_50_FPN_WC2M_s1xs1x0.3420.0685.060.766.964.265.5216245640model | metrics
R_50_FPN_DL_WC1M_s1xs1x0.3710.0686.060.768.065.266.7216245703model | metrics
R_50_FPN_DL_WC2M_s1xs1x0.3850.0716.160.868.165.066.4216245758model | metrics
R_101_FPN_WC1M_s1xs1x0.4230.0795.962.067.364.866.0216453687model | metrics
R_101_FPN_WC2M_s1xs1x0.4360.0805.962.567.464.566.0216245682model | metrics
R_101_FPN_DL_WC1M_s1xs1x0.4530.0796.862.068.166.467.1216245771model | metrics
R_101_FPN_DL_WC2M_s1xs1x0.4640.0806.961.968.266.167.1216245790model | metrics
+ +Acronyms: + +`WC1M`: with confidence estimation model type 1 for `U` and `V` and mask confidence estimation + +`WC2M`: with confidence estimation model type 2 for `U` and `V` and mask confidence estimation + +###
Bootstrapping Baselines + +Master and student models trained using the bootstrapping pipeline with chimpanzee as the target category, +see [Sanakoyeu et al., 2020](https://arxiv.org/pdf/2003.00080.pdf) +and [Bootstrapping Pipeline](BOOTSTRAPPING_PIPELINE.md) for details. +Evaluation is performed on [DensePose Chimps](DENSEPOSE_DATASETS.md#densepose-chimps) dataset. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namelr
sched
train
time
(s/iter)
inference
time
(s/im)
train
mem
(GB)
box
AP
segm
AP
dp. APex
GPS
dp. AP
GPS
dp. AP
GPSm
model iddownload
R_50_FPN_DL_WC1M_3x_Atop10P_CA3x0.5220.0739.761.359.136.220.030.2217578784model | metrics
R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_uniform3x1.9390.07210.160.958.537.221.531.0256453729model | metrics
R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_uv3x1.9850.0729.661.458.938.322.232.1256452095model | metrics
R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_finesegm3x2.0470.07210.360.958.536.720.730.7256452819model | metrics
R_50_FPN_DL_WC1M_3x_Atop10P_CA_B_coarsesegm3x1.8300.0709.661.359.237.921.531.6256455697model | metrics
+ +Acronyms: + +`WC1M`: with confidence estimation model type 1 for `U` and `V` and mask confidence estimation + +`Atop10P`: humans and animals from the 10 best suitable categories are used for training + +`CA`: class agnostic training, where all annotated instances are mapped into a single category + +`B_<...>`: schedule with bootstrapping with the specified results sampling strategy + +Note: + +The relaxed `dp. APex GPS` metric was used in +[Sanakoyeu et al., 2020](https://arxiv.org/pdf/2003.00080.pdf) to evaluate DensePose +results. This metric considers matches at thresholds 0.2, 0.3 and 0.4 additionally +to the standard ones used in the evaluation protocol. The minimum threshold is +controlled by `DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD` config option. + +### License + +All models available for download are licensed under the +[Creative Commons Attribution-ShareAlike 3.0 license](https://creativecommons.org/licenses/by-sa/3.0/) + +## References + +If you use chart-based DensePose methods, please take the references from the following +BibTeX entries: + +DensePose bootstrapping pipeline: +``` +@InProceedings{Sanakoyeu2020TransferringDensePose, + title = {Transferring Dense Pose to Proximal Animal Classes}, + author = {Artsiom Sanakoyeu and Vasil Khalidov and Maureen S. 
McCarthy and Andrea Vedaldi and Natalia Neverova}, + journal = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2020}, +} +``` + +DensePose with confidence estimation: +``` +@InProceedings{Neverova2019DensePoseConfidences, + title = {Correlated Uncertainty for Learning Dense Correspondences from Noisy Labels}, + author = {Neverova, Natalia and Novotny, David and Vedaldi, Andrea}, + journal = {Advances in Neural Information Processing Systems}, + year = {2019}, +} +``` + +Original DensePose: +``` +@InProceedings{Guler2018DensePose, + title={DensePose: Dense Human Pose Estimation In The Wild}, + author={R\{i}za Alp G\"uler, Natalia Neverova, Iasonas Kokkinos}, + journal={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2018} +} +``` diff --git a/data_processing/detectron2/projects/DensePose/doc/GETTING_STARTED.md b/data_processing/detectron2/projects/DensePose/doc/GETTING_STARTED.md new file mode 100644 index 0000000..a5c86f3 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/doc/GETTING_STARTED.md @@ -0,0 +1,76 @@ +# Getting Started with DensePose + +## Inference with Pre-trained Models + +1. Pick a model and its config file from [Model Zoo(IUV)](DENSEPOSE_IUV.md#ModelZoo), [Model Zoo(CSE)](DENSEPOSE_CSE.md#ModelZoo), for example [densepose_rcnn_R_50_FPN_s1x.yaml](../configs/densepose_rcnn_R_50_FPN_s1x.yaml) +2. Run the [Apply Net](TOOL_APPLY_NET.md) tool to visualize the results or save them to disk. For example, to use contour visualization for DensePose, one can run: +```bash +python apply_net.py show configs/densepose_rcnn_R_50_FPN_s1x.yaml densepose_rcnn_R_50_FPN_s1x.pkl image.jpg dp_contour,bbox --output image_densepose_contour.png +``` +Please see [Apply Net](TOOL_APPLY_NET.md) for more details on the tool. 
+ +## Training + +First, prepare the [dataset](http://densepose.org/#dataset) into the following structure under the directory you'll run training scripts: +
+datasets/coco/
+  annotations/
+    densepose_{train,minival,valminusminival}2014.json
+    densepose_minival2014_100.json   (optional, for testing only)
+  {train,val}2014/
+    # image files that are mentioned in the corresponding json
+
+ +To train a model one can use the [train_net.py](../train_net.py) script. +This script was used to train all DensePose models in [Model Zoo(IUV)](DENSEPOSE_IUV.md#ModelZoo), [Model Zoo(CSE)](DENSEPOSE_CSE.md#ModelZoo). +For example, to launch end-to-end DensePose-RCNN training with ResNet-50 FPN backbone +on 8 GPUs following the s1x schedule, one can run: +```bash +python train_net.py --config-file configs/densepose_rcnn_R_50_FPN_s1x.yaml --num-gpus 8 +``` +The configs are made for 8-GPU training. To train on 1 GPU, one can apply the +[linear learning rate scaling rule](https://arxiv.org/abs/1706.02677): +```bash +python train_net.py --config-file configs/densepose_rcnn_R_50_FPN_s1x.yaml \ + SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025 +``` + +## Evaluation + +Model testing can be done in the same way as training, except for an additional flag `--eval-only` and +model location specification through `MODEL.WEIGHTS model.pth` in the command line: +```bash +python train_net.py --config-file configs/densepose_rcnn_R_50_FPN_s1x.yaml \ + --eval-only MODEL.WEIGHTS model.pth +``` + +## Tools + +We provide tools which allow one to: + - easily view DensePose annotated data in a dataset; + - perform DensePose inference on a set of images; + - visualize DensePose model results; + +`query_db` is a tool to print or visualize DensePose data in a dataset. +Please refer to [Query DB](TOOL_QUERY_DB.md) for more details on this tool. + +`apply_net` is a tool to print or visualize DensePose results. +Please refer to [Apply Net](TOOL_APPLY_NET.md) for more details on this tool. + + +## Installation as a package + +DensePose can also be installed as a Python package for integration with other software. 
+ +The following dependencies are needed: +- Python >= 3.7 +- [PyTorch](https://pytorch.org/get-started/locally/#start-locally) >= 1.7 (to match [detectron2 requirements](https://detectron2.readthedocs.io/en/latest/tutorials/install.html#requirements)) +- [torchvision](https://pytorch.org/vision/stable/) version [compatible with your version of PyTorch](https://github.com/pytorch/vision#installation) + +DensePose can then be installed from this repository with: + +``` +pip install git+https://github.com/facebookresearch/detectron2@main#subdirectory=projects/DensePose +``` + +After installation, the package will be importable as `densepose`. diff --git a/data_processing/detectron2/projects/DensePose/doc/RELEASE_2020_04.md b/data_processing/detectron2/projects/DensePose/doc/RELEASE_2020_04.md new file mode 100644 index 0000000..2fab6ae --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/doc/RELEASE_2020_04.md @@ -0,0 +1,6 @@ +# DensePose Confidence Estimation and Model Zoo Improvements + +* [DensePose models with confidence estimation](doc/DENSEPOSE_IUV.md#ModelZooConfidence) +* [Panoptic FPN and DeepLabV3 head implementation](doc/DENSEPOSE_IUV.md#ModelZooDeepLabV3) +* Test time augmentations for DensePose +* New evaluation metric (GPSm) that yields more reliable scores diff --git a/data_processing/detectron2/projects/DensePose/doc/RELEASE_2021_03.md b/data_processing/detectron2/projects/DensePose/doc/RELEASE_2021_03.md new file mode 100644 index 0000000..eb908a6 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/doc/RELEASE_2021_03.md @@ -0,0 +1,45 @@ +# DensePose CSE and DensePose Evolution + +* [DensePose Evolution pipeline](DENSEPOSE_IUV.md#ModelZooBootstrap), a 
framework to bootstrap + DensePose on unlabeled data + * [`InferenceBasedLoader`](../densepose/data/inference_based_loader.py) + with data samplers to use inference results from one model + to train another model (bootstrap); + * [`VideoKeyframeDataset`](../densepose/data/video/video_keyframe_dataset.py) + to efficiently load images from video keyframes; + * Category maps and filters to combine annotations from different categories + and train in a class-agnostic manner; + * [Pretrained models](DENSEPOSE_IUV.md#ModelZooBootstrap) for DensePose estimation on chimpanzees; + * DensePose head training from partial data (segmentation only); + * [DensePose models with mask confidence estimation](DENSEPOSE_IUV.md#ModelZooMaskConfidence); + * [DensePose Chimps](DENSEPOSE_DATASETS.md#densepose-chimps) dataset for IUV evaluation +* [DensePose Continuous Surface Embeddings](DENSEPOSE_CSE.md), a framework to extend DensePose + to various categories using 3D models + * [Hard embedding](../densepose/modeling/losses/embed.py) and + [soft embedding](../densepose/modeling/losses/soft_embed.py) + losses to train universal positional embeddings; + * [Embedder](../densepose/modeling/cse/embedder.py) to handle + mesh vertex embeddings; + * [Storage](../densepose/evaluation/tensor_storage.py) for evaluation with high volumes of data; + * [Pretrained models](DENSEPOSE_CSE.md#ModelZoo) for DensePose CSE estimation on humans and animals; + * [DensePose Chimps](DENSEPOSE_DATASETS.md#densepose-chimps) and + [DensePose LVIS](DENSEPOSE_DATASETS.md#densepose-lvis) datasets for CSE finetuning and evaluation; + * [Vertex and texture mapping visualizers](../densepose/vis/densepose_outputs_vertex.py); +* Refactoring of all major components: losses, predictors, model outputs, model results, visualizers; + * Dedicated structures for [chart outputs](../densepose/structures/chart.py), + [chart outputs with confidences](../densepose/structures/chart_confidence.py), + [chart results](../densepose/structures/chart_result.py), + [CSE 
outputs](../densepose/structures/cse.py); + * Dedicated predictors for + [chart-based estimation](../densepose/modeling/predictors/chart.py), + [confidence estimation](../densepose/modeling/predictors/chart_confidence.py) + and [CSE estimation](../densepose/modeling/predictors/cse.py); + * Generic handling of various [conversions](../densepose/converters) (e.g. from outputs to results); + * Better organization of various [losses](../densepose/modeling/losses); + * Segregation of loss data accumulators for + [IUV setting](../densepose/modeling/losses/utils.py) + and [CSE setting](../densepose/modeling/losses/embed_utils.py); + * Splitting visualizers into separate modules; +* [HRNet](../densepose/modeling/hrnet.py) and [HRFPN](../densepose/modeling/hrfpn.py) backbones; +* [PoseTrack](DENSEPOSE_DATASETS.md#densepose-posetrack) dataset; +* [IUV texture visualizer](../densepose/vis/densepose_results_textures.py) diff --git a/data_processing/detectron2/projects/DensePose/doc/RELEASE_2021_06.md b/data_processing/detectron2/projects/DensePose/doc/RELEASE_2021_06.md new file mode 100644 index 0000000..fb5ff4f --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/doc/RELEASE_2021_06.md @@ -0,0 +1,12 @@ +# DensePose CSE with Cycle Losses + +This release follows the paper [Neverova et al, 2021]() and +adds CSE datasets with more annotations, better CSE animal models +to the model zoo, losses to ensure cycle consistency for models and mesh +alignment evaluator. 
In particular: + +* [Pixel to shape](../densepose/modeling/losses/cycle_pix2shape.py) and [shape to shape](../densepose/modeling/losses/cycle_shape2shape.py) cycle consistency losses; +* Mesh alignment [evaluator](../densepose/evaluation/mesh_alignment_evaluator.py); +* Existing CSE datasets renamed to [ds1_train](https://dl.fbaipublicfiles.com/densepose/annotations/lvis/densepose_lvis_v1_ds1_train_v1.json) and [ds1_val](https://dl.fbaipublicfiles.com/densepose/annotations/lvis/densepose_lvis_v1_ds1_val_v1.json); +* New CSE datasets [ds2_train](https://dl.fbaipublicfiles.com/densepose/annotations/lvis/densepose_lvis_v1_ds2_train_v1.json) and [ds2_val](https://dl.fbaipublicfiles.com/densepose/annotations/lvis/densepose_lvis_v1_ds2_val_v1.json) added; +* Better CSE animal models trained with the 16k schedule added to the [model zoo](DENSEPOSE_CSE.md#animal-cse-models). diff --git a/data_processing/detectron2/projects/DensePose/doc/TOOL_APPLY_NET.md b/data_processing/detectron2/projects/DensePose/doc/TOOL_APPLY_NET.md new file mode 100644 index 0000000..ca8e1dd --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/doc/TOOL_APPLY_NET.md @@ -0,0 +1,203 @@ +# Apply Net + +`apply_net` is a tool to print or visualize DensePose results on a set of images. +It has two modes: `dump` to save DensePose model results to a pickle file +and `show` to visualize them on images. 
+ +The `image.jpg` file that is used as an example in this doc can be found [here](http://images.cocodataset.org/train2017/000000117508.jpg) + +## Dump Mode + +The general command form is: +```bash +python apply_net.py dump [-h] [-v] [--output <dump_file>] <config> <model> <input> +``` + +There are three mandatory arguments: + - `<config>`, configuration file for a given model; + - `<model>`, model file with trained parameters + - `<input>`, input image file name, pattern or folder + +One can additionally provide `--output` argument to define the output file name, +which defaults to `output.pkl`. + + +Examples: + +1. Dump results of the [R_50_FPN_s1x](https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl) DensePose model for images in a folder `images` to file `dump.pkl`: +```bash +python apply_net.py dump configs/densepose_rcnn_R_50_FPN_s1x.yaml \ +https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl \ +images --output dump.pkl -v +``` + +2. 
Dump results of the [R_50_FPN_s1x](https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl) DensePose model for images with file name matching a pattern `image*.jpg` to file `results.pkl`: +```bash +python apply_net.py dump configs/densepose_rcnn_R_50_FPN_s1x.yaml \ +https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl \ +"image*.jpg" --output results.pkl -v +``` + +If you want to load the pickle file generated by the above command: +``` +# make sure DensePose is in your PYTHONPATH, or use the following line to add it: +sys.path.append("/your_detectron2_path/detectron2_repo/projects/DensePose/") + +f = open('/your_result_path/results.pkl', 'rb') +data = pickle.load(f) +``` + +The file `results.pkl` contains the list of results per image, for each image the result is a dictionary. + +**If you use a [IUV model](DENSEPOSE_IUV.md#-model-zoo-and-baselines)**, the dumped data will have the following format: + +``` +data: [{'file_name': '/your_path/image1.jpg', + 'scores': tensor([0.9884]), + 'pred_boxes_XYXY': tensor([[ 69.6114, 0.0000, 706.9797, 706.0000]]), + 'pred_densepose': [DensePoseChartResultWithConfidences(labels=tensor(...), uv=tensor(...), sigma_1=None, + sigma_2=None, kappa_u=None, kappa_v=None, fine_segm_confidence=None, coarse_segm_confidence=None), + DensePoseChartResultWithConfidences, ...] 
+ } + {'file_name': '/your_path/image2.jpg', + 'scores': tensor([0.9999, 0.5373, 0.3991]), + 'pred_boxes_XYXY': tensor([[ 59.5734, 7.7535, 579.9311, 932.3619], + [612.9418, 686.1254, 612.9999, 704.6053], + [164.5081, 407.4034, 598.3944, 920.4266]]), + 'pred_densepose': [DensePoseChartResultWithConfidences(labels=tensor(...), uv=tensor(...), sigma_1=None, + sigma_2=None, kappa_u=None, kappa_v=None, fine_segm_confidence=None, coarse_segm_confidence=None), + DensePoseChartResultWithConfidences, ...] + }] +``` + +`DensePoseChartResultWithConfidences` contains the following fields: +- `labels` - a tensor of size `[H, W]` of type `torch.long` which contains fine segmentation labels (previously called `I`) +- `uv` - a tensor of size `[2, H, W]` of type `torch.float` which contains `U` and `V` coordinates +- various optional confidence-related fields (`sigma_1`, `sigma_2`, `kappa_u`, `kappa_v`, `fine_segm_confidence`, `coarse_segm_confidence`) + + +**If you use a [CSE model](DENSEPOSE_CSE.md#-model-zoo-and-baselines)**, the dumped data will have the following format: +``` +data: [{'file_name': '/your_path/image1.jpg', + 'scores': tensor([0.9984, 0.9961]), + 'pred_boxes_XYXY': tensor([[480.0093, 461.0796, 698.3614, 696.1011], + [78.1589, 168.6614, 307.1287, 653.8522]]), + 'pred_densepose': DensePoseEmbeddingPredictorOutput(embedding=tensor(...), coarse_segm=tensor(...))} + {'file_name': '/your_path/image2.jpg', + 'scores': tensor([0.9189, 0.9491]), + 'pred_boxes_XYXY': tensor([[734.9685, 534.2003, 287.3923, 254.8859], + [434.2853, 765.1219, 132.1029, 867.9283]]), + 'pred_densepose': DensePoseEmbeddingPredictorOutput(embedding=tensor(...), coarse_segm=tensor(...))}] +``` + +`DensePoseEmbeddingPredictorOutput` contains the following fields: +- `embedding` - a tensor of size `[N, D, sz, sz]` of type `torch.float`, which contains embeddings of size `D` of the `N` detections in the image +- `coarse_segm` - a tensor of size `[N, 2, sz, sz]` of type `torch.float` which contains 
segmentation scores of the `N` detections in the image; e.g. a mask can be obtained by `coarse_segm.argmax(dim=1)` + +`sz` is a fixed size for the tensors; you can resize them to the size of the bounding box, if needed + +We can use the following code, to parse the outputs of the first +detected instance on the first image (IUV model). +``` +img_id, instance_id = 0, 0 # Look at the first image and the first detected instance +bbox_xyxy = data[img_id]['pred_boxes_XYXY'][instance_id] +result = data[img_id]['pred_densepose'][instance_id] +uv = result.uv +``` +The array `bbox_xyxy` contains (x0, y0, x1, y1) of the bounding box. + + +## Visualization Mode + +The general command form is: +```bash +python apply_net.py show [-h] [-v] [--min_score ] [--nms_thresh ] [--output ] +``` + +There are four mandatory arguments: + - ``, configuration file for a given model; + - ``, model file with trained parameters + - ``, input image file name, pattern or folder + - ``, visualizations specifier; currently available visualizations are: + * `bbox` - bounding boxes of detected persons; + * `dp_segm` - segmentation masks for detected persons; + * `dp_u` - each body part is colored according to the estimated values of the + U coordinate in part parameterization; + * `dp_v` - each body part is colored according to the estimated values of the + V coordinate in part parameterization; + * `dp_contour` - plots contours with color-coded U and V coordinates; + * `dp_iuv_texture` - transfers the texture from a given texture image file to detected instances, in IUV mode; + * `dp_vertex` - plots the rainbow visualization of the closest vertices prediction for a given mesh, in CSE mode; + * `dp_cse_texture` - transfers the texture from a given list of texture image files (one from each human or animal mesh) to detected instances, in CSE mode + + +One can additionally provide the following optional arguments: + - `--min_score` to only show detections with sufficient scores that are not lower than 
provided value + - `--nms_thresh` to additionally apply non-maximum suppression to detections at a given threshold + - `--output` to define visualization file name template, which defaults to `output.png`. + To distinguish output file names for different images, the tool appends 1-based entry index, + e.g. output.0001.png, output.0002.png, etc... +- `--texture_atlas` to define the texture atlas image for IUV texture transfer +- `--texture_atlases_map` to define the texture atlas images map (a dictionary `{mesh name: texture atlas image}`) for CSE texture transfer + + +The following examples show how to output results of a DensePose model +with ResNet-50 FPN backbone using different visualizations for image `image.jpg`: + +1. Show bounding box and segmentation: +```bash +python apply_net.py show configs/densepose_rcnn_R_50_FPN_s1x.yaml \ +https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl \ +image.jpg bbox,dp_segm -v +``` +![Bounding Box + Segmentation Visualization](https://dl.fbaipublicfiles.com/densepose/web/apply_net/res_bbox_dp_segm.jpg) + +2. Show bounding box and estimated U coordinates for body parts: +```bash +python apply_net.py show configs/densepose_rcnn_R_50_FPN_s1x.yaml \ +https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl \ +image.jpg bbox,dp_u -v +``` +![Bounding Box + U Coordinate Visualization](https://dl.fbaipublicfiles.com/densepose/web/apply_net/res_bbox_dp_u.jpg) + +3. 
Show bounding box and estimated V coordinates for body parts: +```bash +python apply_net.py show configs/densepose_rcnn_R_50_FPN_s1x.yaml \ +https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl \ +image.jpg bbox,dp_v -v +``` +![Bounding Box + V Coordinate Visualization](https://dl.fbaipublicfiles.com/densepose/web/apply_net/res_bbox_dp_v.jpg) + +4. Show bounding box and estimated U and V coordinates via contour plots: +```bash +python apply_net.py show configs/densepose_rcnn_R_50_FPN_s1x.yaml \ +https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl \ +image.jpg dp_contour,bbox -v +``` +![Bounding Box + Contour Visualization](https://dl.fbaipublicfiles.com/densepose/web/apply_net/res_bbox_dp_contour.jpg) + +5. Show bounding box and texture transfer: +```bash +python apply_net.py show configs/densepose_rcnn_R_50_FPN_s1x.yaml \ +https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl \ +image.jpg dp_iuv_texture,bbox --texture_atlas texture_from_SURREAL.jpg -v +``` +![Bounding Box + IUV Texture Transfer Visualization](https://dl.fbaipublicfiles.com/densepose/web/apply_net/res_bbox_dp_iuv_texture.jpg) + +6. 
Show bounding box and CSE rainbow visualization: +```bash +python apply_net.py show configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml \ +https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl \ +image.jpg dp_vertex,bbox -v +``` +![Bounding Box + CSE Rainbow Visualization](https://dl.fbaipublicfiles.com/densepose/web/apply_net/res_bbox_dp_vertex.jpg) + +7. Show bounding box and CSE texture transfer: +```bash +python apply_net.py show configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml \ +https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl \ +image.jpg dp_cse_texture,bbox --texture_atlases_map '{"smpl_27554": "smpl_uvSnapshot_colors.jpg"}' -v +``` +![Bounding Box + CSE Texture Transfer Visualization](https://dl.fbaipublicfiles.com/densepose/web/apply_net/res_bbox_dp_cse_texture.jpg) + +The texture files can be found in the `doc/images` folder diff --git a/data_processing/detectron2/projects/DensePose/doc/TOOL_QUERY_DB.md b/data_processing/detectron2/projects/DensePose/doc/TOOL_QUERY_DB.md new file mode 100644 index 0000000..b0a764b --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/doc/TOOL_QUERY_DB.md @@ -0,0 +1,105 @@ + +# Query Dataset + +`query_db` is a tool to print or visualize DensePose data from a dataset. +It has two modes: `print` and `show` to output dataset entries to standard +output or to visualize them on images. + +## Print Mode + +The general command form is: +```bash +python query_db.py print [-h] [-v] [--max-entries N] +``` + +There are two mandatory arguments: + - ``, DensePose dataset specification, from which to select + the entries (e.g. `densepose_coco_2014_train`). 
+ - `<selector>`, dataset entry selector which can be a single specification, + or a comma-separated list of specifications of the form + `field[:type]=value` for exact match with the value + or `field[:type]=min-max` for a range of values + +One can additionally limit the number of dataset entries that are examined +(the limit is applied before the selector) by providing `--max-entries` argument. + +Examples: + +1. Output at most the first 10 entries from the `densepose_coco_2014_train` dataset: +```bash +python query_db.py print densepose_coco_2014_train \* --max-entries 10 -v +``` + +2. Output all entries with `file_name` equal to `COCO_train2014_000000000036.jpg`: +```bash +python query_db.py print densepose_coco_2014_train file_name=COCO_train2014_000000000036.jpg -v +``` + +3. Output all entries with `image_id` between 36 and 156: +```bash +python query_db.py print densepose_coco_2014_train image_id:int=36-156 -v +``` + +## Visualization Mode + +The general command form is: +```bash +python query_db.py show [-h] [-v] [--max-entries N] [--output <image_file>] <dataset> <selector> <visualizations> +``` + +There are three mandatory arguments: + - `<dataset>`, DensePose dataset specification, from which to select + the entries (e.g. `densepose_coco_2014_train`). 
+ - `<selector>`, dataset entry selector which can be a single specification, + or a comma-separated list of specifications of the form + `field[:type]=value` for exact match with the value + or `field[:type]=min-max` for a range of values + - `<visualizations>`, visualizations specifier; currently available visualizations are: + * `bbox` - bounding boxes of annotated persons; + * `dp_i` - annotated points colored according to the containing part; + * `dp_pts` - annotated points in green color; + * `dp_segm` - segmentation masks for annotated persons; + * `dp_u` - annotated points colored according to their U coordinate in part parameterization; + * `dp_v` - annotated points colored according to their V coordinate in part parameterization; + +One can additionally provide the following optional arguments: + - `--max-entries` to limit the number of dataset entries that are examined + - `--output` to provide visualization file name template, which defaults + to `output.png`. To distinguish file names for different dataset + entries, the tool appends 1-based entry index to the output file name, + e.g. output.0001.png, output.0002.png, etc. + +The following examples show how to output different visualizations for image with `id = 322` +from `densepose_coco_2014_train` dataset: + +1. Show bounding box and segmentation: +```bash +python query_db.py show densepose_coco_2014_train image_id:int=322 bbox,dp_segm -v +``` +![Bounding Box + Segmentation Visualization](images/vis_bbox_dp_segm.jpg) + +2. Show bounding box and points colored according to the containing part: +```bash +python query_db.py show densepose_coco_2014_train image_id:int=322 bbox,dp_i -v +``` +![Bounding Box + Point Label Visualization](images/vis_bbox_dp_i.jpg) + +3. Show bounding box and annotated points in green color: +```bash +python query_db.py show densepose_coco_2014_train image_id:int=322 bbox,dp_pts -v +``` +![Bounding Box + Point Visualization](images/vis_bbox_dp_pts.jpg) + +4. 
Show bounding box and annotated points colored according to their U coordinate in part parameterization: +```bash +python query_db.py show densepose_coco_2014_train image_id:int=322 bbox,dp_u -v +``` +![Bounding Box + Point U Visualization](images/vis_bbox_dp_u.jpg) + +5. Show bounding box and annotated points colored according to their V coordinate in part parameterization: +```bash +python query_db.py show densepose_coco_2014_train image_id:int=322 bbox,dp_v -v +``` +![Bounding Box + Point V Visualization](images/vis_bbox_dp_v.jpg) + + diff --git a/data_processing/detectron2/projects/DensePose/query_db.py b/data_processing/detectron2/projects/DensePose/query_db.py new file mode 100644 index 0000000..814a25f --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/query_db.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. + +import argparse +import logging +import os +import sys +from timeit import default_timer as timer +from typing import Any, ClassVar, Dict, List +import torch + +from detectron2.data.catalog import DatasetCatalog +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import setup_logger + +from densepose.structures import DensePoseDataRelative +from densepose.utils.dbhelper import EntrySelector +from densepose.utils.logger import verbosity_to_level +from densepose.vis.base import CompoundVisualizer +from densepose.vis.bounding_box import BoundingBoxVisualizer +from densepose.vis.densepose_data_points import ( + DensePoseDataCoarseSegmentationVisualizer, + DensePoseDataPointsIVisualizer, + DensePoseDataPointsUVisualizer, + DensePoseDataPointsVisualizer, + DensePoseDataPointsVVisualizer, +) + +DOC = """Query DB - a tool to print / visualize data from a database +""" + +LOGGER_NAME = "query_db" + +logger = logging.getLogger(LOGGER_NAME) + +_ACTION_REGISTRY: Dict[str, "Action"] = {} + + +class Action(object): + @classmethod + def add_arguments(cls: type, parser: 
class EntrywiseAction(Action):
    """
    Base class for actions that are executed once per selected dataset entry.

    Subclasses are expected to provide `execute_on_entry(entry, context)`;
    the `context` dictionary is produced once per run by `create_context`
    and passed to every call.
    """

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """
        Register the `dataset` and `selector` positionals and `--max-entries`.

        The metavars name the positionals explicitly: an empty metavar makes
        argparse render an empty placeholder in the usage and help output.
        """
        super(EntrywiseAction, cls).add_arguments(parser)
        parser.add_argument(
            "dataset",
            metavar="<dataset>",
            help="Dataset name (e.g. densepose_coco_2014_train)",
        )
        parser.add_argument(
            "selector",
            metavar="<selector>",
            help="Dataset entry selector in the form field1[:type]=value1[,"
            "field2[:type]=value_min-value_max...] which selects all "
            "entries from the dataset that satisfy the constraints",
        )
        parser.add_argument(
            "--max-entries", metavar="N", help="Maximum number of entries to process", type=int
        )

    @classmethod
    def execute(cls: type, args: argparse.Namespace):
        """
        Load the dataset and run `execute_on_entry` on every selected entry.

        When `--max-entries N` is given, only the first N dataset entries are
        read; the selector is applied to those, so fewer than N entries may
        actually be processed.
        """
        dataset = setup_dataset(args.dataset)
        entry_selector = EntrySelector.from_string(args.selector)
        context = cls.create_context(args)
        if args.max_entries is None:
            entries = dataset
        else:
            # zip() stops at the shorter input, which caps the entries read
            entries = (entry for _, entry in zip(range(args.max_entries), dataset))
        for entry in entries:
            if entry_selector(entry):
                cls.execute_on_entry(entry, context)

    @classmethod
    def create_context(cls: type, args: argparse.Namespace) -> Dict[str, Any]:
        """
        Return the per-run context passed to `execute_on_entry` (empty here).
        """
        return {}
@register_action
class ShowAction(EntrywiseAction):
    """
    Show action that visualizes selected entries on an image
    """

    COMMAND: ClassVar[str] = "show"
    # Visualizer instances keyed by the names accepted in <visualizations>
    VISUALIZERS: ClassVar[Dict[str, object]] = {
        "dp_segm": DensePoseDataCoarseSegmentationVisualizer(),
        "dp_i": DensePoseDataPointsIVisualizer(),
        "dp_u": DensePoseDataPointsUVisualizer(),
        "dp_v": DensePoseDataPointsVVisualizer(),
        "dp_pts": DensePoseDataPointsVisualizer(),
        "bbox": BoundingBoxVisualizer(),
    }

    @classmethod
    def add_parser(cls: type, subparsers: argparse._SubParsersAction):
        """
        Create the `show` sub-command parser and bind it to `execute`.
        """
        parser = subparsers.add_parser(cls.COMMAND, help="Visualize selected entries")
        cls.add_arguments(parser)
        parser.set_defaults(func=cls.execute)

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """
        Add the `visualizations` positional and the `--output` option.

        The metavars are spelled out because an empty metavar shows up as an
        empty placeholder in the argparse usage and help output.
        """
        super(ShowAction, cls).add_arguments(parser)
        parser.add_argument(
            "visualizations",
            metavar="<visualizations>",
            help="Comma separated list of visualizations, possible values: "
            "[{}]".format(",".join(sorted(cls.VISUALIZERS.keys()))),
        )
        parser.add_argument(
            "--output",
            metavar="<image_file>",
            default="output.png",
            help="File name to save output to",
        )

    @classmethod
    def execute_on_entry(cls: type, entry: Dict[str, Any], context: Dict[str, Any]):
        """
        Render the requested visualizations for one entry and save the image.

        The image is read as grayscale and replicated to three channels
        before it is passed to the visualizer. The output file name carries
        a 1-based counter of the entries visualized so far.

        Raises:
            RuntimeError: if the image file cannot be read.
        """
        import cv2
        import numpy as np

        image_fpath = PathManager.get_local_path(entry["file_name"])
        image = cv2.imread(image_fpath, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports failure by returning None; without this check
            # the indexing below fails with an unrelated-looking TypeError
            raise RuntimeError(f"Failed to read image {image_fpath}")
        image = np.tile(image[:, :, np.newaxis], [1, 1, 3])
        datas = cls._extract_data_for_visualizers_from_entry(context["vis_specs"], entry)
        visualizer = context["visualizer"]
        image_vis = visualizer.visualize(image, datas)
        entry_idx = context["entry_idx"] + 1
        out_fname = cls._get_out_fname(entry_idx, context["out_fname"])
        cv2.imwrite(out_fname, image_vis)
        logger.info(f"Output saved to {out_fname}")
        context["entry_idx"] = entry_idx

    @classmethod
    def _get_out_fname(cls: type, entry_idx: int, fname_base: str):
        """
        Insert the zero-padded entry index before the file extension,
        e.g. `output.png` with index 1 becomes `output.0001.png`.
        """
        base, ext = os.path.splitext(fname_base)
        return base + ".{0:04d}".format(entry_idx) + ext

    @classmethod
    def create_context(cls: type, args: argparse.Namespace) -> Dict[str, Any]:
        """
        Build the per-run context: requested visualization names, the
        compound visualizer, the output file name template and the counter.

        Raises:
            KeyError: if an unknown visualization name is requested.
        """
        vis_specs = args.visualizations.split(",")
        unknown = [vis_spec for vis_spec in vis_specs if vis_spec not in cls.VISUALIZERS]
        if unknown:
            raise KeyError(
                "Unknown visualizations: {}; possible values: [{}]".format(
                    ",".join(unknown), ",".join(sorted(cls.VISUALIZERS.keys()))
                )
            )
        visualizers = [cls.VISUALIZERS[vis_spec] for vis_spec in vis_specs]
        context = {
            "vis_specs": vis_specs,
            "visualizer": CompoundVisualizer(visualizers),
            "out_fname": args.output,
            "entry_idx": 0,
        }
        return context

    @classmethod
    def _extract_data_for_visualizers_from_entry(
        cls: type, vis_specs: List[str], entry: Dict[str, Any]
    ):
        """
        Collect the input data for each requested visualization.

        Annotations rejected by `DensePoseDataRelative.validate_annotation`
        are skipped. The `bbox` visualization receives the list of boxes,
        every other visualization receives `(boxes, densepose data)`.
        """
        dp_list = []
        bbox_list = []
        for annotation in entry["annotations"]:
            is_valid, _ = DensePoseDataRelative.validate_annotation(annotation)
            if not is_valid:
                continue
            bbox = torch.as_tensor(annotation["bbox"])
            bbox_list.append(bbox)
            dp_data = DensePoseDataRelative(annotation)
            dp_list.append(dp_data)
        datas = []
        for vis_spec in vis_specs:
            datas.append(bbox_list if "bbox" == vis_spec else (bbox_list, dp_list))
        return datas
"""Batch driver that runs the DensePose CSE `dp_vertex` visualization
(`apply_net.py show`) over 1000-image chunks of a dataset."""
import subprocess
import sys

# Root folder that holds one sub-folder per dataset
DATA_ROOT = 'G:/full-head-dataset'


def run_chunks(dataset_name, chunk_indices):
    """
    Run `apply_net.py show ... dp_vertex` for every chunk folder.

    The folder of chunk `i` is `<DATA_ROOT>/<dataset_name>/<i * 1000:08d>`;
    images are read from its `aligned_images` sub-folder and results are
    written with the `seg` output prefix. A failing chunk is reported on
    stderr and the remaining chunks are still processed.
    """
    for i in chunk_indices:
        path = f'{DATA_ROOT}/{dataset_name}/{i * 1000:08d}'
        # An argument list (no shell) keeps paths with spaces intact
        cmd = [
            'python', 'apply_net.py', 'show',
            'configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml',
            'R_101_FPN_DL_soft_s1x.pkl',
            f'{path}/aligned_images', 'dp_vertex',
            '--output', f'{path}/seg',
            '--min_score', '0.8',
        ]
        print(' '.join(cmd))
        completed = subprocess.run(cmd)
        if completed.returncode != 0:
            print(f'command failed with exit code {completed.returncode}', file=sys.stderr)


if __name__ == '__main__':
    # Previously used configuration: run_chunks('pexels', range(50))
    run_chunks('unsplash', range(58, 64))
def get_detectron2_current_version():
    """Version is not available for import through Python since it is
    above the top level of the package. Instead, we parse it from the
    file with a regex.

    Returns:
        The version string found in the `__version__ = "..."` assignment of
        `detectron2/__init__.py`, located two directory levels above this
        file's directory.

    Raises:
        RuntimeError: if no such assignment is found in that file.
    """
    # Get version info from detectron2 __init__.py
    init_path = Path(__file__).parents[2] / "detectron2" / "__init__.py"
    version_source = init_path.read_text()
    match = re.search(r'__version__ = "([0-9\.]+)"', version_source)
    if match is None:
        # fail with a message that names the file instead of a bare IndexError
        raise RuntimeError("Unable to find __version__ in {}".format(init_path))
    return match.group(1)
def _collect_config_files(config_dir):
    """
    Collect all configuration files (i.e. densepose_*.yaml) directly in the specified directory

    Only regular files with the configuration extension are kept, and files
    whose name starts with the base-config prefix are skipped. The returned
    paths are relative to the base configuration directory and sorted by
    file name, so the result does not depend on the directory listing order.
    """
    start = _get_base_config_dir()
    results = []
    # os.listdir returns entries in arbitrary order; sort for determinism
    for entry in sorted(os.listdir(config_dir)):
        path = os.path.join(config_dir, entry)
        if not os.path.isfile(path):
            continue
        _, ext = os.path.splitext(entry)
        if ext != _CONFIG_FILE_EXT:
            continue
        if entry.startswith(_BASE_CONFIG_FILE_PREFIX):
            continue
        config_file = os.path.relpath(path, start)
        results.append(config_file)
    return results
# instances.gt_densepose = None cannot happen because instances attributes need a length
class TestChartBasedAnnotationsAccumulator(unittest.TestCase):
    """
    Tests for `ChartBasedAnnotationsAccumulator.accumulate`.

    Every test builds its own `Instances` object, so a `gt_densepose` value
    set by one test cannot leak into another and the outcome does not depend
    on the order in which the tests run.
    """

    @staticmethod
    def _create_instances():
        """
        Create instances with random proposal and ground truth boxes and
        without a `gt_densepose` attribute.
        """
        instances = Instances(image_shape)
        instances.proposal_boxes = Boxes(torch.rand(n_instances, 4))
        instances.gt_boxes = Boxes(torch.rand(n_instances, 4))
        return instances

    def test_chart_based_annotations_accumulator_no_gt_densepose(self):
        instances = self._create_instances()
        accumulator = ChartBasedAnnotationsAccumulator()
        accumulator.accumulate(instances)
        expected_values = {"nxt_bbox_with_dp_index": 0, "nxt_bbox_index": n_instances}
        for key in accumulator.__dict__:
            self.assertEqual(getattr(accumulator, key), expected_values.get(key, []))

    def test_chart_based_annotations_accumulator_gt_densepose_none(self):
        instances = self._create_instances()
        instances.gt_densepose = [None] * n_instances
        accumulator = ChartBasedAnnotationsAccumulator()
        accumulator.accumulate(instances)
        expected_values = {"nxt_bbox_with_dp_index": 0, "nxt_bbox_index": n_instances}
        for key in accumulator.__dict__:
            self.assertEqual(getattr(accumulator, key), expected_values.get(key, []))

    def test_chart_based_annotations_accumulator_gt_densepose(self):
        instances = self._create_instances()
        data_relative_keys = [
            DensePoseDataRelative.X_KEY,
            DensePoseDataRelative.Y_KEY,
            DensePoseDataRelative.I_KEY,
            DensePoseDataRelative.U_KEY,
            DensePoseDataRelative.V_KEY,
            DensePoseDataRelative.S_KEY,
        ]
        annotations = [DensePoseDataRelative({k: [0] for k in data_relative_keys})] * n_instances
        instances.gt_densepose = DensePoseList(annotations, instances.gt_boxes, image_shape)
        accumulator = ChartBasedAnnotationsAccumulator()
        accumulator.accumulate(instances)
        bbox_xywh_est = BoxMode.convert(
            instances.proposal_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
        )
        bbox_xywh_gt = BoxMode.convert(
            instances.gt_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
        )
        expected_values = {
            "s_gt": [
                torch.zeros((3, DensePoseDataRelative.MASK_SIZE, DensePoseDataRelative.MASK_SIZE))
            ]
            * n_instances,
            "bbox_xywh_est": bbox_xywh_est.split(1),
            "bbox_xywh_gt": bbox_xywh_gt.split(1),
            "point_bbox_with_dp_indices": [torch.tensor([i]) for i in range(n_instances)],
            "point_bbox_indices": [torch.tensor([i]) for i in range(n_instances)],
            "bbox_indices": list(range(n_instances)),
            "nxt_bbox_with_dp_index": n_instances,
            "nxt_bbox_index": n_instances,
        }
        # attributes without an explicit expectation are compared to this value
        default_value = [torch.tensor([0])] * 3
        for key in accumulator.__dict__:
            to_test = getattr(accumulator, key)
            gt_value = expected_values.get(key, default_value)
            if key in ["nxt_bbox_with_dp_index", "nxt_bbox_index"]:
                self.assertEqual(to_test, gt_value)
            elif key == "bbox_indices":
                self.assertListEqual(to_test, gt_value)
            else:
                self.assertTrue(torch.allclose(torch.stack(to_test), torch.stack(gt_value)))
+ +import random +import unittest +from typing import Any, Iterable, Iterator, Tuple + +from densepose.data import CombinedDataLoader + + +def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]: + """ + Group elements of an iterable by chunks of size `n`, e.g. + grouper(range(9), 4) -> + (0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None) + """ + it = iter(iterable) + while True: + values = [] + for _ in range(n): + try: + value = next(it) + except StopIteration: + values.extend([fillvalue] * (n - len(values))) + yield tuple(values) + return + values.append(value) + yield tuple(values) + + +class TestCombinedDataLoader(unittest.TestCase): + def test_combine_loaders_1(self): + loader1 = _grouper([f"1_{i}" for i in range(10)], 2) + loader2 = _grouper([f"2_{i}" for i in range(11)], 3) + batch_size = 4 + ratios = (0.1, 0.9) + random.seed(43) + combined = CombinedDataLoader((loader1, loader2), batch_size, ratios) + BATCHES_GT = [ + ["1_0", "1_1", "2_0", "2_1"], + ["2_2", "2_3", "2_4", "2_5"], + ["1_2", "1_3", "2_6", "2_7"], + ["2_8", "2_9", "2_10", None], + ] + for i, batch in enumerate(combined): + self.assertEqual(len(batch), batch_size) + self.assertEqual(batch, BATCHES_GT[i]) diff --git a/data_processing/detectron2/projects/DensePose/tests/test_cse_annotations_accumulator.py b/data_processing/detectron2/projects/DensePose/tests/test_cse_annotations_accumulator.py new file mode 100644 index 0000000..a22dce9 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/tests/test_cse_annotations_accumulator.py @@ -0,0 +1,240 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved + +import unittest +import torch + +from detectron2.structures import Boxes, BoxMode, Instances + +from densepose.modeling.losses.embed_utils import CseAnnotationsAccumulator +from densepose.structures import DensePoseDataRelative, DensePoseList + + +class TestCseAnnotationsAccumulator(unittest.TestCase): + def test_cse_annotations_accumulator_nodp(self): + instances_lst = [ + self._create_instances_nodp(), + ] + self._test_template(instances_lst) + + def test_cse_annotations_accumulator_sparsedp(self): + instances_lst = [ + self._create_instances_sparsedp(), + ] + self._test_template(instances_lst) + + def test_cse_annotations_accumulator_fulldp(self): + instances_lst = [ + self._create_instances_fulldp(), + ] + self._test_template(instances_lst) + + def test_cse_annotations_accumulator_combined(self): + instances_lst = [ + self._create_instances_nodp(), + self._create_instances_sparsedp(), + self._create_instances_fulldp(), + ] + self._test_template(instances_lst) + + def _test_template(self, instances_lst): + acc = CseAnnotationsAccumulator() + for instances in instances_lst: + acc.accumulate(instances) + packed_anns = acc.pack() + self._check_correspondence(packed_anns, instances_lst) + + def _create_instances_nodp(self): + image_shape = (480, 640) + instances = Instances(image_shape) + instances.gt_boxes = Boxes( + torch.as_tensor( + [ + [40.0, 40.0, 140.0, 140.0], + [160.0, 160.0, 270.0, 270.0], + [40.0, 160.0, 160.0, 280.0], + ] + ) + ) + instances.proposal_boxes = Boxes( + torch.as_tensor( + [ + [41.0, 39.0, 142.0, 138.0], + [161.0, 159.0, 272.0, 268.0], + [41.0, 159.0, 162.0, 278.0], + ] + ) + ) + # do not add gt_densepose + return instances + + def _create_instances_sparsedp(self): + image_shape = (540, 720) + instances = Instances(image_shape) + instances.gt_boxes = Boxes( + torch.as_tensor( + [ + [50.0, 50.0, 130.0, 130.0], + [150.0, 150.0, 240.0, 240.0], + [50.0, 150.0, 230.0, 330.0], + ] + ) + ) + instances.proposal_boxes = Boxes( 
+ torch.as_tensor( + [ + [49.0, 51.0, 131.0, 129.0], + [151.0, 149.0, 241.0, 239.0], + [51.0, 149.0, 232.0, 329.0], + ] + ) + ) + instances.gt_densepose = DensePoseList( + [ + None, + self._create_dp_data( + { + "dp_x": [81.69, 153.47, 151.00], + "dp_y": [162.24, 128.71, 113.81], + "dp_vertex": [0, 1, 2], + "ref_model": "zebra_5002", + "dp_masks": [], + }, + {"c": (166, 133), "r": 64}, + ), + None, + ], + instances.gt_boxes, + image_shape, + ) + return instances + + def _create_instances_fulldp(self): + image_shape = (680, 840) + instances = Instances(image_shape) + instances.gt_boxes = Boxes( + torch.as_tensor( + [ + [65.0, 55.0, 165.0, 155.0], + [170.0, 175.0, 275.0, 280.0], + [55.0, 165.0, 165.0, 275.0], + ] + ) + ) + instances.proposal_boxes = Boxes( + torch.as_tensor( + [ + [66.0, 54.0, 166.0, 154.0], + [171.0, 174.0, 276.0, 279.0], + [56.0, 164.0, 166.0, 274.0], + ] + ) + ) + instances.gt_densepose = DensePoseList( + [ + self._create_dp_data( + { + "dp_x": [149.99, 198.62, 157.59], + "dp_y": [170.74, 197.73, 123.12], + "dp_vertex": [3, 4, 5], + "ref_model": "cat_5001", + "dp_masks": [], + }, + {"c": (100, 100), "r": 50}, + ), + self._create_dp_data( + { + "dp_x": [234.53, 116.72, 71.66], + "dp_y": [107.53, 11.31, 142.32], + "dp_vertex": [6, 7, 8], + "ref_model": "dog_5002", + "dp_masks": [], + }, + {"c": (200, 150), "r": 40}, + ), + self._create_dp_data( + { + "dp_x": [225.54, 202.61, 135.90], + "dp_y": [167.46, 181.00, 211.47], + "dp_vertex": [9, 10, 11], + "ref_model": "elephant_5002", + "dp_masks": [], + }, + {"c": (100, 200), "r": 45}, + ), + ], + instances.gt_boxes, + image_shape, + ) + return instances + + def _create_dp_data(self, anns, blob_def=None): + dp_data = DensePoseDataRelative(anns) + if blob_def is not None: + dp_data.segm[ + blob_def["c"][0] - blob_def["r"] : blob_def["c"][0] + blob_def["r"], + blob_def["c"][1] - blob_def["r"] : blob_def["c"][1] + blob_def["r"], + ] = 1 + return dp_data + + def _check_correspondence(self, packed_anns, 
instances_lst): + instance_idx = 0 + data_idx = 0 + pt_offset = 0 + if packed_anns is not None: + bbox_xyxy_gt = BoxMode.convert( + packed_anns.bbox_xywh_gt.clone(), BoxMode.XYWH_ABS, BoxMode.XYXY_ABS + ) + bbox_xyxy_est = BoxMode.convert( + packed_anns.bbox_xywh_est.clone(), BoxMode.XYWH_ABS, BoxMode.XYXY_ABS + ) + for instances in instances_lst: + if not hasattr(instances, "gt_densepose"): + instance_idx += len(instances) + continue + for i, dp_data in enumerate(instances.gt_densepose): + if dp_data is None: + instance_idx += 1 + continue + n_pts = len(dp_data.x) + self.assertTrue( + torch.allclose(dp_data.x, packed_anns.x_gt[pt_offset : pt_offset + n_pts]) + ) + self.assertTrue( + torch.allclose(dp_data.y, packed_anns.y_gt[pt_offset : pt_offset + n_pts]) + ) + self.assertTrue(torch.allclose(dp_data.segm, packed_anns.coarse_segm_gt[data_idx])) + self.assertTrue( + torch.allclose( + torch.ones(n_pts, dtype=torch.long) * dp_data.mesh_id, + packed_anns.vertex_mesh_ids_gt[pt_offset : pt_offset + n_pts], + ) + ) + self.assertTrue( + torch.allclose( + dp_data.vertex_ids, packed_anns.vertex_ids_gt[pt_offset : pt_offset + n_pts] + ) + ) + self.assertTrue( + torch.allclose(instances.gt_boxes.tensor[i], bbox_xyxy_gt[data_idx]) + ) + self.assertTrue( + torch.allclose(instances.proposal_boxes.tensor[i], bbox_xyxy_est[data_idx]) + ) + self.assertTrue( + torch.allclose( + torch.ones(n_pts, dtype=torch.long) * data_idx, + packed_anns.point_bbox_with_dp_indices[pt_offset : pt_offset + n_pts], + ) + ) + self.assertTrue( + torch.allclose( + torch.ones(n_pts, dtype=torch.long) * instance_idx, + packed_anns.point_bbox_indices[pt_offset : pt_offset + n_pts], + ) + ) + self.assertEqual(instance_idx, packed_anns.bbox_indices[data_idx]) + pt_offset += n_pts + instance_idx += 1 + data_idx += 1 + if data_idx == 0: + self.assertIsNone(packed_anns) diff --git a/data_processing/detectron2/projects/DensePose/tests/test_dataset_loaded_annotations.py 
b/data_processing/detectron2/projects/DensePose/tests/test_dataset_loaded_annotations.py new file mode 100644 index 0000000..cf8035b --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/tests/test_dataset_loaded_annotations.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import unittest + +from densepose.data.datasets.builtin import COCO_DATASETS, DENSEPOSE_ANNOTATIONS_DIR, LVIS_DATASETS +from densepose.data.datasets.coco import load_coco_json +from densepose.data.datasets.lvis import load_lvis_json +from densepose.data.utils import maybe_prepend_base_path +from densepose.structures import DensePoseDataRelative + + +class TestDatasetLoadedAnnotations(unittest.TestCase): + COCO_DATASET_DATA = { + "densepose_coco_2014_train": {"n_instances": 39210}, + "densepose_coco_2014_minival": {"n_instances": 2243}, + "densepose_coco_2014_minival_100": {"n_instances": 164}, + "densepose_coco_2014_valminusminival": {"n_instances": 7297}, + "densepose_coco_2014_train_cse": {"n_instances": 39210}, + "densepose_coco_2014_minival_cse": {"n_instances": 2243}, + "densepose_coco_2014_minival_100_cse": {"n_instances": 164}, + "densepose_coco_2014_valminusminival_cse": {"n_instances": 7297}, + "densepose_chimps": {"n_instances": 930}, + "posetrack2017_train": {"n_instances": 8274}, + "posetrack2017_val": {"n_instances": 4753}, + "lvis_v05_train": {"n_instances": 5186}, + "lvis_v05_val": {"n_instances": 1037}, + } + + LVIS_DATASET_DATA = { + "densepose_lvis_v1_train1": {"n_instances": 3394}, + "densepose_lvis_v1_train2": {"n_instances": 1800}, + "densepose_lvis_v1_val": {"n_instances": 1037}, + "densepose_lvis_v1_val_animals_100": {"n_instances": 89}, + } + + def generic_coco_test(self, dataset_info): + if dataset_info.name not in self.COCO_DATASET_DATA: + return + n_inst = self.COCO_DATASET_DATA[dataset_info.name]["n_instances"] + self.generic_test(dataset_info, n_inst, load_coco_json) + + def generic_lvis_test(self, 
dataset_info): + if dataset_info.name not in self.LVIS_DATASET_DATA: + return + n_inst = self.LVIS_DATASET_DATA[dataset_info.name]["n_instances"] + self.generic_test(dataset_info, n_inst, load_lvis_json) + + def generic_test(self, dataset_info, n_inst, loader_fun): + datasets_root = DENSEPOSE_ANNOTATIONS_DIR + annotations_fpath = maybe_prepend_base_path(datasets_root, dataset_info.annotations_fpath) + images_root = maybe_prepend_base_path(datasets_root, dataset_info.images_root) + image_annotation_dicts = loader_fun( + annotations_json_file=annotations_fpath, + image_root=images_root, + dataset_name=dataset_info.name, + ) + num_valid = sum( + 1 + for image_annotation_dict in image_annotation_dicts + for ann in image_annotation_dict["annotations"] + if DensePoseDataRelative.validate_annotation(ann)[0] + ) + self.assertEqual(num_valid, n_inst) + + +def coco_test_fun(dataset_info): + return lambda self: self.generic_coco_test(dataset_info) + + +for dataset_info in COCO_DATASETS: + setattr( + TestDatasetLoadedAnnotations, + f"test_coco_builtin_loaded_annotations_{dataset_info.name}", + coco_test_fun(dataset_info), + ) + + +def lvis_test_fun(dataset_info): + return lambda self: self.generic_lvis_test(dataset_info) + + +for dataset_info in LVIS_DATASETS: + setattr( + TestDatasetLoadedAnnotations, + f"test_lvis_builtin_loaded_annotations_{dataset_info.name}", + lvis_test_fun(dataset_info), + ) diff --git a/data_processing/detectron2/projects/DensePose/tests/test_frame_selector.py b/data_processing/detectron2/projects/DensePose/tests/test_frame_selector.py new file mode 100644 index 0000000..65f05f5 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/tests/test_frame_selector.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import random +import unittest + +from densepose.data.video import FirstKFramesSelector, LastKFramesSelector, RandomKFramesSelector + + +class TestFrameSelector(unittest.TestCase): + def test_frame_selector_random_k_1(self): + _SEED = 43 + _K = 4 + random.seed(_SEED) + selector = RandomKFramesSelector(_K) + frame_tss = list(range(0, 20, 2)) + _SELECTED_GT = [0, 8, 4, 6] + selected = selector(frame_tss) + self.assertEqual(_SELECTED_GT, selected) + + def test_frame_selector_random_k_2(self): + _SEED = 43 + _K = 10 + random.seed(_SEED) + selector = RandomKFramesSelector(_K) + frame_tss = list(range(0, 6, 2)) + _SELECTED_GT = [0, 2, 4] + selected = selector(frame_tss) + self.assertEqual(_SELECTED_GT, selected) + + def test_frame_selector_first_k_1(self): + _K = 4 + selector = FirstKFramesSelector(_K) + frame_tss = list(range(0, 20, 2)) + _SELECTED_GT = frame_tss[:_K] + selected = selector(frame_tss) + self.assertEqual(_SELECTED_GT, selected) + + def test_frame_selector_first_k_2(self): + _K = 10 + selector = FirstKFramesSelector(_K) + frame_tss = list(range(0, 6, 2)) + _SELECTED_GT = frame_tss[:_K] + selected = selector(frame_tss) + self.assertEqual(_SELECTED_GT, selected) + + def test_frame_selector_last_k_1(self): + _K = 4 + selector = LastKFramesSelector(_K) + frame_tss = list(range(0, 20, 2)) + _SELECTED_GT = frame_tss[-_K:] + selected = selector(frame_tss) + self.assertEqual(_SELECTED_GT, selected) + + def test_frame_selector_last_k_2(self): + _K = 10 + selector = LastKFramesSelector(_K) + frame_tss = list(range(0, 6, 2)) + _SELECTED_GT = frame_tss[-_K:] + selected = selector(frame_tss) + self.assertEqual(_SELECTED_GT, selected) diff --git a/data_processing/detectron2/projects/DensePose/tests/test_image_list_dataset.py b/data_processing/detectron2/projects/DensePose/tests/test_image_list_dataset.py new file mode 100644 index 0000000..7932602 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/tests/test_image_list_dataset.py @@ -0,0 +1,48 @@ +# 
Copyright (c) Facebook, Inc. and its affiliates. + +import contextlib +import os +import tempfile +import unittest +import torch +from torchvision.utils import save_image + +from densepose.data.image_list_dataset import ImageListDataset +from densepose.data.transform import ImageResizeTransform + + +@contextlib.contextmanager +def temp_image(height, width): + random_image = torch.rand(height, width) + with tempfile.NamedTemporaryFile(suffix=".jpg") as f: + f.close() + save_image(random_image, f.name) + yield f.name + os.unlink(f.name) + + +class TestImageListDataset(unittest.TestCase): + def test_image_list_dataset(self): + height, width = 720, 1280 + with temp_image(height, width) as image_fpath: + image_list = [image_fpath] + category_list = [None] + dataset = ImageListDataset(image_list, category_list) + self.assertEqual(len(dataset), 1) + data1, categories1 = dataset[0]["images"], dataset[0]["categories"] + self.assertEqual(data1.shape, torch.Size((1, 3, height, width))) + self.assertEqual(data1.dtype, torch.float32) + self.assertIsNone(categories1[0]) + + def test_image_list_dataset_with_transform(self): + height, width = 720, 1280 + with temp_image(height, width) as image_fpath: + image_list = [image_fpath] + category_list = [None] + transform = ImageResizeTransform() + dataset = ImageListDataset(image_list, category_list, transform) + self.assertEqual(len(dataset), 1) + data1, categories1 = dataset[0]["images"], dataset[0]["categories"] + self.assertEqual(data1.shape, torch.Size((1, 3, 749, 1333))) + self.assertEqual(data1.dtype, torch.float32) + self.assertIsNone(categories1[0]) diff --git a/data_processing/detectron2/projects/DensePose/tests/test_image_resize_transform.py b/data_processing/detectron2/projects/DensePose/tests/test_image_resize_transform.py new file mode 100644 index 0000000..01c3373 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/tests/test_image_resize_transform.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates. + +import unittest +import torch + +from densepose.data.transform import ImageResizeTransform + + +class TestImageResizeTransform(unittest.TestCase): + def test_image_resize_1(self): + images_batch = torch.ones((3, 3, 100, 100), dtype=torch.uint8) * 100 + transform = ImageResizeTransform() + images_transformed = transform(images_batch) + IMAGES_GT = torch.ones((3, 3, 800, 800), dtype=torch.float) * 100 + self.assertEqual(images_transformed.size(), IMAGES_GT.size()) + self.assertAlmostEqual(torch.abs(IMAGES_GT - images_transformed).max().item(), 0.0) diff --git a/data_processing/detectron2/projects/DensePose/tests/test_model_e2e.py b/data_processing/detectron2/projects/DensePose/tests/test_model_e2e.py new file mode 100644 index 0000000..055fadf --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/tests/test_model_e2e.py @@ -0,0 +1,43 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import unittest +import torch + +from detectron2.structures import BitMasks, Boxes, Instances + +from .common import get_model + + +# TODO(plabatut): Modularize detectron2 tests and re-use +def make_model_inputs(image, instances=None): + if instances is None: + return {"image": image} + + return {"image": image, "instances": instances} + + +def make_empty_instances(h, w): + instances = Instances((h, w)) + instances.gt_boxes = Boxes(torch.rand(0, 4)) + instances.gt_classes = torch.tensor([]).to(dtype=torch.int64) + instances.gt_masks = BitMasks(torch.rand(0, h, w)) + return instances + + +class ModelE2ETest(unittest.TestCase): + CONFIG_PATH = "" + + def setUp(self): + self.model = get_model(self.CONFIG_PATH) + + def _test_eval(self, sizes): + inputs = [make_model_inputs(torch.rand(3, size[0], size[1])) for size in sizes] + self.model.eval() + self.model(inputs) + + +class DensePoseRCNNE2ETest(ModelE2ETest): + CONFIG_PATH = "densepose_rcnn_R_101_FPN_s1x.yaml" + + def test_empty_data(self): + self._test_eval([(200, 250), (200, 249)]) diff --git 
a/data_processing/detectron2/projects/DensePose/tests/test_setup.py b/data_processing/detectron2/projects/DensePose/tests/test_setup.py new file mode 100644 index 0000000..165a1b9 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/tests/test_setup.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import unittest + +from .common import ( + get_config_files, + get_evolution_config_files, + get_hrnet_config_files, + get_quick_schedules_config_files, + setup, +) + + +class TestSetup(unittest.TestCase): + def _test_setup(self, config_file): + setup(config_file) + + def test_setup_configs(self): + config_files = get_config_files() + for config_file in config_files: + self._test_setup(config_file) + + def test_setup_evolution_configs(self): + config_files = get_evolution_config_files() + for config_file in config_files: + self._test_setup(config_file) + + def test_setup_hrnet_configs(self): + config_files = get_hrnet_config_files() + for config_file in config_files: + self._test_setup(config_file) + + def test_setup_quick_schedules_configs(self): + config_files = get_quick_schedules_config_files() + for config_file in config_files: + self._test_setup(config_file) diff --git a/data_processing/detectron2/projects/DensePose/tests/test_structures.py b/data_processing/detectron2/projects/DensePose/tests/test_structures.py new file mode 100644 index 0000000..54082d3 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/tests/test_structures.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import unittest + +from densepose.structures import normalized_coords_transform + + +class TestStructures(unittest.TestCase): + def test_normalized_coords_transform(self): + bbox = (32, 24, 288, 216) + x0, y0, w, h = bbox + xmin, ymin, xmax, ymax = x0, y0, x0 + w, y0 + h + f = normalized_coords_transform(*bbox) + # Top-left + expected_p, actual_p = (-1, -1), f((xmin, ymin)) + self.assertEqual(expected_p, actual_p) + # Top-right + expected_p, actual_p = (1, -1), f((xmax, ymin)) + self.assertEqual(expected_p, actual_p) + # Bottom-left + expected_p, actual_p = (-1, 1), f((xmin, ymax)) + self.assertEqual(expected_p, actual_p) + # Bottom-right + expected_p, actual_p = (1, 1), f((xmax, ymax)) + self.assertEqual(expected_p, actual_p) diff --git a/data_processing/detectron2/projects/DensePose/tests/test_tensor_storage.py b/data_processing/detectron2/projects/DensePose/tests/test_tensor_storage.py new file mode 100644 index 0000000..aeeeffa --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/tests/test_tensor_storage.py @@ -0,0 +1,256 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import io +import tempfile +import unittest +from contextlib import ExitStack +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from detectron2.utils import comm + +from densepose.evaluation.tensor_storage import ( + SingleProcessFileTensorStorage, + SingleProcessRamTensorStorage, + SizeData, + storage_gather, +) + + +class TestSingleProcessRamTensorStorage(unittest.TestCase): + def test_read_write_1(self): + schema = { + "tf": SizeData(dtype="float32", shape=(112, 112)), + "ti": SizeData(dtype="int32", shape=(4, 64, 64)), + } + # generate data which corresponds to the schema + data_elts = [] + torch.manual_seed(23) + for _i in range(3): + data_elt = { + "tf": torch.rand((112, 112), dtype=torch.float32), + "ti": (torch.rand(4, 64, 64) * 1000).to(dtype=torch.int32), + } + data_elts.append(data_elt) + storage = SingleProcessRamTensorStorage(schema, io.BytesIO()) + # write data to the storage + for i in range(3): + record_id = storage.put(data_elts[i]) + self.assertEqual(record_id, i) + # read data from the storage + for i in range(3): + record = storage.get(i) + self.assertEqual(len(record), len(schema)) + for field_name in schema: + self.assertTrue(field_name in record) + self.assertEqual(data_elts[i][field_name].shape, record[field_name].shape) + self.assertEqual(data_elts[i][field_name].dtype, record[field_name].dtype) + self.assertTrue(torch.allclose(data_elts[i][field_name], record[field_name])) + + +class TestSingleProcessFileTensorStorage(unittest.TestCase): + def test_read_write_1(self): + schema = { + "tf": SizeData(dtype="float32", shape=(112, 112)), + "ti": SizeData(dtype="int32", shape=(4, 64, 64)), + } + # generate data which corresponds to the schema + data_elts = [] + torch.manual_seed(23) + for _i in range(3): + data_elt = { + "tf": torch.rand((112, 112), dtype=torch.float32), + "ti": (torch.rand(4, 64, 64) * 1000).to(dtype=torch.int32), + } + data_elts.append(data_elt) + # WARNING: opens the file several times! 
may not work on all platforms + with tempfile.NamedTemporaryFile() as hFile: + storage = SingleProcessFileTensorStorage(schema, hFile.name, "wb") + # write data to the storage + for i in range(3): + record_id = storage.put(data_elts[i]) + self.assertEqual(record_id, i) + hFile.seek(0) + storage = SingleProcessFileTensorStorage(schema, hFile.name, "rb") + # read data from the storage + for i in range(3): + record = storage.get(i) + self.assertEqual(len(record), len(schema)) + for field_name in schema: + self.assertTrue(field_name in record) + self.assertEqual(data_elts[i][field_name].shape, record[field_name].shape) + self.assertEqual(data_elts[i][field_name].dtype, record[field_name].dtype) + self.assertTrue(torch.allclose(data_elts[i][field_name], record[field_name])) + + +def _find_free_port(): + """ + Copied from detectron2/engine/launch.py + """ + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. 
+ return port + + +def launch(main_func, nprocs, args=()): + port = _find_free_port() + dist_url = f"tcp://127.0.0.1:{port}" + # dist_url = "env://" + mp.spawn( + distributed_worker, nprocs=nprocs, args=(main_func, nprocs, dist_url, args), daemon=False + ) + + +def distributed_worker(local_rank, main_func, nprocs, dist_url, args): + dist.init_process_group( + backend="gloo", init_method=dist_url, world_size=nprocs, rank=local_rank + ) + comm.synchronize() + assert comm._LOCAL_PROCESS_GROUP is None + pg = dist.new_group(list(range(nprocs))) + comm._LOCAL_PROCESS_GROUP = pg + main_func(*args) + + +def ram_read_write_worker(): + schema = { + "tf": SizeData(dtype="float32", shape=(112, 112)), + "ti": SizeData(dtype="int32", shape=(4, 64, 64)), + } + storage = SingleProcessRamTensorStorage(schema, io.BytesIO()) + world_size = comm.get_world_size() + rank = comm.get_rank() + data_elts = [] + # prepare different number of tensors in different processes + for i in range(rank + 1): + data_elt = { + "tf": torch.ones((112, 112), dtype=torch.float32) * (rank + i * world_size), + "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (rank + i * world_size), + } + data_elts.append(data_elt) + # write data to the single process storage + for i in range(rank + 1): + record_id = storage.put(data_elts[i]) + assert record_id == i, f"Process {rank}: record ID {record_id}, expected {i}" + comm.synchronize() + # gather all data in process rank 0 + multi_storage = storage_gather(storage) + if rank != 0: + return + # read and check data from the multiprocess storage + for j in range(world_size): + for i in range(j): + record = multi_storage.get(j, i) + record_gt = { + "tf": torch.ones((112, 112), dtype=torch.float32) * (j + i * world_size), + "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (j + i * world_size), + } + assert len(record) == len(schema), ( + f"Process {rank}: multi storage record, rank {j}, id {i}: " + f"expected {len(schema)} fields in the record, got {len(record)}" + ) + 
for field_name in schema: + assert field_name in record, ( + f"Process {rank}: multi storage record, rank {j}, id {i}: " + f"field {field_name} not in the record" + ) + + assert record_gt[field_name].shape == record[field_name].shape, ( + f"Process {rank}: multi storage record, rank {j}, id {i}: " + f"field {field_name}, expected shape {record_gt[field_name].shape} " + f"got {record[field_name].shape}" + ) + assert record_gt[field_name].dtype == record[field_name].dtype, ( + f"Process {rank}: multi storage record, rank {j}, id {i}: " + f"field {field_name}, expected dtype {record_gt[field_name].dtype} " + f"got {record[field_name].dtype}" + ) + assert torch.allclose(record_gt[field_name], record[field_name]), ( + f"Process {rank}: multi storage record, rank {j}, id {i}: " + f"field {field_name}, tensors are not close enough:" + f"L-inf {(record_gt[field_name]-record[field_name]).abs_().max()} " + f"L1 {(record_gt[field_name]-record[field_name]).abs_().sum()} " + ) + + +def file_read_write_worker(rank_to_fpath): + schema = { + "tf": SizeData(dtype="float32", shape=(112, 112)), + "ti": SizeData(dtype="int32", shape=(4, 64, 64)), + } + world_size = comm.get_world_size() + rank = comm.get_rank() + storage = SingleProcessFileTensorStorage(schema, rank_to_fpath[rank], "wb") + data_elts = [] + # prepare different number of tensors in different processes + for i in range(rank + 1): + data_elt = { + "tf": torch.ones((112, 112), dtype=torch.float32) * (rank + i * world_size), + "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (rank + i * world_size), + } + data_elts.append(data_elt) + # write data to the single process storage + for i in range(rank + 1): + record_id = storage.put(data_elts[i]) + assert record_id == i, f"Process {rank}: record ID {record_id}, expected {i}" + comm.synchronize() + # gather all data in process rank 0 + multi_storage = storage_gather(storage) + if rank != 0: + return + # read and check data from the multiprocess storage + for j in 
range(world_size): + for i in range(j): + record = multi_storage.get(j, i) + record_gt = { + "tf": torch.ones((112, 112), dtype=torch.float32) * (j + i * world_size), + "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (j + i * world_size), + } + assert len(record) == len(schema), ( + f"Process {rank}: multi storage record, rank {j}, id {i}: " + f"expected {len(schema)} fields in the record, got {len(record)}" + ) + for field_name in schema: + assert field_name in record, ( + f"Process {rank}: multi storage record, rank {j}, id {i}: " + f"field {field_name} not in the record" + ) + + assert record_gt[field_name].shape == record[field_name].shape, ( + f"Process {rank}: multi storage record, rank {j}, id {i}: " + f"field {field_name}, expected shape {record_gt[field_name].shape} " + f"got {record[field_name].shape}" + ) + assert record_gt[field_name].dtype == record[field_name].dtype, ( + f"Process {rank}: multi storage record, rank {j}, id {i}: " + f"field {field_name}, expected dtype {record_gt[field_name].dtype} " + f"got {record[field_name].dtype}" + ) + assert torch.allclose(record_gt[field_name], record[field_name]), ( + f"Process {rank}: multi storage record, rank {j}, id {i}: " + f"field {field_name}, tensors are not close enough:" + f"L-inf {(record_gt[field_name]-record[field_name]).abs_().max()} " + f"L1 {(record_gt[field_name]-record[field_name]).abs_().sum()} " + ) + + +class TestMultiProcessRamTensorStorage(unittest.TestCase): + def test_read_write_1(self): + launch(ram_read_write_worker, 8) + + +class TestMultiProcessFileTensorStorage(unittest.TestCase): + def test_read_write_1(self): + with ExitStack() as stack: + # WARNING: opens the files several times! 
may not work on all platforms + rank_to_fpath = { + i: stack.enter_context(tempfile.NamedTemporaryFile()).name for i in range(8) + } + launch(file_read_write_worker, 8, (rank_to_fpath,)) diff --git a/data_processing/detectron2/projects/DensePose/tests/test_video_keyframe_dataset.py b/data_processing/detectron2/projects/DensePose/tests/test_video_keyframe_dataset.py new file mode 100644 index 0000000..988e161 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/tests/test_video_keyframe_dataset.py @@ -0,0 +1,98 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import contextlib +import os +import random +import tempfile +import unittest +import torch +import torchvision.io as io + +from densepose.data.transform import ImageResizeTransform +from densepose.data.video import RandomKFramesSelector, VideoKeyframeDataset + +try: + import av +except ImportError: + av = None + + +# copied from torchvision test/test_io.py +def _create_video_frames(num_frames, height, width): + y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) + data = [] + for i in range(num_frames): + xc = float(i) / num_frames + yc = 1 - float(i) / (2 * num_frames) + d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255 + data.append(d.unsqueeze(2).repeat(1, 1, 3).byte()) + return torch.stack(data, 0) + + +# adapted from torchvision test/test_io.py +@contextlib.contextmanager +def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None): + if lossless: + if video_codec is not None: + raise ValueError("video_codec can't be specified together with lossless") + if options is not None: + raise ValueError("options can't be specified together with lossless") + video_codec = "libx264rgb" + options = {"crf": "0"} + if video_codec is None: + video_codec = "libx264" + if options is None: + options = {} + data = _create_video_frames(num_frames, height, width) + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: + f.close() 
+ io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options) + yield f.name, data + os.unlink(f.name) + + +@unittest.skipIf(av is None, "PyAV unavailable") +class TestVideoKeyframeDataset(unittest.TestCase): + def test_read_keyframes_all(self): + with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (fname, data): + video_list = [fname] + category_list = [None] + dataset = VideoKeyframeDataset(video_list, category_list) + self.assertEqual(len(dataset), 1) + data1, categories1 = dataset[0]["images"], dataset[0]["categories"] + self.assertEqual(data1.shape, torch.Size((5, 3, 300, 300))) + self.assertEqual(data1.dtype, torch.float32) + self.assertIsNone(categories1[0]) + return + self.assertTrue(False) + + def test_read_keyframes_with_selector(self): + with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (fname, data): + video_list = [fname] + category_list = [None] + random.seed(0) + frame_selector = RandomKFramesSelector(3) + dataset = VideoKeyframeDataset(video_list, category_list, frame_selector) + self.assertEqual(len(dataset), 1) + data1, categories1 = dataset[0]["images"], dataset[0]["categories"] + self.assertEqual(data1.shape, torch.Size((3, 3, 300, 300))) + self.assertEqual(data1.dtype, torch.float32) + self.assertIsNone(categories1[0]) + return + self.assertTrue(False) + + def test_read_keyframes_with_selector_with_transform(self): + with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (fname, data): + video_list = [fname] + category_list = [None] + random.seed(0) + frame_selector = RandomKFramesSelector(1) + transform = ImageResizeTransform() + dataset = VideoKeyframeDataset(video_list, category_list, frame_selector, transform) + data1, categories1 = dataset[0]["images"], dataset[0]["categories"] + self.assertEqual(len(dataset), 1) + self.assertEqual(data1.shape, torch.Size((1, 3, 800, 800))) + self.assertEqual(data1.dtype, torch.float32) + self.assertIsNone(categories1[0]) + return + self.assertTrue(False) diff --git 
a/data_processing/detectron2/projects/DensePose/train_net.py b/data_processing/detectron2/projects/DensePose/train_net.py new file mode 100644 index 0000000..e8d77b9 --- /dev/null +++ b/data_processing/detectron2/projects/DensePose/train_net.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. + +""" +DensePose Training Script. + +This script is similar to the training script in detectron2/tools. + +It is an example of how a user might use detectron2 for a new project. +""" + +from datetime import timedelta + +import detectron2.utils.comm as comm +from detectron2.config import get_cfg +from detectron2.engine import DEFAULT_TIMEOUT, default_argument_parser, default_setup, hooks, launch +from detectron2.evaluation import verify_results +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import setup_logger + +from densepose import add_densepose_config +from densepose.engine import Trainer +from densepose.modeling.densepose_checkpoint import DensePoseCheckpointer + + +def setup(args): + cfg = get_cfg() + add_densepose_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + default_setup(cfg, args) + # Setup logger for "densepose" module + setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="densepose") + return cfg + + +def main(args): + cfg = setup(args) + # disable strict kwargs checking: allow one to specify path handle + # hints through kwargs, like timeout in DP evaluation + PathManager.set_strict_kwargs_checking(False) + + if args.eval_only: + model = Trainer.build_model(cfg) + DensePoseCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( + cfg.MODEL.WEIGHTS, resume=args.resume + ) + res = Trainer.test(cfg, model) + if cfg.TEST.AUG.ENABLED: + res.update(Trainer.test_with_TTA(cfg, model)) + if comm.is_main_process(): + verify_results(cfg, res) + return res + + trainer = Trainer(cfg) + 
trainer.resume_or_load(resume=args.resume) + if cfg.TEST.AUG.ENABLED: + trainer.register_hooks( + [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] + ) + return trainer.train() + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + cfg = setup(args) + timeout = ( + DEFAULT_TIMEOUT if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE else timedelta(hours=4) + ) + print("Command Line Args:", args) + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + timeout=timeout, + ) diff --git a/data_processing/detectron2/projects/MViTv2/README.md b/data_processing/detectron2/projects/MViTv2/README.md new file mode 100644 index 0000000..64afd79 --- /dev/null +++ b/data_processing/detectron2/projects/MViTv2/README.md @@ -0,0 +1,142 @@ +# MViTv2: Improved Multiscale Vision Transformers for Classification and Detection + +Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik, Christoph Feichtenhofer* + +[[`arXiv`](https://arxiv.org/abs/2112.01526)] [[`BibTeX`](#CitingMViTv2)] + +In this repository, we provide detection configs and models for MViTv2 (CVPR 2022) in Detectron2. For image classification tasks, please refer to [MViTv2 repo](https://github.com/facebookresearch/mvit). + +## Results and Pretrained Models + +### COCO + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namepre-trainMethodepochsbox
AP
mask
AP
#paramsFLOPSmodel iddownload
MViTV2-TIN1KMask R-CNN3648.343.844M279G307611773model
MViTV2-TIN1KCascade Mask R-CNN3652.245.076M701G308344828model
MViTV2-SIN1KCascade Mask R-CNN3653.246.087M748G308344647model
MViTV2-BIN1KCascade Mask R-CNN3654.146.7103M814G308109448model
MViTV2-BIN21KCascade Mask R-CNN3654.947.4103M814G309003202model
MViTV2-LIN21KCascade Mask R-CNN5055.848.3270M1519G308099658model
MViTV2-HIN21KCascade Mask R-CNN3656.148.5718M3084G309013744model
+ +Note that the above models were trained and measured on 8 nodes with 64 NVIDIA A100 GPUs in total. The ImageNet pre-trained model weights are obtained from [MViTv2 repo](https://github.com/facebookresearch/mvit). + +## Training +All configs can be trained with: + +``` +../../tools/lazyconfig_train_net.py --config-file configs/path/to/config.py +``` +By default, we use 64 GPUs with a batch size of 64 for training. + +## Evaluation +Model evaluation can be done similarly: +``` +../../tools/lazyconfig_train_net.py --config-file configs/path/to/config.py --eval-only train.init_checkpoint=/path/to/model_checkpoint +``` + + + +## Citing MViTv2 + +If you use MViTv2, please use the following BibTeX entry. + +```BibTeX +@inproceedings{li2021improved, + title={MViTv2: Improved multiscale vision transformers for classification and detection}, + author={Li, Yanghao and Wu, Chao-Yuan and Fan, Haoqi and Mangalam, Karttikeya and Xiong, Bo and Malik, Jitendra and Feichtenhofer, Christoph}, + booktitle={CVPR}, + year={2022} +} +``` diff --git a/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_b_3x.py b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_b_3x.py new file mode 100644 index 0000000..61366bf --- /dev/null +++ b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_b_3x.py @@ -0,0 +1,8 @@ +from .cascade_mask_rcnn_mvitv2_t_3x import model, dataloader, optimizer, lr_multiplier, train + + +model.backbone.bottom_up.depth = 24 +model.backbone.bottom_up.last_block_indexes = (1, 4, 20, 23) +model.backbone.bottom_up.drop_path_rate = 0.4 + +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_B_in1k.pyth" diff --git a/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_b_in21k_3x.py b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_b_in21k_3x.py new file mode 100644 index 
0000000..7c3bdce --- /dev/null +++ b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_b_in21k_3x.py @@ -0,0 +1,3 @@ +from .cascade_mask_rcnn_mvitv2_b_3x import model, dataloader, optimizer, lr_multiplier, train + +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_B_in21k.pyth" diff --git a/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_h_in21k_lsj_3x.py b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_h_in21k_lsj_3x.py new file mode 100644 index 0000000..6fee5e9 --- /dev/null +++ b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_h_in21k_lsj_3x.py @@ -0,0 +1,12 @@ +from .cascade_mask_rcnn_mvitv2_b_3x import model, optimizer, train, lr_multiplier +from .common.coco_loader_lsj import dataloader + + +model.backbone.bottom_up.embed_dim = 192 +model.backbone.bottom_up.depth = 80 +model.backbone.bottom_up.num_heads = 3 +model.backbone.bottom_up.last_block_indexes = (3, 11, 71, 79) +model.backbone.bottom_up.drop_path_rate = 0.6 +model.backbone.bottom_up.use_act_checkpoint = True + +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_H_in21k.pyth" diff --git a/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_l_in21k_lsj_50ep.py b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_l_in21k_lsj_50ep.py new file mode 100644 index 0000000..38da895 --- /dev/null +++ b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_l_in21k_lsj_50ep.py @@ -0,0 +1,31 @@ +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler + +from .cascade_mask_rcnn_mvitv2_b_3x import model, optimizer, train +from .common.coco_loader_lsj import dataloader + + +model.backbone.bottom_up.embed_dim = 144 +model.backbone.bottom_up.depth = 48 +model.backbone.bottom_up.num_heads = 2 
+model.backbone.bottom_up.last_block_indexes = (1, 7, 43, 47) +model.backbone.bottom_up.drop_path_rate = 0.5 + +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_L_in21k.pyth" + +# Schedule +# 50ep = 184375 // 2 iters * 64 images/iter / 118000 images/ep +train.max_iter = 184375 // 2 +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[163889 // 2, 177546 // 2], + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +optimizer.lr = 1e-4 diff --git a/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_s_3x.py b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_s_3x.py new file mode 100644 index 0000000..ad8eeb4 --- /dev/null +++ b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_s_3x.py @@ -0,0 +1,7 @@ +from .cascade_mask_rcnn_mvitv2_t_3x import model, dataloader, optimizer, lr_multiplier, train + + +model.backbone.bottom_up.depth = 16 +model.backbone.bottom_up.last_block_indexes = (0, 2, 13, 15) + +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_S_in1k.pyth" diff --git a/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_t_3x.py b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_t_3x.py new file mode 100644 index 0000000..51327dd --- /dev/null +++ b/data_processing/detectron2/projects/MViTv2/configs/cascade_mask_rcnn_mvitv2_t_3x.py @@ -0,0 +1,48 @@ +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.roi_heads import FastRCNNOutputLayers, FastRCNNConvFCHead, CascadeROIHeads +from detectron2.layers.batch_norm import NaiveSyncBatchNorm + +from .mask_rcnn_mvitv2_t_3x import model, dataloader, optimizer, 
lr_multiplier, train + + +# arguments that don't exist for Cascade R-CNN +[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] + +model.roi_heads.update( + _target_=CascadeROIHeads, + box_heads=[ + L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[256, 256, 256, 256], + fc_dims=[1024], + conv_norm=lambda c: NaiveSyncBatchNorm(c, stats_mode="N"), + ) + for _ in range(3) + ], + box_predictors=[ + L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + test_score_thresh=0.05, + box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), + cls_agnostic_bbox_reg=True, + num_classes="${...num_classes}", + ) + for (w1, w2) in [(10, 5), (20, 10), (30, 15)] + ], + proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) + for th in [0.5, 0.6, 0.7] + ], +) + +# Using NaiveSyncBatchNorm because heads may have empty input. That is not supported by +# torch.nn.SyncBatchNorm. We can remove this after +# https://github.com/pytorch/pytorch/issues/36530 is fixed. 
+model.roi_heads.mask_head.conv_norm = lambda c: NaiveSyncBatchNorm(c, stats_mode="N") + +# 2conv in RPN: +# https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/detection/modeling/architecture/heads.py#L95-L97 # noqa: E501, B950 +model.proposal_generator.head.conv_dims = [-1, -1] diff --git a/data_processing/detectron2/projects/MViTv2/configs/common/coco_loader.py b/data_processing/detectron2/projects/MViTv2/configs/common/coco_loader.py new file mode 100644 index 0000000..923878b --- /dev/null +++ b/data_processing/detectron2/projects/MViTv2/configs/common/coco_loader.py @@ -0,0 +1,59 @@ +from omegaconf import OmegaConf + +import detectron2.data.transforms as T +from detectron2.config import LazyCall as L +from detectron2.data import ( + DatasetMapper, + build_detection_test_loader, + build_detection_train_loader, + get_detection_dataset_dicts, +) +from detectron2.evaluation import COCOEvaluator + +dataloader = OmegaConf.create() + +dataloader.train = L(build_detection_train_loader)( + dataset=L(get_detection_dataset_dicts)(names="coco_2017_train"), + mapper=L(DatasetMapper)( + is_train=True, + augmentations=[ + L(T.RandomApply)( + tfm_or_aug=L(T.AugmentationList)( + augs=[ + L(T.ResizeShortestEdge)( + short_edge_length=[400, 500, 600], sample_style="choice" + ), + L(T.RandomCrop)(crop_type="absolute_range", crop_size=(384, 600)), + ] + ), + prob=0.5, + ), + L(T.ResizeShortestEdge)( + short_edge_length=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), + sample_style="choice", + max_size=1333, + ), + L(T.RandomFlip)(horizontal=True), + ], + image_format="RGB", + use_instance_mask=True, + ), + total_batch_size=16, + num_workers=4, +) + +dataloader.test = L(build_detection_test_loader)( + dataset=L(get_detection_dataset_dicts)(names="coco_2017_val", filter_empty=False), + mapper=L(DatasetMapper)( + is_train=False, + augmentations=[ + 
L(T.ResizeShortestEdge)(short_edge_length=800, max_size=1333), + ], + image_format="${...train.mapper.image_format}", + ), + num_workers=4, +) + +dataloader.evaluator = L(COCOEvaluator)( + dataset_name="${..test.dataset.names}", +) diff --git a/data_processing/detectron2/projects/MViTv2/configs/common/coco_loader_lsj.py b/data_processing/detectron2/projects/MViTv2/configs/common/coco_loader_lsj.py new file mode 100644 index 0000000..019b21f --- /dev/null +++ b/data_processing/detectron2/projects/MViTv2/configs/common/coco_loader_lsj.py @@ -0,0 +1,19 @@ +import detectron2.data.transforms as T +from detectron2 import model_zoo +from detectron2.config import LazyCall as L + +from .coco_loader import dataloader + +# Data using LSJ +image_size = 1024 +dataloader.train.mapper.augmentations = [ + L(T.RandomFlip)(horizontal=True), # flip first + L(T.ResizeScale)( + min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size + ), + L(T.FixedSizeCrop)(crop_size=(image_size, image_size)), +] +dataloader.train.mapper.image_format = "RGB" +dataloader.train.total_batch_size = 64 +# recompute boxes due to cropping +dataloader.train.mapper.recompute_boxes = True diff --git a/data_processing/detectron2/projects/MViTv2/configs/mask_rcnn_mvitv2_t_3x.py b/data_processing/detectron2/projects/MViTv2/configs/mask_rcnn_mvitv2_t_3x.py new file mode 100644 index 0000000..ba4bdfe --- /dev/null +++ b/data_processing/detectron2/projects/MViTv2/configs/mask_rcnn_mvitv2_t_3x.py @@ -0,0 +1,55 @@ +from functools import partial +import torch.nn as nn +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2 import model_zoo +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler +from detectron2.modeling import MViT + +from .common.coco_loader import dataloader + +model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model +constants = model_zoo.get_config("common/data/constants.py").constants 
+model.pixel_mean = constants.imagenet_rgb256_mean +model.pixel_std = constants.imagenet_rgb256_std +model.input_format = "RGB" +model.backbone.bottom_up = L(MViT)( + embed_dim=96, + depth=10, + num_heads=1, + last_block_indexes=(0, 2, 7, 9), + residual_pooling=True, + drop_path_rate=0.2, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + out_features=("scale2", "scale3", "scale4", "scale5"), +) +model.backbone.in_features = "${.bottom_up.out_features}" + + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = True +train.ddp.fp16_compression = True +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_T_in1k.pyth" + +dataloader.train.total_batch_size = 64 + +# 36 epochs +train.max_iter = 67500 +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[52500, 62500, 67500], + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.overrides = { + "pos_embed": {"weight_decay": 0.0}, + "rel_pos_h": {"weight_decay": 0.0}, + "rel_pos_w": {"weight_decay": 0.0}, +} +optimizer.lr = 1.6e-4 diff --git a/data_processing/detectron2/projects/Panoptic-DeepLab/README.md b/data_processing/detectron2/projects/Panoptic-DeepLab/README.md new file mode 100644 index 0000000..86b6d42 --- /dev/null +++ b/data_processing/detectron2/projects/Panoptic-DeepLab/README.md @@ -0,0 +1,175 @@ +# Panoptic-DeepLab: A Simple, Strong, and Fast Baseline for Bottom-Up Panoptic Segmentation + +Bowen Cheng, Maxwell D. Collins, Yukun Zhu, Ting Liu, Thomas S. Huang, Hartwig Adam, Liang-Chieh Chen + +[[`arXiv`](https://arxiv.org/abs/1911.10194)] [[`BibTeX`](#CitingPanopticDeepLab)] [[`Reference implementation`](https://github.com/bowenc0221/panoptic-deeplab)] + +
+ +

+ +## Installation +Install Detectron2 following [the instructions](https://detectron2.readthedocs.io/tutorials/install.html). +To use cityscapes, prepare data follow the [tutorial](https://detectron2.readthedocs.io/tutorials/builtin_datasets.html#expected-dataset-structure-for-cityscapes). + +## Training + +To train a model with 8 GPUs run: +```bash +cd /path/to/detectron2/projects/Panoptic-DeepLab +python train_net.py --config-file configs/Cityscapes-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_90k_bs32_crop_512_1024_dsconv.yaml --num-gpus 8 +``` + +## Evaluation + +Model evaluation can be done similarly: +```bash +cd /path/to/detectron2/projects/Panoptic-DeepLab +python train_net.py --config-file configs/Cityscapes-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_90k_bs32_crop_512_1024_dsconv.yaml --eval-only MODEL.WEIGHTS /path/to/model_checkpoint +``` + +## Benchmark network speed + +If you want to benchmark the network speed without post-processing, you can run the evaluation script with `MODEL.PANOPTIC_DEEPLAB.BENCHMARK_NETWORK_SPEED True`: +```bash +cd /path/to/detectron2/projects/Panoptic-DeepLab +python train_net.py --config-file configs/Cityscapes-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_90k_bs32_crop_512_1024_dsconv.yaml --eval-only MODEL.WEIGHTS /path/to/model_checkpoint MODEL.PANOPTIC_DEEPLAB.BENCHMARK_NETWORK_SPEED True +``` + +## Cityscapes Panoptic Segmentation +Cityscapes models are trained with ImageNet pretraining. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
MethodBackboneOutput
resolution
PQSQRQmIoUAPMemory (M)model iddownload
Panoptic-DeepLabR50-DC51024×2048 58.6 80.9 71.2 75.9 29.8 8668 - model | metrics
Panoptic-DeepLabR52-DC51024×2048 60.3 81.5 72.9 78.2 33.2 9682 30841561 model | metrics
Panoptic-DeepLab (DSConv)R52-DC51024×2048 60.3 81.0 73.2 78.7 32.1 10466 33148034 model | metrics
+ +Note: +- [R52](https://dl.fbaipublicfiles.com/detectron2/DeepLab/R-52.pkl): a ResNet-50 with its first 7x7 convolution replaced by 3 3x3 convolutions. This modification has been used in most semantic segmentation papers. We pre-train this backbone on ImageNet using the default recipe of [pytorch examples](https://github.com/pytorch/examples/tree/master/imagenet). +- DC5 means using dilated convolution in `res5`. +- We use a smaller training crop size (512x1024) than the original paper (1025x2049), we find using larger crop size (1024x2048) could further improve PQ by 1.5% but also degrades AP by 3%. +- The implementation with regular Conv2d in ASPP and head is much heavier head than the original paper. +- This implementation does not include optimized post-processing code needed for deployment. Post-processing the network + outputs now takes similar amount of time to the network itself. Please refer to speed in the + original paper for comparison. +- DSConv refers to using DepthwiseSeparableConv2d in ASPP and decoder. The implementation with DSConv is identical to the original paper. + +## COCO Panoptic Segmentation +COCO models are trained with ImageNet pretraining on 16 V100s. + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
MethodBackboneOutput
resolution
PQSQRQBox APMask APMemory (M)model iddownload
Panoptic-DeepLab (DSConv)R52-DC5640×640 35.5 77.3 44.7 18.6 19.7 246448865 model | metrics
+ +Note: +- [R52](https://dl.fbaipublicfiles.com/detectron2/DeepLab/R-52.pkl): a ResNet-50 with its first 7x7 convolution replaced by 3 3x3 convolutions. This modification has been used in most semantic segmentation papers. We pre-train this backbone on ImageNet using the default recipe of [pytorch examples](https://github.com/pytorch/examples/tree/master/imagenet). +- DC5 means using dilated convolution in `res5`. +- This reproduced number matches the original paper (35.5 vs. 35.1 PQ). +- This implementation does not include optimized post-processing code needed for deployment. Post-processing the network + outputs now takes more time than the network itself. Please refer to speed in the original paper for comparison. +- DSConv refers to using DepthwiseSeparableConv2d in ASPP and decoder. + +## Citing Panoptic-DeepLab + +If you use Panoptic-DeepLab, please use the following BibTeX entry. 
+ +* CVPR 2020 paper: + +``` +@inproceedings{cheng2020panoptic, + title={Panoptic-DeepLab: A Simple, Strong, and Fast Baseline for Bottom-Up Panoptic Segmentation}, + author={Cheng, Bowen and Collins, Maxwell D and Zhu, Yukun and Liu, Ting and Huang, Thomas S and Adam, Hartwig and Chen, Liang-Chieh}, + booktitle={CVPR}, + year={2020} +} +``` + +* ICCV 2019 COCO-Mapillary workshop challenge report: + +``` +@inproceedings{cheng2019panoptic, + title={Panoptic-DeepLab}, + author={Cheng, Bowen and Collins, Maxwell D and Zhu, Yukun and Liu, Ting and Huang, Thomas S and Adam, Hartwig and Chen, Liang-Chieh}, + booktitle={ICCV COCO + Mapillary Joint Recognition Challenge Workshop}, + year={2019} +} +``` diff --git a/data_processing/detectron2/projects/Panoptic-DeepLab/configs/COCO-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_200k_bs64_crop_640_640_coco_dsconv.yaml b/data_processing/detectron2/projects/Panoptic-DeepLab/configs/COCO-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_200k_bs64_crop_640_640_coco_dsconv.yaml new file mode 100644 index 0000000..6944c6f --- /dev/null +++ b/data_processing/detectron2/projects/Panoptic-DeepLab/configs/COCO-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_200k_bs64_crop_640_640_coco_dsconv.yaml @@ -0,0 +1,42 @@ +_BASE_: ../Cityscapes-PanopticSegmentation/Base-PanopticDeepLab-OS16.yaml +MODEL: + WEIGHTS: "detectron2://DeepLab/R-52.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + BACKBONE: + NAME: "build_resnet_deeplab_backbone" + RESNETS: + DEPTH: 50 + NORM: "SyncBN" + RES5_MULTI_GRID: [1, 2, 4] + STEM_TYPE: "deeplab" + STEM_OUT_CHANNELS: 128 + STRIDE_IN_1X1: False + SEM_SEG_HEAD: + NUM_CLASSES: 133 + LOSS_TOP_K: 1.0 + USE_DEPTHWISE_SEPARABLE_CONV: True + PANOPTIC_DEEPLAB: + STUFF_AREA: 4096 + NMS_KERNEL: 41 + SIZE_DIVISIBILITY: 640 + USE_DEPTHWISE_SEPARABLE_CONV: True +DATASETS: + TRAIN: ("coco_2017_train_panoptic",) + TEST: ("coco_2017_val_panoptic",) +SOLVER: + 
BASE_LR: 0.0005 + MAX_ITER: 200000 + IMS_PER_BATCH: 64 +INPUT: + FORMAT: "RGB" + GAUSSIAN_SIGMA: 8 + MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 16)]"] + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 640 + MAX_SIZE_TRAIN: 960 + MAX_SIZE_TEST: 640 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (640, 640) diff --git a/data_processing/detectron2/projects/Panoptic-DeepLab/configs/Cityscapes-PanopticSegmentation/Base-PanopticDeepLab-OS16.yaml b/data_processing/detectron2/projects/Panoptic-DeepLab/configs/Cityscapes-PanopticSegmentation/Base-PanopticDeepLab-OS16.yaml new file mode 100644 index 0000000..b737998 --- /dev/null +++ b/data_processing/detectron2/projects/Panoptic-DeepLab/configs/Cityscapes-PanopticSegmentation/Base-PanopticDeepLab-OS16.yaml @@ -0,0 +1,65 @@ +MODEL: + META_ARCHITECTURE: "PanopticDeepLab" + BACKBONE: + FREEZE_AT: 0 + RESNETS: + OUT_FEATURES: ["res2", "res3", "res5"] + RES5_DILATION: 2 + SEM_SEG_HEAD: + NAME: "PanopticDeepLabSemSegHead" + IN_FEATURES: ["res2", "res3", "res5"] + PROJECT_FEATURES: ["res2", "res3"] + PROJECT_CHANNELS: [32, 64] + ASPP_CHANNELS: 256 + ASPP_DILATIONS: [6, 12, 18] + ASPP_DROPOUT: 0.1 + HEAD_CHANNELS: 256 + CONVS_DIM: 256 + COMMON_STRIDE: 4 + NUM_CLASSES: 19 + LOSS_TYPE: "hard_pixel_mining" + NORM: "SyncBN" + INS_EMBED_HEAD: + NAME: "PanopticDeepLabInsEmbedHead" + IN_FEATURES: ["res2", "res3", "res5"] + PROJECT_FEATURES: ["res2", "res3"] + PROJECT_CHANNELS: [32, 64] + ASPP_CHANNELS: 256 + ASPP_DILATIONS: [6, 12, 18] + ASPP_DROPOUT: 0.1 + HEAD_CHANNELS: 32 + CONVS_DIM: 128 + COMMON_STRIDE: 4 + NORM: "SyncBN" + CENTER_LOSS_WEIGHT: 200.0 + OFFSET_LOSS_WEIGHT: 0.01 + PANOPTIC_DEEPLAB: + STUFF_AREA: 2048 + CENTER_THRESHOLD: 0.1 + NMS_KERNEL: 7 + TOP_K_INSTANCE: 200 +DATASETS: + TRAIN: ("cityscapes_fine_panoptic_train",) + TEST: ("cityscapes_fine_panoptic_val",) +SOLVER: + OPTIMIZER: "ADAM" + BASE_LR: 0.001 + WEIGHT_DECAY: 0.0 + WEIGHT_DECAY_NORM: 0.0 + WEIGHT_DECAY_BIAS: 0.0 + 
MAX_ITER: 60000 + LR_SCHEDULER_NAME: "WarmupPolyLR" + IMS_PER_BATCH: 32 +INPUT: + MIN_SIZE_TRAIN: (512, 640, 704, 832, 896, 1024, 1152, 1216, 1344, 1408, 1536, 1664, 1728, 1856, 1920, 2048) + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 1024 + MAX_SIZE_TRAIN: 4096 + MAX_SIZE_TEST: 2048 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (1024, 2048) +DATALOADER: + NUM_WORKERS: 10 +VERSION: 2 diff --git a/data_processing/detectron2/projects/Panoptic-DeepLab/configs/Cityscapes-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_90k_bs32_crop_512_1024.yaml b/data_processing/detectron2/projects/Panoptic-DeepLab/configs/Cityscapes-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_90k_bs32_crop_512_1024.yaml new file mode 100644 index 0000000..fde902b --- /dev/null +++ b/data_processing/detectron2/projects/Panoptic-DeepLab/configs/Cityscapes-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_90k_bs32_crop_512_1024.yaml @@ -0,0 +1,20 @@ +_BASE_: Base-PanopticDeepLab-OS16.yaml +MODEL: + WEIGHTS: "detectron2://DeepLab/R-52.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + BACKBONE: + NAME: "build_resnet_deeplab_backbone" + RESNETS: + DEPTH: 50 + NORM: "SyncBN" + RES5_MULTI_GRID: [1, 2, 4] + STEM_TYPE: "deeplab" + STEM_OUT_CHANNELS: 128 + STRIDE_IN_1X1: False +SOLVER: + MAX_ITER: 90000 +INPUT: + FORMAT: "RGB" + CROP: + SIZE: (512, 1024) diff --git a/data_processing/detectron2/projects/Panoptic-DeepLab/configs/Cityscapes-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_90k_bs32_crop_512_1024_dsconv.yaml b/data_processing/detectron2/projects/Panoptic-DeepLab/configs/Cityscapes-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_90k_bs32_crop_512_1024_dsconv.yaml new file mode 100644 index 0000000..8e31420 --- /dev/null +++ b/data_processing/detectron2/projects/Panoptic-DeepLab/configs/Cityscapes-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_90k_bs32_crop_512_1024_dsconv.yaml @@ -0,0 
+1,24 @@ +_BASE_: Base-PanopticDeepLab-OS16.yaml +MODEL: + WEIGHTS: "detectron2://DeepLab/R-52.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + BACKBONE: + NAME: "build_resnet_deeplab_backbone" + RESNETS: + DEPTH: 50 + NORM: "SyncBN" + RES5_MULTI_GRID: [1, 2, 4] + STEM_TYPE: "deeplab" + STEM_OUT_CHANNELS: 128 + STRIDE_IN_1X1: False + PANOPTIC_DEEPLAB: + USE_DEPTHWISE_SEPARABLE_CONV: True + SEM_SEG_HEAD: + USE_DEPTHWISE_SEPARABLE_CONV: True +SOLVER: + MAX_ITER: 90000 +INPUT: + FORMAT: "RGB" + CROP: + SIZE: (512, 1024) diff --git a/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/__init__.py b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/__init__.py new file mode 100644 index 0000000..8d3c980 --- /dev/null +++ b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .config import add_panoptic_deeplab_config +from .dataset_mapper import PanopticDeeplabDatasetMapper +from .panoptic_seg import ( + PanopticDeepLab, + INS_EMBED_BRANCHES_REGISTRY, + build_ins_embed_branch, + PanopticDeepLabSemSegHead, + PanopticDeepLabInsEmbedHead, +) diff --git a/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/config.py b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/config.py new file mode 100644 index 0000000..5aa2d28 --- /dev/null +++ b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/config.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +from detectron2.config import CfgNode as CN +from detectron2.projects.deeplab import add_deeplab_config + + +def add_panoptic_deeplab_config(cfg): + """ + Add config for Panoptic-DeepLab. + """ + # Reuse DeepLab config. + add_deeplab_config(cfg) + # Target generation parameters. 
+ cfg.INPUT.GAUSSIAN_SIGMA = 10 + cfg.INPUT.IGNORE_STUFF_IN_OFFSET = True + cfg.INPUT.SMALL_INSTANCE_AREA = 4096 + cfg.INPUT.SMALL_INSTANCE_WEIGHT = 3 + cfg.INPUT.IGNORE_CROWD_IN_SEMANTIC = False + # Optimizer type. + cfg.SOLVER.OPTIMIZER = "ADAM" + # Panoptic-DeepLab semantic segmentation head. + # We add an extra convolution before predictor. + cfg.MODEL.SEM_SEG_HEAD.HEAD_CHANNELS = 256 + cfg.MODEL.SEM_SEG_HEAD.LOSS_TOP_K = 0.2 + # Panoptic-DeepLab instance segmentation head. + cfg.MODEL.INS_EMBED_HEAD = CN() + cfg.MODEL.INS_EMBED_HEAD.NAME = "PanopticDeepLabInsEmbedHead" + cfg.MODEL.INS_EMBED_HEAD.IN_FEATURES = ["res2", "res3", "res5"] + cfg.MODEL.INS_EMBED_HEAD.PROJECT_FEATURES = ["res2", "res3"] + cfg.MODEL.INS_EMBED_HEAD.PROJECT_CHANNELS = [32, 64] + cfg.MODEL.INS_EMBED_HEAD.ASPP_CHANNELS = 256 + cfg.MODEL.INS_EMBED_HEAD.ASPP_DILATIONS = [6, 12, 18] + cfg.MODEL.INS_EMBED_HEAD.ASPP_DROPOUT = 0.1 + # We add an extra convolution before predictor. + cfg.MODEL.INS_EMBED_HEAD.HEAD_CHANNELS = 32 + cfg.MODEL.INS_EMBED_HEAD.CONVS_DIM = 128 + cfg.MODEL.INS_EMBED_HEAD.COMMON_STRIDE = 4 + cfg.MODEL.INS_EMBED_HEAD.NORM = "SyncBN" + cfg.MODEL.INS_EMBED_HEAD.CENTER_LOSS_WEIGHT = 200.0 + cfg.MODEL.INS_EMBED_HEAD.OFFSET_LOSS_WEIGHT = 0.01 + # Panoptic-DeepLab post-processing setting. + cfg.MODEL.PANOPTIC_DEEPLAB = CN() + # Stuff area limit, ignore stuff region below this number. + cfg.MODEL.PANOPTIC_DEEPLAB.STUFF_AREA = 2048 + cfg.MODEL.PANOPTIC_DEEPLAB.CENTER_THRESHOLD = 0.1 + cfg.MODEL.PANOPTIC_DEEPLAB.NMS_KERNEL = 7 + cfg.MODEL.PANOPTIC_DEEPLAB.TOP_K_INSTANCE = 200 + # If set to False, Panoptic-DeepLab will not evaluate instance segmentation. + cfg.MODEL.PANOPTIC_DEEPLAB.PREDICT_INSTANCES = True + cfg.MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV = False + # This is the padding parameter for images with various sizes. 
ASPP layers + # requires input images to be divisible by the average pooling size and we + # can use `MODEL.PANOPTIC_DEEPLAB.SIZE_DIVISIBILITY` to pad all images to + # a fixed resolution (e.g. 640x640 for COCO) to avoid having a image size + # that is not divisible by ASPP average pooling size. + cfg.MODEL.PANOPTIC_DEEPLAB.SIZE_DIVISIBILITY = -1 + # Only evaluates network speed (ignores post-processing). + cfg.MODEL.PANOPTIC_DEEPLAB.BENCHMARK_NETWORK_SPEED = False diff --git a/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/dataset_mapper.py b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/dataset_mapper.py new file mode 100644 index 0000000..53272c7 --- /dev/null +++ b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/dataset_mapper.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import logging +import numpy as np +from typing import Callable, List, Union +import torch +from panopticapi.utils import rgb2id + +from detectron2.config import configurable +from detectron2.data import MetadataCatalog +from detectron2.data import detection_utils as utils +from detectron2.data import transforms as T + +from .target_generator import PanopticDeepLabTargetGenerator + +__all__ = ["PanopticDeeplabDatasetMapper"] + + +class PanopticDeeplabDatasetMapper: + """ + The callable currently does the following: + + 1. Read the image from "file_name" and label from "pan_seg_file_name" + 2. Applies random scale, crop and flip transforms to image and label + 3. Prepare data to Tensor and generate training targets from label + """ + + @configurable + def __init__( + self, + *, + augmentations: List[Union[T.Augmentation, T.Transform]], + image_format: str, + panoptic_target_generator: Callable, + ): + """ + NOTE: this interface is experimental. 
+ + Args: + augmentations: a list of augmentations or deterministic transforms to apply + image_format: an image format supported by :func:`detection_utils.read_image`. + panoptic_target_generator: a callable that takes "panoptic_seg" and + "segments_info" to generate training targets for the model. + """ + # fmt: off + self.augmentations = T.AugmentationList(augmentations) + self.image_format = image_format + # fmt: on + logger = logging.getLogger(__name__) + logger.info("Augmentations used in training: " + str(augmentations)) + + self.panoptic_target_generator = panoptic_target_generator + + @classmethod + def from_config(cls, cfg): + augs = [ + T.ResizeShortestEdge( + cfg.INPUT.MIN_SIZE_TRAIN, + cfg.INPUT.MAX_SIZE_TRAIN, + cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING, + ) + ] + if cfg.INPUT.CROP.ENABLED: + augs.append(T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)) + augs.append(T.RandomFlip()) + + # Assume always applies to the training set. + dataset_names = cfg.DATASETS.TRAIN + meta = MetadataCatalog.get(dataset_names[0]) + panoptic_target_generator = PanopticDeepLabTargetGenerator( + ignore_label=meta.ignore_label, + thing_ids=list(meta.thing_dataset_id_to_contiguous_id.values()), + sigma=cfg.INPUT.GAUSSIAN_SIGMA, + ignore_stuff_in_offset=cfg.INPUT.IGNORE_STUFF_IN_OFFSET, + small_instance_area=cfg.INPUT.SMALL_INSTANCE_AREA, + small_instance_weight=cfg.INPUT.SMALL_INSTANCE_WEIGHT, + ignore_crowd_in_semantic=cfg.INPUT.IGNORE_CROWD_IN_SEMANTIC, + ) + + ret = { + "augmentations": augs, + "image_format": cfg.INPUT.FORMAT, + "panoptic_target_generator": panoptic_target_generator, + } + return ret + + def __call__(self, dataset_dict): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + + Returns: + dict: a format that builtin models in detectron2 accept + """ + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + # Load image. 
+ image = utils.read_image(dataset_dict["file_name"], format=self.image_format) + utils.check_image_size(dataset_dict, image) + # Panoptic label is encoded in RGB image. + pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB") + + # Reuses semantic transform for panoptic labels. + aug_input = T.AugInput(image, sem_seg=pan_seg_gt) + _ = self.augmentations(aug_input) + image, pan_seg_gt = aug_input.image, aug_input.sem_seg + + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. + dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + + # Generates training targets for Panoptic-DeepLab. + targets = self.panoptic_target_generator(rgb2id(pan_seg_gt), dataset_dict["segments_info"]) + dataset_dict.update(targets) + + return dataset_dict diff --git a/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/panoptic_seg.py b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/panoptic_seg.py new file mode 100644 index 0000000..c12ca74 --- /dev/null +++ b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/panoptic_seg.py @@ -0,0 +1,572 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
import numpy as np
from typing import Callable, Dict, List, Union
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.layers import Conv2d, DepthwiseSeparableConv2d, ShapeSpec, get_norm
from detectron2.modeling import (
    META_ARCH_REGISTRY,
    SEM_SEG_HEADS_REGISTRY,
    build_backbone,
    build_sem_seg_head,
)
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.projects.deeplab import DeepLabV3PlusHead
from detectron2.projects.deeplab.loss import DeepLabCE
from detectron2.structures import BitMasks, ImageList, Instances
from detectron2.utils.registry import Registry

from .post_processing import get_panoptic_segmentation

__all__ = ["PanopticDeepLab", "INS_EMBED_BRANCHES_REGISTRY", "build_ins_embed_branch"]


INS_EMBED_BRANCHES_REGISTRY = Registry("INS_EMBED_BRANCHES")
INS_EMBED_BRANCHES_REGISTRY.__doc__ = """
Registry for instance embedding branches, which make instance embedding
predictions from feature maps.
"""


@META_ARCH_REGISTRY.register()
class PanopticDeepLab(nn.Module):
    """
    Main class for panoptic segmentation architectures.

    Combines a backbone, a semantic segmentation head and an instance
    embedding head (center heatmap + center offsets). In training mode
    ``forward`` returns the loss dict; in inference mode it merges the three
    predictions into a panoptic map via ``get_panoptic_segmentation``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        self.sem_seg_head = build_sem_seg_head(cfg, self.backbone.output_shape())
        self.ins_embed_head = build_ins_embed_branch(cfg, self.backbone.output_shape())
        # Normalization constants are registered with persistent=False (third argument).
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1), False)
        # Metadata is always taken from the first training dataset.
        self.meta = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
        self.stuff_area = cfg.MODEL.PANOPTIC_DEEPLAB.STUFF_AREA
        self.threshold = cfg.MODEL.PANOPTIC_DEEPLAB.CENTER_THRESHOLD
        self.nms_kernel = cfg.MODEL.PANOPTIC_DEEPLAB.NMS_KERNEL
        self.top_k = cfg.MODEL.PANOPTIC_DEEPLAB.TOP_K_INSTANCE
        self.predict_instances = cfg.MODEL.PANOPTIC_DEEPLAB.PREDICT_INSTANCES
        self.use_depthwise_separable_conv = cfg.MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV
        # Both heads must agree on whether depthwise separable convs are used.
        assert (
            cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV
            == cfg.MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV
        )
        self.size_divisibility = cfg.MODEL.PANOPTIC_DEEPLAB.SIZE_DIVISIBILITY
        self.benchmark_network_speed = cfg.MODEL.PANOPTIC_DEEPLAB.BENCHMARK_NETWORK_SPEED

    @property
    def device(self):
        """Device of the model, read from the ``pixel_mean`` buffer."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "sem_seg": semantic segmentation ground truth
                   * "center": center points heatmap ground truth
                   * "offset": pixel offsets to center points ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.
        Returns:
            In training mode, a dict of losses. In inference mode, an empty
            list when ``benchmark_network_speed`` is set, otherwise
            list[dict]:
                each dict is the results for one image. The dict contains the following keys:

                * "panoptic_seg", "sem_seg": see documentation
                    :doc:`/tutorials/models` for the standard output format
                * "instances": available if ``predict_instances is True`` and at
                    least one thing instance was found. see documentation
                    :doc:`/tutorials/models` for the standard output format
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        # To avoid error in ASPP layer when input has different size.
        size_divisibility = (
            self.size_divisibility
            if self.size_divisibility > 0
            else self.backbone.size_divisibility
        )
        images = ImageList.from_tensors(images, size_divisibility)

        features = self.backbone(images.tensor)

        losses = {}
        # Ground truth presence is decided from the first item of the batch only.
        if "sem_seg" in batched_inputs[0]:
            targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
            # Padding added by batching is filled with the head's ignore value.
            targets = ImageList.from_tensors(
                targets, size_divisibility, self.sem_seg_head.ignore_value
            ).tensor
            if "sem_seg_weights" in batched_inputs[0]:
                # The default D2 DatasetMapper may not contain "sem_seg_weights"
                # Avoid error in testing when default DatasetMapper is used.
                weights = [x["sem_seg_weights"].to(self.device) for x in batched_inputs]
                weights = ImageList.from_tensors(weights, size_divisibility).tensor
            else:
                weights = None
        else:
            targets = None
            weights = None
        sem_seg_results, sem_seg_losses = self.sem_seg_head(features, targets, weights)
        losses.update(sem_seg_losses)

        if "center" in batched_inputs[0] and "offset" in batched_inputs[0]:
            center_targets = [x["center"].to(self.device) for x in batched_inputs]
            # unsqueeze(1) adds a channel dim to the batched center targets.
            center_targets = ImageList.from_tensors(
                center_targets, size_divisibility
            ).tensor.unsqueeze(1)
            center_weights = [x["center_weights"].to(self.device) for x in batched_inputs]
            center_weights = ImageList.from_tensors(center_weights, size_divisibility).tensor

            offset_targets = [x["offset"].to(self.device) for x in batched_inputs]
            offset_targets = ImageList.from_tensors(offset_targets, size_divisibility).tensor
            offset_weights = [x["offset_weights"].to(self.device) for x in batched_inputs]
            offset_weights = ImageList.from_tensors(offset_weights, size_divisibility).tensor
        else:
            center_targets = None
            center_weights = None

            offset_targets = None
            offset_weights = None

        center_results, offset_results, center_losses, offset_losses = self.ins_embed_head(
            features, center_targets, center_weights, offset_targets, offset_weights
        )
        losses.update(center_losses)
        losses.update(offset_losses)

        if self.training:
            return losses

        # Skip all post-processing when only the network speed is measured.
        if self.benchmark_network_speed:
            return []

        processed_results = []
        for sem_seg_result, center_result, offset_result, input_per_image, image_size in zip(
            sem_seg_results, center_results, offset_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            # The same postprocess helper is applied to all three predictions.
            r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
            c = sem_seg_postprocess(center_result, image_size, height, width)
            o = sem_seg_postprocess(offset_result, image_size, height, width)
            # Post-processing to get panoptic segmentation.
            panoptic_image, _ = get_panoptic_segmentation(
                r.argmax(dim=0, keepdim=True),
                c,
                o,
                thing_ids=self.meta.thing_dataset_id_to_contiguous_id.values(),
                label_divisor=self.meta.label_divisor,
                stuff_area=self.stuff_area,
                void_label=-1,
                threshold=self.threshold,
                nms_kernel=self.nms_kernel,
                top_k=self.top_k,
            )
            # For semantic segmentation evaluation.
            processed_results.append({"sem_seg": r})
            panoptic_image = panoptic_image.squeeze(0)
            semantic_prob = F.softmax(r, dim=0)
            # For panoptic segmentation evaluation.
            processed_results[-1]["panoptic_seg"] = (panoptic_image, None)
            # For instance segmentation evaluation.
            if self.predict_instances:
                instances = []
                panoptic_image_cpu = panoptic_image.cpu().numpy()
                for panoptic_label in np.unique(panoptic_image_cpu):
                    # -1 is the void label passed to get_panoptic_segmentation above.
                    if panoptic_label == -1:
                        continue
                    # panoptic id = semantic id * label_divisor + instance id
                    pred_class = panoptic_label // self.meta.label_divisor
                    isthing = pred_class in list(
                        self.meta.thing_dataset_id_to_contiguous_id.values()
                    )
                    # Get instance segmentation results.
                    if isthing:
                        instance = Instances((height, width))
                        # Evaluation code takes continuous id starting from 0
                        instance.pred_classes = torch.tensor(
                            [pred_class], device=panoptic_image.device
                        )
                        mask = panoptic_image == panoptic_label
                        instance.pred_masks = mask.unsqueeze(0)
                        # Average semantic probability
                        sem_scores = semantic_prob[pred_class, ...]
                        sem_scores = torch.mean(sem_scores[mask])
                        # Center point probability
                        mask_indices = torch.nonzero(mask).float()
                        center_y, center_x = (
                            torch.mean(mask_indices[:, 0]),
                            torch.mean(mask_indices[:, 1]),
                        )
                        # The heatmap is sampled at the mask centroid (truncated to int).
                        center_scores = c[0, int(center_y.item()), int(center_x.item())]
                        # Confidence score is semantic prob * center prob.
                        instance.scores = torch.tensor(
                            [sem_scores * center_scores], device=panoptic_image.device
                        )
                        # Get bounding boxes
                        instance.pred_boxes = BitMasks(instance.pred_masks).get_bounding_boxes()
                        instances.append(instance)
                if len(instances) > 0:
                    processed_results[-1]["instances"] = Instances.cat(instances)

        return processed_results


@SEM_SEG_HEADS_REGISTRY.register()
class PanopticDeepLabSemSegHead(DeepLabV3PlusHead):
    """
    A semantic segmentation head described in :paper:`Panoptic-DeepLab`.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        decoder_channels: List[int],
        norm: Union[str, Callable],
        head_channels: int,
        loss_weight: float,
        loss_type: str,
        loss_top_k: float,
        ignore_value: int,
        num_classes: int,
        **kwargs,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature
            decoder_channels (list[int]): a list of output channels of each
                decoder stage. It should have the same length as "input_shape"
                (each element in "input_shape" corresponds to one decoder stage).
            norm (str or callable): normalization for all conv layers.
            head_channels (int): the output channels of extra convolutions
                between decoder and predictor.
            loss_weight (float): loss weight.
            loss_top_k: (float): setting the top k% hardest pixels for
                "hard_pixel_mining" loss.
            loss_type, ignore_value, num_classes: the same as the base class.

        Raises:
            ValueError: if ``loss_type`` is neither "cross_entropy" nor
                "hard_pixel_mining".
        """
        super().__init__(
            input_shape,
            decoder_channels=decoder_channels,
            norm=norm,
            ignore_value=ignore_value,
            **kwargs,
        )
        assert self.decoder_only

        self.loss_weight = loss_weight
        # Convs only carry a bias when no normalization layer follows them.
        use_bias = norm == ""
        # `head` is additional transform before predictor
        if self.use_depthwise_separable_conv:
            # We use a single 5x5 DepthwiseSeparableConv2d to replace
            # 2 3x3 Conv2d since they have the same receptive field.
            self.head = DepthwiseSeparableConv2d(
                decoder_channels[0],
                head_channels,
                kernel_size=5,
                padding=2,
                norm1=norm,
                activation1=F.relu,
                norm2=norm,
                activation2=F.relu,
            )
        else:
            self.head = nn.Sequential(
                Conv2d(
                    decoder_channels[0],
                    decoder_channels[0],
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, decoder_channels[0]),
                    activation=F.relu,
                ),
                Conv2d(
                    decoder_channels[0],
                    head_channels,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, head_channels),
                    activation=F.relu,
                ),
            )
            weight_init.c2_xavier_fill(self.head[0])
            weight_init.c2_xavier_fill(self.head[1])
        self.predictor = Conv2d(head_channels, num_classes, kernel_size=1)
        nn.init.normal_(self.predictor.weight, 0, 0.001)
        nn.init.constant_(self.predictor.bias, 0)

        if loss_type == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(reduction="mean", ignore_index=ignore_value)
        elif loss_type == "hard_pixel_mining":
            self.loss = DeepLabCE(ignore_label=ignore_value, top_k_percent_pixels=loss_top_k)
        else:
            raise ValueError("Unexpected loss type: %s" % loss_type)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Extend the base-class kwargs with the Panoptic-DeepLab specific options."""
        ret = super().from_config(cfg, input_shape)
        ret["head_channels"] = cfg.MODEL.SEM_SEG_HEAD.HEAD_CHANNELS
        ret["loss_top_k"] = cfg.MODEL.SEM_SEG_HEAD.LOSS_TOP_K
        return ret

    def forward(self, features, targets=None, weights=None):
        """
        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        y = self.layers(features)
        if self.training:
            return None, self.losses(y, targets, weights)
        else:
            # Upsample the logits by the head's common stride.
            y = F.interpolate(
                y, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return y, {}

    def layers(self, features):
        """Run the base decoder, then the extra head and the 1x1 predictor."""
        assert self.decoder_only
        y = super().layers(features)
        y = self.head(y)
        y = self.predictor(y)
        return y

    def losses(self, predictions, targets, weights=None):
        """Upsample predictions by ``common_stride`` and return the weighted semantic loss dict."""
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        # NOTE(review): nn.CrossEntropyLoss.forward takes (input, target) only, so the
        # three-argument call presumably requires loss_type="hard_pixel_mining" - confirm.
        loss = self.loss(predictions, targets, weights)
        losses = {"loss_sem_seg": loss * self.loss_weight}
        return losses


def build_ins_embed_branch(cfg, input_shape):
    """
    Build an instance embedding branch from `cfg.MODEL.INS_EMBED_HEAD.NAME`.
    """
    name = cfg.MODEL.INS_EMBED_HEAD.NAME
    return INS_EMBED_BRANCHES_REGISTRY.get(name)(cfg, input_shape)


@INS_EMBED_BRANCHES_REGISTRY.register()
class PanopticDeepLabInsEmbedHead(DeepLabV3PlusHead):
    """
    An instance embedding head described in :paper:`Panoptic-DeepLab`.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        decoder_channels: List[int],
        norm: Union[str, Callable],
        head_channels: int,
        center_loss_weight: float,
        offset_loss_weight: float,
        **kwargs,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature
            decoder_channels (list[int]): a list of output channels of each
                decoder stage. It should have the same length as "input_shape"
                (each element in "input_shape" corresponds to one decoder stage).
            norm (str or callable): normalization for all conv layers.
            head_channels (int): the output channels of extra convolutions
                between decoder and predictor.
            center_loss_weight (float): loss weight for center point prediction.
            offset_loss_weight (float): loss weight for center offset prediction.
        """
        super().__init__(input_shape, decoder_channels=decoder_channels, norm=norm, **kwargs)
        assert self.decoder_only

        self.center_loss_weight = center_loss_weight
        self.offset_loss_weight = offset_loss_weight
        # Convs only carry a bias when no normalization layer follows them.
        use_bias = norm == ""
        # center prediction
        # `head` is additional transform before predictor
        # (the center head always uses two plain 3x3 convs, regardless of
        # use_depthwise_separable_conv)
        self.center_head = nn.Sequential(
            Conv2d(
                decoder_channels[0],
                decoder_channels[0],
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm=get_norm(norm, decoder_channels[0]),
                activation=F.relu,
            ),
            Conv2d(
                decoder_channels[0],
                head_channels,
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm=get_norm(norm, head_channels),
                activation=F.relu,
            ),
        )
        weight_init.c2_xavier_fill(self.center_head[0])
        weight_init.c2_xavier_fill(self.center_head[1])
        # One output channel: the center heatmap.
        self.center_predictor = Conv2d(head_channels, 1, kernel_size=1)
        nn.init.normal_(self.center_predictor.weight, 0, 0.001)
        nn.init.constant_(self.center_predictor.bias, 0)

        # offset prediction
        # `head` is additional transform before predictor
        if self.use_depthwise_separable_conv:
            # We use a single 5x5 DepthwiseSeparableConv2d to replace
            # 2 3x3 Conv2d since they have the same receptive field.
            self.offset_head = DepthwiseSeparableConv2d(
                decoder_channels[0],
                head_channels,
                kernel_size=5,
                padding=2,
                norm1=norm,
                activation1=F.relu,
                norm2=norm,
                activation2=F.relu,
            )
        else:
            self.offset_head = nn.Sequential(
                Conv2d(
                    decoder_channels[0],
                    decoder_channels[0],
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, decoder_channels[0]),
                    activation=F.relu,
                ),
                Conv2d(
                    decoder_channels[0],
                    head_channels,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, head_channels),
                    activation=F.relu,
                ),
            )
            weight_init.c2_xavier_fill(self.offset_head[0])
            weight_init.c2_xavier_fill(self.offset_head[1])
        # Two output channels: the offset map.
        self.offset_predictor = Conv2d(head_channels, 2, kernel_size=1)
        nn.init.normal_(self.offset_predictor.weight, 0, 0.001)
        nn.init.constant_(self.offset_predictor.bias, 0)

        # Element-wise losses; weighting and reduction happen in *_losses().
        self.center_loss = nn.MSELoss(reduction="none")
        self.offset_loss = nn.L1Loss(reduction="none")

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Build the constructor kwargs from ``cfg.MODEL.INS_EMBED_HEAD``."""
        if cfg.INPUT.CROP.ENABLED:
            assert cfg.INPUT.CROP.TYPE == "absolute"
            train_size = cfg.INPUT.CROP.SIZE
        else:
            train_size = None
        # CONVS_DIM for every decoder stage except the last, which uses ASPP_CHANNELS.
        decoder_channels = [cfg.MODEL.INS_EMBED_HEAD.CONVS_DIM] * (
            len(cfg.MODEL.INS_EMBED_HEAD.IN_FEATURES) - 1
        ) + [cfg.MODEL.INS_EMBED_HEAD.ASPP_CHANNELS]
        ret = dict(
            input_shape={
                k: v for k, v in input_shape.items() if k in cfg.MODEL.INS_EMBED_HEAD.IN_FEATURES
            },
            project_channels=cfg.MODEL.INS_EMBED_HEAD.PROJECT_CHANNELS,
            aspp_dilations=cfg.MODEL.INS_EMBED_HEAD.ASPP_DILATIONS,
            aspp_dropout=cfg.MODEL.INS_EMBED_HEAD.ASPP_DROPOUT,
            decoder_channels=decoder_channels,
            common_stride=cfg.MODEL.INS_EMBED_HEAD.COMMON_STRIDE,
            norm=cfg.MODEL.INS_EMBED_HEAD.NORM,
            train_size=train_size,
            head_channels=cfg.MODEL.INS_EMBED_HEAD.HEAD_CHANNELS,
            center_loss_weight=cfg.MODEL.INS_EMBED_HEAD.CENTER_LOSS_WEIGHT,
            offset_loss_weight=cfg.MODEL.INS_EMBED_HEAD.OFFSET_LOSS_WEIGHT,
            # Read from the SEM_SEG_HEAD section, not from INS_EMBED_HEAD.
            use_depthwise_separable_conv=cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV,
        )
        return ret

    def forward(
        self,
        features,
        center_targets=None,
        center_weights=None,
        offset_targets=None,
        offset_weights=None,
    ):
        """
        Returns:
            In training, returns (None, None, center loss dict, offset loss dict).
            In inference, returns (center, offset, {}, {}), where both
            predictions are upsampled by ``common_stride`` and the offset
            values are additionally multiplied by ``common_stride``.
        """
        center, offset = self.layers(features)
        if self.training:
            return (
                None,
                None,
                self.center_losses(center, center_targets, center_weights),
                self.offset_losses(offset, offset_targets, offset_weights),
            )
        else:
            center = F.interpolate(
                center, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            offset = (
                F.interpolate(
                    offset, scale_factor=self.common_stride, mode="bilinear", align_corners=False
                )
                * self.common_stride
            )
            return center, offset, {}, {}

    def layers(self, features):
        """Run the shared decoder, then the separate center and offset branches."""
        assert self.decoder_only
        y = super().layers(features)
        # center
        center = self.center_head(y)
        center = self.center_predictor(center)
        # offset
        offset = self.offset_head(y)
        offset = self.offset_predictor(offset)
        return center, offset

    def center_losses(self, predictions, targets, weights):
        """Return ``{"loss_center": ...}``: weighted MSE normalized by the weight sum."""
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        loss = self.center_loss(predictions, targets) * weights
        if weights.sum() > 0:
            loss = loss.sum() / weights.sum()
        else:
            # No valid pixel: yield a zero loss that is still derived from `loss`.
            loss = loss.sum() * 0
        losses = {"loss_center": loss * self.center_loss_weight}
        return losses

    def offset_losses(self, predictions, targets, weights):
        """Return ``{"loss_offset": ...}``: weighted L1 normalized by the weight sum."""
        # Same upsampling and value scaling as the inference path in forward().
        predictions = (
            F.interpolate(
                predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            * self.common_stride
        )
        loss = self.offset_loss(predictions, targets) * weights
        if weights.sum() > 0:
            loss = loss.sum() / weights.sum()
        else:
            # No valid pixel: yield a zero loss that is still derived from `loss`.
            loss = loss.sum() * 0
        losses = {"loss_offset": loss * self.offset_loss_weight}
        return losses


# diff --git
# a/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/post_processing.py b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/post_processing.py new file mode 100644 index 0000000..194724e --- /dev/null +++ b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/post_processing.py @@ -0,0 +1,234 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Reference: https://github.com/bowenc0221/panoptic-deeplab/blob/master/segmentation/model/post_processing/instance_post_processing.py  # noqa

from collections import Counter
import torch
import torch.nn.functional as F


def find_instance_center(center_heatmap, threshold=0.1, nms_kernel=3, top_k=None):
    """
    Find the center points from the center heatmap.
    Args:
        center_heatmap: A Tensor of shape [1, H, W] of raw center heatmap output.
        threshold: A float, threshold applied to center heatmap score.
        nms_kernel: An integer, NMS max pooling kernel size.
        top_k: An integer, top k centers to keep. Values larger than the
            number of pixels keep every center that survives NMS.
    Returns:
        A Tensor of shape [K, 2] where K is the number of center points. The
        order of second dim is (y, x).
    """
    # Thresholding, setting values below threshold to -1. F.threshold returns
    # a new tensor, so the caller's heatmap is not modified below.
    center_heatmap = F.threshold(center_heatmap, threshold, -1)

    # NMS: keep only the pixels that equal the maximum of their neighbourhood.
    nms_padding = (nms_kernel - 1) // 2
    center_heatmap_max_pooled = F.max_pool2d(
        center_heatmap, kernel_size=nms_kernel, stride=1, padding=nms_padding
    )
    center_heatmap[center_heatmap != center_heatmap_max_pooled] = -1

    # Drop leading singleton dimensions only. A bare `squeeze()` would also
    # collapse H or W when either is 1 and make the check below fail.
    while center_heatmap.dim() > 2 and center_heatmap.size(0) == 1:
        center_heatmap = center_heatmap.squeeze(0)
    assert center_heatmap.dim() == 2, "Something is wrong with center heatmap dimension."

    # Find non-zero elements.
    if top_k is None or top_k > center_heatmap.numel():
        # torch.topk raises when k exceeds the number of elements; asking for
        # more centers than there are pixels simply means "keep them all".
        return torch.nonzero(center_heatmap > 0)

    # find top k centers.
    top_k_scores, _ = torch.topk(torch.flatten(center_heatmap), top_k)
    return torch.nonzero(center_heatmap > top_k_scores[-1].clamp(min=0))


def group_pixels(center_points, offsets):
    """
    Gives each pixel in the image an instance id.
    Args:
        center_points: A Tensor of shape [K, 2] where K is the number of center points.
            The order of second dim is (y, x). K must be at least 1.
        offsets: A Tensor of shape [2, H, W] of raw offset output. The order of
            second dim is (offset_y, offset_x).
    Returns:
        A Tensor of shape [1, H, W] with values in range [1, K], which represents
            the center this pixel belongs to.
    """
    height, width = offsets.size()[1:]

    # Generates a coordinate map, where each location is the coordinate of
    # that location. Broadcasting gives the same "ij" layout as
    # torch.meshgrid without depending on its version-specific `indexing`
    # argument (omitting it warns on recent PyTorch releases).
    y_coord = torch.arange(height, dtype=offsets.dtype, device=offsets.device)
    x_coord = torch.arange(width, dtype=offsets.dtype, device=offsets.device)
    y_coord = y_coord.view(height, 1).expand(height, width)
    x_coord = x_coord.view(1, width).expand(height, width)
    coord = torch.stack((y_coord, x_coord), dim=0)

    center_loc = coord + offsets
    center_loc = center_loc.flatten(1).T.unsqueeze(0)  # [1, H*W, 2]
    center_points = center_points.unsqueeze(1)  # [K, 1, 2]

    # Distance: [K, H*W].
    distance = torch.norm(center_points - center_loc, dim=-1)

    # Finds center with minimum distance at each location, offset by 1, to
    # reserve id=0 for stuff.
    instance_id = torch.argmin(distance, dim=0).reshape((1, height, width)) + 1
    return instance_id


def get_instance_segmentation(
    sem_seg, center_heatmap, offsets, thing_seg, thing_ids, threshold=0.1, nms_kernel=3, top_k=None
):
    """
    Post-processing for instance segmentation, gets class agnostic instance id.
    Args:
        sem_seg: A Tensor of shape [1, H, W], predicted semantic label.
        center_heatmap: A Tensor of shape [1, H, W] of raw center heatmap output.
        offsets: A Tensor of shape [2, H, W] of raw offset output. The order of
            second dim is (offset_y, offset_x).
        thing_seg: A Tensor of shape [1, H, W], predicted foreground mask.
        thing_ids: A set of ids from contiguous category ids belonging
            to thing categories. Unused here; kept for interface compatibility.
        threshold: A float, threshold applied to center heatmap score.
        nms_kernel: An integer, NMS max pooling kernel size.
        top_k: An integer, top k centers to keep.
    Returns:
        A Tensor of shape [1, H, W] with value 0 represent stuff (not instance)
            and other positive values represent different instances.
        A Tensor of shape [1, K, 2] where K is the number of center points.
            The order of second dim is (y, x).
    """
    center_points = find_instance_center(
        center_heatmap, threshold=threshold, nms_kernel=nms_kernel, top_k=top_k
    )
    # Without any center every pixel is "stuff"; group_pixels needs K >= 1.
    if center_points.size(0) == 0:
        return torch.zeros_like(sem_seg), center_points.unsqueeze(0)
    ins_seg = group_pixels(center_points, offsets)
    # Instance ids are only kept inside the foreground (thing) mask.
    return thing_seg * ins_seg, center_points.unsqueeze(0)


def merge_semantic_and_instance(
    sem_seg, ins_seg, semantic_thing_seg, label_divisor, thing_ids, stuff_area, void_label
):
    """
    Post-processing for panoptic segmentation, by merging semantic segmentation
        label and class agnostic instance segmentation label.
    Args:
        sem_seg: A Tensor of shape [1, H, W], predicted category id for each pixel.
        ins_seg: A Tensor of shape [1, H, W], predicted instance id for each pixel.
        semantic_thing_seg: A Tensor of shape [1, H, W], predicted foreground mask.
        label_divisor: An integer, used to convert panoptic id =
            semantic id * label_divisor + instance_id.
        thing_ids: Set, a set of ids from contiguous category ids belonging
            to thing categories. Any iterable of ints is accepted.
        stuff_area: An integer, remove stuff whose area is less than stuff_area.
        void_label: An integer, indicates the region has no confident prediction.
    Returns:
        A Tensor of shape [1, H, W].
    """
    # Materialize once so membership tests are O(1) and views such as
    # dict.values() are handled uniformly.
    thing_ids = set(thing_ids)

    # In case thing mask does not align with semantic prediction.
    pan_seg = torch.zeros_like(sem_seg) + void_label
    is_thing = (ins_seg > 0) & (semantic_thing_seg > 0)

    # Keep track of instance id for each class.
    class_id_tracker = Counter()

    # Paste thing by majority voting.
    instance_ids = torch.unique(ins_seg)
    for ins_id in instance_ids:
        if ins_id == 0:
            continue
        # Make sure only do majority voting within `semantic_thing_seg`.
        thing_mask = (ins_seg == ins_id) & is_thing
        if not thing_mask.any():
            continue
        class_id, _ = torch.mode(sem_seg[thing_mask].view(-1))
        # Instance ids restart at 1 for every semantic class.
        class_id_tracker[class_id.item()] += 1
        new_ins_id = class_id_tracker[class_id.item()]
        pan_seg[thing_mask] = class_id * label_divisor + new_ins_id

    # Paste stuff to unoccupied area.
    class_ids = torch.unique(sem_seg)
    for class_id in class_ids:
        if class_id.item() in thing_ids:
            # thing class
            continue
        # Calculate stuff area.
        stuff_mask = (sem_seg == class_id) & (ins_seg == 0)
        if stuff_mask.sum().item() >= stuff_area:
            pan_seg[stuff_mask] = class_id * label_divisor

    return pan_seg


def get_panoptic_segmentation(
    sem_seg,
    center_heatmap,
    offsets,
    thing_ids,
    label_divisor,
    stuff_area,
    void_label,
    threshold=0.1,
    nms_kernel=7,
    top_k=200,
    foreground_mask=None,
):
    """
    Post-processing for panoptic segmentation.
    Args:
        sem_seg: A Tensor of shape [1, H, W] of predicted semantic label.
        center_heatmap: A Tensor of shape [1, H, W] of raw center heatmap output.
        offsets: A Tensor of shape [2, H, W] of raw offset output. The order of
            second dim is (offset_y, offset_x).
        thing_ids: A set of ids from contiguous category ids belonging
            to thing categories.
        label_divisor: An integer, used to convert panoptic id =
            semantic id * label_divisor + instance_id.
        stuff_area: An integer, remove stuff whose area is less than stuff_area.
        void_label: An integer, indicates the region has no confident prediction.
        threshold: A float, threshold applied to center heatmap score.
        nms_kernel: An integer, NMS max pooling kernel size.
        top_k: An integer, top k centers to keep.
        foreground_mask: Optional, A Tensor of shape [1, H, W] of predicted
            binary foreground mask. If not provided, it will be generated from
            sem_seg.
    Returns:
        A Tensor of shape [1, H, W], with the dtype of ``sem_seg``.
        A Tensor of shape [1, K, 2] with the (y, x) center points.
    Raises:
        ValueError: if any input does not have the documented shape.
    """
    # Materialize once: the ids are iterated here and tested for membership
    # in merge_semantic_and_instance.
    thing_ids = list(thing_ids)

    # The shape is valid only if BOTH conditions hold, so reject when either
    # fails (the previous `and` let e.g. [2, H, W] inputs through).
    if sem_seg.dim() != 3 or sem_seg.size(0) != 1:
        raise ValueError("Semantic prediction with un-supported shape: {}.".format(sem_seg.size()))
    if center_heatmap.dim() != 3:
        raise ValueError(
            "Center prediction with un-supported dimension: {}.".format(center_heatmap.dim())
        )
    if offsets.dim() != 3:
        raise ValueError("Offset prediction with un-supported dimension: {}.".format(offsets.dim()))
    if foreground_mask is not None:
        if foreground_mask.dim() != 3 or foreground_mask.size(0) != 1:
            raise ValueError(
                "Foreground prediction with un-supported shape: {}.".format(foreground_mask.size())
            )
        thing_seg = foreground_mask
    else:
        # inference from semantic segmentation
        thing_seg = torch.zeros_like(sem_seg)
        for thing_class in thing_ids:
            thing_seg[sem_seg == thing_class] = 1

    instance, center = get_instance_segmentation(
        sem_seg,
        center_heatmap,
        offsets,
        thing_seg,
        thing_ids,
        threshold=threshold,
        nms_kernel=nms_kernel,
        top_k=top_k,
    )
    panoptic = merge_semantic_and_instance(
        sem_seg, instance, thing_seg, label_divisor, thing_ids, stuff_area, void_label
    )

    return panoptic, center


# diff --git a/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/target_generator.py b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/target_generator.py new file mode 100644 index 0000000..a575c67 --- /dev/null +++ b/data_processing/detectron2/projects/Panoptic-DeepLab/panoptic_deeplab/target_generator.py @@ -0,0 +1,155 @@
# Copyright (c) Facebook, Inc. and its affiliates.
class PanopticDeepLabTargetGenerator(object):
    """
    Generates training targets for Panoptic-DeepLab.

    An instance is configured once and then called with a panoptic label map
    and its `segments_info`; it returns the semantic, center and offset targets
    together with their per-pixel loss weights.
    """

    def __init__(
        self,
        ignore_label,
        thing_ids,
        sigma=8,
        ignore_stuff_in_offset=False,
        small_instance_area=0,
        small_instance_weight=1,
        ignore_crowd_in_semantic=False,
    ):
        """
        Args:
            ignore_label: Integer, the ignore label for semantic segmentation.
            thing_ids: Set, a set of ids from contiguous category ids belonging
                to thing categories.
            sigma: the sigma for Gaussian kernel.
            ignore_stuff_in_offset: Boolean, whether to ignore stuff region when
                training the offset branch.
            small_instance_area: Integer, indicates largest area for small instances.
            small_instance_weight: Number, indicates semantic loss weights for
                small instances.
            ignore_crowd_in_semantic: Boolean, whether to ignore crowd region in
                semantic segmentation branch, crowd region is ignored in the original
                TensorFlow implementation.
        """
        self.ignore_label = ignore_label
        self.thing_ids = set(thing_ids)
        self.ignore_stuff_in_offset = ignore_stuff_in_offset
        self.small_instance_area = small_instance_area
        self.small_instance_weight = small_instance_weight
        self.ignore_crowd_in_semantic = ignore_crowd_in_semantic

        # Generate the default Gaussian image for each center: a
        # (6 * sigma + 3)-sided patch whose peak (value 1) sits at index
        # 3 * sigma + 1 on both axes.
        self.sigma = sigma
        size = 6 * sigma + 3
        x = np.arange(0, size, 1, float)
        y = x[:, np.newaxis]
        x0, y0 = 3 * sigma + 1, 3 * sigma + 1
        self.g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma**2))

    def __call__(self, panoptic, segments_info):
        """Generates the training target.
        reference: https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/createPanopticImgs.py  # noqa
        reference: https://github.com/facebookresearch/detectron2/blob/main/datasets/prepare_panoptic_fpn.py#L18  # noqa

        Args:
            panoptic: numpy.array, panoptic label, we assume it is already
                converted from rgb image by panopticapi.utils.rgb2id.
            segments_info (list[dict]): see detectron2 documentation of "Use Custom Datasets".
                Each dict is read for the keys "id", "category_id" and "iscrowd".

        Returns:
            A dictionary with fields:
                - sem_seg: Tensor, semantic label, shape=(H, W).
                - center: Tensor, center heatmap, shape=(H, W).
                - center_points: List, center coordinates, with tuple
                    (y-coord, x-coord).
                - offset: Tensor, offset, shape=(2, H, W), first dim is
                    (offset_y, offset_x).
                - sem_seg_weights: Tensor, loss weight for semantic prediction,
                    shape=(H, W).
                - center_weights: Tensor, ignore region of center prediction,
                    shape=(1, H, W), used as weights for center regression 0 is
                    ignore, 1 is has instance. Multiply this mask to loss.
                - offset_weights: Tensor, ignore region of offset prediction,
                    shape=(1, H, W), used as weights for offset regression 0 is
                    ignore, 1 is has instance. Multiply this mask to loss.
        """
        height, width = panoptic.shape[0], panoptic.shape[1]
        # Labels are kept in int64 from the start: a uint8 buffer cannot hold
        # category ids above 255 or a negative ignore label (it wraps or raises,
        # depending on the NumPy version).
        semantic = np.full(panoptic.shape, self.ignore_label, dtype=np.int64)
        center = np.zeros((height, width), dtype=np.float32)
        center_pts = []
        offset = np.zeros((2, height, width), dtype=np.float32)
        y_coord, x_coord = np.meshgrid(
            np.arange(height, dtype=np.float32), np.arange(width, dtype=np.float32), indexing="ij"
        )
        # Generate pixel-wise loss weights. float32 (the dtype that is returned)
        # so that large or fractional `small_instance_weight` values survive.
        semantic_weights = np.ones(panoptic.shape, dtype=np.float32)
        # 0: ignore, 1: has instance
        # three conditions for a region to be ignored for instance branches:
        # (1) It is labeled as `ignore_label`
        # (2) It is crowd region (iscrowd=1)
        # (3) (Optional) It is stuff region (for offset branch)
        center_weights = np.zeros(panoptic.shape, dtype=np.uint8)
        offset_weights = np.zeros(panoptic.shape, dtype=np.uint8)
        for seg in segments_info:
            cat_id = seg["category_id"]
            is_crowd = seg["iscrowd"]
            # Computed once per segment and reused for every target below.
            seg_mask = panoptic == seg["id"]
            if not (self.ignore_crowd_in_semantic and is_crowd):
                semantic[seg_mask] = cat_id
            if not is_crowd:
                # Ignored regions are not in `segments_info`.
                # Handle crowd region.
                center_weights[seg_mask] = 1
                if not self.ignore_stuff_in_offset or cat_id in self.thing_ids:
                    offset_weights[seg_mask] = 1
            if cat_id in self.thing_ids:
                # find instance center
                mask_index = np.where(seg_mask)
                if len(mask_index[0]) == 0:
                    # the instance is completely cropped
                    continue

                # Find instance area
                ins_area = len(mask_index[0])
                if ins_area < self.small_instance_area:
                    semantic_weights[seg_mask] = self.small_instance_weight

                center_y, center_x = np.mean(mask_index[0]), np.mean(mask_index[1])
                center_pts.append([center_y, center_x])

                # generate center heatmap
                y, x = int(round(center_y)), int(round(center_x))
                sigma = self.sigma
                # upper left
                ul = int(np.round(x - 3 * sigma - 1)), int(np.round(y - 3 * sigma - 1))
                # bottom right
                br = int(np.round(x + 3 * sigma + 2)), int(np.round(y + 3 * sigma + 2))

                # start and end indices in default Gaussian image
                gaussian_x0, gaussian_x1 = max(0, -ul[0]), min(br[0], width) - ul[0]
                gaussian_y0, gaussian_y1 = max(0, -ul[1]), min(br[1], height) - ul[1]

                # start and end indices in center heatmap image
                center_x0, center_x1 = max(0, ul[0]), min(br[0], width)
                center_y0, center_y1 = max(0, ul[1]), min(br[1], height)
                # Overlapping Gaussians are merged with an element-wise maximum.
                center[center_y0:center_y1, center_x0:center_x1] = np.maximum(
                    center[center_y0:center_y1, center_x0:center_x1],
                    self.g[gaussian_y0:gaussian_y1, gaussian_x0:gaussian_x1],
                )

                # generate offset (2, h, w) -> (y-dir, x-dir)
                offset[0][mask_index] = center_y - y_coord[mask_index]
                offset[1][mask_index] = center_x - x_coord[mask_index]

        # Add a leading axis so the instance-branch weights come out as (1, H, W).
        center_weights = center_weights[None]
        offset_weights = offset_weights[None]
        return dict(
            sem_seg=torch.as_tensor(semantic.astype("long")),
            center=torch.as_tensor(center.astype(np.float32)),
            center_points=center_pts,
            offset=torch.as_tensor(offset.astype(np.float32)),
            sem_seg_weights=torch.as_tensor(semantic_weights.astype(np.float32)),
            center_weights=torch.as_tensor(center_weights.astype(np.float32)),
            offset_weights=torch.as_tensor(offset_weights.astype(np.float32)),
        )
In that case you can use the cleaner + "SimpleTrainer", or write your own training loop. + """ + + @classmethod + def build_evaluator(cls, cfg, dataset_name, output_folder=None): + """ + Create evaluator(s) for a given dataset. + This uses the special metadata "evaluator_type" associated with each builtin dataset. + For your own dataset, you can simply create an evaluator manually in your + script and do not have to worry about the hacky if-else logic here. + """ + if cfg.MODEL.PANOPTIC_DEEPLAB.BENCHMARK_NETWORK_SPEED: + return None + if output_folder is None: + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") + evaluator_list = [] + evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type + if evaluator_type in ["cityscapes_panoptic_seg", "coco_panoptic_seg"]: + evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) + if evaluator_type == "cityscapes_panoptic_seg": + evaluator_list.append(CityscapesSemSegEvaluator(dataset_name)) + evaluator_list.append(CityscapesInstanceEvaluator(dataset_name)) + if evaluator_type == "coco_panoptic_seg": + # `thing_classes` in COCO panoptic metadata includes both thing and + # stuff classes for visualization. COCOEvaluator requires metadata + # which only contains thing classes, thus we map the name of + # panoptic datasets to their corresponding instance datasets. 
+ dataset_name_mapper = { + "coco_2017_val_panoptic": "coco_2017_val", + "coco_2017_val_100_panoptic": "coco_2017_val_100", + } + evaluator_list.append( + COCOEvaluator(dataset_name_mapper[dataset_name], output_dir=output_folder) + ) + if len(evaluator_list) == 0: + raise NotImplementedError( + "no Evaluator for the dataset {} with the type {}".format( + dataset_name, evaluator_type + ) + ) + elif len(evaluator_list) == 1: + return evaluator_list[0] + return DatasetEvaluators(evaluator_list) + + @classmethod + def build_train_loader(cls, cfg): + mapper = PanopticDeeplabDatasetMapper(cfg, augmentations=build_sem_seg_train_aug(cfg)) + return build_detection_train_loader(cfg, mapper=mapper) + + @classmethod + def build_lr_scheduler(cls, cfg, optimizer): + """ + It now calls :func:`detectron2.solver.build_lr_scheduler`. + Overwrite it if you'd like a different scheduler. + """ + return build_lr_scheduler(cfg, optimizer) + + @classmethod + def build_optimizer(cls, cfg, model): + """ + Build an optimizer from config. + """ + params = get_default_optimizer_params( + model, + weight_decay=cfg.SOLVER.WEIGHT_DECAY, + weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM, + ) + + optimizer_type = cfg.SOLVER.OPTIMIZER + if optimizer_type == "SGD": + return maybe_add_gradient_clipping(cfg, torch.optim.SGD)( + params, + cfg.SOLVER.BASE_LR, + momentum=cfg.SOLVER.MOMENTUM, + nesterov=cfg.SOLVER.NESTEROV, + ) + elif optimizer_type == "ADAM": + return maybe_add_gradient_clipping(cfg, torch.optim.Adam)(params, cfg.SOLVER.BASE_LR) + else: + raise NotImplementedError(f"no optimizer type {optimizer_type}") + + +def setup(args): + """ + Create configs and perform basic setups. 
def main(args):
    """
    Entry point run by `launch`: evaluate only, or train.

    With `args.eval_only` set, a model is built, weights are loaded from
    `cfg.MODEL.WEIGHTS` (honoring `args.resume`) and the result of
    `Trainer.test` is returned. Otherwise a `Trainer` is created, resumed or
    loaded, and the result of `train()` is returned.
    """
    cfg = setup(args)

    if not args.eval_only:
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()

    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
    return Trainer.test(cfg, model)
+ +

+ +In this repository, we release code for PointRend in Detectron2. PointRend can be flexibly applied to both instance and semantic segmentation tasks by building on top of existing state-of-the-art models. + +## Quick start and visualization + +This [Colab Notebook](https://colab.research.google.com/drive/1isGPL5h5_cKoPPhVL9XhMokRtHDvmMVL) tutorial contains examples of PointRend usage and visualizations of its point sampling stages. + +## Training + +To train a model with 8 GPUs run: +```bash +cd /path/to/detectron2/projects/PointRend +python train_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml --num-gpus 8 +``` + +## Evaluation + +Model evaluation can be done similarly: +```bash +cd /path/to/detectron2/projects/PointRend +python train_net.py --config-file configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml --eval-only MODEL.WEIGHTS /path/to/model_checkpoint +``` + +# Pretrained Models + +## Instance Segmentation +#### COCO + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Mask
head
Backbonelr
sched
Output
resolution
mask
AP
mask
AP*
model iddownload
PointRendR50-FPN224×22436.239.7164254221model | metrics
PointRendR50-FPN224×22438.341.6164955410model | metrics
PointRendR101-FPN224×22440.143.8model | metrics
PointRendX101-FPN224×22441.144.7model | metrics
+ +AP* is COCO mask AP evaluated against the higher-quality LVIS annotations; see the paper for details. +Run `python detectron2/datasets/prepare_cocofied_lvis.py` to prepare GT files for AP* evaluation. +Since LVIS annotations are not exhaustive, `lvis-api` and not `cocoapi` should be used to evaluate AP*. + +#### Cityscapes +Cityscapes model is trained with ImageNet pretraining. + + + + + + + + + + + + + + + + + + + + +
Mask
head
Backbonelr
sched
Output
resolution
mask
AP
model iddownload
PointRendR50-FPN224×22435.9164255101model | metrics
+ + +## Semantic Segmentation + +#### Cityscapes +Cityscapes model is trained with ImageNet pretraining. + + + + + + + + + + + + + + + + + + +
MethodBackboneOutput
resolution
mIoUmodel iddownload
SemanticFPN + PointRendR101-FPN1024×204878.9202576688model | metrics
+ +## Citing PointRend + +If you use PointRend, please use the following BibTeX entry. + +```BibTeX +@InProceedings{kirillov2019pointrend, + title={{PointRend}: Image Segmentation as Rendering}, + author={Alexander Kirillov and Yuxin Wu and Kaiming He and Ross Girshick}, + journal={ArXiv:1912.08193}, + year={2019} +} +``` + +## Citing Implicit PointRend + +If you use Implicit PointRend, please use the following BibTeX entry. + +```BibTeX +@InProceedings{cheng2021pointly, + title={Pointly-Supervised Instance Segmentation}, + author={Bowen Cheng and Omkar Parkhi and Alexander Kirillov}, + journal={ArXiv}, + year={2021} +} +``` diff --git a/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/Base-Implicit-PointRend.yaml b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/Base-Implicit-PointRend.yaml new file mode 100644 index 0000000..5ebafb3 --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/Base-Implicit-PointRend.yaml @@ -0,0 +1,25 @@ +_BASE_: "../../../../configs/Base-RCNN-FPN.yaml" +MODEL: + MASK_ON: true + ROI_MASK_HEAD: + NAME: "ImplicitPointRendMaskHead" + POOLER_TYPE: "" # No RoI pooling, let the head process image features directly + FC_DIM: 1024 + NUM_FC: 2 + POINT_HEAD: + NAME: "ImplicitPointHead" + FC_DIM: 256 + NUM_FC: 3 + IN_FEATURES: ["p2"] + NUM_CLASSES: 80 + CLS_AGNOSTIC_MASK: False + TRAIN_NUM_POINTS: 196 + SUBDIVISION_STEPS: 3 + SUBDIVISION_NUM_POINTS: 784 + IMPLICIT_POINTREND: + IMAGE_FEATURE_ENABLED: True + POS_ENC_ENABLED: True + PARAMS_L2_REGULARIZER: 0.00001 +INPUT: + # PointRend for instance segmentation does not work with "polygon" mask_format.
+ MASK_FORMAT: "bitmask" diff --git a/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/Base-PointRend-RCNN-FPN.yaml b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/Base-PointRend-RCNN-FPN.yaml new file mode 100644 index 0000000..e68e707 --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/Base-PointRend-RCNN-FPN.yaml @@ -0,0 +1,20 @@ +_BASE_: "../../../../configs/Base-RCNN-FPN.yaml" +MODEL: + MASK_ON: true + ROI_BOX_HEAD: + TRAIN_ON_PRED_BOXES: True + ROI_MASK_HEAD: + POOLER_TYPE: "" # No RoI pooling, let the head process image features directly + NAME: "PointRendMaskHead" + FC_DIM: 1024 + NUM_FC: 2 + OUTPUT_SIDE_RESOLUTION: 7 + IN_FEATURES: ["p2"] # for the coarse mask head + POINT_HEAD_ON: True + POINT_HEAD: + FC_DIM: 256 + NUM_FC: 3 + IN_FEATURES: ["p2"] +INPUT: + # PointRend for instance segmentation does not work with "polygon" mask_format. + MASK_FORMAT: "bitmask" diff --git a/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/implicit_pointrend_R_50_FPN_1x_coco.yaml b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/implicit_pointrend_R_50_FPN_1x_coco.yaml new file mode 100644 index 0000000..ba35c24 --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/implicit_pointrend_R_50_FPN_1x_coco.yaml @@ -0,0 +1,8 @@ +_BASE_: "Base-Implicit-PointRend.yaml" +MODEL: + WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl + RESNETS: + DEPTH: 50 +# To add COCO AP evaluation against the higher-quality LVIS annotations. 
+# DATASETS: +# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied") diff --git a/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/implicit_pointrend_R_50_FPN_3x_coco.yaml b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/implicit_pointrend_R_50_FPN_3x_coco.yaml new file mode 100644 index 0000000..884236d --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/implicit_pointrend_R_50_FPN_3x_coco.yaml @@ -0,0 +1,11 @@ +_BASE_: "Base-Implicit-PointRend.yaml" +MODEL: + WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +# To add COCO AP evaluation against the higher-quality LVIS annotations. +# DATASETS: +# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied") diff --git a/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_101_FPN_3x_coco.yaml b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_101_FPN_3x_coco.yaml new file mode 100644 index 0000000..4269130 --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_101_FPN_3x_coco.yaml @@ -0,0 +1,12 @@ +_BASE_: Base-PointRend-RCNN-FPN.yaml +MODEL: + WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-101.pkl + MASK_ON: true + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +# To add COCO AP evaluation against the higher-quality LVIS annotations. 
+# DATASETS: +# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied") diff --git a/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml new file mode 100644 index 0000000..0402d6d --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml @@ -0,0 +1,22 @@ +_BASE_: Base-PointRend-RCNN-FPN.yaml +MODEL: + WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl + RESNETS: + DEPTH: 50 + ROI_HEADS: + NUM_CLASSES: 8 + POINT_HEAD: + NUM_CLASSES: 8 +DATASETS: + TEST: ("cityscapes_fine_instance_seg_val",) + TRAIN: ("cityscapes_fine_instance_seg_train",) +SOLVER: + BASE_LR: 0.01 + IMS_PER_BATCH: 8 + MAX_ITER: 24000 + STEPS: (18000,) +INPUT: + MAX_SIZE_TEST: 2048 + MAX_SIZE_TRAIN: 2048 + MIN_SIZE_TEST: 1024 + MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024) diff --git a/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml new file mode 100644 index 0000000..0249b49 --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml @@ -0,0 +1,8 @@ +_BASE_: Base-PointRend-RCNN-FPN.yaml +MODEL: + WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl + RESNETS: + DEPTH: 50 +# To add COCO AP evaluation against the higher-quality LVIS annotations. 
+# DATASETS: +# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied") diff --git a/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml new file mode 100644 index 0000000..a571b4c --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml @@ -0,0 +1,12 @@ +_BASE_: Base-PointRend-RCNN-FPN.yaml +MODEL: + WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +# To add COCO AP evaluation against the higher-quality LVIS annotations. +# DATASETS: +# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied") + diff --git a/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml new file mode 100644 index 0000000..85d26f3 --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml @@ -0,0 +1,16 @@ +_BASE_: Base-PointRend-RCNN-FPN.yaml +MODEL: + MASK_ON: True + WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" + PIXEL_STD: [57.375, 57.120, 58.395] + RESNETS: + STRIDE_IN_1X1: False # this is a C2 model + NUM_GROUPS: 32 + WIDTH_PER_GROUP: 8 + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +# To add COCO AP evaluation against the higher-quality LVIS annotations. 
+# DATASETS: +# TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied") diff --git a/data_processing/detectron2/projects/PointRend/configs/SemanticSegmentation/Base-PointRend-Semantic-FPN.yaml b/data_processing/detectron2/projects/PointRend/configs/SemanticSegmentation/Base-PointRend-Semantic-FPN.yaml new file mode 100644 index 0000000..9b7a1b4 --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/configs/SemanticSegmentation/Base-PointRend-Semantic-FPN.yaml @@ -0,0 +1,20 @@ +_BASE_: "../../../../configs/Base-RCNN-FPN.yaml" +MODEL: + META_ARCHITECTURE: "SemanticSegmentor" + BACKBONE: + FREEZE_AT: 0 + SEM_SEG_HEAD: + NAME: "PointRendSemSegHead" + POINT_HEAD: + NUM_CLASSES: 54 + FC_DIM: 256 + NUM_FC: 3 + IN_FEATURES: ["p2"] + TRAIN_NUM_POINTS: 1024 + SUBDIVISION_STEPS: 2 + SUBDIVISION_NUM_POINTS: 8192 + COARSE_SEM_SEG_HEAD_NAME: "SemSegFPNHead" + COARSE_PRED_EACH_LAYER: False +DATASETS: + TRAIN: ("coco_2017_train_panoptic_stuffonly",) + TEST: ("coco_2017_val_panoptic_stuffonly",) diff --git a/data_processing/detectron2/projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes.yaml b/data_processing/detectron2/projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes.yaml new file mode 100644 index 0000000..6be11fa --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes.yaml @@ -0,0 +1,33 @@ +_BASE_: Base-PointRend-Semantic-FPN.yaml +MODEL: + WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-101.pkl + RESNETS: + DEPTH: 101 + SEM_SEG_HEAD: + NUM_CLASSES: 19 + POINT_HEAD: + NUM_CLASSES: 19 + TRAIN_NUM_POINTS: 2048 + SUBDIVISION_NUM_POINTS: 8192 +DATASETS: + TRAIN: ("cityscapes_fine_sem_seg_train",) + TEST: ("cityscapes_fine_sem_seg_val",) +SOLVER: + BASE_LR: 0.01 + STEPS: (40000, 55000) + MAX_ITER: 65000 + IMS_PER_BATCH: 32 +INPUT: + MIN_SIZE_TRAIN: (512, 768, 1024, 1280, 1536, 1792, 2048) + MIN_SIZE_TRAIN_SAMPLING: 
"choice" + MIN_SIZE_TEST: 1024 + MAX_SIZE_TRAIN: 4096 + MAX_SIZE_TEST: 2048 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (512, 1024) + SINGLE_CATEGORY_MAX_AREA: 0.75 + COLOR_AUG_SSD: True +DATALOADER: + NUM_WORKERS: 10 diff --git a/data_processing/detectron2/projects/PointRend/point_rend/__init__.py b/data_processing/detectron2/projects/PointRend/point_rend/__init__.py new file mode 100644 index 0000000..e3050cb --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/point_rend/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .config import add_pointrend_config +from .mask_head import PointRendMaskHead, ImplicitPointRendMaskHead +from .semantic_seg import PointRendSemSegHead +from .color_augmentation import ColorAugSSDTransform + +from . import roi_heads as _ # only registration diff --git a/data_processing/detectron2/projects/PointRend/point_rend/color_augmentation.py b/data_processing/detectron2/projects/PointRend/point_rend/color_augmentation.py new file mode 100644 index 0000000..cdcb051 --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/point_rend/color_augmentation.py @@ -0,0 +1,98 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +import random +import cv2 +from fvcore.transforms.transform import Transform + + +class ColorAugSSDTransform(Transform): + """ + A color related data augmentation used in Single Shot Multibox Detector (SSD). + + Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, + Scott Reed, Cheng-Yang Fu, Alexander C. Berg. + SSD: Single Shot MultiBox Detector. ECCV 2016. 
+ + Implementation based on: + + https://github.com/weiliu89/caffe/blob + /4817bf8b4200b35ada8ed0dc378dceaf38c539e4 + /src/caffe/util/im_transforms.cpp + + https://github.com/chainer/chainercv/blob + /7159616642e0be7c5b3ef380b848e16b7e99355b/chainercv + /links/model/ssd/transforms.py + """ + + def __init__( + self, + img_format, + brightness_delta=32, + contrast_low=0.5, + contrast_high=1.5, + saturation_low=0.5, + saturation_high=1.5, + hue_delta=18, + ): + super().__init__() + assert img_format in ["BGR", "RGB"] + self.is_rgb = img_format == "RGB" + del img_format + self._set_attributes(locals()) + + def apply_coords(self, coords): + return coords + + def apply_segmentation(self, segmentation): + return segmentation + + def apply_image(self, img, interp=None): + if self.is_rgb: + img = img[:, :, [2, 1, 0]] + img = self.brightness(img) + if random.randrange(2): + img = self.contrast(img) + img = self.saturation(img) + img = self.hue(img) + else: + img = self.saturation(img) + img = self.hue(img) + img = self.contrast(img) + if self.is_rgb: + img = img[:, :, [2, 1, 0]] + return img + + def convert(self, img, alpha=1, beta=0): + img = img.astype(np.float32) * alpha + beta + img = np.clip(img, 0, 255) + return img.astype(np.uint8) + + def brightness(self, img): + if random.randrange(2): + return self.convert( + img, beta=random.uniform(-self.brightness_delta, self.brightness_delta) + ) + return img + + def contrast(self, img): + if random.randrange(2): + return self.convert(img, alpha=random.uniform(self.contrast_low, self.contrast_high)) + return img + + def saturation(self, img): + if random.randrange(2): + img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + img[:, :, 1] = self.convert( + img[:, :, 1], alpha=random.uniform(self.saturation_low, self.saturation_high) + ) + return cv2.cvtColor(img, cv2.COLOR_HSV2BGR) + return img + + def hue(self, img): + if 
def add_pointrend_config(cfg):
    """
    Register the PointRend and Implicit PointRend options, with defaults, on `cfg`.

    `cfg` is modified in place. It must already provide the `INPUT.CROP` and
    `MODEL.ROI_MASK_HEAD` nodes; `MODEL.POINT_HEAD` and
    `MODEL.IMPLICIT_POINTREND` are created here.
    """
    # ---- Input / augmentation ----
    # Random cropping is retried until no single category in the semantic
    # segmentation GT occupies more than `SINGLE_CATEGORY_MAX_AREA` of the crop.
    cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
    # Color augmentation from the SSD paper for semantic segmentation training.
    cfg.INPUT.COLOR_AUG_SSD = False

    # ---- Coarse mask head ----
    mask_head = cfg.MODEL.ROI_MASK_HEAD
    # Names of the input feature maps to be used by a coarse mask head.
    mask_head.IN_FEATURES = ("p2",)
    mask_head.FC_DIM = 1024
    mask_head.NUM_FC = 2
    # The side size of a coarse mask head prediction.
    mask_head.OUTPUT_SIDE_RESOLUTION = 7
    # True if point head is used.
    mask_head.POINT_HEAD_ON = False

    # ---- Point head ----
    cfg.MODEL.POINT_HEAD = CN()
    point_head = cfg.MODEL.POINT_HEAD
    point_head.NAME = "StandardPointHead"
    point_head.NUM_CLASSES = 80
    # Names of the input feature maps to be used by a mask point head.
    point_head.IN_FEATURES = ("p2",)
    # Number of points sampled during training for a mask point head.
    point_head.TRAIN_NUM_POINTS = 14 * 14
    # Oversampling parameter for PointRend point sampling during training.
    # Parameter `k` in the original paper.
    point_head.OVERSAMPLE_RATIO = 3
    # Importance sampling parameter for PointRend point sampling during
    # training. Parameter `beta` in the original paper.
    point_head.IMPORTANCE_SAMPLE_RATIO = 0.75
    # Number of subdivision steps during inference.
    point_head.SUBDIVISION_STEPS = 5
    # Maximum number of points selected at each subdivision step (N).
    point_head.SUBDIVISION_NUM_POINTS = 28 * 28
    point_head.FC_DIM = 256
    point_head.NUM_FC = 3
    point_head.CLS_AGNOSTIC_MASK = False
    # If True, coarse prediction features are used as input for each layer in
    # PointRend's MLP.
    point_head.COARSE_PRED_EACH_LAYER = True
    point_head.COARSE_SEM_SEG_HEAD_NAME = "SemSegFPNHead"

    # ---- Implicit PointRend ----
    cfg.MODEL.IMPLICIT_POINTREND = CN()
    implicit = cfg.MODEL.IMPLICIT_POINTREND
    implicit.IMAGE_FEATURE_ENABLED = True
    implicit.POS_ENC_ENABLED = True
    implicit.PARAMS_L2_REGULARIZER = 0.00001
def calculate_uncertainty(logits, classes):
    """
    Score uncertainty as the negated L1 distance between 0.0 and the logit of
    the foreground class given in `classes`.

    Args:
        logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
            class-agnostic, where R is the total number of predicted masks in all images and C is
            the number of foreground classes. The values are logits.
        classes (list): A list of length R that contains either the predicted or the ground
            truth class for each predicted mask. Not used when `logits` has a single channel.
    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    class_agnostic = logits.shape[1] == 1
    if class_agnostic:
        # Single channel: nothing to select; work on a copy of the input.
        selected = logits.clone()
    else:
        # Pick, for every mask r, the channel of its class: logits[r, classes[r]].
        mask_idx = torch.arange(logits.shape[0], device=logits.device)
        selected = logits[mask_idx, classes].unsqueeze(1)
    # Logits closest to zero give the largest (least negative) score.
    return -selected.abs()
    @configurable
    def __init__(
        self, input_shape: ShapeSpec, *, conv_dim: int, fc_dims: List[int], output_shape: Tuple[int]
    ):
        """
        Build the conv layers, the FC stack and the final prediction layer.

        Args:
            input_shape: shape of the input features; its `channels`, `height`
                and `width` are read
            conv_dim: the output dimension of the conv layers
            fc_dims: a list of N>0 integers representing the output dimensions of N FC layers
            output_shape: shape of the output mask prediction
        """
        super().__init__()

        # fmt: off
        input_channels    = input_shape.channels
        input_h           = input_shape.height
        input_w           = input_shape.width
        self.output_shape = output_shape
        # fmt: on

        # Plain Python list used to run the convs in order in `forward`.
        self.conv_layers = []
        if input_channels > conv_dim:
            # 1x1 conv that brings the channel count down to `conv_dim`.
            self.reduce_channel_dim_conv = Conv2d(
                input_channels,
                conv_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
                activation=F.relu,
            )
            self.conv_layers.append(self.reduce_channel_dim_conv)

        # NOTE(review): this conv always takes `conv_dim` input channels, even
        # when the channel-reducing conv above was skipped — confirm callers
        # never pass input_channels < conv_dim.
        self.reduce_spatial_dim_conv = Conv2d(
            conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu
        )
        self.conv_layers.append(self.reduce_spatial_dim_conv)

        # Flattened feature size after the kernel-2 / stride-2 conv: a quarter
        # of conv_dim * H * W (presumably assumes even H and W — TODO confirm).
        input_dim = conv_dim * input_h * input_w
        input_dim //= 4

        self.fcs = []
        for k, fc_dim in enumerate(fc_dims):
            fc = nn.Linear(input_dim, fc_dim)
            # Registered under the names "fc1", "fc2", ...
            self.add_module("fc{}".format(k + 1), fc)
            self.fcs.append(fc)
            input_dim = fc_dim

        # One output value per element of `output_shape`.
        output_dim = int(np.prod(self.output_shape))

        # Indexing fc_dims[-1] requires a non-empty `fc_dims`.
        self.prediction = nn.Linear(fc_dims[-1], output_dim)
        # use normal distribution initialization for mask prediction layer
        nn.init.normal_(self.prediction.weight, std=0.001)
        nn.init.constant_(self.prediction.bias, 0)

        for layer in self.conv_layers:
            weight_init.c2_msra_fill(layer)
        for layer in self.fcs:
            weight_init.c2_xavier_fill(layer)
def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
    """
    Build the point head and the coarse ROI head from `cfg`.

    Args:
        cfg: config node. `cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES` and
            `cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION` are read here; the node is also
            forwarded to `_init_point_head` and `_init_roi_head`.
        input_shape: shapes keyed by feature name; the `stride` and `channels` of the
            entries are used.
    """
    super().__init__()
    # Scale of every feature map relative to the image, keyed by feature name.
    self._feature_scales = {k: 1.0 / v.stride for k, v in input_shape.items()}
    # point head (must run before `_init_roi_head`, which a subclass may make depend
    # on attributes set by `_init_point_head`)
    self._init_point_head(cfg, input_shape)
    # coarse mask head
    self.roi_pooler_in_features = cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES
    self.roi_pooler_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
    # Plain Python int, same as the channel count computed in `_init_point_head`.
    in_channels = int(np.sum([input_shape[f].channels for f in self.roi_pooler_in_features]))
    self._init_roi_head(
        cfg,
        ShapeSpec(
            channels=in_channels,
            width=self.roi_pooler_size,
            height=self.roi_pooler_size,
        ),
    )
def forward(self, features, instances):
    """
    Args:
        features (dict[str, Tensor]): a dict of image-level features
        instances (list[Instances]): proposals in training; detected
            instances in inference

    Returns:
        In training, a dict of losses: "loss_mask" and, when the point head is enabled,
        "loss_mask_point". In inference, `instances` after mask inference has been run
        on them.
    """
    if self.training:
        proposal_boxes = [x.proposal_boxes for x in instances]
        coarse_mask = self.coarse_head(self._roi_pooler(features, proposal_boxes))
        losses = {"loss_mask": mask_rcnn_loss(coarse_mask, instances)}
        if not self.mask_point_on:
            return losses

        point_coords, point_labels = self._sample_train_points(coarse_mask, instances)
        point_fine_grained_features = self._point_pooler(features, proposal_boxes, point_coords)
        point_logits = self._get_point_logits(
            point_fine_grained_features, point_coords, coarse_mask
        )
        losses["loss_mask_point"] = roi_mask_point_loss(point_logits, instances, point_labels)
        return losses

    pred_boxes = [x.pred_boxes for x in instances]
    coarse_mask = self.coarse_head(self._roi_pooler(features, pred_boxes))
    if not self.mask_point_on:
        # `_init_point_head` returns early when the point head is disabled, so none of
        # the subdivision attributes exist. Use the coarse prediction directly instead
        # of failing with an AttributeError inside `_subdivision_inference`.
        mask_rcnn_inference(coarse_mask, instances)
        return instances
    return self._subdivision_inference(features, coarse_mask, instances)
def _sample_train_points(self, coarse_mask, instances):
    """
    Pick training points for every proposal and look up their ground-truth labels.

    Args:
        coarse_mask (Tensor): coarse mask logits, passed to
            `get_uncertain_point_coords_with_randomness` as the prediction to sample from.
        instances (list[Instances]): training proposals; `gt_classes` and
            `proposal_boxes` are read from each element.

    Returns:
        point_coords: the sampled coordinates, as returned by
            `get_uncertain_point_coords_with_randomness`.
        point_labels: labels for those points, as returned by `sample_point_labels`.
    """
    assert self.training
    gt_classes = cat([x.gt_classes for x in instances])

    def uncertainty_of_gt_class(logits):
        # Uncertainty is measured on the channel of each proposal's ground-truth class.
        return calculate_uncertainty(logits, gt_classes)

    with torch.no_grad():
        # sample point_coords
        point_coords = get_uncertain_point_coords_with_randomness(
            coarse_mask,
            uncertainty_of_gt_class,
            self.mask_point_train_num_points,
            self.mask_point_oversample_ratio,
            self.mask_point_importance_sample_ratio,
        )
        # sample point_labels
        all_boxes = Boxes.cat([x.proposal_boxes for x in instances])
        coords_in_image = get_point_coords_wrt_image(all_boxes.tensor, point_coords)
        point_labels = sample_point_labels(instances, coords_in_image)
    return point_coords, point_labels
prediction with init_resolution, when mask_logits is None. + # We compute initial mask by sampling on a regular grid. coarse_mask + # can be used as initial mask as well, but it's typically very low-res + # so it will be completely overwritten during subdivision anyway. + for _ in range(self.mask_point_subdivision_steps + 1): + if mask_logits is None: + point_coords = generate_regular_grid_point_coords( + pred_classes.size(0), + self.mask_point_subdivision_init_resolution, + pred_boxes[0].device, + ) + else: + mask_logits = interpolate( + mask_logits, scale_factor=2, mode="bilinear", align_corners=False + ) + uncertainty_map = calculate_uncertainty(mask_logits, pred_classes) + point_indices, point_coords = get_uncertain_point_coords_on_grid( + uncertainty_map, self.mask_point_subdivision_num_points + ) + + # Run the point head for every point in point_coords + fine_grained_features = self._point_pooler(features, pred_boxes, point_coords) + point_logits = self._get_point_logits( + fine_grained_features, point_coords, mask_representations + ) + + if mask_logits is None: + # Create initial mask_logits using point_logits on this regular grid + R, C, _ = point_logits.shape + mask_logits = point_logits.reshape( + R, + C, + self.mask_point_subdivision_init_resolution, + self.mask_point_subdivision_init_resolution, + ) + # The subdivision code will fail with the empty list of boxes + if len(pred_classes) == 0: + mask_rcnn_inference(mask_logits, instances) + return instances + else: + # Put point predictions to the right places on the upsampled grid. 
@ROI_MASK_HEAD_REGISTRY.register()
class ImplicitPointRendMaskHead(PointRendMaskHead):
    """
    PointRend mask head whose ROI head (`self.parameter_head`) predicts a flat vector of
    `self.point_head.num_params` values per box. That vector is handed to the point head
    together with the fine-grained point features and the point coordinates.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__(cfg, input_shape)

    def _init_roi_head(self, cfg, input_shape):
        """Create the parameter head; requires `_init_point_head` to have set `num_params`."""
        assert hasattr(self, "num_params"), "Please initialize point_head first!"
        self.parameter_head = ConvFCHead(cfg, input_shape, output_shape=(self.num_params,))
        self.regularizer = cfg.MODEL.IMPLICIT_POINTREND.PARAMS_L2_REGULARIZER

    def _init_point_head(self, cfg, input_shape):
        """Read the point-head settings from `cfg` and build the point head."""
        point_cfg = cfg.MODEL.POINT_HEAD

        # fmt: off
        self.mask_point_on = True  # always on
        assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == point_cfg.NUM_CLASSES
        self.mask_point_in_features      = point_cfg.IN_FEATURES
        self.mask_point_train_num_points = point_cfg.TRAIN_NUM_POINTS
        # the next two parameters are used in the adaptive subdivision inference procedure
        self.mask_point_subdivision_steps      = point_cfg.SUBDIVISION_STEPS
        self.mask_point_subdivision_num_points = point_cfg.SUBDIVISION_NUM_POINTS
        # fmt: on

        channels_per_feature = [input_shape[f].channels for f in self.mask_point_in_features]
        in_channels = int(np.sum(channels_per_feature))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))
        self.num_params = self.point_head.num_params

        # inference parameters: the initial grid side is the square root of the number of
        # subdivision points, which therefore has to be a perfect square.
        side = int(math.sqrt(self.mask_point_subdivision_num_points))
        self.mask_point_subdivision_init_resolution = side
        assert side * side == self.mask_point_subdivision_num_points

    def forward(self, features, instances):
        """
        Args:
            features (dict[str, Tensor]): a dict of image-level features
            instances (list[Instances]): proposals in training; detected
                instances in inference

        Returns:
            In training, a dict with "loss_l2" and "loss_mask_point". In inference, the
            result of `_subdivision_inference`.
        """
        if not self.training:
            pred_boxes = [x.pred_boxes for x in instances]
            parameters = self.parameter_head(self._roi_pooler(features, pred_boxes))
            return self._subdivision_inference(features, parameters, instances)

        proposal_boxes = [x.proposal_boxes for x in instances]
        parameters = self.parameter_head(self._roi_pooler(features, proposal_boxes))
        # L2 penalty on the predicted parameters, weighted by the configured regularizer.
        losses = {"loss_l2": self.regularizer * (parameters**2).mean()}

        point_coords, point_labels = self._uniform_sample_train_points(instances)
        point_features = self._point_pooler(features, proposal_boxes, point_coords)
        point_logits = self._get_point_logits(point_features, point_coords, parameters)
        losses["loss_mask_point"] = roi_mask_point_loss(point_logits, instances, point_labels)
        return losses

    def _uniform_sample_train_points(self, instances):
        """Draw `mask_point_train_num_points` uniform random points per proposal box."""
        assert self.training
        cat_boxes = Boxes.cat([x.proposal_boxes for x in instances])
        # uniform sample in box-normalized [0, 1) x [0, 1) coordinates
        point_coords = torch.rand(
            len(cat_boxes), self.mask_point_train_num_points, 2, device=cat_boxes.tensor.device
        )
        # sample point_labels
        coords_in_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
        point_labels = sample_point_labels(instances, coords_in_image)
        return point_coords, point_labels

    def _get_point_logits(self, fine_grained_features, point_coords, parameters):
        """Run the point head on the point features, coordinates and predicted parameters."""
        return self.point_head(fine_grained_features, point_coords, parameters)
def generate_regular_grid_point_coords(R, side_size, device):
    """
    Generate regular square grid of points in [0, 1] x [0, 1] coordinate space.

    Args:
        R (int): The number of grids to sample, one for each region.
        side_size (int): The side size of the regular grid.
        device (torch.device): Desired device of returned tensor.

    Returns:
        (Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates
            for the regular grids. The R grids are an expanded view of one grid.
    """
    # Affine map x -> 0.5 * x + 0.5 turns affine_grid's [-1, 1] range into [0, 1].
    to_unit_square = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device)
    single_grid_size = torch.Size((1, 1, side_size, side_size))
    grid = F.affine_grid(to_unit_square, single_grid_size, align_corners=False)
    flat_grid = grid.view(1, -1, 2)
    return flat_grid.expand(R, -1, -1)
def get_uncertain_point_coords_on_grid(uncertainty_map, num_points):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid.

    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
            values for a set of points on a regular H x W grid.
        num_points (int): The number of points P to select. It is capped at H * W.

    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, H x W) of the most uncertain points.
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1]
            normalized (x, y) coordinates of the centers of the selected grid cells.
    """
    num_rois, _, height, width = uncertainty_map.shape
    num_points = min(height * width, num_points)

    flat_uncertainty = uncertainty_map.view(num_rois, height * width)
    _, point_indices = torch.topk(flat_uncertainty, k=num_points, dim=1)

    # Row-major flat index -> (column, row) of the grid cell.
    columns = (point_indices % width).to(torch.float)
    rows = (point_indices // width).to(torch.float)

    # Cell centers: half a step from the border plus a whole step per cell.
    w_step = 1.0 / float(width)
    h_step = 1.0 / float(height)
    x_coords = w_step / 2.0 + columns * w_step
    y_coords = h_step / 2.0 + rows * h_step

    point_coords = torch.stack((x_coords, y_coords), dim=2)
    return point_indices, point_coords
def get_point_coords_wrt_image(boxes_coords, point_coords):
    """
    Convert box-normalized [0, 1] x [0, 1] point coordinates to image-level coordinates.

    Every point is scaled by the width/height of its box and shifted by the box's
    top-left corner. The computation runs under `torch.no_grad()` and `point_coords`
    itself is left untouched.

    Args:
        boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding box
            coordinates; columns 0/1 are used as the origin and columns 2/3 minus
            columns 0/1 as the extent.
        point_coords (Tensor): A tensor of shape (R, P, 2) that contains
            [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.

    Returns:
        point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains the
            coordinates of the P sampled points in the same coordinate frame as
            `boxes_coords`.
    """
    with torch.no_grad():
        image_coords = point_coords.clone()
        # axis 0 is x (box columns 0 and 2), axis 1 is y (box columns 1 and 3)
        for axis in (0, 1):
            box_min = boxes_coords[:, None, axis]
            box_max = boxes_coords[:, None, axis + 2]
            image_coords[:, :, axis] = image_coords[:, :, axis] * (box_max - box_min)
            image_coords[:, :, axis] += box_min
    return image_coords
def roi_mask_point_loss(mask_logits, instances, point_labels):
    """
    Compute the point-based loss for instance segmentation mask predictions
    given point-wise mask prediction and its corresponding point-wise labels.

    As a side effect, the training accuracy (logits thresholded at 0.0, ignored points
    excluded) is written to the current event storage under "point/accuracy".

    Args:
        mask_logits (Tensor): A tensor of shape (R, C, P) or (R, 1, P) for class-specific or
            class-agnostic, where R is the total number of predicted masks in all images, C is the
            number of foreground classes, and P is the number of points sampled for each mask.
            The values are logits.
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. These instances are in 1:1 correspondence with the `mask_logits`. So, i_th
            element of the list contains R_i objects and R_1 + ... + R_N is equal to R.
            Only `gt_classes` is read here, and only for class-specific predictions.
        point_labels (Tensor): A tensor of shape (R, P), where R is the total number of
            predicted masks and P is the number of points for each mask.
            Labels with value of -1 will be ignored.
    Returns:
        point_loss (Tensor): A scalar tensor containing the loss. When there are no masks
            (R == 0) it is `mask_logits.sum() * 0`.
    """
    with torch.no_grad():
        cls_agnostic_mask = mask_logits.size(1) == 1
        total_num_masks = mask_logits.size(0)

        gt_classes = []
        if not cls_agnostic_mask:
            # Images without instances contribute nothing and are skipped.
            gt_classes = [
                instances_per_image.gt_classes.to(dtype=torch.int64)
                for instances_per_image in instances
                if len(instances_per_image) > 0
            ]

    gt_mask_logits = point_labels
    point_ignores = point_labels == -1
    if gt_mask_logits.shape[0] == 0:
        # Keep the result attached to `mask_logits` while contributing a zero loss.
        return mask_logits.sum() * 0

    assert gt_mask_logits.numel() > 0, gt_mask_logits.shape

    if cls_agnostic_mask:
        mask_logits = mask_logits[:, 0]
    else:
        # Build the row index on the same device as the tensor it indexes.
        indices = torch.arange(total_num_masks, device=mask_logits.device)
        gt_classes = cat(gt_classes, dim=0)
        mask_logits = mask_logits[indices, gt_classes]

    # Log the training accuracy (using gt classes and 0.0 threshold for the logits)
    mask_accurate = (mask_logits > 0.0) == gt_mask_logits.to(dtype=torch.uint8)
    mask_accurate = mask_accurate[~point_ignores]
    mask_accuracy = mask_accurate.nonzero().size(0) / max(mask_accurate.numel(), 1.0)
    get_event_storage().put_scalar("point/accuracy", mask_accuracy)

    # Ignored points get zero weight; the mean is still taken over all R * P points.
    point_loss = F.binary_cross_entropy_with_logits(
        mask_logits, gt_mask_logits.to(dtype=torch.float32), weight=~point_ignores, reduction="mean"
    )
    return point_loss
def forward(self, fine_grained_features, coarse_features):
    """
    Predict point logits from fine-grained and coarse prediction features.

    The two inputs are concatenated along dim 1 and passed through every layer in
    `self.fc_layers` with a ReLU after each one. When `self.coarse_pred_each_layer` is
    set, `coarse_features` is concatenated again after every layer. The result of
    `self.predictor` on the final activations is returned.
    """
    hidden = torch.cat((fine_grained_features, coarse_features), dim=1)
    append_coarse = self.coarse_pred_each_layer
    for fc_layer in self.fc_layers:
        hidden = F.relu(fc_layer(hidden))
        if append_coarse:
            hidden = torch.cat((hidden, coarse_features), dim=1)
    return self.predictor(hidden)
def __init__(self, cfg, input_shape: ShapeSpec):
    """
    The following attributes are parsed from config:
        channels: the output dimension of each FC layers
        num_layers: the number of FC layers (including the final prediction layer)
        image_feature_enabled: if True, fine-grained image-level features are used
        positional_encoding_enabled: if True, positional encoding is used

    The constructor creates no learnable layers. It records how many weight and bias
    values each layer of the per-instance MLP needs (`num_weight_params`,
    `num_bias_params`) and their total (`num_params`).
    """
    super(ImplicitPointHead, self).__init__()
    point_cfg = cfg.MODEL.POINT_HEAD
    implicit_cfg = cfg.MODEL.IMPLICIT_POINTREND

    # fmt: off
    self.num_layers                  = point_cfg.NUM_FC + 1
    self.channels                    = point_cfg.FC_DIM
    self.image_feature_enabled       = implicit_cfg.IMAGE_FEATURE_ENABLED
    self.positional_encoding_enabled = implicit_cfg.POS_ENC_ENABLED
    self.num_classes                 = 1 if point_cfg.CLS_AGNOSTIC_MASK else point_cfg.NUM_CLASSES
    # fmt: on

    image_channels = input_shape.channels
    self.in_channels = image_channels if self.image_feature_enabled else 0
    if self.positional_encoding_enabled:
        # 128 random projections, each contributing a sin and a cos channel.
        self.in_channels += 256
        self.register_buffer("positional_encoding_gaussian_matrix", torch.randn((2, 128)))

    assert self.in_channels > 0
    assert self.num_layers >= 2

    # Layer widths: in_channels -> channels -> ... -> channels -> num_classes.
    hidden_widths = [self.channels] * (self.num_layers - 1)
    fan_in = [self.in_channels] + hidden_widths
    fan_out = hidden_widths + [self.num_classes]

    self.num_weight_params = [n_in * n_out for n_in, n_out in zip(fan_in, fan_out)]
    self.num_bias_params = list(fan_out)
    self.num_params = sum(self.num_weight_params) + sum(self.num_bias_params)
K] + # point_coords: [R, K, 2] + num_instances = fine_grained_features.size(0) + num_points = fine_grained_features.size(2) + + if num_instances == 0: + return torch.zeros((0, 1, num_points), device=fine_grained_features.device) + + if self.positional_encoding_enabled: + # locations: [R*K, 2] + locations = 2 * point_coords.reshape(num_instances * num_points, 2) - 1 + locations = locations @ self.positional_encoding_gaussian_matrix.to(locations.device) + locations = 2 * np.pi * locations + locations = torch.cat([torch.sin(locations), torch.cos(locations)], dim=1) + # locations: [R, C, K] + locations = locations.reshape(num_instances, num_points, 256).permute(0, 2, 1) + if not self.image_feature_enabled: + fine_grained_features = locations + else: + fine_grained_features = torch.cat([locations, fine_grained_features], dim=1) + + # features [R, C, K] + mask_feat = fine_grained_features.reshape(num_instances, self.in_channels, num_points) + + weights, biases = self._parse_params( + parameters, + self.in_channels, + self.channels, + self.num_classes, + self.num_weight_params, + self.num_bias_params, + ) + + point_logits = self._dynamic_mlp(mask_feat, weights, biases, num_instances) + point_logits = point_logits.reshape(-1, self.num_classes, num_points) + + return point_logits + + @staticmethod + def _dynamic_mlp(features, weights, biases, num_instances): + assert features.dim() == 3, features.dim() + n_layers = len(weights) + x = features + for i, (w, b) in enumerate(zip(weights, biases)): + x = torch.einsum("nck,ndc->ndk", x, w) + b + if i < n_layers - 1: + x = F.relu(x) + return x + + @staticmethod + def _parse_params( + pred_params, + in_channels, + channels, + num_classes, + num_weight_params, + num_bias_params, + ): + assert pred_params.dim() == 2 + assert len(num_weight_params) == len(num_bias_params) + assert pred_params.size(1) == sum(num_weight_params) + sum(num_bias_params) + + num_instances = pred_params.size(0) + num_layers = len(num_weight_params) + + 
@ROI_HEADS_REGISTRY.register()
class PointRendROIHeads(StandardROIHeads):
    """
    StandardROIHeads plus backward-compatibility shims: checkpoints and
    configs written in the old PointRend layout are converted on the fly.
    """

    _version = 2

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Rename old-format keys in ``state_dict`` in place.

        For checkpoints whose recorded version is missing or older than 2,
        keys under ``mask_point_head`` / ``mask_coarse_head`` are moved to
        ``mask_head.point_head`` / ``mask_head.coarse_head``.
        """
        version = local_metadata.get("version", None)
        if version is not None and version >= 2:
            return

        logging.getLogger(__name__).warning(
            "Weight format of PointRend models have changed! "
            "Please upgrade your models. Applying automatic conversion now ..."
        )
        renames = (
            (prefix + "mask_point_head", prefix + "mask_head.point_head"),
            (prefix + "mask_coarse_head", prefix + "mask_head.coarse_head"),
        )
        # Iterate over a snapshot of the keys because the dict is mutated.
        for old_key in list(state_dict.keys()):
            new_key = old_key
            for old_name, new_name in renames:
                if old_key.startswith(old_name):
                    new_key = old_key.replace(old_name, new_name)
            if new_key != old_key:
                state_dict[new_key] = state_dict.pop(old_key)

    @classmethod
    def _init_mask_head(cls, cfg, input_shape):
        """
        Build the mask head, first rewriting an old-style config (mask head
        named "CoarseMaskHead") so that it uses "PointRendMaskHead".
        """
        uses_old_config = (
            cfg.MODEL.MASK_ON and cfg.MODEL.ROI_MASK_HEAD.NAME != "PointRendMaskHead"
        )
        if uses_old_config:
            logging.getLogger(__name__).warning(
                "Config of PointRend models have changed! "
                "Please upgrade your models. Applying automatic conversion now ..."
            )
            assert cfg.MODEL.ROI_MASK_HEAD.NAME == "CoarseMaskHead"
            # The config must be unfrozen while it is being edited.
            cfg.defrost()
            cfg.MODEL.ROI_MASK_HEAD.NAME = "PointRendMaskHead"
            cfg.MODEL.ROI_MASK_HEAD.POOLER_TYPE = ""
            cfg.freeze()
        return super()._init_mask_head(cfg, input_shape)
def calculate_uncertainty(sem_seg_logits):
    """
    Score every location of `sem_seg_logits` by how close its two largest
    logits are: the score is (second largest logit) - (largest logit), so it
    is never positive and is closest to zero where the prediction is least
    certain.

    Args:
        sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size
            and C is the number of classes (C must be at least 2). The values are logits.

    Returns:
        scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    top2 = torch.topk(sem_seg_logits, k=2, dim=1).values
    # Slicing with a length-1 range keeps the channel dimension.
    best, runner_up = top2[:, 0:1], top2[:, 1:2]
    return runner_up - best
+ """ + + def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]): + super().__init__() + + self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE + + self.coarse_sem_seg_head = SEM_SEG_HEADS_REGISTRY.get( + cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME + )(cfg, input_shape) + self._init_point_head(cfg, input_shape) + + def _init_point_head(self, cfg, input_shape: Dict[str, ShapeSpec]): + # fmt: off + assert cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES + feature_channels = {k: v.channels for k, v in input_shape.items()} + self.in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES + self.train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS + self.oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO + self.importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO + self.subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS + self.subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS + # fmt: on + + in_channels = int(np.sum([feature_channels[f] for f in self.in_features])) + self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1)) + + def forward(self, features, targets=None): + coarse_sem_seg_logits = self.coarse_sem_seg_head.layers(features) + + if self.training: + losses = self.coarse_sem_seg_head.losses(coarse_sem_seg_logits, targets) + + with torch.no_grad(): + point_coords = get_uncertain_point_coords_with_randomness( + coarse_sem_seg_logits, + calculate_uncertainty, + self.train_num_points, + self.oversample_ratio, + self.importance_sample_ratio, + ) + coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False) + + fine_grained_features = cat( + [ + point_sample(features[in_feature], point_coords, align_corners=False) + for in_feature in self.in_features + ], + dim=1, + ) + point_logits = self.point_head(fine_grained_features, coarse_features) + point_targets = ( + point_sample( + targets.unsqueeze(1).to(torch.float), + point_coords, + 
mode="nearest", + align_corners=False, + ) + .squeeze(1) + .to(torch.long) + ) + losses["loss_sem_seg_point"] = F.cross_entropy( + point_logits, point_targets, reduction="mean", ignore_index=self.ignore_value + ) + return None, losses + else: + sem_seg_logits = coarse_sem_seg_logits.clone() + for _ in range(self.subdivision_steps): + sem_seg_logits = F.interpolate( + sem_seg_logits, scale_factor=2, mode="bilinear", align_corners=False + ) + uncertainty_map = calculate_uncertainty(sem_seg_logits) + point_indices, point_coords = get_uncertain_point_coords_on_grid( + uncertainty_map, self.subdivision_num_points + ) + fine_grained_features = cat( + [ + point_sample(features[in_feature], point_coords, align_corners=False) + for in_feature in self.in_features + ] + ) + coarse_features = point_sample( + coarse_sem_seg_logits, point_coords, align_corners=False + ) + point_logits = self.point_head(fine_grained_features, coarse_features) + + # put sem seg point predictions to the right places on the upsampled grid. + N, C, H, W = sem_seg_logits.shape + point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) + sem_seg_logits = ( + sem_seg_logits.reshape(N, C, H * W) + .scatter_(2, point_indices, point_logits) + .view(N, C, H, W) + ) + return sem_seg_logits, {} diff --git a/data_processing/detectron2/projects/PointRend/train_net.py b/data_processing/detectron2/projects/PointRend/train_net.py new file mode 100644 index 0000000..9ae6f1a --- /dev/null +++ b/data_processing/detectron2/projects/PointRend/train_net.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. + +""" +PointRend Training Script. + +This script is a simplified version of the training script in detectron2/tools. 
+""" + +import os + +import detectron2.data.transforms as T +import detectron2.utils.comm as comm +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import get_cfg +from detectron2.data import DatasetMapper, MetadataCatalog, build_detection_train_loader +from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch +from detectron2.evaluation import ( + CityscapesInstanceEvaluator, + CityscapesSemSegEvaluator, + COCOEvaluator, + DatasetEvaluators, + LVISEvaluator, + SemSegEvaluator, + verify_results, +) +from detectron2.projects.point_rend import ColorAugSSDTransform, add_pointrend_config + + +def build_sem_seg_train_aug(cfg): + augs = [ + T.ResizeShortestEdge( + cfg.INPUT.MIN_SIZE_TRAIN, cfg.INPUT.MAX_SIZE_TRAIN, cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING + ) + ] + if cfg.INPUT.CROP.ENABLED: + augs.append( + T.RandomCrop_CategoryAreaConstraint( + cfg.INPUT.CROP.TYPE, + cfg.INPUT.CROP.SIZE, + cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA, + cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + ) + ) + if cfg.INPUT.COLOR_AUG_SSD: + augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT)) + augs.append(T.RandomFlip()) + return augs + + +class Trainer(DefaultTrainer): + """ + We use the "DefaultTrainer" which contains a number pre-defined logic for + standard training workflow. They may not work for you, especially if you + are working on a new research project. In that case you can use the cleaner + "SimpleTrainer", or write your own training loop. + """ + + @classmethod + def build_evaluator(cls, cfg, dataset_name, output_folder=None): + """ + Create evaluator(s) for a given dataset. + This uses the special metadata "evaluator_type" associated with each builtin dataset. + For your own dataset, you can simply create an evaluator manually in your + script and do not have to worry about the hacky if-else logic here. 
+ """ + if output_folder is None: + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") + evaluator_list = [] + evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type + if evaluator_type == "lvis": + return LVISEvaluator(dataset_name, output_dir=output_folder) + if evaluator_type == "coco": + return COCOEvaluator(dataset_name, output_dir=output_folder) + if evaluator_type == "sem_seg": + return SemSegEvaluator( + dataset_name, + distributed=True, + output_dir=output_folder, + ) + if evaluator_type == "cityscapes_instance": + return CityscapesInstanceEvaluator(dataset_name) + if evaluator_type == "cityscapes_sem_seg": + return CityscapesSemSegEvaluator(dataset_name) + if len(evaluator_list) == 0: + raise NotImplementedError( + "no Evaluator for the dataset {} with the type {}".format( + dataset_name, evaluator_type + ) + ) + if len(evaluator_list) == 1: + return evaluator_list[0] + return DatasetEvaluators(evaluator_list) + + @classmethod + def build_train_loader(cls, cfg): + if "SemanticSegmentor" in cfg.MODEL.META_ARCHITECTURE: + mapper = DatasetMapper(cfg, is_train=True, augmentations=build_sem_seg_train_aug(cfg)) + else: + mapper = None + return build_detection_train_loader(cfg, mapper=mapper) + + +def setup(args): + """ + Create configs and perform basic setups. 
def main(args):
    """
    Entry point executed in every launched process.

    With ``args.eval_only`` set, a model is built, weights are loaded from
    ``cfg.MODEL.WEIGHTS`` and the test results are returned (and verified on
    the main process). Otherwise a ``Trainer`` is created and trained.
    """
    cfg = setup(args)

    if not args.eval_only:
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()

    # Evaluation only: no trainer is needed, just the model and its weights.
    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
    results = Trainer.test(cfg, model)
    if comm.is_main_process():
        verify_results(cfg, results)
    return results
+ +

+ +## Data preparation +Please follow these steps to prepare your datasets: +1. Follow the official Detectron2 instructions to prepare the COCO dataset. Set the `DETECTRON2_DATASETS` environment variable to the location of your Detectron2 datasets. +2. Generate 10-point annotations for COCO by running: `python tools/prepare_coco_point_annotations_without_masks.py 10` + +## Training + +To train a model with 8 GPUs run: +```bash +python train_net.py --config-file configs/mask_rcnn_R_50_FPN_3x_point_sup_point_aug_coco.yaml --num-gpus 8 +``` + +## Evaluation + +Model evaluation can be done similarly: +```bash +python train_net.py --config-file configs/mask_rcnn_R_50_FPN_3x_point_sup_point_aug_coco.yaml --eval-only MODEL.WEIGHTS /path/to/model_checkpoint +``` + +## Citing Pointly-Supervised Instance Segmentation + +If you use PointSup, please use the following BibTeX entry. + +```BibTeX +@article{cheng2021pointly, + title={Pointly-Supervised Instance Segmentation}, + author={Bowen Cheng and Omkar Parkhi and Alexander Kirillov}, + journal={arXiv}, + year={2021} +} +``` diff --git a/data_processing/detectron2/projects/PointSup/configs/implicit_pointrend_R_50_FPN_3x_point_sup_point_aug_coco.yaml b/data_processing/detectron2/projects/PointSup/configs/implicit_pointrend_R_50_FPN_3x_point_sup_point_aug_coco.yaml new file mode 100644 index 0000000..5b3d427 --- /dev/null +++ b/data_processing/detectron2/projects/PointSup/configs/implicit_pointrend_R_50_FPN_3x_point_sup_point_aug_coco.yaml @@ -0,0 +1,9 @@ +_BASE_: "../../PointRend/configs/InstanceSegmentation/implicit_pointrend_R_50_FPN_3x_coco.yaml" +MODEL: + ROI_MASK_HEAD: + NAME: "ImplicitPointRendPointSupHead" +INPUT: + POINT_SUP: True + SAMPLE_POINTS: 5 +DATASETS: + TRAIN: ("coco_2017_train_points_n10_v1_without_masks",) diff --git a/data_processing/detectron2/projects/PointSup/configs/mask_rcnn_R_50_FPN_3x_point_sup_coco.yaml b/data_processing/detectron2/projects/PointSup/configs/mask_rcnn_R_50_FPN_3x_point_sup_coco.yaml new file
mode 100644 index 0000000..157e384 --- /dev/null +++ b/data_processing/detectron2/projects/PointSup/configs/mask_rcnn_R_50_FPN_3x_point_sup_coco.yaml @@ -0,0 +1,15 @@ +_BASE_: "../../../configs/Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsamplePointSupHead" +INPUT: + POINT_SUP: True +DATASETS: + TRAIN: ("coco_2017_train_points_n10_v1_without_masks",) +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/projects/PointSup/configs/mask_rcnn_R_50_FPN_3x_point_sup_point_aug_coco.yaml b/data_processing/detectron2/projects/PointSup/configs/mask_rcnn_R_50_FPN_3x_point_sup_point_aug_coco.yaml new file mode 100644 index 0000000..4b11224 --- /dev/null +++ b/data_processing/detectron2/projects/PointSup/configs/mask_rcnn_R_50_FPN_3x_point_sup_point_aug_coco.yaml @@ -0,0 +1,3 @@ +_BASE_: "mask_rcnn_R_50_FPN_3x_point_sup_coco.yaml" +INPUT: + SAMPLE_POINTS: 5 diff --git a/data_processing/detectron2/projects/PointSup/point_sup/__init__.py b/data_processing/detectron2/projects/PointSup/point_sup/__init__.py new file mode 100644 index 0000000..510e381 --- /dev/null +++ b/data_processing/detectron2/projects/PointSup/point_sup/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from . import register_point_annotations +from .config import add_point_sup_config +from .dataset_mapper import PointSupDatasetMapper +from .mask_head import MaskRCNNConvUpsamplePointSupHead +from .point_utils import get_point_coords_from_point_annotation diff --git a/data_processing/detectron2/projects/PointSup/point_sup/config.py b/data_processing/detectron2/projects/PointSup/point_sup/config.py new file mode 100644 index 0000000..5e00b78 --- /dev/null +++ b/data_processing/detectron2/projects/PointSup/point_sup/config.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. 
class PointSupDatasetMapper:
    """
    The callable currently does the following:
    1. Read the image from "file_name"
    2. Applies transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        # Extra data augmentation for point supervision
        sample_points: int = 0,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            sample_points: subsample points at each iteration
        """
        self.is_train = is_train
        self.augmentations = T.AugmentationList(augmentations)
        self.image_format = image_format
        self.sample_points = sample_points
        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
        logger.info(f"Point Augmentations used in {mode}: sample {sample_points} points")

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        """
        Translate ``cfg`` into the keyword arguments of ``__init__``.

        Raises:
            ValueError: if crop augmentation is enabled for training, which
                this mapper does not support.
        """
        augs = utils.build_augmentation(cfg, is_train)
        if cfg.INPUT.CROP.ENABLED and is_train:
            raise ValueError("Crop augmentation not supported to point supervision.")

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "sample_points": cfg.INPUT.SAMPLE_POINTS,
        }

        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        aug_input = T.AugInput(image)
        transforms = self.augmentations(aug_input)
        image = aug_input.image

        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # Maps points from the closed interval [0, image_size - 1] on discrete
            # image coordinates to the half-open interval [x1, x2) on continuous image
            # coordinates. We use the continuous-discrete conversion from Heckbert
            # 1990 ("What is the coordinate of a pixel?"): d = floor(c) and c = d + 0.5,
            # where d is a discrete coordinate and c is a continuous coordinate.
            for ann in dataset_dict["annotations"]:
                # `np.float` was only an alias of the builtin float and has been
                # removed from NumPy; np.float64 is the same dtype.
                point_coords_wrt_image = np.array(ann["point_coords"]).astype(np.float64)
                point_coords_wrt_image = point_coords_wrt_image + 0.5
                ann["point_coords"] = point_coords_wrt_image

            annos = [
                # also need to transform point coordinates
                transform_instance_annotations(
                    obj,
                    transforms,
                    image_shape,
                )
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = annotations_to_instances(
                annos,
                image_shape,
                sample_points=self.sample_points,
            )

            dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict
def annotations_to_instances(annos, image_size, sample_points=0):
    """
    Create an :class:`Instances` object used by the models,
    from instance annotations in the dataset dict.

    Args:
        annos (list[dict]): a list of instance annotations in one image, each
            element for one instance. May be empty.
        image_size (tuple): height, width
        sample_points (int): subsample points at each iteration. With the
            default 0, all points of every instance are kept.

    Returns:
        Instances:
            It will contain fields "gt_boxes", "gt_classes",
            "gt_point_coords", "gt_point_labels", if they can be obtained from `annos`.
            This is the format that builtin models with point supervision expect.
    """
    target = base_annotations_to_instances(annos, image_size)

    # An image can end up with no annotations at all; there is no first
    # element to inspect and no point fields to attach in that case.
    if len(annos) == 0:
        return target

    assert ("point_coords" in annos[0]) == ("point_labels" in annos[0])

    if "point_labels" in annos[0]:
        point_coords = []
        point_labels = []
        for anno in annos:
            # Already in the image coordinate system
            point_coords_wrt_image = np.array(anno["point_coords"])
            point_labels_wrt_image = np.array(anno["point_labels"])

            if sample_points > 0:
                # Sample with replacement only when fewer points are available
                # than requested.
                random_indices = np.random.choice(
                    point_coords_wrt_image.shape[0],
                    sample_points,
                    replace=point_coords_wrt_image.shape[0] < sample_points,
                ).astype(int)
                point_coords_wrt_image = point_coords_wrt_image[random_indices]
                point_labels_wrt_image = point_labels_wrt_image[random_indices]
                assert point_coords_wrt_image.shape[0] == point_labels_wrt_image.size

            point_coords.append(point_coords_wrt_image)
            point_labels.append(point_labels_wrt_image)

        point_coords = torch.stack([torch.from_numpy(x) for x in point_coords])
        point_labels = torch.stack([torch.from_numpy(x) for x in point_labels])
        target.gt_point_coords = point_coords
        target.gt_point_labels = point_labels

    return target


def transform_instance_annotations(
    annotation, transforms, image_size, *, keypoint_hflip_indices=None
):
    """
    Apply transforms to box, and point annotations of a single instance.
    It will use `transforms.apply_box` for the box, and
    `transforms.apply_coords` for points.
    Args:
        annotation (dict): dict of instance annotations for a single instance.
            It will be modified in-place.
        transforms (TransformList or list[Transform]):
        image_size (tuple): the height, width of the transformed image
        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
    Returns:
        dict:
            the same input dict with fields "bbox", "point_coords", "point_labels"
            transformed according to `transforms`.
            The "bbox_mode" field will be set to XYXY_ABS.
    """
    # Forward the option by keyword so the call is valid whether or not the
    # base function declares it keyword-only.
    annotation = base_transform_instance_annotations(
        annotation, transforms, image_size, keypoint_hflip_indices=keypoint_hflip_indices
    )

    assert ("point_coords" in annotation) == ("point_labels" in annotation)
    if "point_coords" in annotation and "point_labels" in annotation:
        point_coords = annotation["point_coords"]
        # `np.float` has been removed from NumPy; np.float64 is the same dtype.
        point_labels = np.array(annotation["point_labels"]).astype(np.float64)
        point_coords = transforms.apply_coords(point_coords)

        # Set all out-of-boundary points to "unlabeled"
        # image_size is (height, width) while points are (x, y), hence the reversal.
        inside = (point_coords >= np.array([0, 0])) & (point_coords <= np.array(image_size[::-1]))
        inside = inside.all(axis=1)
        point_labels[~inside] = -1

        annotation["point_coords"] = point_coords
        annotation["point_labels"] = point_labels

    return annotation
@ROI_MASK_HEAD_REGISTRY.register()
class MaskRCNNConvUpsamplePointSupHead(MaskRCNNConvUpsampleHead):
    """
    A mask head with several conv layers, plus an upsample layer (with `ConvTranspose2d`).
    Predictions are made with a final 1x1 conv layer.

    The difference with `MaskRCNNConvUpsampleHead` is that this head is trained
    with point supervision. Please use the `MaskRCNNConvUpsampleHead` if you want
    to train the model with mask supervision.
    """

    def forward(self, x, instances: List[Instances]) -> Any:
        """
        Args:
            x: input region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): contains the boxes & labels corresponding
                to the input features.
                Exact format is up to its caller to decide.
                Typically, this is the foreground instances in training, with
                "proposal_boxes" field and other gt annotations.
                In inference, it contains boxes that are already predicted.
        Returns:
            A dict of losses in training. The predicted "instances" in inference.
        """
        x = self.layers(x)
        if self.training:
            N, C, H, W = x.shape
            assert H == W

            proposal_boxes = [inst.proposal_boxes for inst in instances]
            # Builtin sum: calling np.sum on a generator is deprecated in NumPy.
            assert N == sum(len(boxes) for boxes in proposal_boxes)

            if N == 0:
                # Zero-valued loss that is still derived from the head output.
                return {"loss_mask": x.sum() * 0}

            # Training with point supervision
            point_coords, point_labels = get_point_coords_from_point_annotation(instances)

            # Read the predicted mask logits at the annotated point locations.
            mask_logits = point_sample(
                x,
                point_coords,
                align_corners=False,
            )

            return {"loss_mask": roi_mask_point_loss(mask_logits, instances, point_labels)}
        else:
            mask_rcnn_inference(x, instances)
            return instances
def get_point_coords_wrt_box(boxes_coords, point_coords):
    """
    Express points given in absolute image coordinates relative to their box,
    so that the box spans [0, 1] x [0, 1]. Points outside the box map outside
    of that range.

    Args:
        boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding box
            coordinates, indexed as (x0, y0, x1, y1).
        point_coords (Tensor): A tensor of shape (R, P, 2) that contains the
            image-level (x, y) coordinates of P points per box. Not modified.
    Returns:
        point_coords_wrt_box (Tensor): A tensor of shape (R, P, 2) that contains
            the box-normalized coordinates of the P points.
    """
    with torch.no_grad():
        normalized = point_coords.clone()
        # axis 0 is x (box columns 0 and 2), axis 1 is y (box columns 1 and 3).
        for axis in (0, 1):
            origin = boxes_coords[:, None, axis]
            extent = boxes_coords[:, None, axis + 2] - origin
            normalized[:, :, axis] -= origin
            normalized[:, :, axis] = normalized[:, :, axis] / extent
    return normalized
register a function which returns dicts + DatasetCatalog.register( + name, lambda: load_coco_json(json_file, image_root, name, ["point_coords", "point_labels"]) + ) + + # 2. Optionally, add metadata about this dataset, + # since they might be useful in evaluation, visualization or logging + MetadataCatalog.get(name).set( + json_file=json_file, image_root=image_root, evaluator_type="coco", **metadata + ) + + +_PREDEFINED_SPLITS_COCO = {} +_PREDEFINED_SPLITS_COCO["coco"] = { + # point annotations without masks + "coco_2017_train_points_n10_v1_without_masks": ( + "coco/train2017", + "coco/annotations/instances_train2017_n10_v1_without_masks.json", + ), +} + + +def register_all_coco_train_points(root): + for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_COCO.items(): + for key, (image_root, json_file) in splits_per_dataset.items(): + # Assume pre-defined datasets live in `./datasets`. + register_coco_instances_with_points( + key, + _get_builtin_metadata(dataset_name), + os.path.join(root, json_file) if "://" not in json_file else json_file, + os.path.join(root, image_root), + ) + + +# True for open source; +# Internally at fb, we register them elsewhere +if __name__.endswith(".register_point_annotations"): + _root = os.getenv("DETECTRON2_DATASETS", "datasets") + register_all_coco_train_points(_root) diff --git a/data_processing/detectron2/projects/PointSup/tools/prepare_coco_point_annotations_without_masks.py b/data_processing/detectron2/projects/PointSup/tools/prepare_coco_point_annotations_without_masks.py new file mode 100644 index 0000000..e4aee2a --- /dev/null +++ b/data_processing/detectron2/projects/PointSup/tools/prepare_coco_point_annotations_without_masks.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved + +import copy +import json +import numpy as np +import os +import sys +import pycocotools.mask as mask_utils + +from detectron2.utils.env import seed_all_rng +from detectron2.utils.file_io import PathManager + + +def get_point_annotations(input_filename, output_filename, num_points_per_instance): + with PathManager.open(input_filename, "r") as f: + coco_json = json.load(f) + + coco_annos = coco_json.pop("annotations") + coco_points_json = copy.deepcopy(coco_json) + + imgs = {} + for img in coco_json["images"]: + imgs[img["id"]] = img + + new_annos = [] + for ann in coco_annos: + # convert mask + t = imgs[ann["image_id"]] + h, w = t["height"], t["width"] + segm = ann.pop("segmentation") + if isinstance(segm, list): + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = mask_utils.frPyObjects(segm, h, w) + rle = mask_utils.merge(rles) + elif isinstance(segm["counts"], list): + # uncompressed RLE + rle = mask_utils.frPyObjects(segm, h, w) + else: + # rle + rle = segm + mask = mask_utils.decode(rle) + new_ann = copy.deepcopy(ann) + # sample points in image coordinates + box = ann["bbox"] + point_coords_wrt_image = np.random.rand(num_points_per_instance, 2) + point_coords_wrt_image[:, 0] = point_coords_wrt_image[:, 0] * box[2] + point_coords_wrt_image[:, 1] = point_coords_wrt_image[:, 1] * box[3] + point_coords_wrt_image[:, 0] += box[0] + point_coords_wrt_image[:, 1] += box[1] + # round to integer coordinates + point_coords_wrt_image = np.floor(point_coords_wrt_image).astype(int) + # get labels + assert (point_coords_wrt_image >= 0).all(), (point_coords_wrt_image, mask.shape) + assert (point_coords_wrt_image[:, 0] < w).all(), (point_coords_wrt_image, mask.shape) + assert (point_coords_wrt_image[:, 1] < h).all(), (point_coords_wrt_image, mask.shape) + point_labels = mask[point_coords_wrt_image[:, 1], point_coords_wrt_image[:, 0]] + # store new annotations + new_ann["point_coords"] =
point_coords_wrt_image.tolist() + new_ann["point_labels"] = point_labels.tolist() + new_annos.append(new_ann) + coco_points_json["annotations"] = new_annos + + with PathManager.open(output_filename, "w") as f: + json.dump(coco_points_json, f) + + print("{} is modified and stored in {}.".format(input_filename, output_filename)) + + +if __name__ == "__main__": + """ + Generate point-based supervision for COCO dataset. + + Usage: + python tools/prepare_coco_point_annotations_without_masks.py \ + NUM_POINTS_PER_INSTANCE NUM_VERSIONS_WITH_DIFFERENT_SEED + + Example to generate point-based COCO dataset with 10 points per instance: + python tools/prepare_coco_point_annotations_without_masks.py 10 + """ + + # Fix random seed + seed_all_rng(12345) + + assert len(sys.argv) >= 2, "Please provide number of points to sample per instance" + dataset_dir = os.path.join(os.getenv("DETECTRON2_DATASETS", "datasets"), "coco/annotations") + num_points_per_instance = int(sys.argv[1]) + if len(sys.argv) == 3: + repeat = int(sys.argv[2]) + else: + repeat = 1 + s = "instances_train2017" + for version in range(repeat): + print( + "Start sampling {} points per instance for annotations {}.".format( + num_points_per_instance, s + ) + ) + get_point_annotations( + os.path.join(dataset_dir, "{}.json".format(s)), + os.path.join( + dataset_dir, + "{}_n{}_v{}_without_masks.json".format(s, num_points_per_instance, version + 1), + ), + num_points_per_instance, + ) diff --git a/data_processing/detectron2/projects/PointSup/train_net.py b/data_processing/detectron2/projects/PointSup/train_net.py new file mode 100644 index 0000000..0fe970a --- /dev/null +++ b/data_processing/detectron2/projects/PointSup/train_net.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Point supervision Training Script. + +This script is a simplified version of the training script in detectron2/tools. 
+""" + +import os + +import detectron2.utils.comm as comm +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import get_cfg +from detectron2.data import MetadataCatalog, build_detection_train_loader +from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch +from detectron2.evaluation import COCOEvaluator, DatasetEvaluators, verify_results +from detectron2.projects.point_rend import add_pointrend_config +from detectron2.utils.logger import setup_logger + +from point_sup import PointSupDatasetMapper, add_point_sup_config + + +class Trainer(DefaultTrainer): + """ + We use the "DefaultTrainer" which contains pre-defined default logic for + standard training workflow. They may not work for you, especially if you + are working on a new research project. In that case you can write your + own training loop. You can use "tools/plain_train_net.py" as an example. + """ + + @classmethod + def build_evaluator(cls, cfg, dataset_name, output_folder=None): + """ + Create evaluator(s) for a given dataset. + This uses the special metadata "evaluator_type" associated with each builtin dataset. + For your own dataset, you can simply create an evaluator manually in your + script and do not have to worry about the hacky if-else logic here. 
+ """ + if output_folder is None: + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") + evaluator_list = [] + evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type + if evaluator_type == "coco": + evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder)) + if len(evaluator_list) == 0: + raise NotImplementedError( + "no Evaluator for the dataset {} with the type {}".format( + dataset_name, evaluator_type + ) + ) + elif len(evaluator_list) == 1: + return evaluator_list[0] + return DatasetEvaluators(evaluator_list) + + @classmethod + def build_train_loader(cls, cfg): + if cfg.INPUT.POINT_SUP: + mapper = PointSupDatasetMapper(cfg, is_train=True) + else: + mapper = None + return build_detection_train_loader(cfg, mapper=mapper) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg() + add_pointrend_config(cfg) + add_point_sup_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + default_setup(cfg, args) + # Setup logger for "point_sup" module + setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="point_sup") + return cfg + + +def main(args): + cfg = setup(args) + + if args.eval_only: + model = Trainer.build_model(cfg) + DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( + cfg.MODEL.WEIGHTS, resume=args.resume + ) + res = Trainer.test(cfg, model) + if cfg.TEST.AUG.ENABLED: + res.update(Trainer.test_with_TTA(cfg, model)) + if comm.is_main_process(): + verify_results(cfg, res) + return res + + """ + If you'd like to do anything fancier than the standard training logic, + consider writing your own training loop (see plain_train_net.py) or + subclassing the trainer. 
+ """ + trainer = Trainer(cfg) + trainer.resume_or_load(resume=args.resume) + return trainer.train() + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + print("Command Line Args:", args) + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/data_processing/detectron2/projects/README.md b/data_processing/detectron2/projects/README.md new file mode 100644 index 0000000..7fb29af --- /dev/null +++ b/data_processing/detectron2/projects/README.md @@ -0,0 +1,50 @@ + +Here are a few projects that are built on detectron2. +They are examples of how to use detectron2 as a library, to make your projects more +maintainable. + +## Projects by Facebook + +Note that these are research projects, and therefore may not have the same level +of support or stability as detectron2. + ++ [DensePose: Dense Human Pose Estimation In The Wild](DensePose) ++ [Scale-Aware Trident Networks for Object Detection](TridentNet) ++ [TensorMask: A Foundation for Dense Object Segmentation](TensorMask) ++ [Mesh R-CNN](https://github.com/facebookresearch/meshrcnn) ++ [PointRend: Image Segmentation as Rendering](PointRend) ++ [Momentum Contrast for Unsupervised Visual Representation Learning](https://github.com/facebookresearch/moco/tree/master/detection) ++ [DETR: End-to-End Object Detection with Transformers](https://github.com/facebookresearch/detr/tree/master/d2) ++ [Panoptic-DeepLab: A Simple, Strong, and Fast Baseline for Bottom-Up Panoptic Segmentation](Panoptic-DeepLab) ++ [D2Go (Detectron2Go)](https://github.com/facebookresearch/d2go), an end-to-end production system for training and deployment for mobile platforms.
++ [Pointly-Supervised Instance Segmentation](PointSup) ++ [Unbiased Teacher for Semi-Supervised Object Detection](https://github.com/facebookresearch/unbiased-teacher) ++ [Rethinking "Batch" in BatchNorm](Rethinking-BatchNorm/) ++ [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://github.com/facebookresearch/MaskFormer) ++ [Exploring Plain Vision Transformer Backbones for Object Detection](ViTDet/) ++ [MViTv2: Improved Multiscale Vision Transformers for Classification and Detection](MViTv2/) + + +## External Projects + +External projects in the community that use detectron2: + + + ++ [AdelaiDet](https://github.com/aim-uofa/adet), a detection toolbox including FCOS, BlendMask, etc. ++ [CenterMask](https://github.com/youngwanLEE/centermask2) ++ [Res2Net backbones](https://github.com/Res2Net/Res2Net-detectron2) ++ [VoVNet backbones](https://github.com/youngwanLEE/vovnet-detectron2) ++ [FsDet](https://github.com/ucbdrive/few-shot-object-detection), Few-Shot Object Detection. ++ [Sparse R-CNN](https://github.com/PeizeSun/SparseR-CNN) ++ [BCNet](https://github.com/lkeab/BCNet), a bilayer decoupling instance segmentation method. ++ [DD3D](https://github.com/TRI-ML/dd3d), A fully convolutional 3D detector.
++ [detrex](https://github.com/IDEA-Research/detrex), a detection toolbox for transformer-based detection algorithms including Deformable-DETR, DAB-DETR, DN-DETR, DINO, etc. diff --git a/data_processing/detectron2/projects/Rethinking-BatchNorm/README.md b/data_processing/detectron2/projects/Rethinking-BatchNorm/README.md new file mode 100644 index 0000000..42c5c68 --- /dev/null +++ b/data_processing/detectron2/projects/Rethinking-BatchNorm/README.md @@ -0,0 +1,36 @@ +# Rethinking "Batch" in BatchNorm + +We provide configs that reproduce detection experiments in the paper [Rethinking "Batch" in BatchNorm](https://arxiv.org/abs/2105.07576). + +All configs can be trained with: + +``` +../../tools/lazyconfig_train_net.py --config-file configs/X.py --num-gpus 8 +``` + +## Mask R-CNN + +* `mask_rcnn_BNhead.py`, `mask_rcnn_BNhead_batch_stats.py`: + Mask R-CNN with BatchNorm in the head. See Table 3 in the paper. + +* `mask_rcnn_BNhead_shuffle.py`: Mask R-CNN with cross-GPU shuffling of head inputs. + See Figure 9 and Table 6 in the paper. + +* `mask_rcnn_SyncBNhead.py`: Mask R-CNN with cross-GPU SyncBatchNorm in the head. + It matches Table 6 in the paper. + +## RetinaNet + +* `retinanet_SyncBNhead.py`: RetinaNet with SyncBN in head, a straightforward implementation + which matches row 3 of Table 5. + +* `retinanet_SyncBNhead_SharedTraining.py`: RetinaNet with SyncBN in head, normalizing + all 5 feature levels together. Matches row 1 of Table 5. + +The script `retinanet-eval-domain-specific.py` evaluates a checkpoint after recomputing +domain-specific statistics. Running it with +``` +./retinanet-eval-domain-specific.py checkpoint.pth +``` +on a model produced by the above two configs, can produce results that match row 4 and +row 2 of Table 5.
diff --git a/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead.py b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead.py new file mode 100644 index 0000000..336c133 --- /dev/null +++ b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead.py @@ -0,0 +1,18 @@ +from detectron2.model_zoo import get_config + +model = get_config("common/models/mask_rcnn_fpn.py").model + +model.backbone.bottom_up.freeze_at = 2 + +model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "BN" +# 4conv1fc head +model.roi_heads.box_head.conv_dims = [256, 256, 256, 256] +model.roi_heads.box_head.fc_dims = [1024] + +dataloader = get_config("common/data/coco.py").dataloader +lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_3x +optimizer = get_config("common/optim.py").SGD +train = get_config("common/train.py").train + +train.init_checkpoint = "detectron2://ImageNetPretrained/MSRA/R-50.pkl" +train.max_iter = 270000 # 3x for batchsize = 16 diff --git a/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead_batch_stats.py b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead_batch_stats.py new file mode 100644 index 0000000..872e17c --- /dev/null +++ b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead_batch_stats.py @@ -0,0 +1,20 @@ +from torch.nn import BatchNorm2d +from torch.nn import functional as F + + +class BatchNormBatchStat(BatchNorm2d): + """ + BN that uses batch stat in inference + """ + + def forward(self, input): + if self.training: + return super().forward(input) + return F.batch_norm(input, None, None, self.weight, self.bias, True, 1.0, self.eps) + + +# After training with the base config, it's sufficient to load its model with +# this config only for inference -- because the training-time behavior is identical. 
+from .mask_rcnn_BNhead import model, dataloader, lr_multiplier, optimizer, train + +model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = BatchNormBatchStat diff --git a/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead_shuffle.py b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead_shuffle.py new file mode 100644 index 0000000..5117a7d --- /dev/null +++ b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead_shuffle.py @@ -0,0 +1,74 @@ +import math +import torch +import torch.distributed as dist + +from detectron2.modeling.roi_heads import FastRCNNConvFCHead, MaskRCNNConvUpsampleHead +from detectron2.utils import comm +from fvcore.nn.distributed import differentiable_all_gather + + +def concat_all_gather(input): + bs_int = input.shape[0] + size_list = comm.all_gather(bs_int) + max_size = max(size_list) + max_shape = (max_size,) + input.shape[1:] + + padded_input = input.new_zeros(max_shape) + padded_input[:bs_int] = input + all_inputs = differentiable_all_gather(padded_input) + inputs = [x[:sz] for sz, x in zip(size_list, all_inputs)] + return inputs, size_list + + +def batch_shuffle(x): + # gather from all gpus + batch_size_this = x.shape[0] + all_xs, batch_size_all = concat_all_gather(x) + all_xs_concat = torch.cat(all_xs, dim=0) + total_bs = sum(batch_size_all) + + rank = dist.get_rank() + assert batch_size_all[rank] == batch_size_this + + idx_range = (sum(batch_size_all[:rank]), sum(batch_size_all[: rank + 1])) + + # random shuffle index + idx_shuffle = torch.randperm(total_bs, device=x.device) + # broadcast to all gpus + dist.broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + splits = torch.split(idx_shuffle, math.ceil(total_bs / dist.get_world_size())) + if len(splits) > rank: + idx_this = splits[rank] + else: + idx_this = idx_shuffle.new_zeros([0]) + return 
all_xs_concat[idx_this], idx_unshuffle[idx_range[0] : idx_range[1]] + + +def batch_unshuffle(x, idx_unshuffle): + all_x, _ = concat_all_gather(x) + x_gather = torch.cat(all_x, dim=0) + return x_gather[idx_unshuffle] + + +def wrap_shuffle(module_type, method): + def new_method(self, x): + if self.training: + x, idx = batch_shuffle(x) + x = getattr(module_type, method)(self, x) + if self.training: + x = batch_unshuffle(x, idx) + return x + + return type(module_type.__name__ + "WithShuffle", (module_type,), {method: new_method}) + + +from .mask_rcnn_BNhead import model, dataloader, lr_multiplier, optimizer, train + + +model.roi_heads.box_head._target_ = wrap_shuffle(FastRCNNConvFCHead, "forward") +model.roi_heads.mask_head._target_ = wrap_shuffle(MaskRCNNConvUpsampleHead, "layers") diff --git a/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_SyncBNhead.py b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_SyncBNhead.py new file mode 100644 index 0000000..5f05da0 --- /dev/null +++ b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_SyncBNhead.py @@ -0,0 +1,3 @@ +from .mask_rcnn_BNhead import model, dataloader, lr_multiplier, optimizer, train + +model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "SyncBN" diff --git a/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/retinanet_SyncBNhead.py b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/retinanet_SyncBNhead.py new file mode 100644 index 0000000..222dfdd --- /dev/null +++ b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/retinanet_SyncBNhead.py @@ -0,0 +1,19 @@ +from detectron2.model_zoo import get_config +from torch import nn + +model = get_config("common/models/retinanet.py").model +model.backbone.bottom_up.freeze_at = 2 + +# The head will overwrite string "SyncBN" to use domain-specific BN, so we +# provide a class here to use shared BN in training. 
+model.head.norm = nn.SyncBatchNorm + +dataloader = get_config("common/data/coco.py").dataloader +lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_3x +optimizer = get_config("common/optim.py").SGD +train = get_config("common/train.py").train + +optimizer.lr = 0.01 + +train.init_checkpoint = "detectron2://ImageNetPretrained/MSRA/R-50.pkl" +train.max_iter = 270000 # 3x for batchsize = 16 diff --git a/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/retinanet_SyncBNhead_SharedTraining.py b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/retinanet_SyncBNhead_SharedTraining.py new file mode 100644 index 0000000..3f14600 --- /dev/null +++ b/data_processing/detectron2/projects/Rethinking-BatchNorm/configs/retinanet_SyncBNhead_SharedTraining.py @@ -0,0 +1,32 @@ +from typing import List +import torch +from torch import Tensor, nn + +from detectron2.modeling.meta_arch.retinanet import RetinaNetHead + + +def apply_sequential(inputs, modules): + for mod in modules: + if isinstance(mod, (nn.BatchNorm2d, nn.SyncBatchNorm)): + # for BN layer, normalize all inputs together + shapes = [i.shape for i in inputs] + spatial_sizes = [s[2] * s[3] for s in shapes] + x = [i.flatten(2) for i in inputs] + x = torch.cat(x, dim=2).unsqueeze(3) + x = mod(x).split(spatial_sizes, dim=2) + inputs = [i.view(s) for s, i in zip(shapes, x)] + else: + inputs = [mod(i) for i in inputs] + return inputs + + +class RetinaNetHead_SharedTrainingBN(RetinaNetHead): + def forward(self, features: List[Tensor]): + logits = apply_sequential(features, list(self.cls_subnet) + [self.cls_score]) + bbox_reg = apply_sequential(features, list(self.bbox_subnet) + [self.bbox_pred]) + return logits, bbox_reg + + +from .retinanet_SyncBNhead import model, dataloader, lr_multiplier, optimizer, train + +model.head._target_ = RetinaNetHead_SharedTrainingBN diff --git a/data_processing/detectron2/projects/Rethinking-BatchNorm/retinanet-eval-domain-specific.py
b/data_processing/detectron2/projects/Rethinking-BatchNorm/retinanet-eval-domain-specific.py new file mode 100644 index 0000000..49a74ad --- /dev/null +++ b/data_processing/detectron2/projects/Rethinking-BatchNorm/retinanet-eval-domain-specific.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +import sys +import torch +from fvcore.nn.precise_bn import update_bn_stats + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import LazyConfig, instantiate +from detectron2.evaluation import inference_on_dataset +from detectron2.layers import CycleBatchNormList +from detectron2.utils.events import EventStorage +from detectron2.utils.logger import setup_logger + +logger = setup_logger() +setup_logger(name="fvcore") + + +if __name__ == "__main__": + checkpoint = sys.argv[1] + cfg = LazyConfig.load_rel("./configs/retinanet_SyncBNhead.py") + model = cfg.model + model.head.norm = lambda c: CycleBatchNormList(len(model.head_in_features), num_features=c) + model = instantiate(model) + model.cuda() + DetectionCheckpointer(model).load(checkpoint) + + cfg.dataloader.train.total_batch_size = 8 + logger.info("Running PreciseBN ...") + with EventStorage(), torch.no_grad(): + update_bn_stats(model, instantiate(cfg.dataloader.train), 500) + + logger.info("Running evaluation ...") + inference_on_dataset( + model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator) + ) diff --git a/data_processing/detectron2/projects/TensorMask/README.md b/data_processing/detectron2/projects/TensorMask/README.md new file mode 100644 index 0000000..e81307c --- /dev/null +++ b/data_processing/detectron2/projects/TensorMask/README.md @@ -0,0 +1,63 @@ + +# TensorMask in Detectron2 +**A Foundation for Dense Object Segmentation** + +Xinlei Chen, Ross Girshick, Kaiming He, Piotr Dollár + +[[`arXiv`](https://arxiv.org/abs/1903.12174)] [[`BibTeX`](#CitingTensorMask)] 
+ +
+ +
+ +In this repository, we release code for TensorMask in Detectron2. +TensorMask is a dense sliding-window instance segmentation framework that, for the first time, achieves results close to the well-developed Mask R-CNN framework -- both qualitatively and quantitatively. It establishes a conceptually complementary direction for object instance segmentation research. + +## Installation +First install Detectron2 following the [documentation](https://detectron2.readthedocs.io/tutorials/install.html) and +[setup the dataset](../../datasets). Then compile the TensorMask-specific op (`swap_align2nat`): +```bash +pip install -e /path/to/detectron2/projects/TensorMask +``` + +## Training + +To train a model, run: +```bash +python /path/to/detectron2/projects/TensorMask/train_net.py --config-file /path/to/config.yaml +``` + +For example, to launch TensorMask BiPyramid training (1x schedule) with ResNet-50 backbone on 8 GPUs, +one should execute: +```bash +python /path/to/detectron2/projects/TensorMask/train_net.py --config-file configs/tensormask_R_50_FPN_1x.yaml --num-gpus 8 +``` + +## Evaluation + +Model evaluation can be done similarly (6x schedule with scale augmentation): +```bash +python /path/to/detectron2/projects/TensorMask/train_net.py --config-file configs/tensormask_R_50_FPN_6x.yaml --eval-only MODEL.WEIGHTS /path/to/model_checkpoint +``` + +# Pretrained Models + +| Backbone | lr sched | AP box | AP mask | download | +| -------- | -------- | -- | --- | -------- | +| R50 | 1x | 37.6 | 32.4 | model \|  metrics | +| R50 | 6x | 41.4 | 35.8 | model \|  metrics | + + +## Citing TensorMask + +If you use TensorMask, please use the following BibTeX entry. 
+ +``` +@InProceedings{chen2019tensormask, + title={Tensormask: A Foundation for Dense Object Segmentation}, + author={Chen, Xinlei and Girshick, Ross and He, Kaiming and Doll{\'a}r, Piotr}, + journal={The International Conference on Computer Vision (ICCV)}, + year={2019} +} +``` + diff --git a/data_processing/detectron2/projects/TensorMask/configs/Base-TensorMask.yaml b/data_processing/detectron2/projects/TensorMask/configs/Base-TensorMask.yaml new file mode 100644 index 0000000..a724534 --- /dev/null +++ b/data_processing/detectron2/projects/TensorMask/configs/Base-TensorMask.yaml @@ -0,0 +1,25 @@ +MODEL: + META_ARCHITECTURE: "TensorMask" + MASK_ON: True + BACKBONE: + NAME: "build_retinanet_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[44, 60], [88, 120], [176, 240], [352, 480], [704, 960], [1408, 1920]] + ASPECT_RATIOS: [[1.0]] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + FUSE_TYPE: "avg" + TENSOR_MASK: + ALIGNED_ON: True + BIPYRAMID_ON: True +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +VERSION: 2 diff --git a/data_processing/detectron2/projects/TensorMask/configs/tensormask_R_50_FPN_1x.yaml b/data_processing/detectron2/projects/TensorMask/configs/tensormask_R_50_FPN_1x.yaml new file mode 100644 index 0000000..5d5eee1 --- /dev/null +++ b/data_processing/detectron2/projects/TensorMask/configs/tensormask_R_50_FPN_1x.yaml @@ -0,0 +1,5 @@ +_BASE_: "Base-TensorMask.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 diff --git a/data_processing/detectron2/projects/TensorMask/configs/tensormask_R_50_FPN_6x.yaml b/data_processing/detectron2/projects/TensorMask/configs/tensormask_R_50_FPN_6x.yaml new file mode 100644 index 0000000..366a965 --- /dev/null +++ b/data_processing/detectron2/projects/TensorMask/configs/tensormask_R_50_FPN_6x.yaml 
@@ -0,0 +1,11 @@
+_BASE_: "Base-TensorMask.yaml"
+MODEL:
+  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
+  RESNETS:
+    DEPTH: 50
+SOLVER:
+  STEPS: (480000, 520000)
+  MAX_ITER: 540000
+INPUT:
+  MIN_SIZE_TRAIN_SAMPLING: "range"
+  MIN_SIZE_TRAIN: (640, 800)
diff --git a/data_processing/detectron2/projects/TensorMask/setup.py b/data_processing/detectron2/projects/TensorMask/setup.py
new file mode 100644
index 0000000..f6980e0
--- /dev/null
+++ b/data_processing/detectron2/projects/TensorMask/setup.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import glob
+import os
+from setuptools import find_packages, setup
+import torch
+from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
+
+
+def get_extensions():
+    """Return the one-element list of C++/CUDA extensions to build as ``tensormask._C``."""
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, "tensormask", "layers", "csrc")
+
+    # `extensions_dir` is absolute, so every path collected below is absolute as well.
+    main_source = os.path.join(extensions_dir, "vision.cpp")
+    sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
+    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
+        os.path.join(extensions_dir, "*.cu")
+    )
+
+    sources = [main_source] + sources
+
+    extension = CppExtension
+
+    extra_compile_args = {"cxx": []}
+    define_macros = []
+
+    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1":
+        extension = CUDAExtension
+        sources += source_cuda
+        define_macros += [("WITH_CUDA", None)]
+        extra_compile_args["nvcc"] = [
+            "-DCUDA_HAS_FP16=1",
+            "-D__CUDA_NO_HALF_OPERATORS__",
+            "-D__CUDA_NO_HALF_CONVERSIONS__",
+            "-D__CUDA_NO_HALF2_OPERATORS__",
+        ]
+
+        # It's better if pytorch can do this by default ..
+        CC = os.environ.get("CC", None)
+        if CC is not None:
+            extra_compile_args["nvcc"].append("-ccbin={}".format(CC))
+
+    include_dirs = [extensions_dir]
+
+    ext_modules = [
+        extension(
+            "tensormask._C",
+            sources,
+            include_dirs=include_dirs,
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+        )
+    ]
+
+    return ext_modules
+
+
+setup(
+    name="tensormask",
+    version="0.1",
+    author="FAIR",
+    packages=find_packages(exclude=("configs", "tests")),
+    python_requires=">=3.7",
+    ext_modules=get_extensions(),
+    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
diff --git a/data_processing/detectron2/projects/TensorMask/tensormask/__init__.py b/data_processing/detectron2/projects/TensorMask/tensormask/__init__.py
new file mode 100644
index 0000000..eec7978
--- /dev/null
+++ b/data_processing/detectron2/projects/TensorMask/tensormask/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .config import add_tensormask_config
+from .arch import TensorMask
diff --git a/data_processing/detectron2/projects/TensorMask/tensormask/arch.py b/data_processing/detectron2/projects/TensorMask/tensormask/arch.py
new file mode 100644
index 0000000..d395bea
--- /dev/null
+++ b/data_processing/detectron2/projects/TensorMask/tensormask/arch.py
@@ -0,0 +1,913 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import math
+from typing import List
+import torch
+import torch.nn.functional as F
+from fvcore.nn import sigmoid_focal_loss_star_jit, smooth_l1_loss
+from torch import nn
+
+from detectron2.layers import ShapeSpec, batched_nms, cat, paste_masks_in_image
+from detectron2.modeling.anchor_generator import DefaultAnchorGenerator
+from detectron2.modeling.backbone import build_backbone
+from detectron2.modeling.box_regression import Box2BoxTransform
+from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
+from detectron2.modeling.meta_arch.retinanet import permute_to_N_HWA_K
+from detectron2.structures import Boxes, ImageList, Instances
+
+from tensormask.layers import SwapAlign2Nat
+
+__all__ = ["TensorMask"]
+
+
+def permute_all_cls_and_box_to_N_HWA_K_and_concat(pred_logits, pred_anchor_deltas, num_classes=80):
+    """
+    Rearrange the tensor layout from the network output, i.e.:
+    list[Tensor]: #lvl tensors of shape (N, A x K, Hi, Wi)
+    to per-image predictions, i.e.:
+    Tensor: of shape (N x sum(Hi x Wi x A), K)
+
+    Returns:
+        pred_logits (Tensor): flattened class logits, last dimension `num_classes`.
+        pred_anchor_deltas (Tensor): flattened box deltas, last dimension 4.
+    """
+    # for each feature level, permute the outputs to make them be in the
+    # same format as the labels.
+    pred_logits_flattened = [permute_to_N_HWA_K(x, num_classes) for x in pred_logits]
+    pred_anchor_deltas_flattened = [permute_to_N_HWA_K(x, 4) for x in pred_anchor_deltas]
+    # concatenate on the first dimension (representing the feature levels), to
+    # take into account the way the labels were generated (with all feature maps
+    # being concatenated as well)
+    pred_logits = cat(pred_logits_flattened, dim=1).view(-1, num_classes)
+    pred_anchor_deltas = cat(pred_anchor_deltas_flattened, dim=1).view(-1, 4)
+    return pred_logits, pred_anchor_deltas
+
+
+def _assignment_rule(
+    gt_boxes,
+    anchor_boxes,
+    unit_lengths,
+    min_anchor_size,
+    scale_thresh=2.0,
+    spatial_thresh=1.0,
+    uniqueness_on=True,
+):
+    """
+    Given two lists of boxes of N ground truth boxes and M anchor boxes,
+    compute the assignment between the two, following the assignment rules in
+    https://arxiv.org/abs/1903.12174.
+    The box order must be (xmin, ymin, xmax, ymax), so please make sure to convert
+    to BoxMode.XYXY_ABS before calling this function.
+
+    Args:
+        gt_boxes, anchor_boxes (Boxes): two Boxes. Contains N & M boxes/anchors, respectively.
+        unit_lengths (Tensor): Contains the unit lengths of M anchor boxes.
+        min_anchor_size (float): Minimum size of the anchor, in pixels
+        scale_thresh (float): The `scale` threshold: the maximum size of the anchor
+                              should not be greater than scale_thresh x max(h, w) of
+                              the ground truth box.
+        spatial_thresh (float): The `spatial` threshold: the l2 distance between the
+                              center of the anchor and the ground truth box should not
+                              be greater than spatial_thresh x u where u is the unit length.
+        uniqueness_on (bool): How to label anchors that satisfy the rules for more than
+                              one ground truth box: 0 (negative) if True, -1 (ignored)
+                              otherwise.
+
+    Returns:
+        matches (Tensor[int64]): a vector of length M, where matches[i] is a matched
+                ground-truth index in [0, N)
+        match_labels (Tensor[int8]): a vector of length M, where match_labels[i] is
+                1 for a positive (matched) anchor, 0 for a negative anchor and
+                -1 for an ignored anchor
+    """
+    gt_boxes, anchor_boxes = gt_boxes.tensor, anchor_boxes.tensor
+    N = gt_boxes.shape[0]
+    M = anchor_boxes.shape[0]
+    if N == 0 or M == 0:
+        # Nothing can be matched. Both outputs are per-anchor vectors, so they must
+        # have length M (not N) to stay usable as indices/masks over the anchors.
+        return (
+            gt_boxes.new_full((M,), 0, dtype=torch.int64),
+            gt_boxes.new_full((M,), -1, dtype=torch.int8),
+        )
+
+    # Containment rule
+    lt = torch.min(gt_boxes[:, None, :2], anchor_boxes[:, :2])  # [N,M,2]
+    rb = torch.max(gt_boxes[:, None, 2:], anchor_boxes[:, 2:])  # [N,M,2]
+    union = cat([lt, rb], dim=2)  # [N,M,4]
+
+    # Broadcast the anchors to [N,M,4]; the union box equals the anchor exactly
+    # when the anchor contains the ground truth box.
+    dummy_gt_boxes = torch.zeros_like(gt_boxes)
+    anchor = dummy_gt_boxes[:, None, :] + anchor_boxes[:, :]  # [N,M,4]
+
+    contain_matrix = torch.all(union == anchor, dim=2)  # [N,M]
+
+    # Centrality rule, scale
+    gt_size_lower = torch.max(gt_boxes[:, 2:] - gt_boxes[:, :2], dim=1)[0]  # [N]
+    gt_size_upper = gt_size_lower * scale_thresh  # [N]
+    # Fall back for small objects
+    gt_size_upper[gt_size_upper < min_anchor_size] = min_anchor_size
+    # Due to sampling of locations, the anchor sizes are deducted with sampling strides
+    anchor_size = (
+        torch.max(anchor_boxes[:, 2:] - anchor_boxes[:, :2], dim=1)[0] - unit_lengths
+    )  # [M]
+
+    size_diff_upper = gt_size_upper[:, None] - anchor_size  # [N,M]
+    scale_matrix = size_diff_upper >= 0  # [N,M]
+
+    # Centrality rule, spatial
+    gt_center = (gt_boxes[:, 2:] + gt_boxes[:, :2]) / 2  # [N,2]
+    anchor_center = (anchor_boxes[:, 2:] + anchor_boxes[:, :2]) / 2  # [M,2]
+    offset_center = gt_center[:, None, :] - anchor_center[:, :]  # [N,M,2]
+    offset_center /= unit_lengths[:, None]  # [N,M,2]
+    spatial_square = spatial_thresh * spatial_thresh
+    spatial_matrix = torch.sum(offset_center * offset_center, dim=2) <= spatial_square
+
+    assign_matrix = (contain_matrix & scale_matrix & spatial_matrix).int()
+
+    # assign_matrix is N (gt) x M (predicted)
+    # Max over gt elements (dim 0) to find best gt candidate for each prediction
+    matched_vals, matches = assign_matrix.max(dim=0)
+    match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
+
+    match_labels[matched_vals == 0] = 0
+    match_labels[matched_vals == 1] = 1
+
+    # find all the elements that match to ground truths multiple times
+    not_unique_idxs = assign_matrix.sum(dim=0) > 1
+    if uniqueness_on:
+        match_labels[not_unique_idxs] = 0
+    else:
+        match_labels[not_unique_idxs] = -1
+
+    return matches, match_labels
+
+
+# TODO make the paste_mask function in d2 core support mask list
+def _paste_mask_lists_in_image(masks, boxes, image_shape, threshold=0.5):
+    """
+    Paste a list of masks that are of various resolutions (e.g., 28 x 28) into an image.
+    The location, height, and width for pasting each mask is determined by their
+    corresponding bounding boxes in boxes.
+
+    Args:
+        masks (list(Tensor)): A list of Tensor of shape (1, Hmask_i, Wmask_i).
+                            Values are in [0, 1]. The list length, Bimg, is the
+                            number of detected object instances in the image.
+        boxes (Boxes): A Boxes of length Bimg. boxes.tensor[i] and masks[i] correspond
+                            to the same object instance.
+        image_shape (tuple): height, width
+        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
+            binary masks.
+
+    Returns:
+        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
+        number of detected object instances and Himage, Wimage are the image height
+        and width. img_masks[i] is a binary mask for object instance i.
+    """
+    if len(masks) == 0:
+        # Keep the documented 3-D (Bimg, Himage, Wimage) layout for the empty case too.
+        return torch.empty((0,) + tuple(image_shape), dtype=torch.uint8)
+
+    # Loop over masks groups. Each group has the same mask prediction size.
+    img_masks = []
+    ind_masks = []
+    mask_sizes = torch.tensor([m.shape[-1] for m in masks])
+    unique_sizes = torch.unique(mask_sizes)
+    for msize in unique_sizes.tolist():
+        cur_ind = torch.where(mask_sizes == msize)[0]
+        ind_masks.append(cur_ind)
+
+        cur_masks = cat([masks[i] for i in cur_ind])
+        cur_boxes = boxes[cur_ind]
+        img_masks.append(paste_masks_in_image(cur_masks, cur_boxes, image_shape, threshold))
+
+    img_masks = cat(img_masks)
+    ind_masks = cat(ind_masks)
+
+    # The groups were pasted size by size; scatter them back to the input order.
+    img_masks_out = torch.empty_like(img_masks)
+    img_masks_out[ind_masks, :, :] = img_masks
+
+    return img_masks_out
+
+
+def _postprocess(results, result_mask_info, output_height, output_width, mask_threshold=0.5):
+    """
+    Post-process the output boxes for TensorMask.
+    The input images are often resized when entering an object detector.
+    As a result, we often need the outputs of the detector in a different
+    resolution from its inputs.
+
+    This function will postprocess the raw outputs of TensorMask
+    to produce outputs according to the desired output resolution.
+
+    Args:
+        results (Instances): the raw outputs from the detector.
+            `results.image_size` contains the input image resolution the detector sees.
+            This object might be modified in-place. Note that it does not contain the field
+            `pred_masks`, which is provided by another input `result_mask_info`.
+        result_mask_info (list[Tensor], Boxes): a pair of two items for mask related results.
+                The first item is a list of #detection tensors, each is the predicted masks.
+                The second item is the anchors corresponding to the predicted masks.
+                The anchors are rescaled in-place.
+        output_height, output_width: the desired output resolution.
+        mask_threshold (float): threshold used to binarize the pasted masks.
+
+    Returns:
+        Instances: the postprocessed output from the model, based on the output resolution
+    """
+    scale_x, scale_y = (output_width / results.image_size[1], output_height / results.image_size[0])
+    results = Instances((output_height, output_width), **results.get_fields())
+
+    output_boxes = results.pred_boxes
+    output_boxes.tensor[:, 0::2] *= scale_x
+    output_boxes.tensor[:, 1::2] *= scale_y
+    output_boxes.clip(results.image_size)
+
+    inds_nonempty = output_boxes.nonempty()
+    results = results[inds_nonempty]
+    result_masks, result_anchors = result_mask_info
+    if result_masks:
+        result_anchors.tensor[:, 0::2] *= scale_x
+        result_anchors.tensor[:, 1::2] *= scale_y
+        # Drop the masks of the boxes that became empty after rescaling/clipping.
+        result_masks = [x for (i, x) in zip(inds_nonempty.tolist(), result_masks) if i]
+        results.pred_masks = _paste_mask_lists_in_image(
+            result_masks,
+            result_anchors[inds_nonempty],
+            results.image_size,
+            threshold=mask_threshold,
+        )
+    return results
+
+
+class TensorMaskAnchorGenerator(DefaultAnchorGenerator):
+    """
+    For a set of image sizes and feature maps, computes a set of anchors for TensorMask.
+    It also computes the unit lengths and indexes for each anchor box.
+    """
+
+    def grid_anchors_with_unit_lengths_and_indexes(self, grid_sizes):
+        """
+        Compute anchors, unit lengths and 5D indexes for every feature level.
+
+        Args:
+            grid_sizes (list[tuple[int, int]]): (height, width) of each feature map.
+
+        Returns:
+            anchors (list[Tensor]): #lvl tensors of shape (Hi x Wi x A, 4).
+            unit_lengths (list[Tensor]): #lvl float32 tensors of shape (Hi x Wi x A,),
+                filled with the stride of the level.
+            indexes (list[Tensor]): #lvl int64 tensors of shape (Hi x Wi x A, 5). Each row
+                is (L, I, H, W, A), with the image index I left at 0.
+        """
+        anchors = []
+        unit_lengths = []
+        indexes = []
+        for lvl, (size, stride, base_anchors) in enumerate(
+            zip(grid_sizes, self.strides, self.cell_anchors)
+        ):
+            grid_height, grid_width = size
+            device = base_anchors.device
+            shifts_x = torch.arange(
+                0, grid_width * stride, step=stride, dtype=torch.float32, device=device
+            )
+            shifts_y = torch.arange(
+                0, grid_height * stride, step=stride, dtype=torch.float32, device=device
+            )
+            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=2)
+            # Stack anchors in shapes of (HWA, 4)
+            cur_anchor = (shifts[:, :, None, :] + base_anchors.view(1, 1, -1, 4)).view(-1, 4)
+            anchors.append(cur_anchor)
+            unit_lengths.append(
+                torch.full((cur_anchor.shape[0],), stride, dtype=torch.float32, device=device)
+            )
+            # create mask indexes using mesh grid
+            shifts_l = torch.full((1,), lvl, dtype=torch.int64, device=device)
+            shifts_i = torch.zeros((1,), dtype=torch.int64, device=device)
+            shifts_h = torch.arange(0, grid_height, dtype=torch.int64, device=device)
+            shifts_w = torch.arange(0, grid_width, dtype=torch.int64, device=device)
+            shifts_a = torch.arange(0, base_anchors.shape[0], dtype=torch.int64, device=device)
+            grids = torch.meshgrid(shifts_l, shifts_i, shifts_h, shifts_w, shifts_a)
+
+            indexes.append(torch.stack(grids, dim=5).view(-1, 5))
+
+        return anchors, unit_lengths, indexes
+
+    def forward(self, features):
+        """
+        Returns:
+            list[list[Boxes]]: a list of #image elements. Each is a list of #feature level Boxes.
+                The Boxes contains anchors of this image on the specific feature level.
+            list[list[Tensor]]: a list of #image elements. Each is a list of #feature level tensors.
+                The tensor contains strides, or unit lengths for the anchors.
+            list[list[Tensor]]: a list of #image elements. Each is a list of #feature level tensors.
+                The Tensor contains indexes for the anchors, with the last dimension meaning
+                (L, I, H, W, A), where L is level, I is image (not set yet), H is height,
+                W is width, and A is anchor.
+        """
+        num_images = len(features[0])
+        grid_sizes = [feature_map.shape[-2:] for feature_map in features]
+        anchors_list, lengths_list, indexes_list = self.grid_anchors_with_unit_lengths_and_indexes(
+            grid_sizes
+        )
+
+        # Convert anchors from Tensor to Boxes
+        anchors_per_im = [Boxes(x) for x in anchors_list]
+
+        # TODO it can be simplified to not return duplicated information for
+        # each image, just like detectron2's own AnchorGenerator
+        anchors = [copy.deepcopy(anchors_per_im) for _ in range(num_images)]
+        unit_lengths = [copy.deepcopy(lengths_list) for _ in range(num_images)]
+        indexes = [copy.deepcopy(indexes_list) for _ in range(num_images)]
+
+        return anchors, unit_lengths, indexes
+
+
+@META_ARCH_REGISTRY.register()
+class TensorMask(nn.Module):
+    """
+    TensorMask model. Creates FPN backbone, anchors and a head for classification
+    and box regression. Calculates and applies proper losses to class, box, and
+    masks.
+    """
+
+    def __init__(self, cfg):
+        """
+        Args:
+            cfg: the config object. The fields read here are `MODEL.TENSOR_MASK.*`,
+                `MODEL.ANCHOR_GENERATOR.SIZES`, `MODEL.MASK_ON`, `MODEL.PIXEL_MEAN`,
+                `MODEL.PIXEL_STD` and `TEST.DETECTIONS_PER_IMAGE`.
+        """
+        super().__init__()
+
+        # fmt: off
+        self.num_classes = cfg.MODEL.TENSOR_MASK.NUM_CLASSES
+        self.in_features = cfg.MODEL.TENSOR_MASK.IN_FEATURES
+        self.anchor_sizes = cfg.MODEL.ANCHOR_GENERATOR.SIZES
+        self.num_levels = len(cfg.MODEL.ANCHOR_GENERATOR.SIZES)
+        # Loss parameters:
+        self.focal_loss_alpha = cfg.MODEL.TENSOR_MASK.FOCAL_LOSS_ALPHA
+        self.focal_loss_gamma = cfg.MODEL.TENSOR_MASK.FOCAL_LOSS_GAMMA
+        # Inference parameters:
+        self.score_threshold = cfg.MODEL.TENSOR_MASK.SCORE_THRESH_TEST
+        self.topk_candidates = cfg.MODEL.TENSOR_MASK.TOPK_CANDIDATES_TEST
+        self.nms_threshold = cfg.MODEL.TENSOR_MASK.NMS_THRESH_TEST
+        self.detections_im = cfg.TEST.DETECTIONS_PER_IMAGE
+        # Mask parameters:
+        self.mask_on = cfg.MODEL.MASK_ON
+        self.mask_loss_weight = cfg.MODEL.TENSOR_MASK.MASK_LOSS_WEIGHT
+        self.mask_pos_weight = torch.tensor(cfg.MODEL.TENSOR_MASK.POSITIVE_WEIGHT,
+                                            dtype=torch.float32)
+        self.bipyramid_on = cfg.MODEL.TENSOR_MASK.BIPYRAMID_ON
+        # fmt: on
+
+        # build the backbone
+        self.backbone = build_backbone(cfg)
+
+        backbone_shape = self.backbone.output_shape()
+        feature_shapes = [backbone_shape[f] for f in self.in_features]
+        feature_strides = [x.stride for x in feature_shapes]
+        # build anchors
+        self.anchor_generator = TensorMaskAnchorGenerator(cfg, feature_shapes)
+        self.num_anchors = self.anchor_generator.num_cell_anchors[0]
+        anchors_min_level = cfg.MODEL.ANCHOR_GENERATOR.SIZES[0]
+        # Mask side length per anchor: anchor size of the first level over its stride.
+        self.mask_sizes = [size // feature_strides[0] for size in anchors_min_level]
+        self.min_anchor_size = min(anchors_min_level) - feature_strides[0]
+
+        # head of the TensorMask
+        self.head = TensorMaskHead(
+            cfg, self.num_levels, self.num_anchors, self.mask_sizes, feature_shapes
+        )
+        # box transform
+        self.box2box_transform = Box2BoxTransform(weights=cfg.MODEL.TENSOR_MASK.BBOX_REG_WEIGHTS)
+        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1), False)
+
+    @property
+    def device(self):
+        """Device of the model, taken from the `pixel_mean` buffer."""
+        return self.pixel_mean.device
+
+    def forward(self, batched_inputs):
+        """
+        Args:
+            batched_inputs: a list, batched outputs of :class:`DetectionTransform` .
+                Each item in the list contains the inputs for one image.
+            For now, each item in the list is a dict that contains:
+                image: Tensor, image in (C, H, W) format.
+                instances: Instances
+                Other information that's included in the original dicts, such as:
+                    "height", "width" (int): the output resolution of the model, used in inference.
+                        See :meth:`postprocess` for details.
+        Returns:
+            In training, losses (dict[str: Tensor]): mapping from a named loss to a tensor
+                storing the loss.
+            In inference, a list[dict] with one dict per image; its "instances" entry
+                holds the post-processed `Instances` for that image.
+        """
+        images = self.preprocess_image(batched_inputs)
+        if "instances" in batched_inputs[0]:
+            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
+        else:
+            gt_instances = None
+
+        features = self.backbone(images.tensor)
+        features = [features[f] for f in self.in_features]
+        # apply the TensorMask head
+        pred_logits, pred_deltas, pred_masks = self.head(features)
+        # generate anchors based on features, is it image specific?
+        anchors, unit_lengths, indexes = self.anchor_generator(features)
+
+        if self.training:
+            # get ground truths for class labels and box targets, it will label each anchor
+            gt_class_info, gt_delta_info, gt_mask_info, num_fg = self.get_ground_truth(
+                anchors, unit_lengths, indexes, gt_instances
+            )
+            # compute the loss
+            return self.losses(
+                gt_class_info,
+                gt_delta_info,
+                gt_mask_info,
+                num_fg,
+                pred_logits,
+                pred_deltas,
+                pred_masks,
+            )
+        else:
+            # do inference to get the output
+            results = self.inference(pred_logits, pred_deltas, pred_masks, anchors, indexes, images)
+            processed_results = []
+            for results_im, input_im, image_size in zip(
+                results, batched_inputs, images.image_sizes
+            ):
+                height = input_im.get("height", image_size[0])
+                width = input_im.get("width", image_size[1])
+                # this is to do post-processing with the image size
+                result_box, result_mask = results_im
+                r = _postprocess(result_box, result_mask, height, width)
+                processed_results.append({"instances": r})
+            return processed_results
+
+    def losses(
+        self,
+        gt_class_info,
+        gt_delta_info,
+        gt_mask_info,
+        num_fg,
+        pred_logits,
+        pred_deltas,
+        pred_masks,
+    ):
+        """
+        Args:
+            For `gt_class_info`, `gt_delta_info`, `gt_mask_info` and `num_fg` parameters, see
+                :meth:`TensorMask.get_ground_truth`.
+            For `pred_logits`, `pred_deltas` and `pred_masks`, see
+                :meth:`TensorMaskHead.forward`.
+
+        Returns:
+            losses (dict[str: Tensor]): mapping from a named loss to a scalar tensor
+                storing the loss. Used during training only. The potential dict keys are:
+                "loss_cls", "loss_box_reg" and "loss_mask".
+        """
+        gt_classes_target, gt_valid_inds = gt_class_info
+        gt_deltas, gt_fg_inds = gt_delta_info
+        gt_masks, gt_mask_inds = gt_mask_info
+        # Clamp to 1 so that batches without any foreground anchor do not divide by zero.
+        loss_normalizer = torch.tensor(max(1, num_fg), dtype=torch.float32, device=self.device)
+
+        # classification and regression
+        pred_logits, pred_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
+            pred_logits, pred_deltas, self.num_classes
+        )
+        loss_cls = (
+            sigmoid_focal_loss_star_jit(
+                pred_logits[gt_valid_inds],
+                gt_classes_target[gt_valid_inds],
+                alpha=self.focal_loss_alpha,
+                gamma=self.focal_loss_gamma,
+                reduction="sum",
+            )
+            / loss_normalizer
+        )
+
+        if num_fg == 0:
+            # Zero loss that still depends on `pred_deltas`.
+            loss_box_reg = pred_deltas.sum() * 0
+        else:
+            loss_box_reg = (
+                smooth_l1_loss(pred_deltas[gt_fg_inds], gt_deltas, beta=0.0, reduction="sum")
+                / loss_normalizer
+            )
+        losses = {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
+
+        # mask prediction
+        if self.mask_on:
+            loss_mask = 0
+            for lvl in range(self.num_levels):
+                cur_level_factor = 2**lvl if self.bipyramid_on else 1
+                for anc in range(self.num_anchors):
+                    cur_gt_mask_inds = gt_mask_inds[lvl][anc]
+                    if cur_gt_mask_inds is None:
+                        # No ground truth here: add a zero term that still depends on
+                        # this mask prediction tensor.
+                        loss_mask += pred_masks[lvl][anc][0, 0, 0, 0] * 0
+                    else:
+                        cur_mask_size = self.mask_sizes[anc] * cur_level_factor
+                        # TODO maybe there are numerical issues when mask sizes are large
+                        cur_size_divider = torch.tensor(
+                            self.mask_loss_weight / (cur_mask_size**2),
+                            dtype=torch.float32,
+                            device=self.device,
+                        )
+
+                        cur_pred_masks = pred_masks[lvl][anc][
+                            cur_gt_mask_inds[:, 0],  # N
+                            :,  # V x U
+                            cur_gt_mask_inds[:, 1],  # H
+                            cur_gt_mask_inds[:, 2],  # W
+                        ]
+
+                        loss_mask += F.binary_cross_entropy_with_logits(
+                            cur_pred_masks.view(-1, cur_mask_size, cur_mask_size),  # V, U
+                            gt_masks[lvl][anc].to(dtype=torch.float32),
+                            reduction="sum",
+                            weight=cur_size_divider,
+                            pos_weight=self.mask_pos_weight,
+                        )
+            losses["loss_mask"] = loss_mask / loss_normalizer
+        return losses
+
+    @torch.no_grad()
+    def get_ground_truth(self, anchors, unit_lengths, indexes, targets):
+        """
+        Args:
+            anchors (list[list[Boxes]]): a list of N=#image elements. Each is a
+                list of #feature level Boxes. The Boxes contains anchors of
+                this image on the specific feature level.
+            unit_lengths (list[list[Tensor]]): a list of N=#image elements. Each is a
+                list of #feature level Tensor. The tensor contains unit lengths for anchors of
+                this image on the specific feature level.
+            indexes (list[list[Tensor]]): a list of N=#image elements. Each is a
+                list of #feature level Tensor. The tensor contains the 5D index of
+                each anchor, the second dimension means (L, I, H, W, A), where L
+                is level, I is image, H is height, W is width, and A is anchor.
+            targets (list[Instances]): a list of N `Instances`s. The i-th
+                `Instances` contains the ground-truth per-instance annotations
+                for the i-th input image.  Specify `targets` during training only.
+
+        Returns:
+            gt_class_info (Tensor, Tensor): A pair of two tensors for classification.
+                The first one is a float tensor of shape (R, #classes) storing one-hot
+                ground-truth labels for each anchor. R is the total number of anchors
+                in the batch.
+                The second one is a boolean tensor of shape (R,), to indicate which
+                anchors are valid for loss computation, which anchors are not.
+            gt_delta_info (Tensor, Tensor): A pair of two tensors for boxes.
+                The first one, of shape (F, 4). F=#foreground anchors.
+                The last dimension represents ground-truth box2box transform
+                targets (dx, dy, dw, dh) that map each anchor to its matched ground-truth box.
+                Only foreground anchors have values in this tensor. It is `None` if no
+                image in the batch has ground-truth instances.
+                The second one, of shape (R,), is a boolean tensor indicating which anchors
+                are foreground ones used for box regression.
+            gt_mask_info (list[list[Tensor]], list[list[Tensor]]): A pair of two lists for masks.
+                The first one is a list of P=#feature level elements. Each is a
+                list of A=#anchor tensors. Each tensor contains the ground truth
+                masks of the same size and for the same feature level. Could be `None`.
+                The second one is a list of P=#feature level elements. Each is a
+                list of A=#anchor tensors. Each tensor contains the location of the ground truth
+                masks of the same size and for the same feature level. The second dimension means
+                (N, H, W), where N is image, H is height, and W is width. Could be `None`.
+            num_fg (int): F=#foreground anchors, used later for loss normalization.
+        """
+        gt_classes = []
+        gt_deltas = []
+        gt_masks = [[[] for _ in range(self.num_anchors)] for _ in range(self.num_levels)]
+        gt_mask_inds = [[[] for _ in range(self.num_anchors)] for _ in range(self.num_levels)]
+
+        anchors = [Boxes.cat(anchors_i) for anchors_i in anchors]
+        unit_lengths = [cat(unit_lengths_i) for unit_lengths_i in unit_lengths]
+        indexes = [cat(indexes_i) for indexes_i in indexes]
+
+        num_fg = 0
+        for i, (anchors_im, unit_lengths_im, indexes_im, targets_im) in enumerate(
+            zip(anchors, unit_lengths, indexes, targets)
+        ):
+            # Initialize all
+            gt_classes_i = torch.full_like(
+                unit_lengths_im, self.num_classes, dtype=torch.int64, device=self.device
+            )
+            # Ground truth classes
+            has_gt = len(targets_im) > 0
+            if has_gt:
+                # Compute the pairwise matrix
+                gt_matched_inds, anchor_labels = _assignment_rule(
+                    targets_im.gt_boxes, anchors_im, unit_lengths_im, self.min_anchor_size
+                )
+                # Find the foreground instances
+                fg_inds = anchor_labels == 1
+                fg_anchors = anchors_im[fg_inds]
+                num_fg += len(fg_anchors)
+                # Find the ground truths for foreground instances
+                gt_fg_matched_inds = gt_matched_inds[fg_inds]
+                # Assign labels for foreground instances
+                gt_classes_i[fg_inds] = targets_im.gt_classes[gt_fg_matched_inds]
+                # Anchors with label -1 are ignored, others are left as negative
+                gt_classes_i[anchor_labels == -1] = -1
+
+                # Boxes
+                # Ground truth box regression, only for foregrounds
+                matched_gt_boxes = targets_im[gt_fg_matched_inds].gt_boxes
+                # Compute box regression offsets for foregrounds only
+                gt_deltas_i = self.box2box_transform.get_deltas(
+                    fg_anchors.tensor, matched_gt_boxes.tensor
+                )
+                gt_deltas.append(gt_deltas_i)
+
+                # Masks
+                if self.mask_on:
+                    # Compute masks for each level and each anchor
+                    matched_indexes = indexes_im[fg_inds, :]
+                    for lvl in range(self.num_levels):
+                        ids_lvl = matched_indexes[:, 0] == lvl
+                        if torch.any(ids_lvl):
+                            cur_level_factor = 2**lvl if self.bipyramid_on else 1
+                            for anc in range(self.num_anchors):
+                                ids_lvl_anchor = ids_lvl & (matched_indexes[:, 4] == anc)
+                                if torch.any(ids_lvl_anchor):
+                                    gt_masks[lvl][anc].append(
+                                        targets_im[
+                                            gt_fg_matched_inds[ids_lvl_anchor]
+                                        ].gt_masks.crop_and_resize(
+                                            fg_anchors[ids_lvl_anchor].tensor,
+                                            self.mask_sizes[anc] * cur_level_factor,
+                                        )
+                                    )
+                                    # Select (N, H, W) dimensions
+                                    gt_mask_inds_lvl_anc = matched_indexes[ids_lvl_anchor, 1:4]
+                                    # Set the image index to the current image
+                                    gt_mask_inds_lvl_anc[:, 0] = i
+                                    gt_mask_inds[lvl][anc].append(gt_mask_inds_lvl_anc)
+            gt_classes.append(gt_classes_i)
+
+        # Classes and boxes
+        gt_classes = cat(gt_classes)
+        gt_valid_inds = gt_classes >= 0
+        gt_fg_inds = gt_valid_inds & (gt_classes < self.num_classes)
+        gt_classes_target = torch.zeros(
+            (gt_classes.shape[0], self.num_classes), dtype=torch.float32, device=self.device
+        )
+        gt_classes_target[gt_fg_inds, gt_classes[gt_fg_inds]] = 1
+        gt_deltas = cat(gt_deltas) if gt_deltas else None
+
+        # Masks
+        gt_masks = [[cat(mla) if mla else None for mla in ml] for ml in gt_masks]
+        gt_mask_inds = [[cat(ila) if ila else None for ila in il] for il in gt_mask_inds]
+        return (
+            (gt_classes_target, gt_valid_inds),
+            (gt_deltas, gt_fg_inds),
+            (gt_masks, gt_mask_inds),
+            num_fg,
+        )
+
+    def inference(self, pred_logits, pred_deltas, pred_masks, anchors, indexes, images):
+        """
+        Arguments:
+            pred_logits, pred_deltas, pred_masks: Same as the output of:
+                meth:`TensorMaskHead.forward`
+            anchors, indexes: Same as the input of meth:`TensorMask.get_ground_truth`
+            images (ImageList): the input images
+
+        Returns:
+            results (list): a list of #images elements, each being the output of
+                meth:`TensorMask.inference_single_image` for that image.
+        """
+        assert len(anchors) == len(images)
+        results = []
+
+        pred_logits = [permute_to_N_HWA_K(x, self.num_classes) for x in pred_logits]
+        pred_deltas = [permute_to_N_HWA_K(x, 4) for x in pred_deltas]
+
+        pred_logits = cat(pred_logits, dim=1)
+        pred_deltas = cat(pred_deltas, dim=1)
+
+        for img_idx, (anchors_im, indexes_im) in enumerate(zip(anchors, indexes)):
+            # Get the size of the current image
+            image_size = images.image_sizes[img_idx]
+
+            logits_im = pred_logits[img_idx]
+            deltas_im = pred_deltas[img_idx]
+
+            if self.mask_on:
+                masks_im = [[mla[img_idx] for mla in ml] for ml in pred_masks]
+            else:
+                masks_im = [None] * self.num_levels
+            results_im = self.inference_single_image(
+                logits_im,
+                deltas_im,
+                masks_im,
+                Boxes.cat(anchors_im),
+                cat(indexes_im),
+                tuple(image_size),
+            )
+            results.append(results_im)
+        return results
+
+    def inference_single_image(
+        self, pred_logits, pred_deltas, pred_masks, anchors, indexes, image_size
+    ):
+        """
+        Single-image inference. Return bounding-box detection results by thresholding
+        on scores and applying non-maximum suppression (NMS).
+
+        Arguments:
+            pred_logits (Tensor): class logits of one image with all feature levels
+                concatenated; the last dimension is K=#classes.
+            pred_deltas (Tensor): Same layout as 'pred_logits' except that K becomes 4.
+            pred_masks (list[list[Tensor]]): List of #feature levels, each is a list of #anchors.
+                Each entry contains tensor of size (M_i*M_i, H, W). `None` if mask_on=False.
+            anchors (Boxes): the anchors of this image, with all feature levels
+                concatenated in the same order as the predictions.
+            indexes (Tensor): the (L, I, H, W, A) index of every anchor in `anchors`.
+            image_size (tuple(H, W)): a tuple of the image height and width.
+
+        Returns:
+            results (Instances): the detections, with `pred_boxes`, `scores` and
+                `pred_classes` set.
+            (result_masks, result_anchors): a list with one (1, M, M) sigmoid mask per
+                detection and the Boxes of their anchors; `([], None)` if mask_on=False.
+        """
+        pred_logits = pred_logits.flatten().sigmoid_()
+        # We get top locations across all levels to accelerate the inference speed,
+        # which does not seem to affect the accuracy.
+        # First select values above the threshold
+        logits_top_idxs = torch.where(pred_logits > self.score_threshold)[0]
+        # Then get the top values
+        num_topk = min(self.topk_candidates, logits_top_idxs.shape[0])
+        pred_prob, topk_idxs = pred_logits[logits_top_idxs].sort(descending=True)
+        # Keep top k scoring values
+        pred_prob = pred_prob[:num_topk]
+        # Keep top k values
+        top_idxs = logits_top_idxs[topk_idxs[:num_topk]]
+
+        # class index
+        cls_idxs = top_idxs % self.num_classes
+        # HWA index
+        top_idxs //= self.num_classes
+        # predict boxes
+        pred_boxes = self.box2box_transform.apply_deltas(
+            pred_deltas[top_idxs], anchors[top_idxs].tensor
+        )
+        # apply nms
+        keep = batched_nms(pred_boxes, pred_prob, cls_idxs, self.nms_threshold)
+        # pick the top ones
+        keep = keep[: self.detections_im]
+
+        results = Instances(image_size)
+        results.pred_boxes = Boxes(pred_boxes[keep])
+        results.scores = pred_prob[keep]
+        results.pred_classes = cls_idxs[keep]
+
+        # deal with masks
+        result_masks, result_anchors = [], None
+        if self.mask_on:
+            # index and anchors, useful for masks
+            top_indexes = indexes[top_idxs]
+            top_anchors = anchors[top_idxs]
+            result_indexes = top_indexes[keep]
+            result_anchors = top_anchors[keep]
+            # Get masks and do sigmoid
+            for lvl, _, h, w, anc in result_indexes.tolist():
+                cur_size = self.mask_sizes[anc] * (2**lvl if self.bipyramid_on else 1)
+                result_masks.append(
+                    torch.sigmoid(pred_masks[lvl][anc][:, h, w].view(1, cur_size, cur_size))
+                )
+
+        return results, (result_masks, result_anchors)
+
+    def preprocess_image(self, batched_inputs):
+        """
+        Normalize, pad and batch the input images.
+        """
+        images = [x["image"].to(self.device) for x in batched_inputs]
+        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
+        return images
+
+
+class TensorMaskHead(nn.Module):
+    def __init__(self, cfg, num_levels, num_anchors, mask_sizes, input_shape: List[ShapeSpec]):
+        """
+        TensorMask head.
+
+        Args:
+            cfg: the config object; `MODEL.TENSOR_MASK.*` and `MODEL.MASK_ON` are read.
+            num_levels (int): number of feature levels.
+            num_anchors (int): number of anchors per location.
+            mask_sizes (list[int]): side length of the predicted mask for each anchor;
+                one `mask_pred_XX` layer is created per size.
+            input_shape (List[ShapeSpec]): shapes of the input features; only the
+                channel count of the first entry is used.
+        """
+        super().__init__()
+        # fmt: off
+        self.in_features = cfg.MODEL.TENSOR_MASK.IN_FEATURES
+        in_channels = input_shape[0].channels
+        num_classes = cfg.MODEL.TENSOR_MASK.NUM_CLASSES
+        cls_channels = cfg.MODEL.TENSOR_MASK.CLS_CHANNELS
+        num_convs = cfg.MODEL.TENSOR_MASK.NUM_CONVS
+        # box parameters
+        bbox_channels = cfg.MODEL.TENSOR_MASK.BBOX_CHANNELS
+        # mask parameters
+        self.mask_on = cfg.MODEL.MASK_ON
+        self.mask_sizes = mask_sizes
+        mask_channels = cfg.MODEL.TENSOR_MASK.MASK_CHANNELS
+        self.align_on = cfg.MODEL.TENSOR_MASK.ALIGNED_ON
+        self.bipyramid_on = cfg.MODEL.TENSOR_MASK.BIPYRAMID_ON
+        # fmt: on
+
+        # class subnet
+        cls_subnet = []
+        cur_channels = in_channels
+        for _ in range(num_convs):
+            cls_subnet.append(
+                nn.Conv2d(cur_channels, cls_channels, kernel_size=3, stride=1, padding=1)
+            )
+            cur_channels = cls_channels
+            cls_subnet.append(nn.ReLU())
+
+        self.cls_subnet = nn.Sequential(*cls_subnet)
+        self.cls_score = nn.Conv2d(
+            cur_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1
+        )
+        modules_list = [self.cls_subnet, self.cls_score]
+
+        # box subnet
+        bbox_subnet = []
+        cur_channels = in_channels
+        for _ in range(num_convs):
+            bbox_subnet.append(
+                nn.Conv2d(cur_channels, bbox_channels, kernel_size=3, stride=1, padding=1)
+            )
+            cur_channels = bbox_channels
+            bbox_subnet.append(nn.ReLU())
+
+        self.bbox_subnet = nn.Sequential(*bbox_subnet)
+        self.bbox_pred = nn.Conv2d(
+            cur_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1
+        )
+        modules_list.extend([self.bbox_subnet, self.bbox_pred])
+
+        # mask subnet
+        if self.mask_on:
+            mask_subnet = []
+            cur_channels = in_channels
+            for _ in range(num_convs):
+                mask_subnet.append(
+                    nn.Conv2d(cur_channels, mask_channels, kernel_size=3, stride=1, padding=1)
+                )
+                cur_channels = mask_channels
+                mask_subnet.append(nn.ReLU())
+
+            self.mask_subnet = nn.Sequential(*mask_subnet)
+            modules_list.append(self.mask_subnet)
+            for mask_size in self.mask_sizes:
+                cur_mask_module = "mask_pred_%02d" % mask_size
+                self.add_module(
+                    cur_mask_module,
+                    nn.Conv2d(
+                        cur_channels, mask_size * mask_size, kernel_size=1, stride=1, padding=0
+                    ),
+                )
+                modules_list.append(getattr(self, cur_mask_module))
+            if self.align_on:
+                if self.bipyramid_on:
+                    for lvl in range(num_levels):
+                        cur_mask_module = "align2nat_%02d" % lvl
+                        lambda_val = 2**lvl
+                        setattr(self, cur_mask_module, SwapAlign2Nat(lambda_val))
+                    # Also the fusing layer, stay at the same channel size
+                    mask_fuse = [
+                        nn.Conv2d(cur_channels, cur_channels, kernel_size=3, stride=1, padding=1),
+                        nn.ReLU(),
+                    ]
+                    self.mask_fuse = nn.Sequential(*mask_fuse)
+                    modules_list.append(self.mask_fuse)
+                else:
+                    self.align2nat = SwapAlign2Nat(1)
+
+        # Initialization
+        for modules in modules_list:
+            for layer in modules.modules():
+                if isinstance(layer, nn.Conv2d):
+                    torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
+                    torch.nn.init.constant_(layer.bias, 0)
+
+        # Use prior in model initialization to improve stability
+        bias_value = -(math.log((1 - 0.01) / 0.01))
+        torch.nn.init.constant_(self.cls_score.bias, bias_value)
+
+    def forward(self, features):
+        """
+        Arguments:
+            features (list[Tensor]): FPN feature map tensors in high to low resolution.
+                Each tensor in the list correspond to different feature levels.
+
+        Returns:
+            pred_logits (list[Tensor]): #lvl tensors, each has shape (N, AxK, Hi, Wi).
+                The tensor predicts the classification probability
+                at each spatial position for each of the A anchors and K object
+                classes.
+            pred_deltas (list[Tensor]): #lvl tensors, each has shape (N, Ax4, Hi, Wi).
+                The tensor predicts 4-vector (dx,dy,dw,dh) box
+                regression values for every anchor. These values are the
+                relative offset between the anchor and the ground truth box.
+            pred_masks (list(list[Tensor])): #lvl list of tensors, each is a list of
+                A tensors of shape (N, M_{i,a}, Hi, Wi).
+                The tensor predicts a dense set of M_ixM_i masks at every location.
+                `None` if mask_on=False.
+        """
+        pred_logits = [self.cls_score(self.cls_subnet(x)) for x in features]
+        pred_deltas = [self.bbox_pred(self.bbox_subnet(x)) for x in features]
+
+        pred_masks = None
+        if self.mask_on:
+            mask_feats = [self.mask_subnet(x) for x in features]
+
+            if self.bipyramid_on:
+                mask_feat_high_res = mask_feats[0]
+                H, W = mask_feat_high_res.shape[-2:]
+                mask_feats_up = []
+                for lvl, mask_feat in enumerate(mask_feats):
+                    lambda_val = 2.0**lvl
+                    mask_feat_up = mask_feat
+                    if lvl > 0:
+                        mask_feat_up = F.interpolate(
+                            mask_feat, scale_factor=lambda_val, mode="bilinear", align_corners=False
+                        )
+                    # Crop the upsampled map to the finest level's size before fusing.
+                    mask_feats_up.append(
+                        self.mask_fuse(mask_feat_up[:, :, :H, :W] + mask_feat_high_res)
+                    )
+                mask_feats = mask_feats_up
+
+            pred_masks = []
+            for lvl, mask_feat in enumerate(mask_feats):
+                cur_masks = []
+                for mask_size in self.mask_sizes:
+                    cur_mask_module = getattr(self, "mask_pred_%02d" % mask_size)
+                    cur_mask = cur_mask_module(mask_feat)
+                    if self.align_on:
+                        if self.bipyramid_on:
+                            cur_mask_module = getattr(self, "align2nat_%02d" % lvl)
+                            cur_mask = cur_mask_module(cur_mask)
+                        else:
+                            cur_mask = self.align2nat(cur_mask)
+                    cur_masks.append(cur_mask)
+                pred_masks.append(cur_masks)
+        return pred_logits, pred_deltas, pred_masks
diff --git a/data_processing/detectron2/projects/TensorMask/tensormask/config.py b/data_processing/detectron2/projects/TensorMask/tensormask/config.py
new file mode 100644
index 0000000..cf62d7a
--- /dev/null
+++ b/data_processing/detectron2/projects/TensorMask/tensormask/config.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
def add_tensormask_config(cfg):
    """
    Add config for TensorMask.

    Creates the ``cfg.MODEL.TENSOR_MASK`` node and fills it with the default value of
    every TensorMask option, in the order listed below.
    """
    cfg.MODEL.TENSOR_MASK = CN()

    defaults = {
        # Anchor parameters
        "IN_FEATURES": ["p2", "p3", "p4", "p5", "p6", "p7"],
        # Convolutions to use in the towers
        "NUM_CONVS": 4,
        # Number of foreground classes.
        "NUM_CLASSES": 80,
        # Channel size for the classification tower
        "CLS_CHANNELS": 256,
        "SCORE_THRESH_TEST": 0.05,
        # Only the top (1000 * #levels) candidate boxes across all levels are
        # considered jointly during test (to improve speed)
        "TOPK_CANDIDATES_TEST": 6000,
        "NMS_THRESH_TEST": 0.5,
        # Box parameters
        # Channel size for the box tower
        "BBOX_CHANNELS": 128,
        # Weights on (dx, dy, dw, dh)
        "BBOX_REG_WEIGHTS": (1.5, 1.5, 0.75, 0.75),
        # Loss parameters
        "FOCAL_LOSS_GAMMA": 3.0,
        "FOCAL_LOSS_ALPHA": 0.3,
        # Mask parameters
        # Channel size for the mask tower
        "MASK_CHANNELS": 128,
        # Mask loss weight
        "MASK_LOSS_WEIGHT": 2.0,
        # weight on positive pixels within the mask
        "POSITIVE_WEIGHT": 1.5,
        # Whether to predict in the aligned representation
        "ALIGNED_ON": False,
        # Whether to use the bipyramid architecture
        "BIPYRAMID_ON": False,
    }
    for option, value in defaults.items():
        setattr(cfg.MODEL.TENSOR_MASK, option, value)
+from .swap_align2nat import SwapAlign2Nat, swap_align2nat + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/data_processing/detectron2/projects/TensorMask/tensormask/layers/csrc/SwapAlign2Nat/SwapAlign2Nat.h b/data_processing/detectron2/projects/TensorMask/tensormask/layers/csrc/SwapAlign2Nat/SwapAlign2Nat.h new file mode 100644 index 0000000..75c2178 --- /dev/null +++ b/data_processing/detectron2/projects/TensorMask/tensormask/layers/csrc/SwapAlign2Nat/SwapAlign2Nat.h @@ -0,0 +1,54 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#pragma once +#include + +namespace tensormask { + +#if defined(WITH_CUDA) || defined(WITH_HIP) +at::Tensor SwapAlign2Nat_forward_cuda( + const at::Tensor& X, + const int lambda_val, + const float pad_val); + +at::Tensor SwapAlign2Nat_backward_cuda( + const at::Tensor& gY, + const int lambda_val, + const int batch_size, + const int channel, + const int height, + const int width); +#endif + +inline at::Tensor SwapAlign2Nat_forward( + const at::Tensor& X, + const int lambda_val, + const float pad_val) { + if (X.type().is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + return SwapAlign2Nat_forward_cuda(X, lambda_val, pad_val); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +inline at::Tensor SwapAlign2Nat_backward( + const at::Tensor& gY, + const int lambda_val, + const int batch_size, + const int channel, + const int height, + const int width) { + if (gY.type().is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + return SwapAlign2Nat_backward_cuda( + gY, lambda_val, batch_size, channel, height, width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +} // namespace tensormask diff --git a/data_processing/detectron2/projects/TensorMask/tensormask/layers/csrc/SwapAlign2Nat/SwapAlign2Nat_cuda.cu 
// Read element (idx, v, u, y, x) of a buffer indexed as (idx, V, U, H, W).
// Returns pad_val instead when any of y, x, v, u falls outside its extent,
// so callers can sample past the borders without branching themselves.
template <typename T>
__device__ inline T get_pixel_val(
    const T* tensor,
    const int idx,
    const int H,
    const int W,
    const int y,
    const int x,
    const int V,
    const int U,
    const int v,
    const int u,
    const T pad_val) {
  if ((y < 0) || (y >= H) || (x < 0) || (x >= W) || (v < 0) || (v >= V) ||
      (u < 0) || (u >= U)) {
    return pad_val;
  } else {
    return tensor[(((idx * V + v) * U + u) * H + y) * W + x];
  }
}
// Forward kernel of SwapAlign2Nat.
//
// One loop iteration (see CUDA_1D_KERNEL_LOOP) computes one output element.
// The flat output index is decomposed, fastest dimension first, into
// (x, y, u, v) with extents (Wout, Hout, Uout, Vout); what remains in `idx`
// is the batch index. The element is a 4-D linear interpolation of the input
// at the fractional coordinate (oy, ox, ov, ou): all 2^4 floor/ceil corner
// combinations are read through get_pixel_val (which yields pad_val outside
// the input) and summed with the product of their per-axis weights.
template <typename T>
__global__ void SwapAlign2NatForwardFeat(
    const int nthreads,
    const T* bottom_data,
    const int Vout,
    const int Uout,
    const float hVout,
    const float hUout,
    const int Vin,
    const int Uin,
    const float lambda,
    const int Hin,
    const int Win,
    const int Hout,
    const int Wout,
    const T pad_val,
    T* top_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int idx = index;
    const int x = idx % Wout;
    idx /= Wout;
    const int y = idx % Hout;
    idx /= Hout;
    const int u = idx % Uout;
    idx /= Uout;
    const int v = idx % Vout;
    idx /= Vout;

    // Input spatial coordinate and its floor/ceil neighbours with weights.
    const float ox = x * lambda + u - hUout + 0.5;
    const int xf = static_cast<int>(floor(ox));
    const int xc = static_cast<int>(ceil(ox));
    const float xwc = ox - xf;
    const float xwf = 1. - xwc;

    const float oy = y * lambda + v - hVout + 0.5;
    const int yf = static_cast<int>(floor(oy));
    const int yc = static_cast<int>(ceil(oy));
    const float ywc = oy - yf;
    const float ywf = 1. - ywc;

    // Input mask-window coordinate (output window index scaled by 1/lambda).
    const float ou = (u + 0.5) / lambda - 0.5;
    const int uf = static_cast<int>(floor(ou));
    const int uc = static_cast<int>(ceil(ou));
    const float uwc = ou - uf;
    const float uwf = 1. - uwc;

    const float ov = (v + 0.5) / lambda - 0.5;
    const int vf = static_cast<int>(floor(ov));
    const int vc = static_cast<int>(ceil(ov));
    const float vwc = ov - vf;
    const float vwf = 1. - vwc;

    // The summation order is kept fixed so results stay bit-for-bit stable.
    T val = ywf * xwf * vwf * uwf *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yf, xf, Vin, Uin, vf, uf, pad_val) +
        ywf * xwf * vwf * uwc *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yf, xf, Vin, Uin, vf, uc, pad_val) +
        ywf * xwf * vwc * uwf *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yf, xf, Vin, Uin, vc, uf, pad_val) +
        ywf * xwf * vwc * uwc *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yf, xf, Vin, Uin, vc, uc, pad_val) +
        ywf * xwc * vwf * uwf *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yf, xc, Vin, Uin, vf, uf, pad_val) +
        ywf * xwc * vwf * uwc *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yf, xc, Vin, Uin, vf, uc, pad_val) +
        ywf * xwc * vwc * uwf *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yf, xc, Vin, Uin, vc, uf, pad_val) +
        ywf * xwc * vwc * uwc *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yf, xc, Vin, Uin, vc, uc, pad_val) +
        ywc * xwf * vwf * uwf *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yc, xf, Vin, Uin, vf, uf, pad_val) +
        ywc * xwf * vwf * uwc *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yc, xf, Vin, Uin, vf, uc, pad_val) +
        ywc * xwf * vwc * uwf *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yc, xf, Vin, Uin, vc, uf, pad_val) +
        ywc * xwf * vwc * uwc *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yc, xf, Vin, Uin, vc, uc, pad_val) +
        ywc * xwc * vwf * uwf *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yc, xc, Vin, Uin, vf, uf, pad_val) +
        ywc * xwc * vwf * uwc *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yc, xc, Vin, Uin, vf, uc, pad_val) +
        ywc * xwc * vwc * uwf *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yc, xc, Vin, Uin, vc, uf, pad_val) +
        ywc * xwc * vwc * uwc *
            get_pixel_val(
                bottom_data, idx, Hin, Win, yc, xc, Vin, Uin, vc, uc, pad_val);

    top_data[index] = val;
  }
}
Vin, + const int Uin, + const float lambda, + const int Hin, + const int Win, + const int Hout, + const int Wout, + T* bottom_diff) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int idx = index; + const int x = idx % Wout; + idx /= Wout; + const int y = idx % Hout; + idx /= Hout; + const int u = idx % Uout; + idx /= Uout; + const int v = idx % Vout; + idx /= Vout; + + const float ox = x * lambda + u - hUout + 0.5; + const int xf = static_cast(floor(ox)); + const int xc = static_cast(ceil(ox)); + const float xwc = ox - xf; + const float xwf = 1. - xwc; + + const float oy = y * lambda + v - hVout + 0.5; + const int yf = static_cast(floor(oy)); + const int yc = static_cast(ceil(oy)); + const float ywc = oy - yf; + const float ywf = 1. - ywc; + + const float ou = (u + 0.5) / lambda - 0.5; + const int uf = static_cast(floor(ou)); + const int uc = static_cast(ceil(ou)); + const float uwc = ou - uf; + const float uwf = 1. - uwc; + + const float ov = (v + 0.5) / lambda - 0.5; + const int vf = static_cast(floor(ov)); + const int vc = static_cast(ceil(ov)); + const float vwc = ov - vf; + const float vwf = 1. 
- vwc; + + const T grad = top_diff[index]; + + add_pixel_val( + bottom_diff, + ywf * xwf * vwf * uwf * grad, + idx, + Hin, + Win, + yf, + xf, + Vin, + Uin, + vf, + uf); + add_pixel_val( + bottom_diff, + ywf * xwf * vwf * uwc * grad, + idx, + Hin, + Win, + yf, + xf, + Vin, + Uin, + vf, + uc); + add_pixel_val( + bottom_diff, + ywf * xwf * vwc * uwf * grad, + idx, + Hin, + Win, + yf, + xf, + Vin, + Uin, + vc, + uf); + add_pixel_val( + bottom_diff, + ywf * xwf * vwc * uwc * grad, + idx, + Hin, + Win, + yf, + xf, + Vin, + Uin, + vc, + uc); + add_pixel_val( + bottom_diff, + ywf * xwc * vwf * uwf * grad, + idx, + Hin, + Win, + yf, + xc, + Vin, + Uin, + vf, + uf); + add_pixel_val( + bottom_diff, + ywf * xwc * vwf * uwc * grad, + idx, + Hin, + Win, + yf, + xc, + Vin, + Uin, + vf, + uc); + add_pixel_val( + bottom_diff, + ywf * xwc * vwc * uwf * grad, + idx, + Hin, + Win, + yf, + xc, + Vin, + Uin, + vc, + uf); + add_pixel_val( + bottom_diff, + ywf * xwc * vwc * uwc * grad, + idx, + Hin, + Win, + yf, + xc, + Vin, + Uin, + vc, + uc); + add_pixel_val( + bottom_diff, + ywc * xwf * vwf * uwf * grad, + idx, + Hin, + Win, + yc, + xf, + Vin, + Uin, + vf, + uf); + add_pixel_val( + bottom_diff, + ywc * xwf * vwf * uwc * grad, + idx, + Hin, + Win, + yc, + xf, + Vin, + Uin, + vf, + uc); + add_pixel_val( + bottom_diff, + ywc * xwf * vwc * uwf * grad, + idx, + Hin, + Win, + yc, + xf, + Vin, + Uin, + vc, + uf); + add_pixel_val( + bottom_diff, + ywc * xwf * vwc * uwc * grad, + idx, + Hin, + Win, + yc, + xf, + Vin, + Uin, + vc, + uc); + add_pixel_val( + bottom_diff, + ywc * xwc * vwf * uwf * grad, + idx, + Hin, + Win, + yc, + xc, + Vin, + Uin, + vf, + uf); + add_pixel_val( + bottom_diff, + ywc * xwc * vwf * uwc * grad, + idx, + Hin, + Win, + yc, + xc, + Vin, + Uin, + vf, + uc); + add_pixel_val( + bottom_diff, + ywc * xwc * vwc * uwf * grad, + idx, + Hin, + Win, + yc, + xc, + Vin, + Uin, + vc, + uf); + add_pixel_val( + bottom_diff, + ywc * xwc * vwc * uwc * grad, + idx, + Hin, + Win, + yc, + 
// Host launcher for the SwapAlign2Nat forward pass.
//
// X is an (N, C = Vin * Uin, Hin, Win) CUDA tensor whose channel count must be
// a non-zero perfect square. Returns a new tensor of shape
// (N, (lambda * Vin) * (lambda * Uin), ceil(Hin / lambda), ceil(Win / lambda));
// positions that sample outside X read pad_val.
at::Tensor SwapAlign2Nat_forward_cuda(
    const at::Tensor& X,
    const int lambda_val,
    const float pad_val) {
  AT_ASSERTM(X.device().is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(X.ndimension() == 4, "input must be a 4D tensor");
  AT_ASSERTM(lambda_val >= 1, "lambda should be greater or equal to 1");
  const int N = X.size(0);
  const int C = X.size(1);
  // Vin would be 0 for C == 0 and the division below would be undefined.
  AT_ASSERTM(C > 0, "#channels should be positive");
  const int Vin = static_cast<int>(sqrt(static_cast<float>(C)));
  const int Uin = C / Vin;
  AT_ASSERTM(
      C == Vin * Uin && Vin == Uin, "#channels should be a square number");
  const int Vout = lambda_val * Vin;
  const int Uout = lambda_val * Uin;
  const int Hin = X.size(2);
  const int Win = X.size(3);
  const float lambda = static_cast<float>(lambda_val);
  const int Hout = static_cast<int>(ceil(Hin / lambda));
  const int Wout = static_cast<int>(ceil(Win / lambda));
  const float hVout = Vout / 2.;
  const float hUout = Uout / 2.;

  at::cuda::CUDAGuard device_guard(X.device());

  at::Tensor Y = at::empty({N, Vout * Uout, Hout, Wout}, X.options());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(at::cuda::ATenCeilDiv(Y.numel(), 512L), 4096L));
  dim3 block(512);

  // Nothing to compute for an empty output; skip the launch.
  if (Y.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return Y;
  }

  // The kernel indexes the input as a dense (N, V, U, H, W) buffer.
  auto X_ = X.contiguous();
  AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "SwapAlign2Nat_forward", [&] {
    SwapAlign2NatForwardFeat<scalar_t><<<grid, block, 0, stream>>>(
        Y.numel(),
        X_.data_ptr<scalar_t>(),
        Vout,
        Uout,
        hVout,
        hUout,
        Vin,
        Uin,
        lambda,
        Hin,
        Win,
        Hout,
        Wout,
        pad_val,
        Y.data_ptr<scalar_t>());
  });
  // No cudaDeviceSynchronize() here: the kernel runs on the current stream, so
  // later work queued on that stream is already ordered after it, and a
  // device-wide sync would stall every other stream. This matches the
  // backward launcher, which never synchronized.
  AT_CUDA_CHECK(cudaGetLastError());
  return Y;
}
tensor"); + AT_ASSERTM(lambda_val >= 1, "lambda should be greater or equal to 1"); + const int Vin = static_cast(sqrt(static_cast(channel))); + const int Uin = channel / Vin; + const int Vout = lambda_val * Vin; + const int Uout = lambda_val * Uin; + const float hVout = Vout / 2.; + const float hUout = Uout / 2.; + const int Hout = gY.size(2); + const int Wout = gY.size(3); + + at::cuda::CUDAGuard device_guard(gY.device()); + + at::Tensor gX = at::zeros({batch_size, channel, height, width}, gY.options()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(at::cuda::ATenCeilDiv(gY.numel(), 512L), 4096L)); + dim3 block(512); + + // handle possibly empty gradients + if (gY.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return gX; + } + + auto gY_ = gY.contiguous(); + AT_DISPATCH_FLOATING_TYPES(gY.scalar_type(), "SwapAlign2Nat_backward", [&] { + SwapAlign2NatBackwardFeat<<>>( + gY.numel(), + gY_.data_ptr(), + Vout, + Uout, + hVout, + hUout, + Vin, + Uin, + static_cast(lambda_val), + height, + width, + Hout, + Wout, + gX.data_ptr()); + }); + AT_CUDA_CHECK(cudaGetLastError()); + return gX; +} + +} // namespace tensormask diff --git a/data_processing/detectron2/projects/TensorMask/tensormask/layers/csrc/vision.cpp b/data_processing/detectron2/projects/TensorMask/tensormask/layers/csrc/vision.cpp new file mode 100644 index 0000000..ed1ed0b --- /dev/null +++ b/data_processing/detectron2/projects/TensorMask/tensormask/layers/csrc/vision.cpp @@ -0,0 +1,19 @@ +// Copyright (c) Facebook, Inc. and its affiliates. 
class _SwapAlign2Nat(Function):
    """Autograd function that forwards to the compiled ``_C`` SwapAlign2Nat kernels."""

    @staticmethod
    def forward(ctx, X, lambda_val, pad_val):
        # Stash what backward needs: the ratio and the shape of the input gradient.
        ctx.lambda_val = lambda_val
        ctx.input_shape = X.size()
        return _C.swap_align2nat_forward(X, lambda_val, pad_val)

    @staticmethod
    @once_differentiable
    def backward(ctx, gY):
        batch, channels, height, width = ctx.input_shape
        grad_input = _C.swap_align2nat_backward(
            gY, ctx.lambda_val, batch, channels, height, width
        )
        # lambda_val and pad_val are not differentiable.
        return grad_input, None, None


swap_align2nat = _SwapAlign2Nat.apply


class SwapAlign2Nat(nn.Module):
    """
    The op `SwapAlign2Nat` described in https://arxiv.org/abs/1903.12174.
    Given an input tensor that predicts masks of shape (N, C=VxU, H, W),
    applying the op returns masks of shape (N, V'xU', H', W') where
    the unit lengths of (V, U) and (H, W) are swapped, and the mask representation
    is transformed from aligned to natural.
    Args:
        lambda_val (int): the relative unit length ratio between (V, U) and (H, W),
                as we always have larger unit lengths for (V, U) than (H, W),
                lambda_val is always >= 1.
        pad_val (float):  padding value for the values falling outside of the input
                tensor, default set to -6 as sigmoid(-6) is ~0, indicating
                that there are no masks outside of the tensor.
    """

    def __init__(self, lambda_val, pad_val=-6.0):
        super().__init__()
        self.lambda_val = lambda_val
        self.pad_val = pad_val

    def forward(self, X):
        return swap_align2nat(X, self.lambda_val, self.pad_val)

    def __repr__(self):
        name = self.__class__.__name__
        return f"{name}(lambda_val={self.lambda_val!s}, pad_val={self.pad_val!s})"
class SwapAlign2NatTest(unittest.TestCase):
    """Tests for the ``SwapAlign2Nat`` layer."""

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_swap_align2nat_gradcheck_cuda(self):
        # gradcheck compares against finite differences, hence double precision.
        tensor_kwargs = {"dtype": torch.float64, "device": torch.device("cuda")}
        layer = SwapAlign2Nat(2).to(**tensor_kwargs)
        sample = torch.rand(2, 4, 10, 10, requires_grad=True, **tensor_kwargs)

        passed = gradcheck(layer, sample)
        self.assertTrue(passed, "gradcheck failed for SwapAlign2Nat CUDA")

    def _swap_align2nat(self, tensor, lambda_val):
        """
        The basic setup for testing Swap_Align
        """
        layer = SwapAlign2Nat(lambda_val, pad_val=0.0)
        # Add a leading batch axis, run the op on the GPU, then drop the axis again.
        batched = torch.from_numpy(tensor[None, :, :, :].astype("float32"))
        result = layer.forward(batched.cuda()).cpu().numpy()
        return result[0]
class Trainer(DefaultTrainer):
    """``DefaultTrainer`` whose evaluator is a ``COCOEvaluator``."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Return a ``COCOEvaluator`` for ``dataset_name``.

        Results go to ``output_folder``, or to ``<cfg.OUTPUT_DIR>/inference`` when no
        folder is given.
        """
        eval_dir = (
            os.path.join(cfg.OUTPUT_DIR, "inference") if output_folder is None else output_folder
        )
        return COCOEvaluator(dataset_name, output_dir=eval_dir)


def setup(args):
    """
    Create configs and perform basic setups.

    The TensorMask options are registered first, then the config file and the
    command-line overrides are merged in, and the frozen config is returned.
    """
    cfg = get_cfg()
    add_tensormask_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    """Train a model, or only evaluate it when ``args.eval_only`` is set."""
    cfg = setup(args)

    if not args.eval_only:
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()

    # Evaluation only: load the weights, test, and hand the results to
    # verify_results on the main process.
    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
    res = Trainer.test(cfg, model)
    if comm.is_main_process():
        verify_results(cfg, res)
    return res


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch_options = {
        "num_machines": args.num_machines,
        "machine_rank": args.machine_rank,
        "dist_url": args.dist_url,
        "args": (args,),
    }
    launch(main, args.num_gpus, **launch_options)
+[[`TridentNet`](https://github.com/TuSimple/simpledet/tree/master/models/tridentnet)] [[`arXiv`](https://arxiv.org/abs/1901.01892)] [[`BibTeX`](#CitingTridentNet)] + +
+ +
+ +In this repository, we implement TridentNet-Fast in Detectron2. +Trident Network (TridentNet) aims to generate scale-specific feature maps with a uniform representational power. We construct a parallel multi-branch architecture in which each branch shares the same transformation parameters but with different receptive fields. TridentNet-Fast is a fast approximation version of TridentNet that could achieve significant improvements without any additional parameters and computational cost. + +## Training + +To train a model, run +```bash +python /path/to/detectron2/projects/TridentNet/train_net.py --config-file <config.yaml> +``` + +For example, to launch end-to-end TridentNet training with ResNet-50 backbone on 8 GPUs, +one should execute: +```bash +python /path/to/detectron2/projects/TridentNet/train_net.py --config-file configs/tridentnet_fast_R_50_C4_1x.yaml --num-gpus 8 +``` + +## Evaluation + +Model evaluation can be done similarly: +```bash +python /path/to/detectron2/projects/TridentNet/train_net.py --config-file configs/tridentnet_fast_R_50_C4_1x.yaml --eval-only MODEL.WEIGHTS model.pth +``` + +## Results on MS-COCO in Detectron2 + +|Model|Backbone|Head|lr sched|AP|AP50|AP75|APs|APm|APl|download| +|-----|--------|----|--------|--|----|----|---|---|---|--------| +|Faster|R50-C4|C5-512ROI|1X|35.7|56.1|38.0|19.2|40.9|48.7|model \| metrics| +|TridentFast|R50-C4|C5-128ROI|1X|38.0|58.1|40.8|19.5|42.2|54.6|model \| metrics| +|Faster|R50-C4|C5-512ROI|3X|38.4|58.7|41.3|20.7|42.7|53.1|model \| metrics| +|TridentFast|R50-C4|C5-128ROI|3X|40.6|60.8|43.6|23.4|44.7|57.1|model \| metrics| +|Faster|R101-C4|C5-512ROI|3X|41.1|61.4|44.0|22.2|45.5|55.9|model \| metrics| +|TridentFast|R101-C4|C5-128ROI|3X|43.6|63.4|47.0|24.3|47.8|60.0|model \| metrics| + + +## <a name="CitingTridentNet"></a>Citing TridentNet + +If you use TridentNet, please use the following BibTeX entry.
+ +``` +@InProceedings{li2019scale, + title={Scale-Aware Trident Networks for Object Detection}, + author={Li, Yanghao and Chen, Yuntao and Wang, Naiyan and Zhang, Zhaoxiang}, + booktitle={The International Conference on Computer Vision (ICCV)}, + year={2019} +} +``` + diff --git a/data_processing/detectron2/projects/TridentNet/configs/Base-TridentNet-Fast-C4.yaml b/data_processing/detectron2/projects/TridentNet/configs/Base-TridentNet-Fast-C4.yaml new file mode 100644 index 0000000..8c3d807 --- /dev/null +++ b/data_processing/detectron2/projects/TridentNet/configs/Base-TridentNet-Fast-C4.yaml @@ -0,0 +1,29 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: "build_trident_resnet_backbone" + ROI_HEADS: + NAME: "TridentRes5ROIHeads" + POSITIVE_FRACTION: 0.5 + BATCH_SIZE_PER_IMAGE: 128 + PROPOSAL_APPEND_GT: False + PROPOSAL_GENERATOR: + NAME: "TridentRPN" + RPN: + POST_NMS_TOPK_TRAIN: 500 + TRIDENT: + NUM_BRANCH: 3 + BRANCH_DILATIONS: [1, 2, 3] + TEST_BRANCH_IDX: 1 + TRIDENT_STAGE: "res4" +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +VERSION: 2 diff --git a/data_processing/detectron2/projects/TridentNet/configs/tridentnet_fast_R_101_C4_3x.yaml b/data_processing/detectron2/projects/TridentNet/configs/tridentnet_fast_R_101_C4_3x.yaml new file mode 100644 index 0000000..bc83c2f --- /dev/null +++ b/data_processing/detectron2/projects/TridentNet/configs/tridentnet_fast_R_101_C4_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "Base-TridentNet-Fast-C4.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: False + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 diff --git a/data_processing/detectron2/projects/TridentNet/configs/tridentnet_fast_R_50_C4_1x.yaml b/data_processing/detectron2/projects/TridentNet/configs/tridentnet_fast_R_50_C4_1x.yaml new
class Trainer(DefaultTrainer):
    """``DefaultTrainer`` that evaluates with ``COCOEvaluator``."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Build the COCO evaluator for ``dataset_name``.

        When ``output_folder`` is ``None`` the evaluator writes to
        ``<cfg.OUTPUT_DIR>/inference``.
        """
        if output_folder is not None:
            return COCOEvaluator(dataset_name, output_dir=output_folder)
        default_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, output_dir=default_folder)


def setup(args):
    """
    Create configs and perform basic setups.

    Registers the TridentNet options, merges the config file and the command-line
    overrides, freezes the config and returns it.
    """
    cfg = get_cfg()
    add_tridentnet_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    """Run training, or evaluation only when ``args.eval_only`` is set."""
    cfg = setup(args)

    if not args.eval_only:
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()

    # Evaluation only: build the model, load its weights and return the test results.
    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
    return Trainer.test(cfg, model)


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch_options = {
        "num_machines": args.num_machines,
        "machine_rank": args.machine_rank,
        "dist_url": args.dist_url,
        "args": (args,),
    }
    launch(main, args.num_gpus, **launch_options)
+ """ + _C = cfg + + _C.MODEL.TRIDENT = CN() + + # Number of branches for TridentNet. + _C.MODEL.TRIDENT.NUM_BRANCH = 3 + # Specify the dilations for each branch. + _C.MODEL.TRIDENT.BRANCH_DILATIONS = [1, 2, 3] + # Specify the stage for applying trident blocks. Default stage is Res4 according to the + # TridentNet paper. + _C.MODEL.TRIDENT.TRIDENT_STAGE = "res4" + # Specify the test branch index TridentNet Fast inference: + # - use -1 to aggregate results of all branches during inference. + # - otherwise, only using specified branch for fast inference. Recommended setting is + # to use the middle branch. + _C.MODEL.TRIDENT.TEST_BRANCH_IDX = 1 diff --git a/data_processing/detectron2/projects/TridentNet/tridentnet/trident_backbone.py b/data_processing/detectron2/projects/TridentNet/tridentnet/trident_backbone.py new file mode 100644 index 0000000..7789bd2 --- /dev/null +++ b/data_processing/detectron2/projects/TridentNet/tridentnet/trident_backbone.py @@ -0,0 +1,220 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import fvcore.nn.weight_init as weight_init +import torch +import torch.nn.functional as F + +from detectron2.layers import Conv2d, FrozenBatchNorm2d, get_norm +from detectron2.modeling import BACKBONE_REGISTRY, ResNet, ResNetBlockBase +from detectron2.modeling.backbone.resnet import BasicStem, BottleneckBlock, DeformBottleneckBlock + +from .trident_conv import TridentConv + +__all__ = ["TridentBottleneckBlock", "make_trident_stage", "build_trident_resnet_backbone"] + + +class TridentBottleneckBlock(ResNetBlockBase): + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm="BN", + stride_in_1x1=False, + num_branch=3, + dilations=(1, 2, 3), + concat_output=False, + test_branch_idx=-1, + ): + """ + Args: + num_branch (int): the number of branches in TridentNet. + dilations (tuple): the dilations of multiple branches in TridentNet. 
+ concat_output (bool): if concatenate outputs of multiple branches in TridentNet. + Use 'True' for the last trident block. + """ + super().__init__(in_channels, out_channels, stride) + + assert num_branch == len(dilations) + + self.num_branch = num_branch + self.concat_output = concat_output + self.test_branch_idx = test_branch_idx + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels), + ) + + self.conv2 = TridentConv( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride_3x3, + paddings=dilations, + bias=False, + groups=num_groups, + dilations=dilations, + num_branch=num_branch, + test_branch_idx=test_branch_idx, + norm=get_norm(norm, bottleneck_channels), + ) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + def forward(self, x): + num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 + if not isinstance(x, list): + x = [x] * num_branch + out = [self.conv1(b) for b in x] + out = [F.relu_(b) for b in out] + + out = self.conv2(out) + out = [F.relu_(b) for b in out] + + out = [self.conv3(b) for b in out] + + if self.shortcut is not None: + shortcut = [self.shortcut(b) for b in x] + else: + shortcut = x + + out = [out_b + shortcut_b for out_b, shortcut_b in zip(out, shortcut)] + out = [F.relu_(b) for b in out] + if self.concat_output: + out = torch.cat(out) + return out + + +def 
make_trident_stage(block_class, num_blocks, **kwargs): + """ + Create a resnet stage by creating many blocks for TridentNet. + """ + concat_output = [False] * (num_blocks - 1) + [True] + kwargs["concat_output_per_block"] = concat_output + return ResNet.make_stage(block_class, num_blocks, **kwargs) + + +@BACKBONE_REGISTRY.register() +def build_trident_resnet_backbone(cfg, input_shape): + """ + Create a ResNet instance from config for TridentNet. + + Returns: + ResNet: a :class:`ResNet` instance. + """ + # need registration of new blocks/stems? + norm = cfg.MODEL.RESNETS.NORM + stem = BasicStem( + in_channels=input_shape.channels, + out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS, + norm=norm, + ) + freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT + + if freeze_at >= 1: + for p in stem.parameters(): + p.requires_grad = False + stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem) + + # fmt: off + out_features = cfg.MODEL.RESNETS.OUT_FEATURES + depth = cfg.MODEL.RESNETS.DEPTH + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + bottleneck_channels = num_groups * width_per_group + in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 + res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION + deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE + deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED + deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS + num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH + branch_dilations = cfg.MODEL.TRIDENT.BRANCH_DILATIONS + trident_stage = cfg.MODEL.TRIDENT.TRIDENT_STAGE + test_branch_idx = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX + # fmt: on + assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation) + + num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth] + + stages = [] + + res_stage_idx = {"res2": 2, "res3": 3, "res4": 4, "res5": 5} + 
out_stage_idx = [res_stage_idx[f] for f in out_features] + trident_stage_idx = res_stage_idx[trident_stage] + max_stage_idx = max(out_stage_idx) + for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)): + dilation = res5_dilation if stage_idx == 5 else 1 + first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2 + stage_kargs = { + "num_blocks": num_blocks_per_stage[idx], + "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1), + "in_channels": in_channels, + "bottleneck_channels": bottleneck_channels, + "out_channels": out_channels, + "num_groups": num_groups, + "norm": norm, + "stride_in_1x1": stride_in_1x1, + "dilation": dilation, + } + if stage_idx == trident_stage_idx: + assert not deform_on_per_stage[ + idx + ], "Not support deformable conv in Trident blocks yet." + stage_kargs["block_class"] = TridentBottleneckBlock + stage_kargs["num_branch"] = num_branch + stage_kargs["dilations"] = branch_dilations + stage_kargs["test_branch_idx"] = test_branch_idx + stage_kargs.pop("dilation") + elif deform_on_per_stage[idx]: + stage_kargs["block_class"] = DeformBottleneckBlock + stage_kargs["deform_modulated"] = deform_modulated + stage_kargs["deform_num_groups"] = deform_num_groups + else: + stage_kargs["block_class"] = BottleneckBlock + blocks = ( + make_trident_stage(**stage_kargs) + if stage_idx == trident_stage_idx + else ResNet.make_stage(**stage_kargs) + ) + in_channels = out_channels + out_channels *= 2 + bottleneck_channels *= 2 + + if freeze_at >= stage_idx: + for block in blocks: + block.freeze() + stages.append(blocks) + return ResNet(stem, stages, out_features=out_features) diff --git a/data_processing/detectron2/projects/TridentNet/tridentnet/trident_conv.py b/data_processing/detectron2/projects/TridentNet/tridentnet/trident_conv.py new file mode 100644 index 0000000..18d5b0b --- /dev/null +++ b/data_processing/detectron2/projects/TridentNet/tridentnet/trident_conv.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, 
Inc. and its affiliates. +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _pair + +from detectron2.layers.wrappers import _NewEmptyTensorOp + + +class TridentConv(nn.Module): + """ + 2D convolution whose single weight (and optional bias) is shared by several + branches; each branch applies it with its own padding and dilation. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + paddings=0, + dilations=1, + groups=1, + num_branch=1, + test_branch_idx=-1, + bias=False, + norm=None, + activation=None, + ): + """ + Args: + paddings (int or list): padding per branch; an int is repeated for every branch. + dilations (int or list): dilation per branch; an int is repeated for every branch. + num_branch (int): number of branches; must equal len(paddings) and len(dilations). + test_branch_idx (int): branch used outside training; -1 runs all branches. + norm, activation: optional callables applied to every branch output. + """ + super(TridentConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.num_branch = num_branch + self.stride = _pair(stride) + self.groups = groups + self.with_bias = bias + if isinstance(paddings, int): + paddings = [paddings] * self.num_branch + if isinstance(dilations, int): + dilations = [dilations] * self.num_branch + self.paddings = [_pair(padding) for padding in paddings] + self.dilations = [_pair(dilation) for dilation in dilations] + self.test_branch_idx = test_branch_idx + self.norm = norm + self.activation = activation + + assert len({self.num_branch, len(self.paddings), len(self.dilations)}) == 1 + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, inputs): + """ + Args: + inputs (list[Tensor]): one tensor per active branch (a single tensor in the + list when a test branch is selected outside training). + + Returns: + list[Tensor]: one output per active branch. + """ + num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 + assert len(inputs) == num_branch + + if inputs[0].numel() == 0: + # Empty input: compute the output shape analytically instead of running conv2d. + # Indexing with test_branch_idx is always valid (-1 selects the last branch). + # NOTE: assumes every branch yields the same spatial size, which holds when + # padding == dilation for a 3x3 kernel. + padding = self.paddings[self.test_branch_idx] + dilation = self.dilations[self.test_branch_idx] + output_shape = [ + (i + 2 * p - (di * (k - 1) + 1)) // s + 1 + for i, p, di, k, s in zip( + inputs[0].shape[-2:], padding, dilation, self.kernel_size, self.stride + ) + ] + output_shape = [inputs[0].shape[0], self.weight.shape[0]] + output_shape + return [_NewEmptyTensorOp.apply(x, output_shape) for x in inputs] + + if
self.training or self.test_branch_idx == -1: + outputs = [ + F.conv2d(input, self.weight, self.bias, self.stride, padding, dilation, self.groups) + for input, dilation, padding in zip(inputs, self.dilations, self.paddings) + ] + else: + outputs = [ + F.conv2d( + inputs[0], + self.weight, + self.bias, + self.stride, + self.paddings[self.test_branch_idx], + self.dilations[self.test_branch_idx], + self.groups, + ) + ] + + if self.norm is not None: + outputs = [self.norm(x) for x in outputs] + if self.activation is not None: + outputs = [self.activation(x) for x in outputs] + return outputs + + def extra_repr(self): + tmpstr = "in_channels=" + str(self.in_channels) + tmpstr += ", out_channels=" + str(self.out_channels) + tmpstr += ", kernel_size=" + str(self.kernel_size) + tmpstr += ", num_branch=" + str(self.num_branch) + tmpstr += ", test_branch_idx=" + str(self.test_branch_idx) + tmpstr += ", stride=" + str(self.stride) + tmpstr += ", paddings=" + str(self.paddings) + tmpstr += ", dilations=" + str(self.dilations) + tmpstr += ", groups=" + str(self.groups) + tmpstr += ", bias=" + str(self.with_bias) + return tmpstr diff --git a/data_processing/detectron2/projects/TridentNet/tridentnet/trident_rcnn.py b/data_processing/detectron2/projects/TridentNet/tridentnet/trident_rcnn.py new file mode 100644 index 0000000..fc22c71 --- /dev/null +++ b/data_processing/detectron2/projects/TridentNet/tridentnet/trident_rcnn.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.layers import batched_nms +from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads +from detectron2.modeling.roi_heads.roi_heads import Res5ROIHeads +from detectron2.structures import Instances + + +def merge_branch_instances(instances, num_branch, nms_thresh, topk_per_image): + """ + Merge detection results from different branches of TridentNet. 
+ Return detection results by applying non-maximum suppression (NMS) on bounding boxes + and keep the unsuppressed boxes and other instances (e.g. mask) if any. + + Args: + instances (list[Instances]): A list of N * num_branch instances that store detection + results. Contain N images and each image has num_branch instances. + num_branch (int): Number of branches used for merging detection results for each image. + nms_thresh (float): The threshold to use for box non-maximum suppression. Value in [0, 1]. + topk_per_image (int): The number of top scoring detections to return. Set < 0 to return + all detections. + + Returns: + results: (list[Instances]): A list of N instances, one for each image in the batch, + that stores the topk most confident detections after merging results from multiple + branches. + """ + if num_branch == 1: + return instances + + batch_size = len(instances) // num_branch + results = [] + for i in range(batch_size): + # Branch-major layout: image i of branch j is stored at index i + batch_size * j. + instance = Instances.cat([instances[i + batch_size * j] for j in range(num_branch)]) + + # Apply per-class NMS + keep = batched_nms( + instance.pred_boxes.tensor, instance.scores, instance.pred_classes, nms_thresh + ) + # A negative topk means "return all detections"; slicing with it would instead + # drop entries from the end of `keep`. + if topk_per_image >= 0: + keep = keep[:topk_per_image] + result = instance[keep] + + results.append(result) + + return results + + +@ROI_HEADS_REGISTRY.register() +class TridentRes5ROIHeads(Res5ROIHeads): + """ + The TridentNet ROIHeads in a typical "C4" R-CNN model. + See :class:`Res5ROIHeads`. + """ + + def __init__(self, cfg, input_shape): + super().__init__(cfg, input_shape) + + self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH + self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1 + + def forward(self, images, features, proposals, targets=None): + """ + See :class:`Res5ROIHeads.forward`.
+ """ + # Use 1 branch if using trident_fast during inference. + num_branch = self.num_branch if self.training or not self.trident_fast else 1 + # Duplicate targets for all branches in TridentNet. + all_targets = targets * num_branch if targets is not None else None + pred_instances, losses = super().forward(images, features, proposals, all_targets) + del images, all_targets, targets + + if self.training: + return pred_instances, losses + else: + pred_instances = merge_branch_instances( + pred_instances, + num_branch, + self.box_predictor.test_nms_thresh, + self.box_predictor.test_topk_per_image, + ) + + return pred_instances, {} + + +@ROI_HEADS_REGISTRY.register() +class TridentStandardROIHeads(StandardROIHeads): + """ + The `StandardROIHeads` for TridentNet. + See :class:`StandardROIHeads`. + """ + + def __init__(self, cfg, input_shape): + super(TridentStandardROIHeads, self).__init__(cfg, input_shape) + + self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH + self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1 + + def forward(self, images, features, proposals, targets=None): + """ + See :class:`StandardROIHeads.forward`. + """ + # Use 1 branch if using trident_fast during inference. + num_branch = self.num_branch if self.training or not self.trident_fast else 1 + # Duplicate targets for all branches in TridentNet.
+ all_targets = targets * num_branch if targets is not None else None + pred_instances, losses = super().forward(images, features, proposals, all_targets) + del images, all_targets, targets + + if self.training: + return pred_instances, losses + else: + pred_instances = merge_branch_instances( + pred_instances, + num_branch, + self.box_predictor.test_nms_thresh, + self.box_predictor.test_topk_per_image, + ) + + return pred_instances, {} diff --git a/data_processing/detectron2/projects/TridentNet/tridentnet/trident_rpn.py b/data_processing/detectron2/projects/TridentNet/tridentnet/trident_rpn.py new file mode 100644 index 0000000..f95fbbf --- /dev/null +++ b/data_processing/detectron2/projects/TridentNet/tridentnet/trident_rpn.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch + +from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY +from detectron2.modeling.proposal_generator.rpn import RPN +from detectron2.structures import ImageList + + +@PROPOSAL_GENERATOR_REGISTRY.register() +class TridentRPN(RPN): + """ + Trident RPN subnetwork. + """ + + def __init__(self, cfg, input_shape): + super(TridentRPN, self).__init__(cfg, input_shape) + + self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH + self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1 + + def forward(self, images, features, gt_instances=None): + """ + See :class:`RPN.forward`. + """ + num_branch = self.num_branch if self.training or not self.trident_fast else 1 + # Duplicate images and gt_instances for all branches in TridentNet. 
+ all_images = ImageList( + torch.cat([images.tensor] * num_branch), images.image_sizes * num_branch + ) + all_gt_instances = gt_instances * num_branch if gt_instances is not None else None + + return super(TridentRPN, self).forward(all_images, features, all_gt_instances) diff --git a/data_processing/detectron2/projects/ViTDet/README.md b/data_processing/detectron2/projects/ViTDet/README.md new file mode 100644 index 0000000..0a525e0 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/README.md @@ -0,0 +1,364 @@ +# ViTDet: Exploring Plain Vision Transformer Backbones for Object Detection + +Yanghao Li, Hanzi Mao, Ross Girshick†, Kaiming He† + +[[`arXiv`](https://arxiv.org/abs/2203.16527)] [[`BibTeX`](#CitingViTDet)] + +In this repository, we provide configs and models in Detectron2 for ViTDet as well as MViTv2 and Swin backbones with our implementation and settings as described in [ViTDet](https://arxiv.org/abs/2203.16527) paper. + + +## Pretrained Models + +### COCO + +#### Mask R-CNN + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namepre-traintrain
time
(s/im)
inference
time
(s/im)
train
mem
(GB)
box
AP
mask
AP
model iddownload
ViTDet, ViT-BIN1K, MAE0.3140.07910.951.645.9325346929model
ViTDet, ViT-LIN1K, MAE0.6030.12520.955.549.2325599698model
ViTDet, ViT-HIN1K, MAE1.0980.17831.556.750.2329145471model
+ +#### Cascade Mask R-CNN + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namepre-traintrain
time
(s/im)
inference
time
(s/im)
train
mem
(GB)
box
AP
mask
AP
model iddownload
Swin-BIN21K, sup0.3890.0778.753.946.2342979038model
Swin-LIN21K, sup0.5080.09712.655.047.2342979186model
MViTv2-BIN21K, sup0.4750.0908.955.648.1325820315model
MViTv2-LIN21K, sup0.8440.15719.755.748.3325607715model
MViTv2-HIN21K, sup1.6550.28518.4*55.948.3326187358model
ViTDet, ViT-BIN1K, MAE0.3620.08912.354.046.7325358525model
ViTDet, ViT-LIN1K, MAE0.6430.14222.357.650.0328021305model
ViTDet, ViT-HIN1K, MAE1.1370.19632.958.751.0328730692model
+ + +### LVIS + +#### Mask R-CNN + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namepre-traintrain
time
(s/im)
inference
time
(s/im)
train
mem
(GB)
box
AP
mask
AP
model iddownload
ViTDet, ViT-BIN1K, MAE0.3170.08514.440.238.2329225748model
ViTDet, ViT-LIN1K, MAE0.5760.13724.746.143.6329211570model
ViTDet, ViT-HIN1K, MAE1.0590.18635.349.146.0332434656model
+ +#### Cascade Mask R-CNN + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Namepre-traintrain
time
(s/im)
inference
time
(s/im)
train
mem
(GB)
box
AP
mask
AP
model iddownload
Swin-BIN21K, sup0.3680.09011.544.039.6329222304model
Swin-LIN21K, sup0.4860.10513.846.041.4329222724model
MViTv2-BIN21K, sup0.4750.10011.846.342.0329477206model
MViTv2-LIN21K, sup0.8440.17221.049.444.2329661552model
MViTv2-HIN21K, sup1.6610.29021.3*49.544.1330445165model
ViTDet, ViT-BIN1K, MAE0.3560.09915.243.038.9329226874model
ViTDet, ViT-LIN1K, MAE0.6290.15024.949.244.5329042206model
ViTDet, ViT-HIN1K, MAE1.1000.20435.551.546.6332552778model
+ +Note: Unlike the system-level comparisons in the paper, these models use a lower resolution (1024 instead of 1280) and standard NMS (instead of soft NMS). As a result, they have slightly lower box and mask AP. + +We observed higher variance on LVIS evaluation results compared to COCO. For example, the standard deviations of box AP and mask AP were 0.30% (compared to 0.10% on COCO) when we trained ViTDet, ViT-B five times with varying random seeds. + +The above models were trained and measured on 8 nodes with 64 NVIDIA A100 GPUs in total. *: Activation checkpointing is used. + + +## Training +All configs can be trained with: + +``` +../../tools/lazyconfig_train_net.py --config-file configs/path/to/config.py +``` +By default, we use 64 GPUs with a batch size of 64 for training. + +## Evaluation +Model evaluation can be done similarly: +``` +../../tools/lazyconfig_train_net.py --config-file configs/path/to/config.py --eval-only train.init_checkpoint=/path/to/model_checkpoint +``` + + +## Citing ViTDet + +If you use ViTDet, please use the following BibTeX entry.
+ +```BibTeX +@article{li2022exploring, + title={Exploring plain vision transformer backbones for object detection}, + author={Li, Yanghao and Mao, Hanzi and Girshick, Ross and He, Kaiming}, + journal={arXiv preprint arXiv:2203.16527}, + year={2022} +} +``` diff --git a/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_b_in21k_100ep.py b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_b_in21k_100ep.py new file mode 100644 index 0000000..9dba203 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_b_in21k_100ep.py @@ -0,0 +1,95 @@ +from functools import partial +import torch.nn as nn +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2 import model_zoo +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler +from detectron2.modeling import MViT +from detectron2.layers import ShapeSpec +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.roi_heads import ( + FastRCNNOutputLayers, + FastRCNNConvFCHead, + CascadeROIHeads, +) + +from ..common.coco_loader_lsj import dataloader + +model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model +constants = model_zoo.get_config("common/data/constants.py").constants +model.pixel_mean = constants.imagenet_rgb256_mean +model.pixel_std = constants.imagenet_rgb256_std +model.input_format = "RGB" +model.backbone.bottom_up = L(MViT)( + embed_dim=96, + depth=24, + num_heads=1, + last_block_indexes=(1, 4, 20, 23), + residual_pooling=True, + drop_path_rate=0.4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + out_features=("scale2", "scale3", "scale4", "scale5"), +) +model.backbone.in_features = "${.bottom_up.out_features}" +model.backbone.square_pad = 1024 + +# New heads and LN +model.backbone.norm = "LN" # Use LN in FPN +model.roi_heads.box_head.conv_norm = 
model.roi_heads.mask_head.conv_norm = "LN" + +# 2conv in RPN: +model.proposal_generator.head.conv_dims = [-1, -1] + +# arguments that don't exist for Cascade R-CNN +[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] +model.roi_heads.update( + _target_=CascadeROIHeads, + box_heads=[ + L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[256, 256, 256, 256], + fc_dims=[1024], + conv_norm="LN", + ) + for _ in range(3) + ], + box_predictors=[ + L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + test_score_thresh=0.05, + box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), + cls_agnostic_bbox_reg=True, + num_classes="${...num_classes}", + ) + for (w1, w2) in [(10, 5), (20, 10), (30, 15)] + ], + proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) + for th in [0.5, 0.6, 0.7] + ], +) + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = True +train.ddp.fp16_compression = True +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_B_in21k.pyth" + +# Schedule +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +train.max_iter = 184375 +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[163889, 177546], + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} +optimizer.lr = 8e-5 diff --git a/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_h_in21k_36ep.py b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_h_in21k_36ep.py new file mode 100644 index 0000000..5770450 --- /dev/null +++ 
b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_h_in21k_36ep.py @@ -0,0 +1,39 @@ +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler + +from .cascade_mask_rcnn_mvitv2_b_in21k_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, +) + +model.backbone.bottom_up.embed_dim = 192 +model.backbone.bottom_up.depth = 80 +model.backbone.bottom_up.num_heads = 3 +model.backbone.bottom_up.last_block_indexes = (3, 11, 71, 79) +model.backbone.bottom_up.drop_path_rate = 0.6 +model.backbone.bottom_up.use_act_checkpoint = True + + +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_H_in21k.pyth" + + +# 36 epochs +train.max_iter = 67500 +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[ + 52500, + 62500, + 67500, + ], + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) +optimizer.lr = 1.6e-4 diff --git a/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_l_in21k_50ep.py b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_l_in21k_50ep.py new file mode 100644 index 0000000..c64f0c1 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_l_in21k_50ep.py @@ -0,0 +1,22 @@ +from .cascade_mask_rcnn_mvitv2_b_in21k_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, +) + +model.backbone.bottom_up.embed_dim = 144 +model.backbone.bottom_up.depth = 48 +model.backbone.bottom_up.num_heads = 2 +model.backbone.bottom_up.last_block_indexes = (1, 7, 43, 47) +model.backbone.bottom_up.drop_path_rate = 0.5 + + +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_L_in21k.pyth" + +train.max_iter = train.max_iter // 2 # 100ep -> 50ep +lr_multiplier.scheduler.milestones = [ + 
milestone // 2 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_b_in21k_50ep.py b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_b_in21k_50ep.py new file mode 100644 index 0000000..b2aad98 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_b_in21k_50ep.py @@ -0,0 +1,50 @@ +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2 import model_zoo +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler +from detectron2.modeling import SwinTransformer + +from ..common.coco_loader_lsj import dataloader +from .cascade_mask_rcnn_mvitv2_b_in21k_100ep import model + +model.backbone.bottom_up = L(SwinTransformer)( + depths=[2, 2, 18, 2], + drop_path_rate=0.4, + embed_dim=128, + num_heads=[4, 8, 16, 32], +) +model.backbone.in_features = ("p0", "p1", "p2", "p3") +model.backbone.square_pad = 1024 + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = True +train.ddp.fp16_compression = True +train.init_checkpoint = "detectron2://ImageNetPretrained/swin/swin_base_patch4_window7_224_22k.pth" + +# Schedule +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +train.max_iter = 184375 +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[163889, 177546], + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +# Rescale schedule +train.max_iter = train.max_iter // 2 # 100ep -> 50ep +lr_multiplier.scheduler.milestones = [ + milestone // 2 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter + + +optimizer = model_zoo.get_config("common/optim.py").AdamW 
+optimizer.lr = 4e-5 +optimizer.weight_decay = 0.05 +optimizer.params.overrides = {"relative_position_bias_table": {"weight_decay": 0.0}} diff --git a/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_l_in21k_50ep.py b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_l_in21k_50ep.py new file mode 100644 index 0000000..60bc917 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_l_in21k_50ep.py @@ -0,0 +1,15 @@ +from .cascade_mask_rcnn_swin_b_in21k_50ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, +) + +model.backbone.bottom_up.depths = [2, 2, 18, 2] +model.backbone.bottom_up.drop_path_rate = 0.4 +model.backbone.bottom_up.embed_dim = 192 +model.backbone.bottom_up.num_heads = [6, 12, 24, 48] + + +train.init_checkpoint = "detectron2://ImageNetPretrained/swin/swin_large_patch4_window7_224_22k.pth" diff --git a/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_b_100ep.py b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_b_100ep.py new file mode 100644 index 0000000..95823ef --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_b_100ep.py @@ -0,0 +1,48 @@ +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.roi_heads import ( + FastRCNNOutputLayers, + FastRCNNConvFCHead, + CascadeROIHeads, +) + +from .mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, + get_vit_lr_decay_rate, +) + +# arguments that don't exist for Cascade R-CNN +[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] + +model.roi_heads.update( + _target_=CascadeROIHeads, + box_heads=[ + L(FastRCNNConvFCHead)( + 
input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[256, 256, 256, 256], + fc_dims=[1024], + conv_norm="LN", + ) + for _ in range(3) + ], + box_predictors=[ + L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + test_score_thresh=0.05, + box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), + cls_agnostic_bbox_reg=True, + num_classes="${...num_classes}", + ) + for (w1, w2) in [(10, 5), (20, 10), (30, 15)] + ], + proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) + for th in [0.5, 0.6, 0.7] + ], +) diff --git a/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py new file mode 100644 index 0000000..e508a68 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py @@ -0,0 +1,33 @@ +from functools import partial + +from .cascade_mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, + get_vit_lr_decay_rate, +) + +train.init_checkpoint = ( + "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth?matching_heuristics=True" +) + +model.backbone.net.embed_dim = 1280 +model.backbone.net.depth = 32 +model.backbone.net.num_heads = 16 +model.backbone.net.drop_path_rate = 0.5 +# 7, 15, 23, 31 for global attention +model.backbone.net.window_block_indexes = ( + list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31)) +) + +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.9, num_layers=32) +optimizer.params.overrides = {} +optimizer.params.weight_decay_norm = None + +train.max_iter = train.max_iter * 3 // 4 # 100ep -> 75ep +lr_multiplier.scheduler.milestones = [ + milestone * 3 // 4 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git 
a/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_l_100ep.py b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_l_100ep.py new file mode 100644 index 0000000..2743603 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_l_100ep.py @@ -0,0 +1,25 @@ +from functools import partial + +from .cascade_mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, + get_vit_lr_decay_rate, +) + +train.init_checkpoint = ( + "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_large.pth?matching_heuristics=True" +) + +model.backbone.net.embed_dim = 1024 +model.backbone.net.depth = 24 +model.backbone.net.num_heads = 16 +model.backbone.net.drop_path_rate = 0.4 +# 5, 11, 17, 23 for global attention +model.backbone.net.window_block_indexes = ( + list(range(0, 5)) + list(range(6, 11)) + list(range(12, 17)) + list(range(18, 23)) +) + +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.8, num_layers=24) diff --git a/data_processing/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py b/data_processing/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py new file mode 100644 index 0000000..8fd36e9 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py @@ -0,0 +1,40 @@ +from functools import partial +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2 import model_zoo +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler +from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate + +from ..common.coco_loader_lsj import dataloader + + +model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = True 
+train.ddp.fp16_compression = True +train.init_checkpoint = ( + "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth?matching_heuristics=True" +) + + +# Schedule +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +train.max_iter = 184375 + +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[163889, 177546], + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +# Optimizer +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7) +optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} diff --git a/data_processing/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_h_75ep.py b/data_processing/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_h_75ep.py new file mode 100644 index 0000000..7de96f0 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_h_75ep.py @@ -0,0 +1,33 @@ +from functools import partial + +from .mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, + get_vit_lr_decay_rate, +) + +train.init_checkpoint = ( + "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth?matching_heuristics=True" +) + +model.backbone.net.embed_dim = 1280 +model.backbone.net.depth = 32 +model.backbone.net.num_heads = 16 +model.backbone.net.drop_path_rate = 0.5 +# 7, 15, 23, 31 for global attention +model.backbone.net.window_block_indexes = ( + list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31)) +) + +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.9, num_layers=32) +optimizer.params.overrides = {} +optimizer.params.weight_decay_norm = None + +train.max_iter = train.max_iter * 3 // 4 # 100ep -> 75ep +lr_multiplier.scheduler.milestones = [ + milestone * 3 // 4 for 
milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/data_processing/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_l_100ep.py b/data_processing/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_l_100ep.py new file mode 100644 index 0000000..0d193cb --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_l_100ep.py @@ -0,0 +1,25 @@ +from functools import partial + +from .mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, + get_vit_lr_decay_rate, +) + +train.init_checkpoint = ( + "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_large.pth?matching_heuristics=True" +) + +model.backbone.net.embed_dim = 1024 +model.backbone.net.depth = 24 +model.backbone.net.num_heads = 16 +model.backbone.net.drop_path_rate = 0.4 +# 5, 11, 17, 23 for global attention +model.backbone.net.window_block_indexes = ( + list(range(0, 5)) + list(range(6, 11)) + list(range(12, 17)) + list(range(18, 23)) +) + +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.8, num_layers=24) diff --git a/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_mvitv2_b_in21k_100ep.py b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_mvitv2_b_in21k_100ep.py new file mode 100644 index 0000000..1cf9c3e --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_mvitv2_b_in21k_100ep.py @@ -0,0 +1,48 @@ +from functools import partial +import torch.nn as nn + +from detectron2.config import LazyCall as L +from detectron2.data.detection_utils import get_fed_loss_cls_weights +from detectron2.data.samplers import RepeatFactorTrainingSampler +from detectron2.evaluation.lvis_evaluation import LVISEvaluator + +from ..COCO.cascade_mask_rcnn_mvitv2_b_in21k_100ep import ( + dataloader, + model, + train, + lr_multiplier, + optimizer, +) + 
+dataloader.train.dataset.names = "lvis_v1_train" +dataloader.train.sampler = L(RepeatFactorTrainingSampler)( + repeat_factors=L(RepeatFactorTrainingSampler.repeat_factors_from_category_frequency)( + dataset_dicts="${dataloader.train.dataset}", repeat_thresh=0.001 + ) +) +dataloader.test.dataset.names = "lvis_v1_val" +dataloader.evaluator = L(LVISEvaluator)( + dataset_name="${..test.dataset.names}", + max_dets_per_image=300, +) + +model.roi_heads.num_classes = 1203 +for i in range(3): + model.roi_heads.box_predictors[i].test_score_thresh = 0.02 + model.roi_heads.box_predictors[i].test_topk_per_image = 300 + model.roi_heads.box_predictors[i].use_sigmoid_ce = True + model.roi_heads.box_predictors[i].use_fed_loss = True + model.roi_heads.box_predictors[i].get_fed_loss_cls_weights = lambda: get_fed_loss_cls_weights( + dataloader.train.dataset.names, 0.5 + ) + +# Schedule +# 100 ep = 156250 iters * 64 images/iter / 100000 images/ep +train.max_iter = 156250 +train.eval_period = 30000 + +lr_multiplier.scheduler.milestones = [138889, 150463] +lr_multiplier.scheduler.num_updates = train.max_iter +lr_multiplier.warmup_length = 250 / train.max_iter + +optimizer.lr = 1e-4 diff --git a/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_mvitv2_h_in21k_50ep.py b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_mvitv2_h_in21k_50ep.py new file mode 100644 index 0000000..084444b --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_mvitv2_h_in21k_50ep.py @@ -0,0 +1,25 @@ +from .cascade_mask_rcnn_mvitv2_b_in21k_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, +) + +model.backbone.bottom_up.embed_dim = 192 +model.backbone.bottom_up.depth = 80 +model.backbone.bottom_up.num_heads = 3 +model.backbone.bottom_up.last_block_indexes = (3, 11, 71, 79) +model.backbone.bottom_up.drop_path_rate = 0.6 +model.backbone.bottom_up.use_act_checkpoint = True + +train.init_checkpoint = 
"detectron2://ImageNetPretrained/mvitv2/MViTv2_H_in21k.pyth" + +train.max_iter = train.max_iter // 2 # 100ep -> 50ep +lr_multiplier.scheduler.milestones = [ + milestone // 2 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter +lr_multiplier.warmup_length = 250 / train.max_iter + +optimizer.lr = 2e-5 diff --git a/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_mvitv2_l_in21k_50ep.py b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_mvitv2_l_in21k_50ep.py new file mode 100644 index 0000000..779442c --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_mvitv2_l_in21k_50ep.py @@ -0,0 +1,24 @@ +from .cascade_mask_rcnn_mvitv2_b_in21k_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, +) + +model.backbone.bottom_up.embed_dim = 144 +model.backbone.bottom_up.depth = 48 +model.backbone.bottom_up.num_heads = 2 +model.backbone.bottom_up.last_block_indexes = (1, 7, 43, 47) +model.backbone.bottom_up.drop_path_rate = 0.5 + +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_L_in21k.pyth" + +train.max_iter = train.max_iter // 2 # 100ep -> 50ep +lr_multiplier.scheduler.milestones = [ + milestone // 2 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter +lr_multiplier.warmup_length = 250 / train.max_iter + +optimizer.lr = 4e-5 diff --git a/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_swin_b_in21k_50ep.py b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_swin_b_in21k_50ep.py new file mode 100644 index 0000000..d18c925 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_swin_b_in21k_50ep.py @@ -0,0 +1,49 @@ +from detectron2.config.lazy import LazyCall as L +from detectron2.data.detection_utils import get_fed_loss_cls_weights +from 
detectron2.data.samplers import RepeatFactorTrainingSampler +from detectron2.evaluation.lvis_evaluation import LVISEvaluator + +from ..COCO.cascade_mask_rcnn_swin_b_in21k_50ep import ( + dataloader, + model, + train, + lr_multiplier, + optimizer, +) + +dataloader.train.dataset.names = "lvis_v1_train" +dataloader.train.sampler = L(RepeatFactorTrainingSampler)( + repeat_factors=L(RepeatFactorTrainingSampler.repeat_factors_from_category_frequency)( + dataset_dicts="${dataloader.train.dataset}", repeat_thresh=0.001 + ) +) +dataloader.test.dataset.names = "lvis_v1_val" +dataloader.evaluator = L(LVISEvaluator)( + dataset_name="${..test.dataset.names}", + max_dets_per_image=300, +) + +model.backbone.bottom_up.drop_path_rate = 0.3 + +model.roi_heads.num_classes = 1203 +for i in range(3): + model.roi_heads.box_predictors[i].test_score_thresh = 0.02 + model.roi_heads.box_predictors[i].test_topk_per_image = 300 + model.roi_heads.box_predictors[i].use_sigmoid_ce = True + model.roi_heads.box_predictors[i].use_fed_loss = True + model.roi_heads.box_predictors[i].get_fed_loss_cls_weights = lambda: get_fed_loss_cls_weights( + dataloader.train.dataset.names, 0.5 + ) + +# Schedule +# 100 ep = 156250 iters * 64 images/iter / 100000 images/ep +# 100 ep -> 50 ep as the model achieves better performance with 50 epochs +train.max_iter = 156250 // 2 +train.eval_period = 30000 + +lr_multiplier.scheduler.milestones = [milestone // 2 for milestone in [138889, 150463]] +lr_multiplier.scheduler.num_updates = train.max_iter +lr_multiplier.warmup_length = 250 / train.max_iter + +# Optimized hyperparams +optimizer.lr = 1e-4 diff --git a/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_swin_l_in21k_50ep.py b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_swin_l_in21k_50ep.py new file mode 100644 index 0000000..9e22e3b --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_swin_l_in21k_50ep.py @@ -0,0 +1,12 @@ 
+from .cascade_mask_rcnn_swin_b_in21k_50ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, +) + +model.backbone.bottom_up.embed_dim = 192 +model.backbone.bottom_up.num_heads = [6, 12, 24, 48] + +train.init_checkpoint = "detectron2://ImageNetPretrained/swin/swin_large_patch4_window7_224_22k.pth" diff --git a/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_vitdet_b_100ep.py b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_vitdet_b_100ep.py new file mode 100644 index 0000000..8115224 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_vitdet_b_100ep.py @@ -0,0 +1,51 @@ +from detectron2.config import LazyCall as L +from detectron2.data.detection_utils import get_fed_loss_cls_weights +from detectron2.layers import ShapeSpec +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.roi_heads import FastRCNNOutputLayers, FastRCNNConvFCHead, CascadeROIHeads + +from .mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +# arguments that don't exist for Cascade R-CNN +[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] + +model.roi_heads.update( + _target_=CascadeROIHeads, + num_classes=1203, + box_heads=[ + L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[256, 256, 256, 256], + fc_dims=[1024], + conv_norm="LN", + ) + for _ in range(3) + ], + box_predictors=[ + L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), + num_classes="${...num_classes}", + test_score_thresh=0.02, + test_topk_per_image=300, + cls_agnostic_bbox_reg=True, + use_sigmoid_ce=True, + use_fed_loss=True, + get_fed_loss_cls_weights=lambda: get_fed_loss_cls_weights( + dataloader.train.dataset.names, 0.5 + ), + ) 
+ for (w1, w2) in [(10, 5), (20, 10), (30, 15)] + ], + proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) + for th in [0.5, 0.6, 0.7] + ], +) diff --git a/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_vitdet_h_100ep.py b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_vitdet_h_100ep.py new file mode 100644 index 0000000..68bec57 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_vitdet_h_100ep.py @@ -0,0 +1,51 @@ +from detectron2.config import LazyCall as L +from detectron2.data.detection_utils import get_fed_loss_cls_weights +from detectron2.layers import ShapeSpec +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.roi_heads import FastRCNNOutputLayers, FastRCNNConvFCHead, CascadeROIHeads + +from .mask_rcnn_vitdet_h_100ep import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +# arguments that don't exist for Cascade R-CNN +[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] + +model.roi_heads.update( + _target_=CascadeROIHeads, + num_classes=1203, + box_heads=[ + L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[256, 256, 256, 256], + fc_dims=[1024], + conv_norm="LN", + ) + for _ in range(3) + ], + box_predictors=[ + L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), + num_classes="${...num_classes}", + test_score_thresh=0.02, + test_topk_per_image=300, + cls_agnostic_bbox_reg=True, + use_sigmoid_ce=True, + use_fed_loss=True, + get_fed_loss_cls_weights=lambda: get_fed_loss_cls_weights( + dataloader.train.dataset.names, 0.5 + ), + ) + for (w1, w2) in [(10, 5), (20, 10), (30, 15)] + ], + proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], 
allow_low_quality_matches=False) + for th in [0.5, 0.6, 0.7] + ], +) diff --git a/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_vitdet_l_100ep.py b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_vitdet_l_100ep.py new file mode 100644 index 0000000..ebaf526 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/LVIS/cascade_mask_rcnn_vitdet_l_100ep.py @@ -0,0 +1,51 @@ +from detectron2.config import LazyCall as L +from detectron2.data.detection_utils import get_fed_loss_cls_weights +from detectron2.layers import ShapeSpec +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.roi_heads import FastRCNNOutputLayers, FastRCNNConvFCHead, CascadeROIHeads + +from .mask_rcnn_vitdet_l_100ep import ( + dataloader, + lr_multiplier, + model, + optimizer, + train, +) + +# arguments that don't exist for Cascade R-CNN +[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] + +model.roi_heads.update( + _target_=CascadeROIHeads, + num_classes=1203, + box_heads=[ + L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[256, 256, 256, 256], + fc_dims=[1024], + conv_norm="LN", + ) + for _ in range(3) + ], + box_predictors=[ + L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), + num_classes="${...num_classes}", + test_score_thresh=0.02, + test_topk_per_image=300, + cls_agnostic_bbox_reg=True, + use_sigmoid_ce=True, + use_fed_loss=True, + get_fed_loss_cls_weights=lambda: get_fed_loss_cls_weights( + dataloader.train.dataset.names, 0.5 + ), + ) + for (w1, w2) in [(10, 5), (20, 10), (30, 15)] + ], + proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) + for th in [0.5, 0.6, 0.7] + ], +) diff --git 
a/data_processing/detectron2/projects/ViTDet/configs/LVIS/mask_rcnn_vitdet_b_100ep.py b/data_processing/detectron2/projects/ViTDet/configs/LVIS/mask_rcnn_vitdet_b_100ep.py new file mode 100644 index 0000000..ef90545 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/LVIS/mask_rcnn_vitdet_b_100ep.py @@ -0,0 +1,44 @@ +from detectron2.config import LazyCall as L +from detectron2.data.samplers import RepeatFactorTrainingSampler +from detectron2.evaluation.lvis_evaluation import LVISEvaluator +from detectron2.data.detection_utils import get_fed_loss_cls_weights + +from ..COCO.mask_rcnn_vitdet_b_100ep import ( + dataloader, + model, + train, + lr_multiplier, + optimizer, +) + +dataloader.train.dataset.names = "lvis_v1_train" +dataloader.train.sampler = L(RepeatFactorTrainingSampler)( + repeat_factors=L(RepeatFactorTrainingSampler.repeat_factors_from_category_frequency)( + dataset_dicts="${dataloader.train.dataset}", repeat_thresh=0.001 + ) +) +dataloader.test.dataset.names = "lvis_v1_val" +dataloader.evaluator = L(LVISEvaluator)( + dataset_name="${..test.dataset.names}", + max_dets_per_image=300, +) + +model.roi_heads.num_classes = 1203 +model.roi_heads.box_predictor.test_score_thresh = 0.02 +model.roi_heads.box_predictor.test_topk_per_image = 300 +model.roi_heads.box_predictor.use_sigmoid_ce = True +model.roi_heads.box_predictor.use_fed_loss = True +model.roi_heads.box_predictor.get_fed_loss_cls_weights = lambda: get_fed_loss_cls_weights( + dataloader.train.dataset.names, 0.5 +) + +# Schedule +# 100 ep = 156250 iters * 64 images/iter / 100000 images/ep +train.max_iter = 156250 +train.eval_period = 30000 + +lr_multiplier.scheduler.milestones = [138889, 150463] +lr_multiplier.scheduler.num_updates = train.max_iter +lr_multiplier.warmup_length = 250 / train.max_iter + +optimizer.lr = 2e-4 diff --git a/data_processing/detectron2/projects/ViTDet/configs/LVIS/mask_rcnn_vitdet_h_100ep.py 
b/data_processing/detectron2/projects/ViTDet/configs/LVIS/mask_rcnn_vitdet_h_100ep.py new file mode 100644 index 0000000..0f99bad --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/LVIS/mask_rcnn_vitdet_h_100ep.py @@ -0,0 +1,30 @@ +from functools import partial + +from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate + +from .mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, +) + +train.init_checkpoint = ( + "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth?matching_heuristics=True" +) + +model.backbone.net.embed_dim = 1280 +model.backbone.net.depth = 32 +model.backbone.net.num_heads = 16 +model.backbone.net.drop_path_rate = 0.4 +# 7, 15, 23, 31 for global attention +model.backbone.net.window_block_indexes = ( + list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31)) +) + + +optimizer.lr = 1e-4 +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.9, num_layers=32) +optimizer.params.overrides = {} +optimizer.params.weight_decay_norm = None diff --git a/data_processing/detectron2/projects/ViTDet/configs/LVIS/mask_rcnn_vitdet_l_100ep.py b/data_processing/detectron2/projects/ViTDet/configs/LVIS/mask_rcnn_vitdet_l_100ep.py new file mode 100644 index 0000000..15d8792 --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/LVIS/mask_rcnn_vitdet_l_100ep.py @@ -0,0 +1,26 @@ +from functools import partial + +from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate + +from .mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, +) + +train.init_checkpoint = ( + "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_large.pth?matching_heuristics=True" +) + +model.backbone.net.embed_dim = 1024 +model.backbone.net.depth = 24 +model.backbone.net.num_heads = 16 +model.backbone.net.drop_path_rate = 0.4 +# 5, 11, 17, 23 for global attention 
+model.backbone.net.window_block_indexes = ( + list(range(0, 5)) + list(range(6, 11)) + list(range(12, 17)) + list(range(18, 23)) +) + +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.8, num_layers=24) diff --git a/data_processing/detectron2/projects/ViTDet/configs/common/coco_loader_lsj.py b/data_processing/detectron2/projects/ViTDet/configs/common/coco_loader_lsj.py new file mode 100644 index 0000000..e6c2f1e --- /dev/null +++ b/data_processing/detectron2/projects/ViTDet/configs/common/coco_loader_lsj.py @@ -0,0 +1,22 @@ +import detectron2.data.transforms as T +from detectron2 import model_zoo +from detectron2.config import LazyCall as L + +# Data using LSJ +image_size = 1024 +dataloader = model_zoo.get_config("common/data/coco.py").dataloader +dataloader.train.mapper.augmentations = [ + L(T.RandomFlip)(horizontal=True), # flip first + L(T.ResizeScale)( + min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size + ), + L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False), +] +dataloader.train.mapper.image_format = "RGB" +dataloader.train.total_batch_size = 64 +# recompute boxes due to cropping +dataloader.train.mapper.recompute_boxes = True + +dataloader.test.mapper.augmentations = [ + L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size), +] diff --git a/data_processing/detectron2/setup.cfg b/data_processing/detectron2/setup.cfg new file mode 100644 index 0000000..f127d7b --- /dev/null +++ b/data_processing/detectron2/setup.cfg @@ -0,0 +1,26 @@ +[isort] +line_length=100 +multi_line_output=3 +include_trailing_comma=True +known_standard_library=numpy,setuptools,mock +skip=./datasets,docs +skip_glob=*/__init__.py,**/configs/**,**/tests/config/** +known_myself=detectron2 
def get_version():
    """Compute the package version and write it back into ``detectron2/__init__.py``.

    The base version is taken from the first line of ``detectron2/__init__.py``
    (located next to this file) that starts with ``__version__``.  The value of
    the ``D2_VERSION_SUFFIX`` environment variable is appended, and when
    ``BUILD_NIGHTLY`` is ``"1"`` a ``.devYYMMDD`` tag built from today's date is
    appended as well.

    Side effect: ``__init__.py`` is rewritten with every existing
    ``__version__`` line removed and a single new one appended at the end.

    Returns:
        The final version string.

    Raises:
        IndexError: if no line of ``__init__.py`` starts with ``__version__``.
    """
    init_py_path = path.join(path.abspath(path.dirname(__file__)), "detectron2", "__init__.py")
    # Read through a context manager so the handle is closed before the same
    # file is reopened for writing below.
    with open(init_py_path, "r") as f:
        init_py = f.readlines()
    version_line = [line.strip() for line in init_py if line.startswith("__version__")][0]
    version = version_line.split("=")[-1].strip().strip("'\"")

    # The following is used to build release packages.
    # Users should never use it.
    suffix = os.getenv("D2_VERSION_SUFFIX", "")
    version = version + suffix
    if os.getenv("BUILD_NIGHTLY", "0") == "1":
        from datetime import datetime

        date_str = datetime.today().strftime("%y%m%d")
        version = version + ".dev" + date_str

    new_init_py = [line for line in init_py if not line.startswith("__version__")]
    new_init_py.append('__version__ = "{}"\n'.format(version))
    with open(init_py_path, "w") as f:
        f.write("".join(new_init_py))
    return version


def get_extensions():
    """Build the list of extension modules for ``detectron2._C``.

    Sources are collected from ``detectron2/layers/csrc`` next to this file.
    A plain ``CppExtension`` is used by default; a ``CUDAExtension`` (with the
    ``.cu`` sources added) is used when CUDA/ROCm is usable or when the
    ``FORCE_CUDA`` environment variable is ``"1"``.

    Returns:
        A one-element list holding the configured extension object.
    """
    this_dir = path.dirname(path.abspath(__file__))
    extensions_dir = path.join(this_dir, "detectron2", "layers", "csrc")

    main_source = path.join(extensions_dir, "vision.cpp")
    sources = glob.glob(path.join(extensions_dir, "**", "*.cpp"))

    from torch.utils.cpp_extension import ROCM_HOME

    is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)
    if is_rocm_pytorch:
        assert torch_ver >= [1, 8], "ROCM support requires PyTorch >= 1.8!"

    # common code between cuda and rocm platforms, for hipify version [1,0,0] and later.
    source_cuda = glob.glob(path.join(extensions_dir, "**", "*.cu")) + glob.glob(
        path.join(extensions_dir, "*.cu")
    )
    sources = [main_source] + sources

    extension = CppExtension

    extra_compile_args = {"cxx": []}
    define_macros = []

    if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or os.getenv(
        "FORCE_CUDA", "0"
    ) == "1":
        extension = CUDAExtension
        sources += source_cuda

        if not is_rocm_pytorch:
            define_macros += [("WITH_CUDA", None)]
            extra_compile_args["nvcc"] = [
                "-O3",
                "-DCUDA_HAS_FP16=1",
                "-D__CUDA_NO_HALF_OPERATORS__",
                "-D__CUDA_NO_HALF_CONVERSIONS__",
                "-D__CUDA_NO_HALF2_OPERATORS__",
            ]
        else:
            define_macros += [("WITH_HIP", None)]
            extra_compile_args["nvcc"] = []

        # NOTE: the module-level assert already requires torch >= 1.8, so this
        # branch only matters when asserts are disabled (``python -O``).
        if torch_ver < [1, 7]:
            # supported by https://github.com/pytorch/pytorch/pull/43931
            CC = os.environ.get("CC", None)
            if CC is not None:
                extra_compile_args["nvcc"].append("-ccbin={}".format(CC))

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "detectron2._C",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    return ext_modules


def get_model_zoo_configs() -> List[str]:
    """
    Return a list of configs to include in package for model zoo. Copy over these configs inside
    detectron2/model_zoo.

    Side effect: ``detectron2/model_zoo/configs`` is (re)created as a symlink to
    the top-level ``configs`` directory, or as a copy of it when symlinking
    raises ``OSError``.  The returned paths come from globbing ``configs/``
    relative to the current working directory.
    """

    # Use absolute paths while symlinking.
    source_configs_dir = path.join(path.dirname(path.realpath(__file__)), "configs")
    destination = path.join(
        path.dirname(path.realpath(__file__)), "detectron2", "model_zoo", "configs"
    )
    # Symlink the config directory inside package to have a cleaner pip install.

    # Remove stale symlink/directory from a previous build.
    if path.exists(source_configs_dir):
        if path.islink(destination):
            os.unlink(destination)
        elif path.isdir(destination):
            shutil.rmtree(destination)

    if not path.exists(destination):
        try:
            os.symlink(source_configs_dir, destination)
        except OSError:
            # Fall back to copying if symlink fails: ex. on Windows.
            shutil.copytree(source_configs_dir, destination)

    config_paths = glob.glob("configs/**/*.yaml", recursive=True) + glob.glob(
        "configs/**/*.py", recursive=True
    )
    return config_paths
Just like pytorch, user should install + # opencv themselves, preferably by OS's package manager, or by + # choosing the proper pypi package name at https://github.com/skvark/opencv-python + # Also, avoid adding dependencies that transitively depend on pytorch or opencv. + # ------------------------------------------------------------ + # The following are pure-python dependencies that should be easily installable. + # But still be careful when adding more: fewer people are able to use the software + # with every new dependency added. + "termcolor>=1.1", + "yacs>=0.1.8", + "tabulate", + "cloudpickle", + "tqdm>4.29.0", + "tensorboard", + # Lock version of fvcore/iopath because they may have breaking changes + # NOTE: when updating fvcore/iopath version, make sure fvcore depends + # on compatible version of iopath. + "fvcore>=0.1.5,<0.1.6", # required like this to make it pip installable + "iopath>=0.1.7,<0.1.10", + "dataclasses; python_version<'3.7'", + "omegaconf>=2.1", + "hydra-core>=1.1", + "black", + "packaging", + # NOTE: When adding new dependencies, if it is required at import time (in addition + # to runtime), it probably needs to appear in docs/requirements.txt, or as a mock + # in docs/conf.py + ], + extras_require={ + # optional dependencies, required by some features + "all": [ + "fairscale", + "timm", # Used by a few ViT models. + "scipy>1.5.1", + "shapely", + "pygments>=2.2", + "psutil", + "panopticapi @ https://github.com/cocodataset/panopticapi/archive/master.zip", + ], + # dev dependencies.
Install them by `pip install 'detectron2[dev]'` + "dev": [ + "flake8==3.8.1", + "isort==4.3.21", + "flake8-bugbear", + "flake8-comprehensions", + "black==22.3.0", + ], + }, + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/data_processing/detectron2/tests/README.md b/data_processing/detectron2/tests/README.md new file mode 100644 index 0000000..f560384 --- /dev/null +++ b/data_processing/detectron2/tests/README.md @@ -0,0 +1,9 @@ +## Unit Tests + +To run the unittests, do: +``` +cd detectron2 +python -m unittest discover -v -s ./tests +``` + +There are also end-to-end inference & training tests, in [dev/run_*_tests.sh](../dev). diff --git a/data_processing/detectron2/tests/__init__.py b/data_processing/detectron2/tests/__init__.py new file mode 100644 index 0000000..9020c2d --- /dev/null +++ b/data_processing/detectron2/tests/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. diff --git a/data_processing/detectron2/tests/config/dir1/bad_import.py b/data_processing/detectron2/tests/config/dir1/bad_import.py new file mode 100644 index 0000000..d7452c4 --- /dev/null +++ b/data_processing/detectron2/tests/config/dir1/bad_import.py @@ -0,0 +1,2 @@ +# import from directory is not allowed +from . import dir1a diff --git a/data_processing/detectron2/tests/config/dir1/bad_import2.py b/data_processing/detectron2/tests/config/dir1/bad_import2.py new file mode 100644 index 0000000..085a4df --- /dev/null +++ b/data_processing/detectron2/tests/config/dir1/bad_import2.py @@ -0,0 +1 @@ +from .does_not_exist import x diff --git a/data_processing/detectron2/tests/config/dir1/dir1_a.py b/data_processing/detectron2/tests/config/dir1/dir1_a.py new file mode 100644 index 0000000..a939955 --- /dev/null +++ b/data_processing/detectron2/tests/config/dir1/dir1_a.py @@ -0,0 +1,3 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+dir1a_str = "base_a_1" +dir1a_dict = {"a": 1, "b": 2} diff --git a/data_processing/detectron2/tests/config/dir1/dir1_b.py b/data_processing/detectron2/tests/config/dir1/dir1_b.py new file mode 100644 index 0000000..2dcb54c --- /dev/null +++ b/data_processing/detectron2/tests/config/dir1/dir1_b.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.config import LazyConfig + +# equivalent to relative import +dir1a_str, dir1a_dict = LazyConfig.load_rel("dir1_a.py", ("dir1a_str", "dir1a_dict")) + +dir1b_str = dir1a_str + "_from_b" +dir1b_dict = dir1a_dict + +# Every import is a reload: not modified by other config files +assert dir1a_dict.a == 1 diff --git a/data_processing/detectron2/tests/config/dir1/load_rel.py b/data_processing/detectron2/tests/config/dir1/load_rel.py new file mode 100644 index 0000000..22d10db --- /dev/null +++ b/data_processing/detectron2/tests/config/dir1/load_rel.py @@ -0,0 +1,5 @@ +# test that load_rel can work +from detectron2.config import LazyConfig + +x = LazyConfig.load_rel("dir1_a.py", "dir1a_dict") +assert x["a"] == 1 diff --git a/data_processing/detectron2/tests/config/root_cfg.py b/data_processing/detectron2/tests/config/root_cfg.py new file mode 100644 index 0000000..33d1d4b --- /dev/null +++ b/data_processing/detectron2/tests/config/root_cfg.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+from itertools import count + +from detectron2.config import LazyCall as L + +from .dir1.dir1_a import dir1a_dict, dir1a_str + +dir1a_dict.a = "modified" + +# modification above won't affect future imports +from .dir1.dir1_b import dir1b_dict, dir1b_str + + +lazyobj = L(count)(x=dir1a_str, y=dir1b_str) diff --git a/data_processing/detectron2/tests/config/test_instantiate_config.py b/data_processing/detectron2/tests/config/test_instantiate_config.py new file mode 100644 index 0000000..6b72894 --- /dev/null +++ b/data_processing/detectron2/tests/config/test_instantiate_config.py @@ -0,0 +1,109 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import os +import tempfile +import unittest +import yaml +from omegaconf import OmegaConf +from omegaconf import __version__ as oc_version +from dataclasses import dataclass + +from detectron2.config import LazyConfig, instantiate, LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.utils.testing import reload_lazy_config + +OC_VERSION = tuple(int(x) for x in oc_version.split(".")[:2]) + + +class TestClass: + def __init__(self, int_arg, list_arg=None, dict_arg=None, extra_arg=None): + self.int_arg = int_arg + self.list_arg = list_arg + self.dict_arg = dict_arg + self.extra_arg = extra_arg + + def __call__(self, call_arg): + return call_arg + self.int_arg + + +@unittest.skipIf(OC_VERSION < (2, 1), "omegaconf version too old") +class TestConstruction(unittest.TestCase): + def test_basic_construct(self): + cfg = L(TestClass)( + int_arg=3, + list_arg=[10], + dict_arg={}, + extra_arg=L(TestClass)(int_arg=4, list_arg="${..list_arg}"), + ) + + for x in [cfg, reload_lazy_config(cfg)]: + obj = instantiate(x) + self.assertIsInstance(obj, TestClass) + self.assertEqual(obj.int_arg, 3) + self.assertEqual(obj.extra_arg.int_arg, 4) + self.assertEqual(obj.extra_arg.list_arg, obj.list_arg) + + # Test interpolation + x.extra_arg.list_arg = [5] + obj = instantiate(x) + self.assertIsInstance(obj, TestClass) + 
self.assertEqual(obj.extra_arg.list_arg, [5]) + + def test_instantiate_other_obj(self): + # do nothing for other obj + self.assertEqual(instantiate(5), 5) + x = [3, 4, 5] + self.assertEqual(instantiate(x), x) + x = TestClass(1) + self.assertIs(instantiate(x), x) + x = {"xx": "yy"} + self.assertIs(instantiate(x), x) + + def test_instantiate_lazy_target(self): + # _target_ is result of instantiate + objconf = L(L(len)(int_arg=3))(call_arg=4) + objconf._target_._target_ = TestClass + self.assertEqual(instantiate(objconf), 7) + + def test_instantiate_list(self): + lst = [1, 2, L(TestClass)(int_arg=1)] + x = L(TestClass)(int_arg=lst) # list as an argument should be recursively instantiated + x = instantiate(x).int_arg + self.assertEqual(x[:2], [1, 2]) + self.assertIsInstance(x[2], TestClass) + self.assertEqual(x[2].int_arg, 1) + + def test_instantiate_dataclass(self): + cfg = L(ShapeSpec)(channels=1, width=3) + # Test original cfg as well as serialization + for x in [cfg, reload_lazy_config(cfg)]: + obj = instantiate(x) + self.assertIsInstance(obj, ShapeSpec) + self.assertEqual(obj.channels, 1) + self.assertEqual(obj.height, None) + + def test_instantiate_dataclass_as_subconfig(self): + cfg = L(TestClass)(int_arg=1, extra_arg=ShapeSpec(channels=1, width=3)) + # Test original cfg as well as serialization + for x in [cfg, reload_lazy_config(cfg)]: + obj = instantiate(x) + self.assertIsInstance(obj.extra_arg, ShapeSpec) + self.assertEqual(obj.extra_arg.channels, 1) + self.assertEqual(obj.extra_arg.height, None) + + def test_bad_lazycall(self): + with self.assertRaises(Exception): + L(3) + + def test_interpolation(self): + cfg = L(TestClass)(int_arg=3, extra_arg="${int_arg}") + + cfg.int_arg = 4 + obj = instantiate(cfg) + self.assertEqual(obj.extra_arg, 4) + + # Test that interpolation still works after serialization + cfg = reload_lazy_config(cfg) + cfg.int_arg = 5 + obj = instantiate(cfg) + self.assertEqual(obj.extra_arg, 5) diff --git 
a/data_processing/detectron2/tests/config/test_lazy_config.py b/data_processing/detectron2/tests/config/test_lazy_config.py new file mode 100644 index 0000000..ff68143 --- /dev/null +++ b/data_processing/detectron2/tests/config/test_lazy_config.py @@ -0,0 +1,98 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os +import unittest +import tempfile +from itertools import count + +from detectron2.config import LazyConfig, LazyCall as L +from omegaconf import DictConfig + + +class TestLazyPythonConfig(unittest.TestCase): + def setUp(self): + self.curr_dir = os.path.dirname(__file__) + self.root_filename = os.path.join(self.curr_dir, "root_cfg.py") + + def test_load(self): + cfg = LazyConfig.load(self.root_filename) + + self.assertEqual(cfg.dir1a_dict.a, "modified") + self.assertEqual(cfg.dir1b_dict.a, 1) + self.assertEqual(cfg.lazyobj.x, "base_a_1") + + cfg.lazyobj.x = "new_x" + # reload + cfg = LazyConfig.load(self.root_filename) + self.assertEqual(cfg.lazyobj.x, "base_a_1") + + def test_save_load(self): + cfg = LazyConfig.load(self.root_filename) + with tempfile.TemporaryDirectory(prefix="detectron2") as d: + fname = os.path.join(d, "test_config.yaml") + LazyConfig.save(cfg, fname) + cfg2 = LazyConfig.load(fname) + + self.assertEqual(cfg2.lazyobj._target_, "itertools.count") + self.assertEqual(cfg.lazyobj._target_, count) + cfg2.lazyobj.pop("_target_") + cfg.lazyobj.pop("_target_") + # the rest are equal + self.assertEqual(cfg, cfg2) + + def test_failed_save(self): + cfg = DictConfig({"x": lambda: 3}, flags={"allow_objects": True}) + with tempfile.TemporaryDirectory(prefix="detectron2") as d: + fname = os.path.join(d, "test_config.yaml") + LazyConfig.save(cfg, fname) + self.assertTrue(os.path.exists(fname)) + self.assertTrue(os.path.exists(fname + ".pkl")) + + def test_overrides(self): + cfg = LazyConfig.load(self.root_filename) + LazyConfig.apply_overrides(cfg, ["lazyobj.x=123", 'dir1b_dict.a="123"']) + self.assertEqual(cfg.dir1b_dict.a, "123") + 
self.assertEqual(cfg.lazyobj.x, 123) + + LazyConfig.apply_overrides(cfg, ["dir1b_dict.a=abc"]) + self.assertEqual(cfg.dir1b_dict.a, "abc") + + def test_invalid_overrides(self): + cfg = LazyConfig.load(self.root_filename) + with self.assertRaises(KeyError): + LazyConfig.apply_overrides(cfg, ["lazyobj.x.xxx=123"]) + + def test_to_py(self): + cfg = LazyConfig.load(self.root_filename) + cfg.lazyobj.x = {"a": 1, "b": 2, "c": L(count)(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]})} + cfg.list = ["a", 1, "b", 3.2] + py_str = LazyConfig.to_py(cfg) + expected = """cfg.dir1a_dict.a = "modified" +cfg.dir1a_dict.b = 2 +cfg.dir1b_dict.a = 1 +cfg.dir1b_dict.b = 2 +cfg.lazyobj = itertools.count( + x={ + "a": 1, + "b": 2, + "c": itertools.count(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]}), + }, + y="base_a_1_from_b", +) +cfg.list = ["a", 1, "b", 3.2] +""" + self.assertEqual(py_str, expected) + + def test_bad_import(self): + file = os.path.join(self.curr_dir, "dir1", "bad_import.py") + with self.assertRaisesRegex(ImportError, "relative import"): + LazyConfig.load(file) + + def test_bad_import2(self): + file = os.path.join(self.curr_dir, "dir1", "bad_import2.py") + with self.assertRaisesRegex(ImportError, "not exist"): + LazyConfig.load(file) + + def test_load_rel(self): + file = os.path.join(self.curr_dir, "dir1", "load_rel.py") + cfg = LazyConfig.load(file) + self.assertIn("x", cfg) diff --git a/data_processing/detectron2/tests/config/test_yacs_config.py b/data_processing/detectron2/tests/config/test_yacs_config.py new file mode 100644 index 0000000..01dd695 --- /dev/null +++ b/data_processing/detectron2/tests/config/test_yacs_config.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. 
+ + +import os +import tempfile +import unittest +import torch +from omegaconf import OmegaConf + +from detectron2 import model_zoo +from detectron2.config import configurable, downgrade_config, get_cfg, upgrade_config +from detectron2.layers import ShapeSpec +from detectron2.modeling import build_model + +_V0_CFG = """ +MODEL: + RPN_HEAD: + NAME: "TEST" +VERSION: 0 +""" + +_V1_CFG = """ +MODEL: + WEIGHT: "/path/to/weight" +""" + + +class TestConfigVersioning(unittest.TestCase): + def test_upgrade_downgrade_consistency(self): + cfg = get_cfg() + # check that custom is preserved + cfg.USER_CUSTOM = 1 + + down = downgrade_config(cfg, to_version=0) + up = upgrade_config(down) + self.assertTrue(up == cfg) + + def _merge_cfg_str(self, cfg, merge_str): + f = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) + try: + f.write(merge_str) + f.close() + cfg.merge_from_file(f.name) + finally: + os.remove(f.name) + return cfg + + def test_auto_upgrade(self): + cfg = get_cfg() + latest_ver = cfg.VERSION + cfg.USER_CUSTOM = 1 + + self._merge_cfg_str(cfg, _V0_CFG) + + self.assertEqual(cfg.MODEL.RPN.HEAD_NAME, "TEST") + self.assertEqual(cfg.VERSION, latest_ver) + + def test_guess_v1(self): + cfg = get_cfg() + latest_ver = cfg.VERSION + self._merge_cfg_str(cfg, _V1_CFG) + self.assertEqual(cfg.VERSION, latest_ver) + + +class _TestClassA(torch.nn.Module): + @configurable + def __init__(self, arg1, arg2, arg3=3): + super().__init__() + self.arg1 = arg1 + self.arg2 = arg2 + self.arg3 = arg3 + assert arg1 == 1 + assert arg2 == 2 + assert arg3 == 3 + + @classmethod + def from_config(cls, cfg): + args = {"arg1": cfg.ARG1, "arg2": cfg.ARG2} + return args + + +class _TestClassB(_TestClassA): + @configurable + def __init__(self, input_shape, arg1, arg2, arg3=3): + """ + Doc of _TestClassB + """ + assert input_shape == "shape" + super().__init__(arg1, arg2, arg3) + + @classmethod + def from_config(cls, cfg, input_shape): # test extra positional arg in from_config + args = 
{"arg1": cfg.ARG1, "arg2": cfg.ARG2} + args["input_shape"] = input_shape + return args + + +class _LegacySubClass(_TestClassB): + # an old subclass written in cfg style + def __init__(self, cfg, input_shape, arg4=4): + super().__init__(cfg, input_shape) + assert self.arg1 == 1 + assert self.arg2 == 2 + assert self.arg3 == 3 + + +class _NewSubClassNewInit(_TestClassB): + # test new subclass with a new __init__ + @configurable + def __init__(self, input_shape, arg4=4, **kwargs): + super().__init__(input_shape, **kwargs) + assert self.arg1 == 1 + assert self.arg2 == 2 + assert self.arg3 == 3 + + +class _LegacySubClassNotCfg(_TestClassB): + # an old subclass written in cfg style, but argument is not called "cfg" + def __init__(self, config, input_shape): + super().__init__(config, input_shape) + assert self.arg1 == 1 + assert self.arg2 == 2 + assert self.arg3 == 3 + + +class _TestClassC(_TestClassB): + @classmethod + def from_config(cls, cfg, input_shape, **kwargs): # test extra kwarg overwrite + args = {"arg1": cfg.ARG1, "arg2": cfg.ARG2} + args["input_shape"] = input_shape + args.update(kwargs) + return args + + +class _TestClassD(_TestClassA): + @configurable + def __init__(self, input_shape: ShapeSpec, arg1: int, arg2, arg3=3): + assert input_shape == "shape" + super().__init__(arg1, arg2, arg3) + + # _TestClassA.from_config does not have input_shape args. 
+ # Test whether input_shape will be forwarded to __init__ + + +@configurable(from_config=lambda cfg, arg2: {"arg1": cfg.ARG1, "arg2": arg2, "arg3": cfg.ARG3}) +def _test_func(arg1, arg2=2, arg3=3, arg4=4): + return arg1, arg2, arg3, arg4 + + +class TestConfigurable(unittest.TestCase): + def testInitWithArgs(self): + _ = _TestClassA(arg1=1, arg2=2, arg3=3) + _ = _TestClassB("shape", arg1=1, arg2=2) + _ = _TestClassC("shape", arg1=1, arg2=2) + _ = _TestClassD("shape", arg1=1, arg2=2, arg3=3) + + def testPatchedAttr(self): + self.assertTrue("Doc" in _TestClassB.__init__.__doc__) + self.assertEqual(_TestClassD.__init__.__annotations__["arg1"], int) + + def testInitWithCfg(self): + cfg = get_cfg() + cfg.ARG1 = 1 + cfg.ARG2 = 2 + cfg.ARG3 = 3 + _ = _TestClassA(cfg) + _ = _TestClassB(cfg, input_shape="shape") + _ = _TestClassC(cfg, input_shape="shape") + _ = _TestClassD(cfg, input_shape="shape") + _ = _LegacySubClass(cfg, input_shape="shape") + _ = _NewSubClassNewInit(cfg, input_shape="shape") + _ = _LegacySubClassNotCfg(cfg, input_shape="shape") + with self.assertRaises(TypeError): + # disallow forwarding positional args to __init__ since it's prone to errors + _ = _TestClassD(cfg, "shape") + + # call with kwargs instead + _ = _TestClassA(cfg=cfg) + _ = _TestClassB(cfg=cfg, input_shape="shape") + _ = _TestClassC(cfg=cfg, input_shape="shape") + _ = _TestClassD(cfg=cfg, input_shape="shape") + _ = _LegacySubClass(cfg=cfg, input_shape="shape") + _ = _NewSubClassNewInit(cfg=cfg, input_shape="shape") + _ = _LegacySubClassNotCfg(config=cfg, input_shape="shape") + + def testInitWithCfgOverwrite(self): + cfg = get_cfg() + cfg.ARG1 = 1 + cfg.ARG2 = 999 # wrong config + with self.assertRaises(AssertionError): + _ = _TestClassA(cfg, arg3=3) + + # overwrite arg2 with correct config later: + _ = _TestClassA(cfg, arg2=2, arg3=3) + _ = _TestClassB(cfg, input_shape="shape", arg2=2, arg3=3) + _ = _TestClassC(cfg, input_shape="shape", arg2=2, arg3=3) + _ = _TestClassD(cfg, 
input_shape="shape", arg2=2, arg3=3) + + # call with kwargs cfg=cfg instead + _ = _TestClassA(cfg=cfg, arg2=2, arg3=3) + _ = _TestClassB(cfg=cfg, input_shape="shape", arg2=2, arg3=3) + _ = _TestClassC(cfg=cfg, input_shape="shape", arg2=2, arg3=3) + _ = _TestClassD(cfg=cfg, input_shape="shape", arg2=2, arg3=3) + + def testInitWithCfgWrongArgs(self): + cfg = get_cfg() + cfg.ARG1 = 1 + cfg.ARG2 = 2 + with self.assertRaises(TypeError): + _ = _TestClassB(cfg, "shape", not_exist=1) + with self.assertRaises(TypeError): + _ = _TestClassC(cfg, "shape", not_exist=1) + with self.assertRaises(TypeError): + _ = _TestClassD(cfg, "shape", not_exist=1) + + def testBadClass(self): + class _BadClass1: + @configurable + def __init__(self, a=1, b=2): + pass + + class _BadClass2: + @configurable + def __init__(self, a=1, b=2): + pass + + def from_config(self, cfg): # noqa + pass + + class _BadClass3: + @configurable + def __init__(self, a=1, b=2): + pass + + # bad name: must be cfg + @classmethod + def from_config(cls, config): # noqa + pass + + with self.assertRaises(AttributeError): + _ = _BadClass1(a=1) + + with self.assertRaises(TypeError): + _ = _BadClass2(a=1) + + with self.assertRaises(TypeError): + _ = _BadClass3(get_cfg()) + + def testFuncWithCfg(self): + cfg = get_cfg() + cfg.ARG1 = 10 + cfg.ARG3 = 30 + + self.assertEqual(_test_func(1), (1, 2, 3, 4)) + with self.assertRaises(TypeError): + _test_func(cfg) + self.assertEqual(_test_func(cfg, arg2=2), (10, 2, 30, 4)) + self.assertEqual(_test_func(cfg, arg1=100, arg2=20), (100, 20, 30, 4)) + self.assertEqual(_test_func(cfg, arg1=100, arg2=20, arg4=40), (100, 20, 30, 40)) + + self.assertTrue(callable(_test_func.from_config)) + + def testOmegaConf(self): + cfg = model_zoo.get_config("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml") + cfg = OmegaConf.create(cfg.dump()) + if not torch.cuda.is_available(): + cfg.MODEL.DEVICE = "cpu" + # test that a model can be built with omegaconf config as well + build_model(cfg) diff --git 
a/data_processing/detectron2/tests/data/__init__.py b/data_processing/detectron2/tests/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/detectron2/tests/data/test_coco.py b/data_processing/detectron2/tests/data/test_coco.py new file mode 100644 index 0000000..caabead --- /dev/null +++ b/data_processing/detectron2/tests/data/test_coco.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import json +import numpy as np +import os +import tempfile +import unittest +import pycocotools.mask as mask_util + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets.coco import convert_to_coco_dict, load_coco_json +from detectron2.structures import BoxMode + + +def make_mask(): + """ + Makes a donut shaped binary mask. + """ + H = 100 + W = 100 + mask = np.zeros([H, W], dtype=np.uint8) + for x in range(W): + for y in range(H): + d = np.linalg.norm(np.array([W, H]) / 2 - np.array([x, y])) + if d > 10 and d < 20: + mask[y, x] = 1 + return mask + + +def uncompressed_rle(mask): + l = mask.flatten(order="F").tolist() + counts = [] + p = False + cnt = 0 + for i in l: + if i == p: + cnt += 1 + else: + counts.append(cnt) + p = i + cnt = 1 + counts.append(cnt) + return {"counts": counts, "size": [mask.shape[0], mask.shape[1]]} + + +def make_dataset_dicts(mask, compressed: bool = True): + """ + Returns a list of dicts that represents a single COCO data point for + object detection. The single instance given by `mask` is represented by + RLE, either compressed or uncompressed. 
+ """ + record = {} + record["file_name"] = "test" + record["image_id"] = 0 + record["height"] = mask.shape[0] + record["width"] = mask.shape[1] + + y, x = np.nonzero(mask) + if compressed: + segmentation = mask_util.encode(np.asarray(mask, order="F")) + else: + segmentation = uncompressed_rle(mask) + min_x = np.min(x) + max_x = np.max(x) + min_y = np.min(y) + max_y = np.max(y) + obj = { + "bbox": [min_x, min_y, max_x, max_y], + "bbox_mode": BoxMode.XYXY_ABS, + "category_id": 0, + "iscrowd": 0, + "segmentation": segmentation, + } + record["annotations"] = [obj] + return [record] + + +class TestRLEToJson(unittest.TestCase): + def test(self): + # Make a dummy dataset. + mask = make_mask() + DatasetCatalog.register("test_dataset", lambda: make_dataset_dicts(mask)) + MetadataCatalog.get("test_dataset").set(thing_classes=["test_label"]) + + # Dump to json. + json_dict = convert_to_coco_dict("test_dataset") + with tempfile.TemporaryDirectory() as tmpdir: + json_file_name = os.path.join(tmpdir, "test.json") + with open(json_file_name, "w") as f: + json.dump(json_dict, f) + # Load from json. + dicts = load_coco_json(json_file_name, "") + + # Check the loaded mask matches the original. 
+ anno = dicts[0]["annotations"][0] + loaded_mask = mask_util.decode(anno["segmentation"]) + self.assertTrue(np.array_equal(loaded_mask, mask)) + DatasetCatalog.pop("test_dataset") + MetadataCatalog.pop("test_dataset") + + def test_uncompressed_RLE(self): + mask = make_mask() + rle = mask_util.encode(np.asarray(mask, order="F")) + uncompressed = uncompressed_rle(mask) + compressed = mask_util.frPyObjects(uncompressed, *rle["size"]) + self.assertEqual(rle, compressed) + + +class TestConvertCOCO(unittest.TestCase): + @staticmethod + def generate_data(): + record = { + "file_name": "test", + "image_id": 0, + "height": 100, + "width": 100, + "annotations": [ + { + "bbox": [10, 10, 10, 10, 5], + "bbox_mode": BoxMode.XYWHA_ABS, + "category_id": 0, + "iscrowd": 0, + }, + { + "bbox": [15, 15, 3, 3], + "bbox_mode": BoxMode.XYXY_ABS, + "category_id": 0, + "iscrowd": 0, + }, + ], + } + + return [record] + + def test_convert_to_coco(self): + DatasetCatalog.register("test_dataset", lambda: TestConvertCOCO.generate_data()) + MetadataCatalog.get("test_dataset").set(thing_classes=["test_label"]) + convert_to_coco_dict("test_dataset") + DatasetCatalog.pop("test_dataset") + MetadataCatalog.pop("test_dataset") diff --git a/data_processing/detectron2/tests/data/test_coco_evaluation.py b/data_processing/detectron2/tests/data/test_coco_evaluation.py new file mode 100644 index 0000000..964f002 --- /dev/null +++ b/data_processing/detectron2/tests/data/test_coco_evaluation.py @@ -0,0 +1,138 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import contextlib +import copy +import io +import json +import numpy as np +import os +import tempfile +import unittest +import torch +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +from detectron2.data import DatasetCatalog +from detectron2.evaluation import COCOEvaluator +from detectron2.evaluation.fast_eval_api import COCOeval_opt +from detectron2.structures import Boxes, Instances + + +class TestCOCOeval(unittest.TestCase): + def test_fast_eval(self): + # A small set of images/categories from COCO val + # fmt: off + detections = [{"image_id": 139, "category_id": 1, "bbox": [417.3332824707031, 159.27003479003906, 47.66064453125, 143.00193786621094], "score": 0.9949821829795837, "segmentation": {"size": [426, 640], "counts": "Tc`52W=3N0N4aNN^E7]:4XE1g:8kDMT;U100000001O1gE[Nk8h1dFiNY9Z1aFkN]9g2J3NdN`FlN`9S1cFRN07]9g1bFoM6;X9c1cFoM=8R9g1bFQN>3U9Y30O01OO1O001N2O1N1O4L4L5UNoE3V:CVF6Q:@YF9l9@ZF 0 else 0.0 + msg = "%s: comparing COCO APIs, %s differs by %f" % (name, k, abs_diff) + self.assertTrue(abs_diff < 1e-4, msg=msg) + + def test_unknown_category(self): + dataset = "coco_2017_val_100" + evaluator = COCOEvaluator(dataset) + evaluator.reset() + inputs = DatasetCatalog.get(dataset)[:2] + pred = Instances((100, 100)) + pred.pred_boxes = Boxes(torch.rand(2, 4)) + pred.scores = torch.rand(2) + pred.pred_classes = torch.tensor([10, 80]) + output = {"instances": pred} + evaluator.process(inputs, [output, output]) + with self.assertRaises(AssertionError): + evaluator.evaluate() diff --git a/data_processing/detectron2/tests/data/test_dataset.py b/data_processing/detectron2/tests/data/test_dataset.py new file mode 100644 index 0000000..7bdcda0 --- /dev/null +++ b/data_processing/detectron2/tests/data/test_dataset.py @@ -0,0 +1,185 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import os +import pickle +import sys +import unittest +from functools import partial +import torch +from iopath.common.file_io import LazyPath + +from detectron2 import model_zoo +from detectron2.config import get_cfg, instantiate +from detectron2.data import ( + DatasetCatalog, + DatasetFromList, + MapDataset, + ToIterableDataset, + build_batch_data_loader, + build_detection_test_loader, + build_detection_train_loader, +) +from detectron2.data.common import ( + AspectRatioGroupedDataset, + set_default_dataset_from_list_serialize_method, +) +from detectron2.data.samplers import InferenceSampler, TrainingSampler + + +def _a_slow_func(x): + return "path/{}".format(x) + + +class TestDatasetFromList(unittest.TestCase): + # Failing for py3.6, likely due to pickle + @unittest.skipIf(sys.version_info.minor <= 6, "Not supported in Python 3.6") + def test_using_lazy_path(self): + dataset = [] + for i in range(10): + dataset.append({"file_name": LazyPath(partial(_a_slow_func, i))}) + + dataset = DatasetFromList(dataset) + for i in range(10): + path = dataset[i]["file_name"] + self.assertTrue(isinstance(path, LazyPath)) + self.assertEqual(os.fspath(path), _a_slow_func(i)) + + def test_alternative_serialize_method(self): + dataset = [1, 2, 3] + dataset = DatasetFromList(dataset, serialize=torch.tensor) + self.assertEqual(dataset[2], torch.tensor(3)) + + def test_change_default_serialize_method(self): + dataset = [1, 2, 3] + with set_default_dataset_from_list_serialize_method(torch.tensor): + dataset_1 = DatasetFromList(dataset, serialize=True) + self.assertEqual(dataset_1[2], torch.tensor(3)) + dataset_2 = DatasetFromList(dataset, serialize=True) + self.assertEqual(dataset_2[2], 3) + + +class TestMapDataset(unittest.TestCase): + @staticmethod + def map_func(x): + if x == 2: + return None + return x * 2 + + def test_map_style(self): + ds = DatasetFromList([1, 2, 3]) + ds = MapDataset(ds, TestMapDataset.map_func) + self.assertEqual(ds[0], 2) + self.assertEqual(ds[2], 6) + 
self.assertIn(ds[1], [2, 6]) + + def test_iter_style(self): + class DS(torch.utils.data.IterableDataset): + def __iter__(self): + yield from [1, 2, 3] + + ds = DS() + ds = MapDataset(ds, TestMapDataset.map_func) + self.assertIsInstance(ds, torch.utils.data.IterableDataset) + + data = list(iter(ds)) + self.assertEqual(data, [2, 6]) + + def test_pickleability(self): + ds = DatasetFromList([1, 2, 3]) + ds = MapDataset(ds, lambda x: x * 2) + ds = pickle.loads(pickle.dumps(ds)) + self.assertEqual(ds[0], 2) + + +class TestAspectRatioGrouping(unittest.TestCase): + def test_reiter_leak(self): + data = [(1, 0), (0, 1), (1, 0), (0, 1)] + data = [{"width": a, "height": b} for (a, b) in data] + batchsize = 2 + dataset = AspectRatioGroupedDataset(data, batchsize) + + for _ in range(5): + for idx, __ in enumerate(dataset): + if idx == 1: + # manually break, so the iterator does not stop by itself + break + # check that bucket sizes are valid + for bucket in dataset._buckets: + self.assertLess(len(bucket), batchsize) + + +class _MyData(torch.utils.data.IterableDataset): + def __iter__(self): + while True: + yield 1 + + +class TestDataLoader(unittest.TestCase): + def _get_kwargs(self): + # get kwargs of build_detection_train_loader + cfg = model_zoo.get_config("common/data/coco.py").dataloader.train + cfg.dataset.names = "coco_2017_val_100" + cfg.pop("_target_") + kwargs = {k: instantiate(v) for k, v in cfg.items()} + return kwargs + + def test_build_dataloader_train(self): + kwargs = self._get_kwargs() + dl = build_detection_train_loader(**kwargs) + next(iter(dl)) + + def test_build_iterable_dataloader_train(self): + kwargs = self._get_kwargs() + ds = DatasetFromList(kwargs.pop("dataset")) + ds = ToIterableDataset(ds, TrainingSampler(len(ds))) + dl = build_detection_train_loader(dataset=ds, **kwargs) + next(iter(dl)) + + def test_build_iterable_dataloader_from_cfg(self): + cfg = get_cfg() + cfg.DATASETS.TRAIN = ["iter_data"] + DatasetCatalog.register("iter_data", lambda: 
_MyData()) + dl = build_detection_train_loader(cfg, mapper=lambda x: x, aspect_ratio_grouping=False) + next(iter(dl)) + + dl = build_detection_test_loader(cfg, "iter_data", mapper=lambda x: x) + next(iter(dl)) + + def _check_is_range(self, data_loader, N): + # check that data_loader produces range(N) + data = list(iter(data_loader)) + data = [x for batch in data for x in batch] # flatten the batches + self.assertEqual(len(data), N) + self.assertEqual(set(data), set(range(N))) + + def test_build_batch_dataloader_inference(self): + # Test that build_batch_data_loader can be used for inference + N = 96 + ds = DatasetFromList(list(range(N))) + sampler = InferenceSampler(len(ds)) + dl = build_batch_data_loader(ds, sampler, 8, num_workers=3) + self._check_is_range(dl, N) + + def test_build_dataloader_inference(self): + N = 50 + ds = DatasetFromList(list(range(N))) + sampler = InferenceSampler(len(ds)) + # test that parallel loader works correctly + dl = build_detection_test_loader( + dataset=ds, sampler=sampler, mapper=lambda x: x, num_workers=3 + ) + self._check_is_range(dl, N) + + # test that batch_size works correctly + dl = build_detection_test_loader( + dataset=ds, sampler=sampler, mapper=lambda x: x, batch_size=4, num_workers=0 + ) + self._check_is_range(dl, N) + + def test_build_iterable_dataloader_inference(self): + # Test that build_detection_test_loader supports iterable dataset + N = 50 + ds = DatasetFromList(list(range(N))) + ds = ToIterableDataset(ds, InferenceSampler(len(ds))) + dl = build_detection_test_loader(dataset=ds, mapper=lambda x: x, num_workers=3) + self._check_is_range(dl, N) diff --git a/data_processing/detectron2/tests/data/test_detection_utils.py b/data_processing/detectron2/tests/data/test_detection_utils.py new file mode 100644 index 0000000..aac56c0 --- /dev/null +++ b/data_processing/detectron2/tests/data/test_detection_utils.py @@ -0,0 +1,176 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import copy +import numpy as np +import os +import unittest +import pycocotools.mask as mask_util + +from detectron2.data import MetadataCatalog, detection_utils +from detectron2.data import transforms as T +from detectron2.structures import BitMasks, BoxMode +from detectron2.utils.file_io import PathManager + + +class TestTransformAnnotations(unittest.TestCase): + def test_transform_simple_annotation(self): + transforms = T.TransformList([T.HFlipTransform(400)]) + anno = { + "bbox": np.asarray([10, 10, 200, 300]), + "bbox_mode": BoxMode.XYXY_ABS, + "category_id": 3, + "segmentation": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]], + } + + output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400)) + self.assertTrue(np.allclose(output["bbox"], [200, 10, 390, 300])) + self.assertEqual(len(output["segmentation"]), len(anno["segmentation"])) + self.assertTrue(np.allclose(output["segmentation"][0], [390, 10, 300, 100, 300, 10])) + + detection_utils.annotations_to_instances([output, output], (400, 400)) + + def test_transform_empty_annotation(self): + detection_utils.annotations_to_instances([], (400, 400)) + + def test_flip_keypoints(self): + transforms = T.TransformList([T.HFlipTransform(400)]) + anno = { + "bbox": np.asarray([10, 10, 200, 300]), + "bbox_mode": BoxMode.XYXY_ABS, + "keypoints": np.random.rand(17, 3) * 50 + 15, + } + + output = detection_utils.transform_instance_annotations( + copy.deepcopy(anno), + transforms, + (400, 400), + keypoint_hflip_indices=detection_utils.create_keypoint_hflip_indices( + ["keypoints_coco_2017_train"] + ), + ) + # The first keypoint is nose + self.assertTrue(np.allclose(output["keypoints"][0, 0], 400 - anno["keypoints"][0, 0])) + # The last 16 keypoints are 8 left-right pairs + self.assertTrue( + np.allclose( + output["keypoints"][1:, 0].reshape(-1, 2)[:, ::-1], + 400 - anno["keypoints"][1:, 0].reshape(-1, 2), + ) + ) + self.assertTrue( + np.allclose( + output["keypoints"][1:, 
1:].reshape(-1, 2, 2)[:, ::-1, :], + anno["keypoints"][1:, 1:].reshape(-1, 2, 2), + ) + ) + + def test_crop(self): + transforms = T.TransformList([T.CropTransform(300, 300, 10, 10)]) + keypoints = np.random.rand(17, 3) * 50 + 15 + keypoints[:, 2] = 2 + anno = { + "bbox": np.asarray([10, 10, 200, 400]), + "bbox_mode": BoxMode.XYXY_ABS, + "keypoints": keypoints, + } + + output = detection_utils.transform_instance_annotations( + copy.deepcopy(anno), transforms, (10, 10) + ) + # box is shifted and cropped + self.assertTrue((output["bbox"] == np.asarray([0, 0, 0, 10])).all()) + # keypoints are no longer visible + self.assertTrue((output["keypoints"][:, 2] == 0).all()) + + def test_transform_RLE(self): + transforms = T.TransformList([T.HFlipTransform(400)]) + mask = np.zeros((300, 400), order="F").astype("uint8") + mask[:, :200] = 1 + + anno = { + "bbox": np.asarray([10, 10, 200, 300]), + "bbox_mode": BoxMode.XYXY_ABS, + "segmentation": mask_util.encode(mask[:, :, None])[0], + "category_id": 3, + } + output = detection_utils.transform_instance_annotations( + copy.deepcopy(anno), transforms, (300, 400) + ) + mask = output["segmentation"] + self.assertTrue((mask[:, 200:] == 1).all()) + self.assertTrue((mask[:, :200] == 0).all()) + + inst = detection_utils.annotations_to_instances( + [output, output], (400, 400), mask_format="bitmask" + ) + self.assertTrue(isinstance(inst.gt_masks, BitMasks)) + + def test_transform_RLE_resize(self): + transforms = T.TransformList( + [T.HFlipTransform(400), T.ScaleTransform(300, 400, 400, 400, "bilinear")] + ) + mask = np.zeros((300, 400), order="F").astype("uint8") + mask[:, :200] = 1 + + anno = { + "bbox": np.asarray([10, 10, 200, 300]), + "bbox_mode": BoxMode.XYXY_ABS, + "segmentation": mask_util.encode(mask[:, :, None])[0], + "category_id": 3, + } + output = detection_utils.transform_instance_annotations( + copy.deepcopy(anno), transforms, (400, 400) + ) + + inst = detection_utils.annotations_to_instances( + [output, output], (400, 400), 
mask_format="bitmask" + ) + self.assertTrue(isinstance(inst.gt_masks, BitMasks)) + + def test_gen_crop(self): + instance = {"bbox": [10, 10, 100, 100], "bbox_mode": BoxMode.XYXY_ABS} + t = detection_utils.gen_crop_transform_with_instance((10, 10), (150, 150), instance) + # the box center must fall into the cropped region + self.assertTrue(t.x0 <= 55 <= t.x0 + t.w) + + def test_gen_crop_outside_boxes(self): + instance = {"bbox": [10, 10, 100, 100], "bbox_mode": BoxMode.XYXY_ABS} + with self.assertRaises(AssertionError): + detection_utils.gen_crop_transform_with_instance((10, 10), (15, 15), instance) + + def test_read_sem_seg(self): + cityscapes_dir = MetadataCatalog.get("cityscapes_fine_sem_seg_val").gt_dir + sem_seg_gt_path = os.path.join( + cityscapes_dir, "frankfurt", "frankfurt_000001_083852_gtFine_labelIds.png" + ) + if not PathManager.exists(sem_seg_gt_path): + raise unittest.SkipTest( + "Semantic segmentation ground truth {} not found.".format(sem_seg_gt_path) + ) + sem_seg = detection_utils.read_image(sem_seg_gt_path, "L") + self.assertEqual(sem_seg.ndim, 3) + self.assertEqual(sem_seg.shape[2], 1) + self.assertEqual(sem_seg.dtype, np.uint8) + self.assertEqual(sem_seg.max(), 32) + self.assertEqual(sem_seg.min(), 1) + + def test_read_exif_orientation(self): + # https://github.com/recurser/exif-orientation-examples/raw/master/Landscape_5.jpg + URL = "detectron2://assets/Landscape_5.jpg" + img = detection_utils.read_image(URL, "RGB") + self.assertEqual(img.ndim, 3) + self.assertEqual(img.dtype, np.uint8) + self.assertEqual(img.shape, (1200, 1800, 3)) # check that shape is not transposed + + def test_opencv_exif_orientation(self): + import cv2 + + URL = "detectron2://assets/Landscape_5.jpg" + with PathManager.open(URL, "rb") as f: + img = cv2.imdecode(np.frombuffer(f.read(), dtype="uint8"), cv2.IMREAD_COLOR) + self.assertEqual(img.dtype, np.uint8) + self.assertEqual(img.shape, (1200, 1800, 3)) + + +if __name__ 
== "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/data/test_rotation_transform.py b/data_processing/detectron2/tests/data/test_rotation_transform.py new file mode 100644 index 0000000..0e8299e --- /dev/null +++ b/data_processing/detectron2/tests/data/test_rotation_transform.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +import unittest + +from detectron2.data.transforms.transform import RotationTransform + + +class TestRotationTransform(unittest.TestCase): + def assertEqualsArrays(self, a1, a2): + self.assertTrue(np.allclose(a1, a2)) + + def randomData(self, h=5, w=5): + image = np.random.rand(h, w) + coords = np.array([[i, j] for j in range(h + 1) for i in range(w + 1)], dtype=float) + return image, coords, h, w + + def test180(self): + image, coords, h, w = self.randomData(6, 6) + rot = RotationTransform(h, w, 180, expand=False, center=None) + self.assertEqualsArrays(rot.apply_image(image), image[::-1, ::-1]) + rotated_coords = [[w - c[0], h - c[1]] for c in coords] + self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords) + + def test45_coords(self): + _, coords, h, w = self.randomData(4, 6) + rot = RotationTransform(h, w, 45, expand=False, center=None) + rotated_coords = [ + [(x + y - (h + w) / 2) / np.sqrt(2) + w / 2, h / 2 + (y + (w - h) / 2 - x) / np.sqrt(2)] + for (x, y) in coords + ] + self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords) + + def test90(self): + image, coords, h, w = self.randomData() + rot = RotationTransform(h, w, 90, expand=False, center=None) + self.assertEqualsArrays(rot.apply_image(image), image.T[::-1]) + rotated_coords = [[c[1], w - c[0]] for c in coords] + self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords) + + def test90_expand(self): # non-square image + image, coords, h, w = self.randomData(h=5, w=8) + rot = RotationTransform(h, w, 90, expand=True, center=None) + self.assertEqualsArrays(rot.apply_image(image), 
image.T[::-1]) + rotated_coords = [[c[1], w - c[0]] for c in coords] + self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords) + + def test_center_expand(self): + # center has no effect if expand=True because it only affects shifting + image, coords, h, w = self.randomData(h=5, w=8) + angle = np.random.randint(360) + rot1 = RotationTransform(h, w, angle, expand=True, center=None) + rot2 = RotationTransform(h, w, angle, expand=True, center=(0, 0)) + rot3 = RotationTransform(h, w, angle, expand=True, center=(h, w)) + rot4 = RotationTransform(h, w, angle, expand=True, center=(2, 5)) + for r1 in [rot1, rot2, rot3, rot4]: + for r2 in [rot1, rot2, rot3, rot4]: + self.assertEqualsArrays(r1.apply_image(image), r2.apply_image(image)) + self.assertEqualsArrays(r1.apply_coords(coords), r2.apply_coords(coords)) + + def test_inverse_transform(self): + image, coords, h, w = self.randomData(h=5, w=8) + rot = RotationTransform(h, w, 90, expand=True, center=None) + rot_image = rot.apply_image(image) + self.assertEqualsArrays(rot.inverse().apply_image(rot_image), image) + rot = RotationTransform(h, w, 65, expand=True, center=None) + rotated_coords = rot.apply_coords(coords) + self.assertEqualsArrays(rot.inverse().apply_coords(rotated_coords), coords) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/data/test_sampler.py b/data_processing/detectron2/tests/data/test_sampler.py new file mode 100644 index 0000000..0d27843 --- /dev/null +++ b/data_processing/detectron2/tests/data/test_sampler.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import itertools +import math +import operator +import unittest +import torch +from torch.utils import data +from torch.utils.data.sampler import SequentialSampler + +from detectron2.data.build import worker_init_reset_seed +from detectron2.data.common import DatasetFromList, ToIterableDataset +from detectron2.data.samplers import ( + GroupedBatchSampler, + InferenceSampler, + RepeatFactorTrainingSampler, + TrainingSampler, +) +from detectron2.utils.env import seed_all_rng + + +class TestGroupedBatchSampler(unittest.TestCase): + def test_missing_group_id(self): + sampler = SequentialSampler(list(range(100))) + group_ids = [1] * 100 + samples = GroupedBatchSampler(sampler, group_ids, 2) + + for mini_batch in samples: + self.assertEqual(len(mini_batch), 2) + + def test_groups(self): + sampler = SequentialSampler(list(range(100))) + group_ids = [1, 0] * 50 + samples = GroupedBatchSampler(sampler, group_ids, 2) + + for mini_batch in samples: + self.assertEqual((mini_batch[0] + mini_batch[1]) % 2, 0) + + +class TestSamplerDeterministic(unittest.TestCase): + def test_to_iterable(self): + sampler = TrainingSampler(100, seed=10) + gt_output = list(itertools.islice(sampler, 100)) + self.assertEqual(set(gt_output), set(range(100))) + + dataset = DatasetFromList(list(range(100))) + dataset = ToIterableDataset(dataset, sampler) + data_loader = data.DataLoader(dataset, num_workers=0, collate_fn=operator.itemgetter(0)) + + output = list(itertools.islice(data_loader, 100)) + self.assertEqual(output, gt_output) + + data_loader = data.DataLoader( + dataset, + num_workers=2, + collate_fn=operator.itemgetter(0), + worker_init_fn=worker_init_reset_seed, + # reset seed should not affect behavior of TrainingSampler + ) + output = list(itertools.islice(data_loader, 100)) + # multiple workers should not lead to duplicate or different data + self.assertEqual(output, gt_output) + + def test_training_sampler_seed(self): + seed_all_rng(42) + sampler = TrainingSampler(30) + data = 
list(itertools.islice(sampler, 65)) + + seed_all_rng(42) + sampler = TrainingSampler(30) + seed_all_rng(999) # should be ineffective + data2 = list(itertools.islice(sampler, 65)) + self.assertEqual(data, data2) + + +class TestRepeatFactorTrainingSampler(unittest.TestCase): + def test_repeat_factors_from_category_frequency(self): + repeat_thresh = 0.5 + + dataset_dicts = [ + {"annotations": [{"category_id": 0}, {"category_id": 1}]}, + {"annotations": [{"category_id": 0}]}, + {"annotations": []}, + ] + + rep_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( + dataset_dicts, repeat_thresh + ) + + expected_rep_factors = torch.tensor([math.sqrt(3 / 2), 1.0, 1.0]) + self.assertTrue(torch.allclose(rep_factors, expected_rep_factors)) + + +class TestInferenceSampler(unittest.TestCase): + def test_local_indices(self): + sizes = [0, 16, 2, 42] + world_sizes = [5, 2, 3, 4] + + expected_results = [ + [range(0) for _ in range(5)], + [range(8), range(8, 16)], + [range(1), range(1, 2), range(0)], + [range(11), range(11, 22), range(22, 32), range(32, 42)], + ] + + for size, world_size, expected_result in zip(sizes, world_sizes, expected_results): + with self.subTest(f"size={size}, world_size={world_size}"): + local_indices = [ + InferenceSampler._get_local_indices(size, world_size, r) + for r in range(world_size) + ] + self.assertEqual(local_indices, expected_result) diff --git a/data_processing/detectron2/tests/data/test_transforms.py b/data_processing/detectron2/tests/data/test_transforms.py new file mode 100644 index 0000000..382048e --- /dev/null +++ b/data_processing/detectron2/tests/data/test_transforms.py @@ -0,0 +1,268 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import logging +import numpy as np +import unittest +from unittest import mock +import torch +from PIL import Image, ImageOps +from torch.nn import functional as F + +from detectron2.config import get_cfg +from detectron2.data import detection_utils +from detectron2.data import transforms as T +from detectron2.utils.logger import setup_logger + +logger = logging.getLogger(__name__) + + +def polygon_allclose(poly1, poly2): + """ + Test whether two polygons are the same. + Both arguments are nx2 numpy arrays. + """ + # ABCD and CDAB are the same polygon. So it's important to check after rolling + for k in range(len(poly1)): + rolled_poly1 = np.roll(poly1, k, axis=0) + if np.allclose(rolled_poly1, poly2): + return True + return False + + +class TestTransforms(unittest.TestCase): + def setUp(self): + setup_logger() + + def test_apply_rotated_boxes(self): + np.random.seed(125) + cfg = get_cfg() + is_train = True + augs = detection_utils.build_augmentation(cfg, is_train) + image = np.random.rand(200, 300) + image, transforms = T.apply_augmentations(augs, image) + image_shape = image.shape[:2] # h, w + assert image_shape == (800, 1200) + annotation = {"bbox": [179, 97, 62, 40, -56]} + + boxes = np.array([annotation["bbox"]], dtype=np.float64) # boxes.shape = (1, 5) + transformed_bbox = transforms.apply_rotated_box(boxes)[0] + + expected_bbox = np.array([484, 388, 248, 160, 56], dtype=np.float64) + err_msg = "transformed_bbox = {}, expected {}".format(transformed_bbox, expected_bbox) + assert np.allclose(transformed_bbox, expected_bbox), err_msg + + def test_resize_and_crop(self): + np.random.seed(125) + min_scale = 0.2 + max_scale = 2.0 + target_height = 1100 + target_width = 1000 + resize_aug = T.ResizeScale(min_scale, max_scale, target_height, target_width) + fixed_size_crop_aug = T.FixedSizeCrop((target_height, target_width)) + hflip_aug = T.RandomFlip() + augs = [resize_aug, fixed_size_crop_aug, hflip_aug] + original_image = np.random.rand(900, 800) + image, 
transforms = T.apply_augmentations(augs, original_image) + image_shape = image.shape[:2] # h, w + self.assertEqual((1100, 1000), image_shape) + + boxes = np.array( + [[91, 46, 144, 111], [523, 251, 614, 295]], + dtype=np.float64, + ) + transformed_bboxs = transforms.apply_box(boxes) + expected_bboxs = np.array( + [ + [895.42, 33.42666667, 933.91125, 80.66], + [554.0825, 182.39333333, 620.17125, 214.36666667], + ], + dtype=np.float64, + ) + err_msg = "transformed_bbox = {}, expected {}".format(transformed_bboxs, expected_bboxs) + self.assertTrue(np.allclose(transformed_bboxs, expected_bboxs), err_msg) + + polygon = np.array([[91, 46], [144, 46], [144, 111], [91, 111]]) + transformed_polygons = transforms.apply_polygons([polygon]) + expected_polygon = np.array([[934.0, 33.0], [934.0, 80.0], [896.0, 80.0], [896.0, 33.0]]) + self.assertEqual(1, len(transformed_polygons)) + err_msg = "transformed_polygon = {}, expected {}".format( + transformed_polygons[0], expected_polygon + ) + self.assertTrue(polygon_allclose(transformed_polygons[0], expected_polygon), err_msg) + + def test_apply_rotated_boxes_unequal_scaling_factor(self): + np.random.seed(125) + h, w = 400, 200 + newh, neww = 800, 800 + image = np.random.rand(h, w) + augs = [] + augs.append(T.Resize(shape=(newh, neww))) + image, transforms = T.apply_augmentations(augs, image) + image_shape = image.shape[:2] # h, w + assert image_shape == (newh, neww) + + boxes = np.array( + [ + [150, 100, 40, 20, 0], + [150, 100, 40, 20, 30], + [150, 100, 40, 20, 90], + [150, 100, 40, 20, -90], + ], + dtype=np.float64, + ) + transformed_boxes = transforms.apply_rotated_box(boxes) + + expected_bboxes = np.array( + [ + [600, 200, 160, 40, 0], + [600, 200, 144.22205102, 52.91502622, 49.10660535], + [600, 200, 80, 80, 90], + [600, 200, 80, 80, -90], + ], + dtype=np.float64, + ) + err_msg = "transformed_boxes = {}, expected {}".format(transformed_boxes, expected_bboxes) + assert np.allclose(transformed_boxes, expected_bboxes), err_msg + 
+ def test_print_augmentation(self): + t = T.RandomCrop("relative", (100, 100)) + self.assertEqual(str(t), "RandomCrop(crop_type='relative', crop_size=(100, 100))") + + t0 = T.RandomFlip(prob=0.5) + self.assertEqual(str(t0), "RandomFlip(prob=0.5)") + + t1 = T.RandomFlip() + self.assertEqual(str(t1), "RandomFlip()") + + t = T.AugmentationList([t0, t1]) + self.assertEqual(str(t), f"AugmentationList[{t0}, {t1}]") + + def test_random_apply_prob_out_of_range_check(self): + test_probabilities = {0.0: True, 0.5: True, 1.0: True, -0.01: False, 1.01: False} + + for given_probability, is_valid in test_probabilities.items(): + if not is_valid: + self.assertRaises(AssertionError, T.RandomApply, None, prob=given_probability) + else: + T.RandomApply(T.NoOpTransform(), prob=given_probability) + + def test_random_apply_wrapping_aug_probability_occured_evaluation(self): + transform_mock = mock.MagicMock(name="MockTransform", spec=T.Augmentation) + image_mock = mock.MagicMock(name="MockImage") + random_apply = T.RandomApply(transform_mock, prob=0.001) + + with mock.patch.object(random_apply, "_rand_range", return_value=0.0001): + transform = random_apply.get_transform(image_mock) + transform_mock.get_transform.assert_called_once_with(image_mock) + self.assertIsNot(transform, transform_mock) + + def test_random_apply_wrapping_std_transform_probability_occured_evaluation(self): + transform_mock = mock.MagicMock(name="MockTransform", spec=T.Transform) + image_mock = mock.MagicMock(name="MockImage") + random_apply = T.RandomApply(transform_mock, prob=0.001) + + with mock.patch.object(random_apply, "_rand_range", return_value=0.0001): + transform = random_apply.get_transform(image_mock) + self.assertIs(transform, transform_mock) + + def test_random_apply_probability_not_occured_evaluation(self): + transform_mock = mock.MagicMock(name="MockTransform", spec=T.Augmentation) + image_mock = mock.MagicMock(name="MockImage") + random_apply = T.RandomApply(transform_mock, prob=0.001) + + with 
mock.patch.object(random_apply, "_rand_range", return_value=0.9): + transform = random_apply.get_transform(image_mock) + transform_mock.get_transform.assert_not_called() + self.assertIsInstance(transform, T.NoOpTransform) + + def test_augmentation_input_args(self): + input_shape = (100, 100) + output_shape = (50, 50) + + # define two augmentations with different args + class TG1(T.Augmentation): + def get_transform(self, image, sem_seg): + return T.ResizeTransform( + input_shape[0], input_shape[1], output_shape[0], output_shape[1] + ) + + class TG2(T.Augmentation): + def get_transform(self, image): + assert image.shape[:2] == output_shape # check that TG1 is applied + return T.HFlipTransform(output_shape[1]) + + image = np.random.rand(*input_shape).astype("float32") + sem_seg = (np.random.rand(*input_shape) < 0.5).astype("uint8") + inputs = T.AugInput(image, sem_seg=sem_seg) # provide two args + tfms = inputs.apply_augmentations([TG1(), TG2()]) + self.assertIsInstance(tfms[0], T.ResizeTransform) + self.assertIsInstance(tfms[1], T.HFlipTransform) + self.assertTrue(inputs.image.shape[:2] == output_shape) + self.assertTrue(inputs.sem_seg.shape[:2] == output_shape) + + class TG3(T.Augmentation): + def get_transform(self, image, nonexist): + pass + + with self.assertRaises(AttributeError): + inputs.apply_augmentations([TG3()]) + + def test_augmentation_list(self): + input_shape = (100, 100) + image = np.random.rand(*input_shape).astype("float32") + sem_seg = (np.random.rand(*input_shape) < 0.5).astype("uint8") + inputs = T.AugInput(image, sem_seg=sem_seg) # provide two args + + augs = T.AugmentationList([T.RandomFlip(), T.Resize(20)]) + _ = T.AugmentationList([augs, T.Resize(30)])(inputs) + # 3 in latest fvcore (flattened transformlist), 2 in older + # self.assertEqual(len(tfms), 3) + + def test_color_transforms(self): + rand_img = np.random.random((100, 100, 3)) * 255 + rand_img = rand_img.astype("uint8") + + # Test no-op + noop_transform = T.ColorTransform(lambda img: 
img) + self.assertTrue(np.array_equal(rand_img, noop_transform.apply_image(rand_img))) + + # Test a ImageOps operation + magnitude = np.random.randint(0, 256) + solarize_transform = T.PILColorTransform(lambda img: ImageOps.solarize(img, magnitude)) + expected_img = ImageOps.solarize(Image.fromarray(rand_img), magnitude) + self.assertTrue(np.array_equal(expected_img, solarize_transform.apply_image(rand_img))) + + def test_resize_transform(self): + input_shapes = [(100, 100), (100, 100, 1), (100, 100, 3)] + output_shapes = [(200, 200), (200, 200, 1), (200, 200, 3)] + for in_shape, out_shape in zip(input_shapes, output_shapes): + in_img = np.random.randint(0, 255, size=in_shape, dtype=np.uint8) + tfm = T.ResizeTransform(in_shape[0], in_shape[1], out_shape[0], out_shape[1]) + out_img = tfm.apply_image(in_img) + self.assertEqual(out_img.shape, out_shape) + + def test_resize_shorted_edge_scriptable(self): + def f(image): + newh, neww = T.ResizeShortestEdge.get_output_shape( + image.shape[-2], image.shape[-1], 80, 133 + ) + return F.interpolate(image.unsqueeze(0), size=(newh, neww)) + + input = torch.randn(3, 10, 10) + script_f = torch.jit.script(f) + self.assertTrue(torch.allclose(f(input), script_f(input))) + + # generalize to new shapes + input = torch.randn(3, 8, 100) + self.assertTrue(torch.allclose(f(input), script_f(input))) + + def test_extent_transform(self): + input_shapes = [(100, 100), (100, 100, 1), (100, 100, 3)] + src_rect = (20, 20, 80, 80) + output_shapes = [(200, 200), (200, 200, 1), (200, 200, 3)] + for in_shape, out_shape in zip(input_shapes, output_shapes): + in_img = np.random.randint(0, 255, size=in_shape, dtype=np.uint8) + tfm = T.ExtentTransform(src_rect, out_shape[:2]) + out_img = tfm.apply_image(in_img) + self.assertTrue(out_img.shape == out_shape) diff --git a/data_processing/detectron2/tests/export/test_c10.py b/data_processing/detectron2/tests/export/test_c10.py new file mode 100644 index 0000000..55076ab --- /dev/null +++ 
b/data_processing/detectron2/tests/export/test_c10.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import unittest + +try: + # Caffe2 used to be included in PyTorch, but since PyTorch 1.10+, + # it is not included in pre-built packages. This is a safety BC check + from detectron2.config import get_cfg + from detectron2.export.c10 import Caffe2RPN + from detectron2.layers import ShapeSpec +except ImportError: + raise unittest.SkipTest( + f"PyTorch does not have Caffe2 support. Skipping all tests in {__name__}" + ) from None + + +class TestCaffe2RPN(unittest.TestCase): + def test_instantiation(self): + cfg = get_cfg() + cfg.MODEL.RPN.BBOX_REG_WEIGHTS = (1, 1, 1, 1, 1) + input_shapes = {"res4": ShapeSpec(channels=256, stride=4)} + rpn = Caffe2RPN(cfg, input_shapes) + assert rpn is not None + cfg.MODEL.RPN.BBOX_REG_WEIGHTS = (10, 10, 5, 5, 1) + with self.assertRaises(AssertionError): + rpn = Caffe2RPN(cfg, input_shapes) diff --git a/data_processing/detectron2/tests/layers/__init__.py b/data_processing/detectron2/tests/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/detectron2/tests/layers/test_blocks.py b/data_processing/detectron2/tests/layers/test_blocks.py new file mode 100644 index 0000000..5a0488a --- /dev/null +++ b/data_processing/detectron2/tests/layers/test_blocks.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import unittest +import torch +from torch import nn + +from detectron2.layers import ASPP, DepthwiseSeparableConv2d, FrozenBatchNorm2d +from detectron2.modeling.backbone.resnet import BasicStem, ResNet + + +""" +Test for misc layers. 
+""" + + +class TestBlocks(unittest.TestCase): + def test_separable_conv(self): + DepthwiseSeparableConv2d(3, 10, norm1="BN", activation1=nn.PReLU()) + + def test_aspp(self): + m = ASPP(3, 10, [2, 3, 4], norm="", activation=nn.PReLU()) + self.assertIsNot(m.convs[0].activation.weight, m.convs[1].activation.weight) + self.assertIsNot(m.convs[0].activation.weight, m.project.activation.weight) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_frozen_batchnorm_fp16(self): + from torch.cuda.amp import autocast + + C = 10 + input = torch.rand(1, C, 10, 10).cuda() + m = FrozenBatchNorm2d(C).cuda() + with autocast(): + output = m(input.half()) + self.assertEqual(output.dtype, torch.float16) + + # requires_grad triggers a different codepath + input.requires_grad_() + with autocast(): + output = m(input.half()) + self.assertEqual(output.dtype, torch.float16) + + def test_resnet_unused_stages(self): + resnet = ResNet(BasicStem(), ResNet.make_default_stages(18), out_features=["res2"]) + self.assertTrue(hasattr(resnet, "res2")) + self.assertFalse(hasattr(resnet, "res3")) + self.assertFalse(hasattr(resnet, "res5")) + + resnet = ResNet(BasicStem(), ResNet.make_default_stages(18), out_features=["res2", "res5"]) + self.assertTrue(hasattr(resnet, "res2")) + self.assertTrue(hasattr(resnet, "res4")) + self.assertTrue(hasattr(resnet, "res5")) diff --git a/data_processing/detectron2/tests/layers/test_deformable.py b/data_processing/detectron2/tests/layers/test_deformable.py new file mode 100644 index 0000000..4aa319f --- /dev/null +++ b/data_processing/detectron2/tests/layers/test_deformable.py @@ -0,0 +1,175 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import numpy as np +import unittest +import torch + +from detectron2.layers import DeformConv, ModulatedDeformConv +from detectron2.utils.env import TORCH_VERSION + + +@unittest.skipIf( + TORCH_VERSION == (1, 8) and torch.cuda.is_available(), + "This test fails under cuda11 + torch1.8.", +) +class DeformableTest(unittest.TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "Deformable not supported for cpu") + def test_forward_output(self): + device = torch.device("cuda") + N, C, H, W = shape = 1, 1, 5, 5 + kernel_size = 3 + padding = 1 + + inputs = torch.arange(np.prod(shape), dtype=torch.float32).reshape(*shape).to(device) + """ + 0 1 2 3 4 + 5 6 7 8 9 + 10 11 12 13 14 + 15 16 17 18 19 + 20 21 22 23 24 + """ + offset_channels = kernel_size * kernel_size * 2 + offset = torch.full((N, offset_channels, H, W), 0.5, dtype=torch.float32).to(device) + + # Test DCN v1 + deform = DeformConv(C, C, kernel_size=kernel_size, padding=padding).to(device) + deform.weight = torch.nn.Parameter(torch.ones_like(deform.weight)) + output = deform(inputs, offset) + output = output.detach().cpu().numpy() + deform_results = np.array( + [ + [30, 41.25, 48.75, 45, 28.75], + [62.25, 81, 90, 80.25, 50.25], + [99.75, 126, 135, 117.75, 72.75], + [105, 131.25, 138.75, 120, 73.75], + [71.75, 89.25, 93.75, 80.75, 49.5], + ] + ) + self.assertTrue(np.allclose(output.flatten(), deform_results.flatten())) + + # Test DCN v2 + mask_channels = kernel_size * kernel_size + mask = torch.full((N, mask_channels, H, W), 0.5, dtype=torch.float32).to(device) + modulate_deform = ModulatedDeformConv(C, C, kernel_size, padding=padding, bias=False).to( + device + ) + modulate_deform.weight = deform.weight + output = modulate_deform(inputs, offset, mask) + output = output.detach().cpu().numpy() + self.assertTrue(np.allclose(output.flatten(), deform_results.flatten() * 0.5)) + + def test_forward_output_on_cpu(self): + device = torch.device("cpu") + N, C, H, W = shape = 1, 1, 5, 5 + kernel_size = 3 + padding 
= 1 + + inputs = torch.arange(np.prod(shape), dtype=torch.float32).reshape(*shape).to(device) + + offset_channels = kernel_size * kernel_size * 2 + offset = torch.full((N, offset_channels, H, W), 0.5, dtype=torch.float32).to(device) + + # Test DCN v1 on cpu + deform = DeformConv(C, C, kernel_size=kernel_size, padding=padding).to(device) + deform.weight = torch.nn.Parameter(torch.ones_like(deform.weight)) + output = deform(inputs, offset) + output = output.detach().cpu().numpy() + deform_results = np.array( + [ + [30, 41.25, 48.75, 45, 28.75], + [62.25, 81, 90, 80.25, 50.25], + [99.75, 126, 135, 117.75, 72.75], + [105, 131.25, 138.75, 120, 73.75], + [71.75, 89.25, 93.75, 80.75, 49.5], + ] + ) + self.assertTrue(np.allclose(output.flatten(), deform_results.flatten())) + + @unittest.skipIf(not torch.cuda.is_available(), "This test requires gpu access") + def test_forward_output_on_cpu_equals_output_on_gpu(self): + N, C, H, W = shape = 2, 4, 10, 10 + kernel_size = 3 + padding = 1 + + for groups in [1, 2]: + inputs = torch.arange(np.prod(shape), dtype=torch.float32).reshape(*shape) + offset_channels = kernel_size * kernel_size * 2 + offset = torch.full((N, offset_channels, H, W), 0.5, dtype=torch.float32) + + deform_gpu = DeformConv( + C, C, kernel_size=kernel_size, padding=padding, groups=groups + ).to("cuda") + deform_gpu.weight = torch.nn.Parameter(torch.ones_like(deform_gpu.weight)) + output_gpu = deform_gpu(inputs.to("cuda"), offset.to("cuda")).detach().cpu().numpy() + + deform_cpu = DeformConv( + C, C, kernel_size=kernel_size, padding=padding, groups=groups + ).to("cpu") + deform_cpu.weight = torch.nn.Parameter(torch.ones_like(deform_cpu.weight)) + output_cpu = deform_cpu(inputs.to("cpu"), offset.to("cpu")).detach().numpy() + + self.assertTrue(np.allclose(output_gpu.flatten(), output_cpu.flatten())) + + @unittest.skipIf(not torch.cuda.is_available(), "Deformable not supported for cpu") + def test_small_input(self): + device = torch.device("cuda") + for kernel_size 
in [3, 5]: + padding = kernel_size // 2 + N, C, H, W = shape = (1, 1, kernel_size - 1, kernel_size - 1) + + inputs = torch.rand(shape).to(device) # input size is smaller than kernel size + + offset_channels = kernel_size * kernel_size * 2 + offset = torch.randn((N, offset_channels, H, W), dtype=torch.float32).to(device) + deform = DeformConv(C, C, kernel_size=kernel_size, padding=padding).to(device) + output = deform(inputs, offset) + self.assertTrue(output.shape == inputs.shape) + + mask_channels = kernel_size * kernel_size + mask = torch.ones((N, mask_channels, H, W), dtype=torch.float32).to(device) + modulate_deform = ModulatedDeformConv( + C, C, kernel_size, padding=padding, bias=False + ).to(device) + output = modulate_deform(inputs, offset, mask) + self.assertTrue(output.shape == inputs.shape) + + @unittest.skipIf(not torch.cuda.is_available(), "Deformable not supported for cpu") + def test_raise_exception(self): + device = torch.device("cuda") + N, C, H, W = shape = 1, 1, 3, 3 + kernel_size = 3 + padding = 1 + + inputs = torch.rand(shape, dtype=torch.float32).to(device) + offset_channels = kernel_size * kernel_size # This is wrong channels for offset + offset = torch.randn((N, offset_channels, H, W), dtype=torch.float32).to(device) + deform = DeformConv(C, C, kernel_size=kernel_size, padding=padding).to(device) + self.assertRaises(RuntimeError, deform, inputs, offset) + + offset_channels = kernel_size * kernel_size * 2 + offset = torch.randn((N, offset_channels, H, W), dtype=torch.float32).to(device) + mask_channels = kernel_size * kernel_size * 2 # This is wrong channels for mask + mask = torch.ones((N, mask_channels, H, W), dtype=torch.float32).to(device) + modulate_deform = ModulatedDeformConv(C, C, kernel_size, padding=padding, bias=False).to( + device + ) + self.assertRaises(RuntimeError, modulate_deform, inputs, offset, mask) + + def test_repr(self): + module = DeformConv(3, 10, kernel_size=3, padding=1, deformable_groups=2) + correct_string = ( + 
"DeformConv(in_channels=3, out_channels=10, kernel_size=(3, 3), " + "stride=(1, 1), padding=(1, 1), dilation=(1, 1), " + "groups=1, deformable_groups=2, bias=False)" + ) + self.assertEqual(repr(module), correct_string) + + module = ModulatedDeformConv(3, 10, kernel_size=3, padding=1, deformable_groups=2) + correct_string = ( + "ModulatedDeformConv(in_channels=3, out_channels=10, kernel_size=(3, 3), " + "stride=1, padding=1, dilation=1, groups=1, deformable_groups=2, bias=True)" + ) + self.assertEqual(repr(module), correct_string) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/layers/test_losses.py b/data_processing/detectron2/tests/layers/test_losses.py new file mode 100644 index 0000000..d749202 --- /dev/null +++ b/data_processing/detectron2/tests/layers/test_losses.py @@ -0,0 +1,82 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +import unittest +import torch + +from detectron2.layers import ciou_loss, diou_loss + + +class TestLosses(unittest.TestCase): + def test_diou_loss(self): + """ + loss = 1 - iou + d/c + where, + d = (distance between centers of the 2 boxes)^2 + c = (diagonal length of the smallest enclosing box covering the 2 boxes)^2 + """ + # Identical boxes should have loss of 0 + box = torch.tensor([-1, -1, 1, 1], dtype=torch.float32) + loss = diou_loss(box, box) + self.assertTrue(np.allclose(loss, [0.0])) + + # Half size box inside other box + # iou = 0.5, d = 0.25, c = 8 + box2 = torch.tensor([0, -1, 1, 1], dtype=torch.float32) + loss = diou_loss(box, box2) + self.assertTrue(np.allclose(loss, [0.53125])) + + # Two diagonally adjacent boxes + # iou = 0, d = 2, c = 8 + box3 = torch.tensor([0, 0, 1, 1], dtype=torch.float32) + box4 = torch.tensor([1, 1, 2, 2], dtype=torch.float32) + loss = diou_loss(box3, box4) + self.assertTrue(np.allclose(loss, [1.25])) + + # Test batched loss and reductions + box1s = torch.stack([box, box3], dim=0) + box2s = torch.stack([box2, box4], dim=0) + 
+ loss = diou_loss(box1s, box2s, reduction="sum") + self.assertTrue(np.allclose(loss, [1.78125])) + + loss = diou_loss(box1s, box2s, reduction="mean") + self.assertTrue(np.allclose(loss, [0.890625])) + + def test_ciou_loss(self): + """ + loss = 1 - iou + d/c + alpha*v + where, + d = (distance between centers of the 2 boxes)^2 + c = (diagonal length of the smallest enclosing box covering the 2 boxes)^2 + v = (4/pi^2) * (arctan(box1_w/box1_h) - arctan(box2_w/box2_h))^2 + alpha = v/(1 - iou + v) + """ + # Identical boxes should have loss of 0 + box = torch.tensor([-1, -1, 1, 1], dtype=torch.float32) + loss = ciou_loss(box, box) + self.assertTrue(np.allclose(loss, [0.0])) + + # Half size box inside other box + # iou = 0.5, d = 0.25, c = 8 + # v = (4/pi^2) * (arctan(1) - arctan(0.5))^2 = 0.042 + # alpha = 0.0775 + box2 = torch.tensor([0, -1, 1, 1], dtype=torch.float32) + loss = ciou_loss(box, box2) + self.assertTrue(np.allclose(loss, [0.5345])) + + # Two diagonally adjacent boxes + # iou = 0, d = 2, c = 8, v = 0, alpha = 0 + box3 = torch.tensor([0, 0, 1, 1], dtype=torch.float32) + box4 = torch.tensor([1, 1, 2, 2], dtype=torch.float32) + loss = ciou_loss(box3, box4) + self.assertTrue(np.allclose(loss, [1.25])) + + # Test batched loss and reductions + box1s = torch.stack([box, box3], dim=0) + box2s = torch.stack([box2, box4], dim=0) + + loss = ciou_loss(box1s, box2s, reduction="sum") + self.assertTrue(np.allclose(loss, [1.7845])) + + loss = ciou_loss(box1s, box2s, reduction="mean") + self.assertTrue(np.allclose(loss, [0.89225])) diff --git a/data_processing/detectron2/tests/layers/test_mask_ops.py b/data_processing/detectron2/tests/layers/test_mask_ops.py new file mode 100644 index 0000000..dfbcaf5 --- /dev/null +++ b/data_processing/detectron2/tests/layers/test_mask_ops.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import contextlib +import io +import numpy as np +import unittest +from collections import defaultdict +import torch +import tqdm +from fvcore.common.benchmark import benchmark +from pycocotools.coco import COCO +from tabulate import tabulate +from torch.nn import functional as F + +from detectron2.data import MetadataCatalog +from detectron2.layers.mask_ops import ( + pad_masks, + paste_mask_in_image_old, + paste_masks_in_image, + scale_boxes, +) +from detectron2.structures import BitMasks, Boxes, BoxMode, PolygonMasks +from detectron2.structures.masks import polygons_to_bitmask +from detectron2.utils.file_io import PathManager +from detectron2.utils.testing import random_boxes + + +def iou_between_full_image_bit_masks(a, b): + intersect = (a & b).sum() + union = (a | b).sum() + return intersect / union + + +def rasterize_polygons_with_grid_sample(full_image_bit_mask, box, mask_size, threshold=0.5): + x0, y0, x1, y1 = box[0], box[1], box[2], box[3] + + img_h, img_w = full_image_bit_mask.shape + + mask_y = np.arange(0.0, mask_size) + 0.5 # mask y sample coords in [0.5, mask_size - 0.5] + mask_x = np.arange(0.0, mask_size) + 0.5 # mask x sample coords in [0.5, mask_size - 0.5] + mask_y = mask_y / mask_size * (y1 - y0) + y0 + mask_x = mask_x / mask_size * (x1 - x0) + x0 + + mask_x = (mask_x - 0.5) / (img_w - 1) * 2 + -1 + mask_y = (mask_y - 0.5) / (img_h - 1) * 2 + -1 + gy, gx = torch.meshgrid(torch.from_numpy(mask_y), torch.from_numpy(mask_x)) + ind = torch.stack([gx, gy], dim=-1).to(dtype=torch.float32) + + full_image_bit_mask = torch.from_numpy(full_image_bit_mask) + mask = F.grid_sample( + full_image_bit_mask[None, None, :, :].to(dtype=torch.float32), + ind[None, :, :, :], + align_corners=True, + ) + + return mask[0, 0] >= threshold + + +class TestMaskCropPaste(unittest.TestCase): + def setUp(self): + json_file = MetadataCatalog.get("coco_2017_val_100").json_file + if not PathManager.isfile(json_file): + raise unittest.SkipTest("{} not 
found".format(json_file)) + with contextlib.redirect_stdout(io.StringIO()): + json_file = PathManager.get_local_path(json_file) + self.coco = COCO(json_file) + + def test_crop_paste_consistency(self): + """ + rasterize_polygons_within_box (used in training) + and + paste_masks_in_image (used in inference) + should be inverse operations to each other. + + This function runs several implementations of the above two operations and prints + the reconstruction error. + """ + + anns = self.coco.loadAnns(self.coco.getAnnIds(iscrowd=False)) # avoid crowd annotations + + selected_anns = anns[:100] + + ious = [] + for ann in tqdm.tqdm(selected_anns): + results = self.process_annotation(ann) + ious.append([k[2] for k in results]) + + ious = np.array(ious) + mean_ious = ious.mean(axis=0) + table = [] + res_dic = defaultdict(dict) + for row, iou in zip(results, mean_ious): + table.append((row[0], row[1], iou)) + res_dic[row[0]][row[1]] = iou + print(tabulate(table, headers=["rasterize", "paste", "iou"], tablefmt="simple")) + # assert that the reconstruction is good: + self.assertTrue(res_dic["polygon"]["aligned"] > 0.94) + self.assertTrue(res_dic["roialign"]["aligned"] > 0.95) + + def process_annotation(self, ann, mask_side_len=28): + # Parse annotation data + img_info = self.coco.loadImgs(ids=[ann["image_id"]])[0] + height, width = img_info["height"], img_info["width"] + gt_polygons = [np.array(p, dtype=np.float64) for p in ann["segmentation"]] + gt_bbox = BoxMode.convert(ann["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) + gt_bit_mask = polygons_to_bitmask(gt_polygons, height, width) + + # Run rasterize .. 
+ torch_gt_bbox = torch.tensor(gt_bbox).to(dtype=torch.float32).reshape(-1, 4) + box_bitmasks = { + "polygon": PolygonMasks([gt_polygons]).crop_and_resize(torch_gt_bbox, mask_side_len)[0], + "gridsample": rasterize_polygons_with_grid_sample(gt_bit_mask, gt_bbox, mask_side_len), + "roialign": BitMasks(torch.from_numpy(gt_bit_mask[None, :, :])).crop_and_resize( + torch_gt_bbox, mask_side_len + )[0], + } + + # Run paste .. + results = defaultdict(dict) + for k, box_bitmask in box_bitmasks.items(): + padded_bitmask, scale = pad_masks(box_bitmask[None, :, :], 1) + scaled_boxes = scale_boxes(torch_gt_bbox, scale) + + r = results[k] + r["old"] = paste_mask_in_image_old( + padded_bitmask[0], scaled_boxes[0], height, width, threshold=0.5 + ) + r["aligned"] = paste_masks_in_image( + box_bitmask[None, :, :], Boxes(torch_gt_bbox), (height, width) + )[0] + + table = [] + for rasterize_method, r in results.items(): + for paste_method, mask in r.items(): + mask = np.asarray(mask) + iou = iou_between_full_image_bit_masks(gt_bit_mask.astype("uint8"), mask) + table.append((rasterize_method, paste_method, iou)) + return table + + def test_polygon_area(self): + # Draw polygon boxes + for d in [5.0, 10.0, 1000.0]: + polygon = PolygonMasks([[[0, 0, 0, d, d, d, d, 0]]]) + area = polygon.area()[0] + target = d**2 + self.assertEqual(area, target) + + # Draw polygon triangles + for d in [5.0, 10.0, 1000.0]: + polygon = PolygonMasks([[[0, 0, 0, d, d, d]]]) + area = polygon.area()[0] + target = d**2 / 2 + self.assertEqual(area, target) + + def test_paste_mask_scriptable(self): + scripted_f = torch.jit.script(paste_masks_in_image) + N = 10 + masks = torch.rand(N, 28, 28) + boxes = Boxes(random_boxes(N, 100)).tensor + image_shape = (150, 150) + + out = paste_masks_in_image(masks, boxes, image_shape) + scripted_out = scripted_f(masks, boxes, image_shape) + self.assertTrue(torch.equal(out, scripted_out)) + + +def benchmark_paste(): + S = 800 + H, W = image_shape = (S, S) + N = 64 + 
torch.manual_seed(42) + masks = torch.rand(N, 28, 28) + + center = torch.rand(N, 2) * 600 + 100 + wh = torch.clamp(torch.randn(N, 2) * 40 + 200, min=50) + x0y0 = torch.clamp(center - wh * 0.5, min=0.0) + x1y1 = torch.clamp(center + wh * 0.5, max=S) + boxes = Boxes(torch.cat([x0y0, x1y1], axis=1)) + + def func(device, n=3): + m = masks.to(device=device) + b = boxes.to(device=device) + + def bench(): + for _ in range(n): + paste_masks_in_image(m, b, image_shape) + if device.type == "cuda": + torch.cuda.synchronize() + + return bench + + specs = [{"device": torch.device("cpu"), "n": 3}] + if torch.cuda.is_available(): + specs.append({"device": torch.device("cuda"), "n": 3}) + + benchmark(func, "paste_masks", specs, num_iters=10, warmup_iters=2) + + +if __name__ == "__main__": + benchmark_paste() + unittest.main() diff --git a/data_processing/detectron2/tests/layers/test_nms.py b/data_processing/detectron2/tests/layers/test_nms.py new file mode 100644 index 0000000..a042db6 --- /dev/null +++ b/data_processing/detectron2/tests/layers/test_nms.py @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+from __future__ import absolute_import, division, print_function, unicode_literals +import unittest +import torch + +from detectron2.layers import batched_nms +from detectron2.utils.testing import random_boxes + + +class TestNMS(unittest.TestCase): + def _create_tensors(self, N): + boxes = random_boxes(N, 200) + scores = torch.rand(N) + return boxes, scores + + def test_nms_scriptability(self): + N = 2000 + num_classes = 50 + boxes, scores = self._create_tensors(N) + idxs = torch.randint(0, num_classes, (N,)) + scripted_batched_nms = torch.jit.script(batched_nms) + err_msg = "NMS is incompatible with jit-scripted NMS for IoU={}" + + for iou in [0.2, 0.5, 0.8]: + keep_ref = batched_nms(boxes, scores, idxs, iou) + backup = boxes.clone() + scripted_keep = scripted_batched_nms(boxes, scores, idxs, iou) + assert torch.allclose(boxes, backup), "boxes modified by jit-scripted batched_nms" + self.assertTrue(torch.equal(keep_ref, scripted_keep), err_msg.format(iou)) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/layers/test_nms_rotated.py b/data_processing/detectron2/tests/layers/test_nms_rotated.py new file mode 100644 index 0000000..4b45384 --- /dev/null +++ b/data_processing/detectron2/tests/layers/test_nms_rotated.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import, division, print_function, unicode_literals +import numpy as np +import unittest +from copy import deepcopy +import torch +from torchvision import ops + +from detectron2.layers import batched_nms, batched_nms_rotated, nms_rotated +from detectron2.utils.testing import random_boxes + + +def nms_edit_distance(keep1, keep2): + """ + Compare the "keep" result of two nms call. 
+ They are allowed to be different in terms of edit distance + due to floating point precision issues, e.g., + if a box happens to have an IoU of 0.5 with another box, + one implementation may choose to keep it while another may discard it. + """ + keep1, keep2 = keep1.cpu(), keep2.cpu() + if torch.equal(keep1, keep2): + # they should be equal most of the time + return 0 + keep1, keep2 = tuple(keep1), tuple(keep2) + m, n = len(keep1), len(keep2) + + # edit distance with DP + f = [np.arange(n + 1), np.arange(n + 1)] + for i in range(m): + cur_row = i % 2 + other_row = (i + 1) % 2 + f[other_row][0] = i + 1 + for j in range(n): + f[other_row][j + 1] = ( + f[cur_row][j] + if keep1[i] == keep2[j] + else min(min(f[cur_row][j], f[cur_row][j + 1]), f[other_row][j]) + 1 + ) + return f[m % 2][n] + + +class TestNMSRotated(unittest.TestCase): + def reference_horizontal_nms(self, boxes, scores, iou_threshold): + """ + Args: + box_scores (N, 5): boxes in corner-form and probabilities. + (Note here 5 == 4 + 1, i.e., 4-dim horizontal box + 1-dim prob) + iou_threshold: intersection over union threshold. 
+ Returns: + picked: a list of indexes of the kept boxes + """ + picked = [] + _, indexes = scores.sort(descending=True) + while len(indexes) > 0: + current = indexes[0] + picked.append(current.item()) + if len(indexes) == 1: + break + current_box = boxes[current, :] + indexes = indexes[1:] + rest_boxes = boxes[indexes, :] + iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1) + indexes = indexes[iou <= iou_threshold] + + return torch.as_tensor(picked) + + def _create_tensors(self, N, device="cpu"): + boxes = random_boxes(N, 200, device=device) + scores = torch.rand(N, device=device) + return boxes, scores + + def test_batched_nms_rotated_0_degree_cpu(self, device="cpu"): + N = 2000 + num_classes = 50 + boxes, scores = self._create_tensors(N, device=device) + idxs = torch.randint(0, num_classes, (N,)) + rotated_boxes = torch.zeros(N, 5, device=device) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + err_msg = "Rotated NMS with 0 degree is incompatible with horizontal NMS for IoU={}" + for iou in [0.2, 0.5, 0.8]: + backup = boxes.clone() + keep_ref = batched_nms(boxes, scores, idxs, iou) + assert torch.allclose(boxes, backup), "boxes modified by batched_nms" + backup = rotated_boxes.clone() + keep = batched_nms_rotated(rotated_boxes, scores, idxs, iou) + assert torch.allclose( + rotated_boxes, backup + ), "rotated_boxes modified by batched_nms_rotated" + # Occasionally the gap can be large if there are many IOU on the threshold boundary + self.assertLessEqual(nms_edit_distance(keep, keep_ref), 5, err_msg.format(iou)) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_batched_nms_rotated_0_degree_cuda(self): + self.test_batched_nms_rotated_0_degree_cpu(device="cuda") + + def test_nms_rotated_0_degree_cpu(self, device="cpu"): + N = 1000 + boxes, scores = 
self._create_tensors(N, device=device) + rotated_boxes = torch.zeros(N, 5, device=device) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + err_msg = "Rotated NMS incompatible between CPU and reference implementation for IoU={}" + for iou in [0.2, 0.5, 0.8]: + keep_ref = self.reference_horizontal_nms(boxes, scores, iou) + keep = nms_rotated(rotated_boxes, scores, iou) + self.assertLessEqual(nms_edit_distance(keep, keep_ref), 1, err_msg.format(iou)) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_nms_rotated_0_degree_cuda(self): + self.test_nms_rotated_0_degree_cpu(device="cuda") + + def test_nms_rotated_90_degrees_cpu(self): + N = 1000 + boxes, scores = self._create_tensors(N) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + # Note for rotated_boxes[:, 2] and rotated_boxes[:, 3]: + # widths and heights are intentionally swapped here for 90 degrees case + # so that the reference horizontal nms could be used + rotated_boxes[:, 2] = boxes[:, 3] - boxes[:, 1] + rotated_boxes[:, 3] = boxes[:, 2] - boxes[:, 0] + + rotated_boxes[:, 4] = torch.ones(N) * 90 + err_msg = "Rotated NMS incompatible between CPU and reference implementation for IoU={}" + for iou in [0.2, 0.5, 0.8]: + keep_ref = self.reference_horizontal_nms(boxes, scores, iou) + keep = nms_rotated(rotated_boxes, scores, iou) + self.assertLessEqual(nms_edit_distance(keep, keep_ref), 1, err_msg.format(iou)) + + def test_nms_rotated_180_degrees_cpu(self): + N = 1000 + boxes, scores = self._create_tensors(N) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + 
rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + rotated_boxes[:, 4] = torch.ones(N) * 180 + err_msg = "Rotated NMS incompatible between CPU and reference implementation for IoU={}" + for iou in [0.2, 0.5, 0.8]: + keep_ref = self.reference_horizontal_nms(boxes, scores, iou) + keep = nms_rotated(rotated_boxes, scores, iou) + self.assertLessEqual(nms_edit_distance(keep, keep_ref), 1, err_msg.format(iou)) + + +class TestScriptable(unittest.TestCase): + def setUp(self): + class TestingModule(torch.nn.Module): + def forward(self, boxes, scores, threshold): + return nms_rotated(boxes, scores, threshold) + + self.module = TestingModule() + + def test_scriptable_cpu(self): + m = deepcopy(self.module).cpu() + _ = torch.jit.script(m) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_scriptable_cuda(self): + m = deepcopy(self.module).cuda() + _ = torch.jit.script(m) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/layers/test_roi_align.py b/data_processing/detectron2/tests/layers/test_roi_align.py new file mode 100644 index 0000000..b6fd8ed --- /dev/null +++ b/data_processing/detectron2/tests/layers/test_roi_align.py @@ -0,0 +1,210 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import numpy as np +import unittest +from copy import copy +import cv2 +import torch +from fvcore.common.benchmark import benchmark +from torch.nn import functional as F + +from detectron2.layers.roi_align import ROIAlign, roi_align + + +class ROIAlignTest(unittest.TestCase): + def test_forward_output(self): + input = np.arange(25).reshape(5, 5).astype("float32") + """ + 0 1 2 3 4 + 5 6 7 8 9 + 10 11 12 13 14 + 15 16 17 18 19 + 20 21 22 23 24 + """ + + output = self._simple_roialign(input, [1, 1, 3, 3], (4, 4), aligned=False) + output_correct = self._simple_roialign(input, [1, 1, 3, 3], (4, 4), aligned=True) + + # without correction: + old_results = [ + [7.5, 8, 8.5, 9], + [10, 10.5, 11, 11.5], + [12.5, 13, 13.5, 14], + [15, 15.5, 16, 16.5], + ] + + # with 0.5 correction: + correct_results = [ + [4.5, 5.0, 5.5, 6.0], + [7.0, 7.5, 8.0, 8.5], + [9.5, 10.0, 10.5, 11.0], + [12.0, 12.5, 13.0, 13.5], + ] + # This is an upsampled version of [[6, 7], [11, 12]] + + self.assertTrue(np.allclose(output.flatten(), np.asarray(old_results).flatten())) + self.assertTrue( + np.allclose(output_correct.flatten(), np.asarray(correct_results).flatten()) + ) + + # Also see similar issues in tensorflow at + # https://github.com/tensorflow/tensorflow/issues/26278 + + def test_resize(self): + H, W = 30, 30 + input = np.random.rand(H, W).astype("float32") * 100 + box = [10, 10, 20, 20] + output = self._simple_roialign(input, box, (5, 5), aligned=True) + + input2x = cv2.resize(input, (W // 2, H // 2), interpolation=cv2.INTER_LINEAR) + box2x = [x / 2 for x in box] + output2x = self._simple_roialign(input2x, box2x, (5, 5), aligned=True) + diff = np.abs(output2x - output) + self.assertTrue(diff.max() < 1e-4) + + def test_grid_sample_equivalence(self): + H, W = 30, 30 + input = np.random.rand(H, W).astype("float32") * 100 + box = [10, 10, 20, 20] + for ratio in [1, 2, 3]: + output = self._simple_roialign(input, box, (5, 5), 
sampling_ratio=ratio) + output_grid_sample = grid_sample_roi_align( + torch.from_numpy(input[None, None, :, :]).float(), + torch.as_tensor(box).float()[None, :], + 5, + 1.0, + ratio, + ) + self.assertTrue(torch.allclose(output, output_grid_sample)) + + def _simple_roialign(self, img, box, resolution, sampling_ratio=0, aligned=True): + """ + RoiAlign with scale 1.0. + """ + if isinstance(resolution, int): + resolution = (resolution, resolution) + op = ROIAlign(resolution, 1.0, sampling_ratio, aligned=aligned) + input = torch.from_numpy(img[None, None, :, :].astype("float32")) + + rois = [0] + list(box) + rois = torch.from_numpy(np.asarray(rois)[None, :].astype("float32")) + output = op.forward(input, rois) + if torch.cuda.is_available(): + output_cuda = op.forward(input.cuda(), rois.cuda()).cpu() + self.assertTrue(torch.allclose(output, output_cuda)) + return output[0, 0] + + def _simple_roialign_with_grad(self, img, box, resolution, device): + if isinstance(resolution, int): + resolution = (resolution, resolution) + + op = ROIAlign(resolution, 1.0, 0, aligned=True) + input = torch.from_numpy(img[None, None, :, :].astype("float32")) + + rois = [0] + list(box) + rois = torch.from_numpy(np.asarray(rois)[None, :].astype("float32")) + input = input.to(device=device) + rois = rois.to(device=device) + input.requires_grad = True + output = op.forward(input, rois) + return input, output + + def test_empty_box(self): + img = np.random.rand(5, 5) + box = [3, 4, 5, 4] + o = self._simple_roialign(img, box, 7) + self.assertTrue(o.shape == (7, 7)) + self.assertTrue((o == 0).all()) + + for dev in ["cpu"] + ["cuda"] if torch.cuda.is_available() else []: + input, output = self._simple_roialign_with_grad(img, box, 7, torch.device(dev)) + output.sum().backward() + self.assertTrue(torch.allclose(input.grad, torch.zeros_like(input))) + + def test_empty_batch(self): + input = torch.zeros(0, 3, 10, 10, dtype=torch.float32) + rois = torch.zeros(0, 5, dtype=torch.float32) + op = 
ROIAlign((7, 7), 1.0, 0, aligned=True) + output = op.forward(input, rois) + self.assertTrue(output.shape == (0, 3, 7, 7)) + + +def grid_sample_roi_align(input, boxes, output_size, scale, sampling_ratio): + # unlike true roi_align, this does not support different batch_idx + from detectron2.projects.point_rend.point_features import ( + generate_regular_grid_point_coords, + get_point_coords_wrt_image, + point_sample, + ) + + N, _, H, W = input.shape + R = len(boxes) + assert N == 1 + boxes = boxes * scale + grid = generate_regular_grid_point_coords(R, output_size * sampling_ratio, device=boxes.device) + coords = get_point_coords_wrt_image(boxes, grid) + coords = coords / torch.as_tensor([W, H], device=coords.device) # R, s^2, 2 + res = point_sample(input, coords.unsqueeze(0), align_corners=False) # 1,C, R,s^2 + res = ( + res.squeeze(0) + .permute(1, 0, 2) + .reshape(R, -1, output_size * sampling_ratio, output_size * sampling_ratio) + ) + res = F.avg_pool2d(res, sampling_ratio) + return res + + +def benchmark_roi_align(): + def random_boxes(mean_box, stdev, N, maxsize): + ret = torch.rand(N, 4) * stdev + torch.tensor(mean_box, dtype=torch.float) + ret.clamp_(min=0, max=maxsize) + return ret + + def func(shape, nboxes_per_img, sampling_ratio, device, box_size="large"): + N, _, H, _ = shape + input = torch.rand(*shape) + boxes = [] + batch_idx = [] + for k in range(N): + if box_size == "large": + b = random_boxes([80, 80, 130, 130], 24, nboxes_per_img, H) + else: + b = random_boxes([100, 100, 110, 110], 4, nboxes_per_img, H) + boxes.append(b) + batch_idx.append(torch.zeros(nboxes_per_img, 1, dtype=torch.float32) + k) + boxes = torch.cat(boxes, axis=0) + batch_idx = torch.cat(batch_idx, axis=0) + boxes = torch.cat([batch_idx, boxes], axis=1) + + input = input.to(device=device) + boxes = boxes.to(device=device) + + def bench(): + if False and sampling_ratio > 0 and N == 1: + # enable to benchmark grid_sample (slower) + grid_sample_roi_align(input, boxes[:, 1:], 7, 1.0, 
sampling_ratio) + else: + roi_align(input, boxes, 7, 1.0, sampling_ratio, True) + if device == "cuda": + torch.cuda.synchronize() + + return bench + + def gen_args(arg): + args = [] + for size in ["small", "large"]: + for ratio in [0, 2]: + args.append(copy(arg)) + args[-1]["sampling_ratio"] = ratio + args[-1]["box_size"] = size + return args + + arg = dict(shape=(1, 512, 256, 256), nboxes_per_img=512, device="cuda") + benchmark(func, "cuda_roialign", gen_args(arg), num_iters=20, warmup_iters=1) + arg.update({"device": "cpu", "shape": (1, 256, 128, 128)}) + benchmark(func, "cpu_roialign", gen_args(arg), num_iters=5, warmup_iters=1) + + +if __name__ == "__main__": + if torch.cuda.is_available(): + benchmark_roi_align() + unittest.main() diff --git a/data_processing/detectron2/tests/layers/test_roi_align_rotated.py b/data_processing/detectron2/tests/layers/test_roi_align_rotated.py new file mode 100644 index 0000000..7323d7d --- /dev/null +++ b/data_processing/detectron2/tests/layers/test_roi_align_rotated.py @@ -0,0 +1,176 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import logging +import unittest +import cv2 +import torch +from torch.autograd import Variable, gradcheck + +from detectron2.layers.roi_align import ROIAlign +from detectron2.layers.roi_align_rotated import ROIAlignRotated + +logger = logging.getLogger(__name__) + + +class ROIAlignRotatedTest(unittest.TestCase): + def _box_to_rotated_box(self, box, angle): + return [ + (box[0] + box[2]) / 2.0, + (box[1] + box[3]) / 2.0, + box[2] - box[0], + box[3] - box[1], + angle, + ] + + def _rot90(self, img, num): + num = num % 4 # note: -1 % 4 == 3 + for _ in range(num): + img = img.transpose(0, 1).flip(0) + return img + + def test_forward_output_0_90_180_270(self): + for i in range(4): + # i = 0, 1, 2, 3 corresponding to 0, 90, 180, 270 degrees + img = torch.arange(25, dtype=torch.float32).reshape(5, 5) + """ + 0 1 2 3 4 + 5 6 7 8 9 + 10 11 12 13 14 + 15 16 17 18 19 + 20 21 22 23 24 + """ + box = [1, 1, 3, 3] + rotated_box = self._box_to_rotated_box(box=box, angle=90 * i) + + result = self._simple_roi_align_rotated(img=img, box=rotated_box, resolution=(4, 4)) + + # Here's an explanation for 0 degree case: + # point 0 in the original input lies at [0.5, 0.5] + # (the center of bin [0, 1] x [0, 1]) + # point 1 in the original input lies at [1.5, 0.5], etc. 
+ # since the resolution is (4, 4) that divides [1, 3] x [1, 3] + # into 4 x 4 equal bins, + # the top-left bin is [1, 1.5] x [1, 1.5], and its center + # (1.25, 1.25) lies at the 3/4 position + # between point 0 and point 1, point 5 and point 6, + # point 0 and point 5, point 1 and point 6, so it can be calculated as + # 0.25*(0*0.25+1*0.75)+(5*0.25+6*0.75)*0.75 = 4.5 + result_expected = torch.tensor( + [ + [4.5, 5.0, 5.5, 6.0], + [7.0, 7.5, 8.0, 8.5], + [9.5, 10.0, 10.5, 11.0], + [12.0, 12.5, 13.0, 13.5], + ] + ) + # This is also an upsampled version of [[6, 7], [11, 12]] + + # When the box is rotated by 90 degrees CCW, + # the result would be rotated by 90 degrees CW, thus it's -i here + result_expected = self._rot90(result_expected, -i) + + assert torch.allclose(result, result_expected) + + def test_resize(self): + H, W = 30, 30 + input = torch.rand(H, W) * 100 + box = [10, 10, 20, 20] + rotated_box = self._box_to_rotated_box(box, angle=0) + output = self._simple_roi_align_rotated(img=input, box=rotated_box, resolution=(5, 5)) + + input2x = cv2.resize(input.numpy(), (W // 2, H // 2), interpolation=cv2.INTER_LINEAR) + input2x = torch.from_numpy(input2x) + box2x = [x / 2 for x in box] + rotated_box2x = self._box_to_rotated_box(box2x, angle=0) + output2x = self._simple_roi_align_rotated(img=input2x, box=rotated_box2x, resolution=(5, 5)) + assert torch.allclose(output2x, output) + + def _simple_roi_align_rotated(self, img, box, resolution): + """ + RoiAlignRotated with scale 1.0 and 0 sample ratio. 
+ """ + op = ROIAlignRotated(output_size=resolution, spatial_scale=1.0, sampling_ratio=0) + input = img[None, None, :, :] + + rois = [0] + list(box) + rois = torch.tensor(rois, dtype=torch.float32)[None, :] + result_cpu = op.forward(input, rois) + if torch.cuda.is_available(): + result_cuda = op.forward(input.cuda(), rois.cuda()) + assert torch.allclose(result_cpu, result_cuda.cpu()) + return result_cpu[0, 0] + + def test_empty_box(self): + img = torch.rand(5, 5) + out = self._simple_roi_align_rotated(img, [2, 3, 0, 0, 0], (7, 7)) + self.assertTrue((out == 0).all()) + + def test_roi_align_rotated_gradcheck_cpu(self): + dtype = torch.float64 + device = torch.device("cpu") + roi_align_rotated_op = ROIAlignRotated( + output_size=(5, 5), spatial_scale=0.5, sampling_ratio=1 + ).to(dtype=dtype, device=device) + x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True) + # roi format is (batch index, x_center, y_center, width, height, angle) + rois = torch.tensor( + [[0, 4.5, 4.5, 9, 9, 0], [0, 2, 7, 4, 4, 0], [0, 7, 7, 4, 4, 0]], + dtype=dtype, + device=device, + ) + + def func(input): + return roi_align_rotated_op(input, rois) + + assert gradcheck(func, (x,)), "gradcheck failed for RoIAlignRotated CPU" + assert gradcheck(func, (x.transpose(2, 3),)), "gradcheck failed for RoIAlignRotated CPU" + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_roi_align_rotated_gradient_cuda(self): + """ + Compute gradients for ROIAlignRotated with multiple bounding boxes on the GPU, + and compare the result with ROIAlign + """ + # torch.manual_seed(123) + dtype = torch.float64 + device = torch.device("cuda") + pool_h, pool_w = (5, 5) + + roi_align = ROIAlign(output_size=(pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to( + device=device + ) + + roi_align_rotated = ROIAlignRotated( + output_size=(pool_h, pool_w), spatial_scale=1, sampling_ratio=2 + ).to(device=device) + + x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, 
requires_grad=True) + # x_rotated = x.clone() won't work (will lead to grad_fun=CloneBackward)! + x_rotated = Variable(x.data.clone(), requires_grad=True) + + # roi_rotated format is (batch index, x_center, y_center, width, height, angle) + rois_rotated = torch.tensor( + [[0, 4.5, 4.5, 9, 9, 0], [0, 2, 7, 4, 4, 0], [0, 7, 7, 4, 4, 0]], + dtype=dtype, + device=device, + ) + + y_rotated = roi_align_rotated(x_rotated, rois_rotated) + s_rotated = y_rotated.sum() + s_rotated.backward() + + # roi format is (batch index, x1, y1, x2, y2) + rois = torch.tensor( + [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9]], dtype=dtype, device=device + ) + + y = roi_align(x, rois) + s = y.sum() + s.backward() + + assert torch.allclose( + x.grad, x_rotated.grad + ), "gradients for ROIAlign and ROIAlignRotated mismatch on CUDA" + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/modeling/__init__.py b/data_processing/detectron2/tests/modeling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/detectron2/tests/modeling/test_anchor_generator.py b/data_processing/detectron2/tests/modeling/test_anchor_generator.py new file mode 100644 index 0000000..13a808e --- /dev/null +++ b/data_processing/detectron2/tests/modeling/test_anchor_generator.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import logging +import unittest +import torch + +from detectron2.config import get_cfg +from detectron2.layers import ShapeSpec +from detectron2.modeling.anchor_generator import DefaultAnchorGenerator, RotatedAnchorGenerator + +logger = logging.getLogger(__name__) + + +class TestAnchorGenerator(unittest.TestCase): + def test_default_anchor_generator(self): + cfg = get_cfg() + cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64]] + cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.25, 1, 4]] + + anchor_generator = DefaultAnchorGenerator(cfg, [ShapeSpec(stride=4)]) + + # only the last two dimensions of features matter here + num_images = 2 + features = {"stage3": torch.rand(num_images, 96, 1, 2)} + anchors = anchor_generator([features["stage3"]]) + expected_anchor_tensor = torch.tensor( + [ + [-32.0, -8.0, 32.0, 8.0], + [-16.0, -16.0, 16.0, 16.0], + [-8.0, -32.0, 8.0, 32.0], + [-64.0, -16.0, 64.0, 16.0], + [-32.0, -32.0, 32.0, 32.0], + [-16.0, -64.0, 16.0, 64.0], + [-28.0, -8.0, 36.0, 8.0], # -28.0 == -32.0 + STRIDE (4) + [-12.0, -16.0, 20.0, 16.0], + [-4.0, -32.0, 12.0, 32.0], + [-60.0, -16.0, 68.0, 16.0], + [-28.0, -32.0, 36.0, 32.0], + [-12.0, -64.0, 20.0, 64.0], + ] + ) + + self.assertTrue(torch.allclose(anchors[0].tensor, expected_anchor_tensor)) + + def test_default_anchor_generator_centered(self): + # test explicit args + anchor_generator = DefaultAnchorGenerator( + sizes=[32, 64], aspect_ratios=[0.25, 1, 4], strides=[4] + ) + + # only the last two dimensions of features matter here + num_images = 2 + features = {"stage3": torch.rand(num_images, 96, 1, 2)} + expected_anchor_tensor = torch.tensor( + [ + [-30.0, -6.0, 34.0, 10.0], + [-14.0, -14.0, 18.0, 18.0], + [-6.0, -30.0, 10.0, 34.0], + [-62.0, -14.0, 66.0, 18.0], + [-30.0, -30.0, 34.0, 34.0], + [-14.0, -62.0, 18.0, 66.0], + [-26.0, -6.0, 38.0, 10.0], + [-10.0, -14.0, 22.0, 18.0], + [-2.0, -30.0, 14.0, 34.0], + [-58.0, -14.0, 70.0, 18.0], + [-26.0, -30.0, 38.0, 34.0], + [-10.0, -62.0, 22.0, 66.0], + ] + ) + + anchors 
= anchor_generator([features["stage3"]]) + self.assertTrue(torch.allclose(anchors[0].tensor, expected_anchor_tensor)) + + anchors = torch.jit.script(anchor_generator)([features["stage3"]]) + self.assertTrue(torch.allclose(anchors[0].tensor, expected_anchor_tensor)) + + def test_rrpn_anchor_generator(self): + cfg = get_cfg() + cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64]] + cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.25, 1, 4]] + cfg.MODEL.ANCHOR_GENERATOR.ANGLES = [0, 45] # test single list[float] + anchor_generator = RotatedAnchorGenerator(cfg, [ShapeSpec(stride=4)]) + + # only the last two dimensions of features matter here + num_images = 2 + features = {"stage3": torch.rand(num_images, 96, 1, 2)} + anchors = anchor_generator([features["stage3"]]) + expected_anchor_tensor = torch.tensor( + [ + [0.0, 0.0, 64.0, 16.0, 0.0], + [0.0, 0.0, 64.0, 16.0, 45.0], + [0.0, 0.0, 32.0, 32.0, 0.0], + [0.0, 0.0, 32.0, 32.0, 45.0], + [0.0, 0.0, 16.0, 64.0, 0.0], + [0.0, 0.0, 16.0, 64.0, 45.0], + [0.0, 0.0, 128.0, 32.0, 0.0], + [0.0, 0.0, 128.0, 32.0, 45.0], + [0.0, 0.0, 64.0, 64.0, 0.0], + [0.0, 0.0, 64.0, 64.0, 45.0], + [0.0, 0.0, 32.0, 128.0, 0.0], + [0.0, 0.0, 32.0, 128.0, 45.0], + [4.0, 0.0, 64.0, 16.0, 0.0], # 4.0 == 0.0 + STRIDE (4) + [4.0, 0.0, 64.0, 16.0, 45.0], + [4.0, 0.0, 32.0, 32.0, 0.0], + [4.0, 0.0, 32.0, 32.0, 45.0], + [4.0, 0.0, 16.0, 64.0, 0.0], + [4.0, 0.0, 16.0, 64.0, 45.0], + [4.0, 0.0, 128.0, 32.0, 0.0], + [4.0, 0.0, 128.0, 32.0, 45.0], + [4.0, 0.0, 64.0, 64.0, 0.0], + [4.0, 0.0, 64.0, 64.0, 45.0], + [4.0, 0.0, 32.0, 128.0, 0.0], + [4.0, 0.0, 32.0, 128.0, 45.0], + ] + ) + + self.assertTrue(torch.allclose(anchors[0].tensor, expected_anchor_tensor)) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/modeling/test_backbone.py b/data_processing/detectron2/tests/modeling/test_backbone.py new file mode 100644 index 0000000..3bb100f --- /dev/null +++ b/data_processing/detectron2/tests/modeling/test_backbone.py @@ -0,0 
+1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import unittest +import torch + +import detectron2.export.torchscript # apply patch # noqa +from detectron2 import model_zoo +from detectron2.config import get_cfg +from detectron2.layers import ShapeSpec +from detectron2.modeling.backbone import build_resnet_backbone +from detectron2.modeling.backbone.fpn import build_resnet_fpn_backbone + + +class TestBackBone(unittest.TestCase): + def test_resnet_scriptability(self): + cfg = get_cfg() + resnet = build_resnet_backbone(cfg, ShapeSpec(channels=3)) + + scripted_resnet = torch.jit.script(resnet) + + inp = torch.rand(2, 3, 100, 100) + out1 = resnet(inp)["res4"] + out2 = scripted_resnet(inp)["res4"] + self.assertTrue(torch.allclose(out1, out2)) + + def test_fpn_scriptability(self): + cfg = model_zoo.get_config("Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.yaml") + bb = build_resnet_fpn_backbone(cfg, ShapeSpec(channels=3)) + bb_s = torch.jit.script(bb) + + inp = torch.rand(2, 3, 128, 128) + out1 = bb(inp)["p5"] + out2 = bb_s(inp)["p5"] + self.assertTrue(torch.allclose(out1, out2)) diff --git a/data_processing/detectron2/tests/modeling/test_box2box_transform.py b/data_processing/detectron2/tests/modeling/test_box2box_transform.py new file mode 100644 index 0000000..fd3a7b7 --- /dev/null +++ b/data_processing/detectron2/tests/modeling/test_box2box_transform.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import logging +import unittest +import torch + +from detectron2.modeling.box_regression import ( + Box2BoxTransform, + Box2BoxTransformLinear, + Box2BoxTransformRotated, +) +from detectron2.utils.testing import random_boxes + +logger = logging.getLogger(__name__) + + +class TestBox2BoxTransform(unittest.TestCase): + def test_reconstruction(self): + weights = (5, 5, 10, 10) + b2b_tfm = Box2BoxTransform(weights=weights) + src_boxes = random_boxes(10) + dst_boxes = random_boxes(10) + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + src_boxes = src_boxes.to(device=device) + dst_boxes = dst_boxes.to(device=device) + deltas = b2b_tfm.get_deltas(src_boxes, dst_boxes) + dst_boxes_reconstructed = b2b_tfm.apply_deltas(deltas, src_boxes) + self.assertTrue(torch.allclose(dst_boxes, dst_boxes_reconstructed)) + + def test_apply_deltas_tracing(self): + weights = (5, 5, 10, 10) + b2b_tfm = Box2BoxTransform(weights=weights) + + with torch.no_grad(): + func = torch.jit.trace(b2b_tfm.apply_deltas, (torch.randn(10, 20), torch.randn(10, 4))) + + o = func(torch.randn(10, 20), torch.randn(10, 4)) + self.assertEqual(o.shape, (10, 20)) + o = func(torch.randn(5, 20), torch.randn(5, 4)) + self.assertEqual(o.shape, (5, 20)) + + +def random_rotated_boxes(mean_box, std_length, std_angle, N): + return torch.cat( + [torch.rand(N, 4) * std_length, torch.rand(N, 1) * std_angle], dim=1 + ) + torch.tensor(mean_box, dtype=torch.float) + + +class TestBox2BoxTransformRotated(unittest.TestCase): + def test_reconstruction(self): + weights = (5, 5, 10, 10, 1) + b2b_transform = Box2BoxTransformRotated(weights=weights) + src_boxes = random_rotated_boxes([10, 10, 20, 20, -30], 5, 60.0, 10) + dst_boxes = random_rotated_boxes([10, 10, 20, 20, -30], 5, 60.0, 10) + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + src_boxes = 
src_boxes.to(device=device) + dst_boxes = dst_boxes.to(device=device) + deltas = b2b_transform.get_deltas(src_boxes, dst_boxes) + dst_boxes_reconstructed = b2b_transform.apply_deltas(deltas, src_boxes) + assert torch.allclose(dst_boxes[:, :4], dst_boxes_reconstructed[:, :4], atol=1e-5) + # angle difference has to be normalized + assert torch.allclose( + (dst_boxes[:, 4] - dst_boxes_reconstructed[:, 4] + 180.0) % 360.0 - 180.0, + torch.zeros_like(dst_boxes[:, 4]), + atol=1e-4, + ) + + +class TestBox2BoxTransformLinear(unittest.TestCase): + def test_reconstruction(self): + b2b_tfm = Box2BoxTransformLinear() + src_boxes = random_boxes(10) + dst_boxes = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float() + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + src_boxes = src_boxes.to(device=device) + dst_boxes = dst_boxes.to(device=device) + deltas = b2b_tfm.get_deltas(src_boxes, dst_boxes) + dst_boxes_reconstructed = b2b_tfm.apply_deltas(deltas, src_boxes) + self.assertTrue(torch.allclose(dst_boxes, dst_boxes_reconstructed, atol=1e-3)) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/modeling/test_fast_rcnn.py b/data_processing/detectron2/tests/modeling/test_fast_rcnn.py new file mode 100644 index 0000000..e29b944 --- /dev/null +++ b/data_processing/detectron2/tests/modeling/test_fast_rcnn.py @@ -0,0 +1,171 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import logging +import unittest +import torch + +from detectron2.layers import ShapeSpec +from detectron2.modeling.box_regression import Box2BoxTransform, Box2BoxTransformRotated +from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers +from detectron2.modeling.roi_heads.rotated_fast_rcnn import RotatedFastRCNNOutputLayers +from detectron2.structures import Boxes, Instances, RotatedBoxes +from detectron2.utils.events import EventStorage + +logger = logging.getLogger(__name__) + + +class FastRCNNTest(unittest.TestCase): + def test_fast_rcnn(self): + torch.manual_seed(132) + + box_head_output_size = 8 + + box_predictor = FastRCNNOutputLayers( + ShapeSpec(channels=box_head_output_size), + box2box_transform=Box2BoxTransform(weights=(10, 10, 5, 5)), + num_classes=5, + ) + feature_pooled = torch.rand(2, box_head_output_size) + predictions = box_predictor(feature_pooled) + + proposal_boxes = torch.tensor([[0.8, 1.1, 3.2, 2.8], [2.3, 2.5, 7, 8]], dtype=torch.float32) + gt_boxes = torch.tensor([[1, 1, 3, 3], [2, 2, 6, 6]], dtype=torch.float32) + proposal = Instances((10, 10)) + proposal.proposal_boxes = Boxes(proposal_boxes) + proposal.gt_boxes = Boxes(gt_boxes) + proposal.gt_classes = torch.tensor([1, 2]) + + with EventStorage(): # capture events in a new storage to discard them + losses = box_predictor.losses(predictions, [proposal]) + + expected_losses = { + "loss_cls": torch.tensor(1.7951188087), + "loss_box_reg": torch.tensor(4.0357131958), + } + for name in expected_losses.keys(): + assert torch.allclose(losses[name], expected_losses[name]) + + def test_fast_rcnn_empty_batch(self, device="cpu"): + box_predictor = FastRCNNOutputLayers( + ShapeSpec(channels=10), + box2box_transform=Box2BoxTransform(weights=(10, 10, 5, 5)), + num_classes=8, + ).to(device=device) + + logits = torch.randn(0, 100, requires_grad=True, device=device) + deltas = torch.randn(0, 4, requires_grad=True, device=device) + losses = box_predictor.losses([logits, deltas], []) + for 
value in losses.values(): + self.assertTrue(torch.allclose(value, torch.zeros_like(value))) + sum(losses.values()).backward() + self.assertTrue(logits.grad is not None) + self.assertTrue(deltas.grad is not None) + + predictions, _ = box_predictor.inference([logits, deltas], []) + self.assertEqual(len(predictions), 0) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_fast_rcnn_empty_batch_cuda(self): + self.test_fast_rcnn_empty_batch(device=torch.device("cuda")) + + def test_fast_rcnn_rotated(self): + torch.manual_seed(132) + box_head_output_size = 8 + + box_predictor = RotatedFastRCNNOutputLayers( + ShapeSpec(channels=box_head_output_size), + box2box_transform=Box2BoxTransformRotated(weights=(10, 10, 5, 5, 1)), + num_classes=5, + ) + feature_pooled = torch.rand(2, box_head_output_size) + predictions = box_predictor(feature_pooled) + proposal_boxes = torch.tensor( + [[2, 1.95, 2.4, 1.7, 0], [4.65, 5.25, 4.7, 5.5, 0]], dtype=torch.float32 + ) + gt_boxes = torch.tensor([[2, 2, 2, 2, 0], [4, 4, 4, 4, 0]], dtype=torch.float32) + proposal = Instances((10, 10)) + proposal.proposal_boxes = RotatedBoxes(proposal_boxes) + proposal.gt_boxes = RotatedBoxes(gt_boxes) + proposal.gt_classes = torch.tensor([1, 2]) + + with EventStorage(): # capture events in a new storage to discard them + losses = box_predictor.losses(predictions, [proposal]) + + # Note: the expected losses are slightly different even if + # the boxes are essentially the same as in the FastRCNNOutput test, because + # bbox_pred in FastRCNNOutputLayers have different Linear layers/initialization + # between the two cases. 
+ expected_losses = { + "loss_cls": torch.tensor(1.7920907736), + "loss_box_reg": torch.tensor(4.0410838127), + } + for name in expected_losses.keys(): + assert torch.allclose(losses[name], expected_losses[name]) + + def test_predict_boxes_tracing(self): + class Model(torch.nn.Module): + def __init__(self, output_layer): + super(Model, self).__init__() + self._output_layer = output_layer + + def forward(self, proposal_deltas, proposal_boxes): + instances = Instances((10, 10)) + instances.proposal_boxes = Boxes(proposal_boxes) + return self._output_layer.predict_boxes((None, proposal_deltas), [instances]) + + box_head_output_size = 8 + + box_predictor = FastRCNNOutputLayers( + ShapeSpec(channels=box_head_output_size), + box2box_transform=Box2BoxTransform(weights=(10, 10, 5, 5)), + num_classes=5, + ) + + model = Model(box_predictor) + + from detectron2.export.torchscript_patch import patch_builtin_len + + with torch.no_grad(), patch_builtin_len(): + func = torch.jit.trace(model, (torch.randn(10, 20), torch.randn(10, 4))) + + o = func(torch.randn(10, 20), torch.randn(10, 4)) + self.assertEqual(o[0].shape, (10, 20)) + o = func(torch.randn(5, 20), torch.randn(5, 4)) + self.assertEqual(o[0].shape, (5, 20)) + o = func(torch.randn(20, 20), torch.randn(20, 4)) + self.assertEqual(o[0].shape, (20, 20)) + + def test_predict_probs_tracing(self): + class Model(torch.nn.Module): + def __init__(self, output_layer): + super(Model, self).__init__() + self._output_layer = output_layer + + def forward(self, scores, proposal_boxes): + instances = Instances((10, 10)) + instances.proposal_boxes = Boxes(proposal_boxes) + return self._output_layer.predict_probs((scores, None), [instances]) + + box_head_output_size = 8 + + box_predictor = FastRCNNOutputLayers( + ShapeSpec(channels=box_head_output_size), + box2box_transform=Box2BoxTransform(weights=(10, 10, 5, 5)), + num_classes=5, + ) + + model = Model(box_predictor) + + from detectron2.export.torchscript_patch import patch_builtin_len + + 
with torch.no_grad(), patch_builtin_len(): + func = torch.jit.trace(model, (torch.randn(10, 6), torch.rand(10, 4))) + o = func(torch.randn(10, 6), torch.randn(10, 4)) + self.assertEqual(o[0].shape, (10, 6)) + o = func(torch.randn(5, 6), torch.randn(5, 4)) + self.assertEqual(o[0].shape, (5, 6)) + o = func(torch.randn(20, 6), torch.randn(20, 4)) + self.assertEqual(o[0].shape, (20, 6)) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/modeling/test_matcher.py b/data_processing/detectron2/tests/modeling/test_matcher.py new file mode 100644 index 0000000..6eb2db0 --- /dev/null +++ b/data_processing/detectron2/tests/modeling/test_matcher.py @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import unittest +from typing import List +import torch + +from detectron2.config import get_cfg +from detectron2.modeling.matcher import Matcher + + +class TestMatcher(unittest.TestCase): + def test_scriptability(self): + cfg = get_cfg() + anchor_matcher = Matcher( + cfg.MODEL.RPN.IOU_THRESHOLDS, cfg.MODEL.RPN.IOU_LABELS, allow_low_quality_matches=True + ) + match_quality_matrix = torch.tensor( + [[0.15, 0.45, 0.2, 0.6], [0.3, 0.65, 0.05, 0.1], [0.05, 0.4, 0.25, 0.4]] + ) + expected_matches = torch.tensor([1, 1, 2, 0]) + expected_match_labels = torch.tensor([-1, 1, 0, 1], dtype=torch.int8) + + matches, match_labels = anchor_matcher(match_quality_matrix) + self.assertTrue(torch.allclose(matches, expected_matches)) + self.assertTrue(torch.allclose(match_labels, expected_match_labels)) + + # nonzero_tuple must be import explicitly to let jit know what it is. 
+ # https://github.com/pytorch/pytorch/issues/38964 + from detectron2.layers import nonzero_tuple # noqa F401 + + def f(thresholds: List[float], labels: List[int]): + return Matcher(thresholds, labels, allow_low_quality_matches=True) + + scripted_anchor_matcher = torch.jit.script(f)( + cfg.MODEL.RPN.IOU_THRESHOLDS, cfg.MODEL.RPN.IOU_LABELS + ) + matches, match_labels = scripted_anchor_matcher(match_quality_matrix) + self.assertTrue(torch.allclose(matches, expected_matches)) + self.assertTrue(torch.allclose(match_labels, expected_match_labels)) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/modeling/test_mmdet.py b/data_processing/detectron2/tests/modeling/test_mmdet.py new file mode 100644 index 0000000..a743b0b --- /dev/null +++ b/data_processing/detectron2/tests/modeling/test_mmdet.py @@ -0,0 +1,186 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import unittest + +from detectron2.layers import ShapeSpec +from detectron2.modeling.mmdet_wrapper import MMDetBackbone, MMDetDetector + +try: + import mmdet.models # noqa + + HAS_MMDET = True +except ImportError: + HAS_MMDET = False + + +@unittest.skipIf(not HAS_MMDET, "mmdet not available") +class TestMMDetWrapper(unittest.TestCase): + def test_backbone(self): + MMDetBackbone( + backbone=dict( + type="DetectoRS_ResNet", + conv_cfg=dict(type="ConvAWS"), + sac=dict(type="SAC", use_deform=True), + stage_with_sac=(False, True, True, True), + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type="BN", requires_grad=True), + norm_eval=True, + style="pytorch", + ), + neck=dict( + type="FPN", + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5, + ), + # skip pretrained model for tests + # pretrained_backbone="torchvision://resnet50", + output_shapes=[ShapeSpec(channels=256, stride=s) for s in [4, 8, 16, 32, 64]], + output_names=["p2", "p3", "p4", "p5", "p6"], 
+ ) + + def test_detector(self): + # a basic R50 Mask R-CNN + MMDetDetector( + detector=dict( + type="MaskRCNN", + backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type="BN", requires_grad=True), + norm_eval=True, + style="pytorch", + # skip pretrained model for tests + # init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')) + ), + neck=dict( + type="FPN", in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5 + ), + rpn_head=dict( + type="RPNHead", + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type="AnchorGenerator", + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + ), + bbox_coder=dict( + type="DeltaXYWHBBoxCoder", + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[1.0, 1.0, 1.0, 1.0], + ), + loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type="L1Loss", loss_weight=1.0), + ), + roi_head=dict( + type="StandardRoIHead", + bbox_roi_extractor=dict( + type="SingleRoIExtractor", + roi_layer=dict(type="RoIAlign", output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + ), + bbox_head=dict( + type="Shared2FCBBoxHead", + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type="DeltaXYWHBBoxCoder", + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[0.1, 0.1, 0.2, 0.2], + ), + reg_class_agnostic=False, + loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type="L1Loss", loss_weight=1.0), + ), + mask_roi_extractor=dict( + type="SingleRoIExtractor", + roi_layer=dict(type="RoIAlign", output_size=14, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32], + ), + mask_head=dict( + type="FCNMaskHead", + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=80, + loss_mask=dict(type="CrossEntropyLoss", use_mask=True, loss_weight=1.0), + ), 
+ ), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type="MaxIoUAssigner", + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1, + ), + sampler=dict( + type="RandomSampler", + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False, + ), + allowed_border=-1, + pos_weight=-1, + debug=False, + ), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type="nms", iou_threshold=0.7), + min_bbox_size=0, + ), + rcnn=dict( + assigner=dict( + type="MaxIoUAssigner", + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=True, + ignore_iof_thr=-1, + ), + sampler=dict( + type="RandomSampler", + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + ), + mask_size=28, + pos_weight=-1, + debug=False, + ), + ), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type="nms", iou_threshold=0.7), + min_bbox_size=0, + ), + rcnn=dict( + score_thr=0.05, + nms=dict(type="nms", iou_threshold=0.5), + max_per_img=100, + mask_thr_binary=0.5, + ), + ), + ), + pixel_mean=[1, 2, 3], + pixel_std=[1, 2, 3], + ) diff --git a/data_processing/detectron2/tests/modeling/test_model_e2e.py b/data_processing/detectron2/tests/modeling/test_model_e2e.py new file mode 100644 index 0000000..8c07e68 --- /dev/null +++ b/data_processing/detectron2/tests/modeling/test_model_e2e.py @@ -0,0 +1,227 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ + +import itertools +import unittest +from contextlib import contextmanager +from copy import deepcopy +import torch + +from detectron2.structures import BitMasks, Boxes, ImageList, Instances +from detectron2.utils.events import EventStorage +from detectron2.utils.testing import get_model_no_weights + + +@contextmanager +def typecheck_hook(model, *, in_dtype=None, out_dtype=None): + """ + Check that the model must be called with the given input/output dtype + """ + if not isinstance(in_dtype, set): + in_dtype = {in_dtype} + if not isinstance(out_dtype, set): + out_dtype = {out_dtype} + + def flatten(x): + if isinstance(x, torch.Tensor): + return [x] + if isinstance(x, (list, tuple)): + return list(itertools.chain(*[flatten(t) for t in x])) + if isinstance(x, dict): + return flatten(list(x.values())) + return [] + + def hook(module, input, output): + if in_dtype is not None: + dtypes = {x.dtype for x in flatten(input)} + assert ( + dtypes == in_dtype + ), f"Expected input dtype of {type(module)} is {in_dtype}. Got {dtypes} instead!" + + if out_dtype is not None: + dtypes = {x.dtype for x in flatten(output)} + assert ( + dtypes == out_dtype + ), f"Expected output dtype of {type(module)} is {out_dtype}. Got {dtypes} instead!" 
+ + with model.register_forward_hook(hook): + yield + + +def create_model_input(img, inst=None): + if inst is not None: + return {"image": img, "instances": inst} + else: + return {"image": img} + + +def get_empty_instance(h, w): + inst = Instances((h, w)) + inst.gt_boxes = Boxes(torch.rand(0, 4)) + inst.gt_classes = torch.tensor([]).to(dtype=torch.int64) + inst.gt_masks = BitMasks(torch.rand(0, h, w)) + return inst + + +def get_regular_bitmask_instances(h, w): + inst = Instances((h, w)) + inst.gt_boxes = Boxes(torch.rand(3, 4)) + inst.gt_boxes.tensor[:, 2:] += inst.gt_boxes.tensor[:, :2] + inst.gt_classes = torch.tensor([3, 4, 5]).to(dtype=torch.int64) + inst.gt_masks = BitMasks((torch.rand(3, h, w) > 0.5)) + return inst + + +class InstanceModelE2ETest: + def setUp(self): + torch.manual_seed(43) + self.model = get_model_no_weights(self.CONFIG_PATH) + + def _test_eval(self, input_sizes): + inputs = [create_model_input(torch.rand(3, s[0], s[1])) for s in input_sizes] + self.model.eval() + self.model(inputs) + + def _test_train(self, input_sizes, instances): + assert len(input_sizes) == len(instances) + inputs = [ + create_model_input(torch.rand(3, s[0], s[1]), inst) + for s, inst in zip(input_sizes, instances) + ] + self.model.train() + with EventStorage(): + losses = self.model(inputs) + sum(losses.values()).backward() + del losses + + def _inf_tensor(self, *shape): + return 1.0 / torch.zeros(*shape, device=self.model.device) + + def _nan_tensor(self, *shape): + return torch.zeros(*shape, device=self.model.device).fill_(float("nan")) + + def test_empty_data(self): + instances = [get_empty_instance(200, 250), get_empty_instance(200, 249)] + self._test_eval([(200, 250), (200, 249)]) + self._test_train([(200, 250), (200, 249)], instances) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_eval_tocpu(self): + model = deepcopy(self.model).cpu() + model.eval() + input_sizes = [(200, 250), (200, 249)] + inputs = 
[create_model_input(torch.rand(3, s[0], s[1])) for s in input_sizes] + model(inputs) + + +class MaskRCNNE2ETest(InstanceModelE2ETest, unittest.TestCase): + CONFIG_PATH = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml" + + def test_half_empty_data(self): + instances = [get_empty_instance(200, 250), get_regular_bitmask_instances(200, 249)] + self._test_train([(200, 250), (200, 249)], instances) + + # This test is flaky because in some environment the output features are zero due to relu + # def test_rpn_inf_nan_data(self): + # self.model.eval() + # for tensor in [self._inf_tensor, self._nan_tensor]: + # images = ImageList(tensor(1, 3, 512, 512), [(510, 510)]) + # features = { + # "p2": tensor(1, 256, 256, 256), + # "p3": tensor(1, 256, 128, 128), + # "p4": tensor(1, 256, 64, 64), + # "p5": tensor(1, 256, 32, 32), + # "p6": tensor(1, 256, 16, 16), + # } + # props, _ = self.model.proposal_generator(images, features) + # self.assertEqual(len(props[0]), 0) + + def test_roiheads_inf_nan_data(self): + self.model.eval() + for tensor in [self._inf_tensor, self._nan_tensor]: + images = ImageList(tensor(1, 3, 512, 512), [(510, 510)]) + features = { + "p2": tensor(1, 256, 256, 256), + "p3": tensor(1, 256, 128, 128), + "p4": tensor(1, 256, 64, 64), + "p5": tensor(1, 256, 32, 32), + "p6": tensor(1, 256, 16, 16), + } + props = [Instances((510, 510))] + props[0].proposal_boxes = Boxes([[10, 10, 20, 20]]).to(device=self.model.device) + props[0].objectness_logits = torch.tensor([1.0]).reshape(1, 1) + det, _ = self.model.roi_heads(images, features, props) + self.assertEqual(len(det[0]), 0) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_autocast(self): + from torch.cuda.amp import autocast + + inputs = [{"image": torch.rand(3, 100, 100)}] + self.model.eval() + with autocast(), typecheck_hook( + self.model.backbone, in_dtype=torch.float32, out_dtype=torch.float16 + ), typecheck_hook( + self.model.roi_heads.box_predictor, 
in_dtype=torch.float16, out_dtype=torch.float16 + ): + out = self.model.inference(inputs, do_postprocess=False)[0] + self.assertEqual(out.pred_boxes.tensor.dtype, torch.float32) + self.assertEqual(out.pred_masks.dtype, torch.float16) + self.assertEqual(out.scores.dtype, torch.float32) # scores comes from softmax + + +class RetinaNetE2ETest(InstanceModelE2ETest, unittest.TestCase): + CONFIG_PATH = "COCO-Detection/retinanet_R_50_FPN_1x.yaml" + + def test_inf_nan_data(self): + self.model.eval() + self.model.score_threshold = -999999999 + for tensor in [self._inf_tensor, self._nan_tensor]: + images = ImageList(tensor(1, 3, 512, 512), [(510, 510)]) + features = [ + tensor(1, 256, 128, 128), + tensor(1, 256, 64, 64), + tensor(1, 256, 32, 32), + tensor(1, 256, 16, 16), + tensor(1, 256, 8, 8), + ] + pred_logits, pred_anchor_deltas = self.model.head(features) + pred_logits = [tensor(*x.shape) for x in pred_logits] + pred_anchor_deltas = [tensor(*x.shape) for x in pred_anchor_deltas] + det = self.model.forward_inference(images, features, [pred_logits, pred_anchor_deltas]) + # all predictions (if any) are infinite or nan + if len(det[0]): + self.assertTrue(torch.isfinite(det[0].pred_boxes.tensor).sum() == 0) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_autocast(self): + from torch.cuda.amp import autocast + + inputs = [{"image": torch.rand(3, 100, 100)}] + self.model.eval() + with autocast(), typecheck_hook( + self.model.backbone, in_dtype=torch.float32, out_dtype=torch.float16 + ), typecheck_hook(self.model.head, in_dtype=torch.float16, out_dtype=torch.float16): + out = self.model(inputs)[0]["instances"] + self.assertEqual(out.pred_boxes.tensor.dtype, torch.float32) + self.assertEqual(out.scores.dtype, torch.float16) + + +class FCOSE2ETest(InstanceModelE2ETest, unittest.TestCase): + CONFIG_PATH = "COCO-Detection/fcos_R_50_FPN_1x.py" + + +class SemSegE2ETest(unittest.TestCase): + CONFIG_PATH = "Misc/semantic_R_50_FPN_1x.yaml" + + def 
setUp(self): + torch.manual_seed(43) + self.model = get_model_no_weights(self.CONFIG_PATH) + + def _test_eval(self, input_sizes): + inputs = [create_model_input(torch.rand(3, s[0], s[1])) for s in input_sizes] + self.model.eval() + self.model(inputs) + + def test_forward(self): + self._test_eval([(200, 250), (200, 249)]) diff --git a/data_processing/detectron2/tests/modeling/test_roi_heads.py b/data_processing/detectron2/tests/modeling/test_roi_heads.py new file mode 100644 index 0000000..86360e1 --- /dev/null +++ b/data_processing/detectron2/tests/modeling/test_roi_heads.py @@ -0,0 +1,323 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import unittest +from copy import deepcopy +import torch +from torch import nn + +from detectron2 import model_zoo +from detectron2.config import get_cfg +from detectron2.export.torchscript_patch import ( + freeze_training_mode, + patch_builtin_len, + patch_instances, +) +from detectron2.layers import ShapeSpec +from detectron2.modeling.proposal_generator.build import build_proposal_generator +from detectron2.modeling.roi_heads import ( + FastRCNNConvFCHead, + KRCNNConvDeconvUpsampleHead, + MaskRCNNConvUpsampleHead, + StandardROIHeads, + build_roi_heads, +) +from detectron2.projects import point_rend +from detectron2.structures import BitMasks, Boxes, ImageList, Instances, RotatedBoxes +from detectron2.utils.events import EventStorage +from detectron2.utils.testing import assert_instances_allclose, random_boxes + +logger = logging.getLogger(__name__) + +""" +Make sure the losses of ROIHeads/RPN do not change, to avoid +breaking the forward logic by mistake. +This relies on assumption that pytorch's RNG is stable. 
+""" + + +class ROIHeadsTest(unittest.TestCase): + def test_roi_heads(self): + torch.manual_seed(121) + cfg = get_cfg() + cfg.MODEL.ROI_BOX_HEAD.NAME = "FastRCNNConvFCHead" + cfg.MODEL.ROI_BOX_HEAD.NUM_FC = 2 + cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignV2" + cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10, 10, 5, 5) + cfg.MODEL.MASK_ON = True + num_images = 2 + images_tensor = torch.rand(num_images, 20, 30) + image_sizes = [(10, 10), (20, 30)] + images = ImageList(images_tensor, image_sizes) + num_channels = 1024 + features = {"res4": torch.rand(num_images, num_channels, 1, 2)} + feature_shape = {"res4": ShapeSpec(channels=num_channels, stride=16)} + + image_shape = (15, 15) + gt_boxes0 = torch.tensor([[1, 1, 3, 3], [2, 2, 6, 6]], dtype=torch.float32) + gt_instance0 = Instances(image_shape) + gt_instance0.gt_boxes = Boxes(gt_boxes0) + gt_instance0.gt_classes = torch.tensor([2, 1]) + gt_instance0.gt_masks = BitMasks(torch.rand((2,) + image_shape) > 0.5) + gt_boxes1 = torch.tensor([[1, 5, 2, 8], [7, 3, 10, 5]], dtype=torch.float32) + gt_instance1 = Instances(image_shape) + gt_instance1.gt_boxes = Boxes(gt_boxes1) + gt_instance1.gt_classes = torch.tensor([1, 2]) + gt_instance1.gt_masks = BitMasks(torch.rand((2,) + image_shape) > 0.5) + gt_instances = [gt_instance0, gt_instance1] + + proposal_generator = build_proposal_generator(cfg, feature_shape) + roi_heads = StandardROIHeads(cfg, feature_shape) + + with EventStorage(): # capture events in a new storage to discard them + proposals, proposal_losses = proposal_generator(images, features, gt_instances) + _, detector_losses = roi_heads(images, features, proposals, gt_instances) + + detector_losses.update(proposal_losses) + expected_losses = { + "loss_cls": 4.5253729820251465, + "loss_box_reg": 0.009785720147192478, + "loss_mask": 0.693184494972229, + "loss_rpn_cls": 0.08186662942171097, + "loss_rpn_loc": 0.1104838103055954, + } + succ = all( + torch.allclose(detector_losses[name], 
torch.tensor(expected_losses.get(name, 0.0))) + for name in detector_losses.keys() + ) + self.assertTrue( + succ, + "Losses has changed! New losses: {}".format( + {k: v.item() for k, v in detector_losses.items()} + ), + ) + + def test_rroi_heads(self): + torch.manual_seed(121) + cfg = get_cfg() + cfg.MODEL.PROPOSAL_GENERATOR.NAME = "RRPN" + cfg.MODEL.ANCHOR_GENERATOR.NAME = "RotatedAnchorGenerator" + cfg.MODEL.ROI_HEADS.NAME = "RROIHeads" + cfg.MODEL.ROI_BOX_HEAD.NAME = "FastRCNNConvFCHead" + cfg.MODEL.ROI_BOX_HEAD.NUM_FC = 2 + cfg.MODEL.RPN.BBOX_REG_WEIGHTS = (1, 1, 1, 1, 1) + cfg.MODEL.RPN.HEAD_NAME = "StandardRPNHead" + cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignRotated" + cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10, 10, 5, 5, 1) + num_images = 2 + images_tensor = torch.rand(num_images, 20, 30) + image_sizes = [(10, 10), (20, 30)] + images = ImageList(images_tensor, image_sizes) + num_channels = 1024 + features = {"res4": torch.rand(num_images, num_channels, 1, 2)} + feature_shape = {"res4": ShapeSpec(channels=num_channels, stride=16)} + + image_shape = (15, 15) + gt_boxes0 = torch.tensor([[2, 2, 2, 2, 30], [4, 4, 4, 4, 0]], dtype=torch.float32) + gt_instance0 = Instances(image_shape) + gt_instance0.gt_boxes = RotatedBoxes(gt_boxes0) + gt_instance0.gt_classes = torch.tensor([2, 1]) + gt_boxes1 = torch.tensor([[1.5, 5.5, 1, 3, 0], [8.5, 4, 3, 2, -50]], dtype=torch.float32) + gt_instance1 = Instances(image_shape) + gt_instance1.gt_boxes = RotatedBoxes(gt_boxes1) + gt_instance1.gt_classes = torch.tensor([1, 2]) + gt_instances = [gt_instance0, gt_instance1] + + proposal_generator = build_proposal_generator(cfg, feature_shape) + roi_heads = build_roi_heads(cfg, feature_shape) + + with EventStorage(): # capture events in a new storage to discard them + proposals, proposal_losses = proposal_generator(images, features, gt_instances) + _, detector_losses = roi_heads(images, features, proposals, gt_instances) + + detector_losses.update(proposal_losses) + 
expected_losses = { + "loss_cls": 4.365657806396484, + "loss_box_reg": 0.0015851043863222003, + "loss_rpn_cls": 0.2427729219198227, + "loss_rpn_loc": 0.3646621108055115, + } + succ = all( + torch.allclose(detector_losses[name], torch.tensor(expected_losses.get(name, 0.0))) + for name in detector_losses.keys() + ) + self.assertTrue( + succ, + "Losses has changed! New losses: {}".format( + {k: v.item() for k, v in detector_losses.items()} + ), + ) + + def test_box_head_scriptability(self): + input_shape = ShapeSpec(channels=1024, height=14, width=14) + box_features = torch.randn(4, 1024, 14, 14) + + box_head = FastRCNNConvFCHead( + input_shape, conv_dims=[512, 512], fc_dims=[1024, 1024] + ).eval() + script_box_head = torch.jit.script(box_head) + + origin_output = box_head(box_features) + script_output = script_box_head(box_features) + self.assertTrue(torch.equal(origin_output, script_output)) + + def test_mask_head_scriptability(self): + input_shape = ShapeSpec(channels=1024) + mask_features = torch.randn(4, 1024, 14, 14) + + image_shapes = [(10, 10), (15, 15)] + pred_instance0 = Instances(image_shapes[0]) + pred_classes0 = torch.tensor([1, 2, 3], dtype=torch.int64) + pred_instance0.pred_classes = pred_classes0 + pred_instance1 = Instances(image_shapes[1]) + pred_classes1 = torch.tensor([4], dtype=torch.int64) + pred_instance1.pred_classes = pred_classes1 + + mask_head = MaskRCNNConvUpsampleHead( + input_shape, num_classes=80, conv_dims=[256, 256] + ).eval() + # pred_instance will be in-place changed during the inference + # process of `MaskRCNNConvUpsampleHead` + origin_outputs = mask_head(mask_features, deepcopy([pred_instance0, pred_instance1])) + + fields = {"pred_masks": torch.Tensor, "pred_classes": torch.Tensor} + with freeze_training_mode(mask_head), patch_instances(fields) as NewInstances: + sciript_mask_head = torch.jit.script(mask_head) + pred_instance0 = NewInstances.from_instances(pred_instance0) + pred_instance1 = 
NewInstances.from_instances(pred_instance1) + script_outputs = sciript_mask_head(mask_features, [pred_instance0, pred_instance1]) + + for origin_ins, script_ins in zip(origin_outputs, script_outputs): + assert_instances_allclose(origin_ins, script_ins, rtol=0) + + def test_keypoint_head_scriptability(self): + input_shape = ShapeSpec(channels=1024, height=14, width=14) + keypoint_features = torch.randn(4, 1024, 14, 14) + + image_shapes = [(10, 10), (15, 15)] + pred_boxes0 = torch.tensor([[1, 1, 3, 3], [2, 2, 6, 6], [1, 5, 2, 8]], dtype=torch.float32) + pred_instance0 = Instances(image_shapes[0]) + pred_instance0.pred_boxes = Boxes(pred_boxes0) + pred_boxes1 = torch.tensor([[7, 3, 10, 5]], dtype=torch.float32) + pred_instance1 = Instances(image_shapes[1]) + pred_instance1.pred_boxes = Boxes(pred_boxes1) + + keypoint_head = KRCNNConvDeconvUpsampleHead( + input_shape, num_keypoints=17, conv_dims=[512, 512] + ).eval() + origin_outputs = keypoint_head( + keypoint_features, deepcopy([pred_instance0, pred_instance1]) + ) + + fields = { + "pred_boxes": Boxes, + "pred_keypoints": torch.Tensor, + "pred_keypoint_heatmaps": torch.Tensor, + } + with freeze_training_mode(keypoint_head), patch_instances(fields) as NewInstances: + script_keypoint_head = torch.jit.script(keypoint_head) + pred_instance0 = NewInstances.from_instances(pred_instance0) + pred_instance1 = NewInstances.from_instances(pred_instance1) + script_outputs = script_keypoint_head( + keypoint_features, [pred_instance0, pred_instance1] + ) + + for origin_ins, script_ins in zip(origin_outputs, script_outputs): + assert_instances_allclose(origin_ins, script_ins, rtol=0) + + def test_StandardROIHeads_scriptability(self): + cfg = get_cfg() + cfg.MODEL.ROI_BOX_HEAD.NAME = "FastRCNNConvFCHead" + cfg.MODEL.ROI_BOX_HEAD.NUM_FC = 2 + cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignV2" + cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10, 10, 5, 5) + cfg.MODEL.MASK_ON = True + cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.01 + 
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.01 + num_images = 2 + images_tensor = torch.rand(num_images, 20, 30) + image_sizes = [(10, 10), (20, 30)] + images = ImageList(images_tensor, image_sizes) + num_channels = 1024 + features = {"res4": torch.rand(num_images, num_channels, 1, 2)} + feature_shape = {"res4": ShapeSpec(channels=num_channels, stride=16)} + + roi_heads = StandardROIHeads(cfg, feature_shape).eval() + + proposal0 = Instances(image_sizes[0]) + proposal_boxes0 = torch.tensor([[1, 1, 3, 3], [2, 2, 6, 6]], dtype=torch.float32) + proposal0.proposal_boxes = Boxes(proposal_boxes0) + proposal0.objectness_logits = torch.tensor([0.5, 0.7], dtype=torch.float32) + + proposal1 = Instances(image_sizes[1]) + proposal_boxes1 = torch.tensor([[1, 5, 2, 8], [7, 3, 10, 5]], dtype=torch.float32) + proposal1.proposal_boxes = Boxes(proposal_boxes1) + proposal1.objectness_logits = torch.tensor([0.1, 0.9], dtype=torch.float32) + proposals = [proposal0, proposal1] + + pred_instances, _ = roi_heads(images, features, proposals) + fields = { + "objectness_logits": torch.Tensor, + "proposal_boxes": Boxes, + "pred_classes": torch.Tensor, + "scores": torch.Tensor, + "pred_masks": torch.Tensor, + "pred_boxes": Boxes, + "pred_keypoints": torch.Tensor, + "pred_keypoint_heatmaps": torch.Tensor, + } + with freeze_training_mode(roi_heads), patch_instances(fields) as new_instances: + proposal0 = new_instances.from_instances(proposal0) + proposal1 = new_instances.from_instances(proposal1) + proposals = [proposal0, proposal1] + scripted_rot_heads = torch.jit.script(roi_heads) + scripted_pred_instances, _ = scripted_rot_heads(images, features, proposals) + + for instance, scripted_instance in zip(pred_instances, scripted_pred_instances): + assert_instances_allclose(instance, scripted_instance, rtol=0) + + def test_PointRend_mask_head_tracing(self): + cfg = model_zoo.get_config("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml") + point_rend.add_pointrend_config(cfg) + 
cfg.MODEL.ROI_HEADS.IN_FEATURES = ["p2", "p3"] + cfg.MODEL.ROI_MASK_HEAD.NAME = "PointRendMaskHead" + cfg.MODEL.ROI_MASK_HEAD.POOLER_TYPE = "" + cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON = True + chan = 256 + head = point_rend.PointRendMaskHead( + cfg, + { + "p2": ShapeSpec(channels=chan, stride=4), + "p3": ShapeSpec(channels=chan, stride=8), + }, + ) + + def gen_inputs(h, w, N): + p2 = torch.rand(1, chan, h, w) + p3 = torch.rand(1, chan, h // 2, w // 2) + boxes = random_boxes(N, max_coord=h) + return p2, p3, boxes + + class Wrap(nn.ModuleDict): + def forward(self, p2, p3, boxes): + features = { + "p2": p2, + "p3": p3, + } + inst = Instances((p2.shape[2] * 4, p2.shape[3] * 4)) + inst.pred_boxes = Boxes(boxes) + inst.pred_classes = torch.zeros(inst.__len__(), dtype=torch.long) + out = self.head(features, [inst])[0] + return out.pred_masks + + model = Wrap({"head": head}) + model.eval() + with torch.no_grad(), patch_builtin_len(): + traced = torch.jit.trace(model, gen_inputs(302, 208, 20)) + inputs = gen_inputs(100, 120, 30) + out_eager = model(*inputs) + out_trace = traced(*inputs) + self.assertTrue(torch.allclose(out_eager, out_trace)) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/modeling/test_roi_pooler.py b/data_processing/detectron2/tests/modeling/test_roi_pooler.py new file mode 100644 index 0000000..e1d7c1c --- /dev/null +++ b/data_processing/detectron2/tests/modeling/test_roi_pooler.py @@ -0,0 +1,165 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import logging +import unittest +import torch + +from detectron2.modeling.poolers import ROIPooler +from detectron2.structures import Boxes, RotatedBoxes +from detectron2.utils.testing import random_boxes + +logger = logging.getLogger(__name__) + + +class TestROIPooler(unittest.TestCase): + def _test_roialignv2_roialignrotated_match(self, device): + pooler_resolution = 14 + canonical_level = 4 + canonical_scale_factor = 2**canonical_level + pooler_scales = (1.0 / canonical_scale_factor,) + sampling_ratio = 0 + + N, C, H, W = 2, 4, 10, 8 + N_rois = 10 + std = 11 + mean = 0 + feature = (torch.rand(N, C, H, W) - 0.5) * 2 * std + mean + + features = [feature.to(device)] + + rois = [] + rois_rotated = [] + for _ in range(N): + boxes = random_boxes(N_rois, W * canonical_scale_factor) + rotated_boxes = torch.zeros(N_rois, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + rois.append(Boxes(boxes).to(device)) + rois_rotated.append(RotatedBoxes(rotated_boxes).to(device)) + + roialignv2_pooler = ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type="ROIAlignV2", + ) + + roialignv2_out = roialignv2_pooler(features, rois) + + roialignrotated_pooler = ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type="ROIAlignRotated", + ) + + roialignrotated_out = roialignrotated_pooler(features, rois_rotated) + + self.assertTrue(torch.allclose(roialignv2_out, roialignrotated_out, atol=1e-4)) + + def test_roialignv2_roialignrotated_match_cpu(self): + self._test_roialignv2_roialignrotated_match(device="cpu") + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_roialignv2_roialignrotated_match_cuda(self): + self._test_roialignv2_roialignrotated_match(device="cuda") + + def 
_test_scriptability(self, device): + pooler_resolution = 14 + canonical_level = 4 + canonical_scale_factor = 2**canonical_level + pooler_scales = (1.0 / canonical_scale_factor,) + sampling_ratio = 0 + + N, C, H, W = 2, 4, 10, 8 + N_rois = 10 + std = 11 + mean = 0 + feature = (torch.rand(N, C, H, W) - 0.5) * 2 * std + mean + + features = [feature.to(device)] + + rois = [] + for _ in range(N): + boxes = random_boxes(N_rois, W * canonical_scale_factor) + + rois.append(Boxes(boxes).to(device)) + + roialignv2_pooler = ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type="ROIAlignV2", + ) + + roialignv2_out = roialignv2_pooler(features, rois) + scripted_roialignv2_out = torch.jit.script(roialignv2_pooler)(features, rois) + self.assertTrue(torch.equal(roialignv2_out, scripted_roialignv2_out)) + + def test_scriptability_cpu(self): + self._test_scriptability(device="cpu") + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_scriptability_gpu(self): + self._test_scriptability(device="cuda") + + def test_no_images(self): + N, C, H, W = 0, 32, 32, 32 + feature = torch.rand(N, C, H, W) - 0.5 + features = [feature] + pooler = ROIPooler( + output_size=14, scales=(1.0,), sampling_ratio=0.0, pooler_type="ROIAlignV2" + ) + output = pooler.forward(features, []) + self.assertEqual(output.shape, (0, C, 14, 14)) + + def test_roi_pooler_tracing(self): + class Model(torch.nn.Module): + def __init__(self, roi): + super(Model, self).__init__() + self.roi = roi + + def forward(self, x, boxes): + return self.roi(x, [Boxes(boxes)]) + + pooler_resolution = 14 + canonical_level = 4 + canonical_scale_factor = 2**canonical_level + pooler_scales = (1.0 / canonical_scale_factor, 0.5 / canonical_scale_factor) + sampling_ratio = 0 + + N, C, H, W = 1, 4, 10, 8 + N_rois = 10 + std = 11 + mean = 0 + feature = (torch.rand(N, C, H, W) - 0.5) * 2 * std + mean + feature = [feature, feature] + + rois = 
random_boxes(N_rois, W * canonical_scale_factor) + # Add one larger box so that this level has only one box. + # This may trigger the bug https://github.com/pytorch/pytorch/issues/49852 + # that we shall workaround. + rois = torch.cat([rois, torch.tensor([[0, 0, 448, 448]])]) + + model = Model( + ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type="ROIAlign", + ) + ) + + with torch.no_grad(): + func = torch.jit.trace(model, (feature, rois)) + o = func(feature, rois) + self.assertEqual(o.shape, (11, 4, 14, 14)) + o = func(feature, rois[:5]) + self.assertEqual(o.shape, (5, 4, 14, 14)) + o = func(feature, random_boxes(20, W * canonical_scale_factor)) + self.assertEqual(o.shape, (20, 4, 14, 14)) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/modeling/test_rpn.py b/data_processing/detectron2/tests/modeling/test_rpn.py new file mode 100644 index 0000000..f14faae --- /dev/null +++ b/data_processing/detectron2/tests/modeling/test_rpn.py @@ -0,0 +1,262 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import logging +import unittest +import torch + +from detectron2.config import get_cfg +from detectron2.export import scripting_with_instances +from detectron2.layers import ShapeSpec +from detectron2.modeling.backbone import build_backbone +from detectron2.modeling.proposal_generator import RPN, build_proposal_generator +from detectron2.modeling.proposal_generator.proposal_utils import ( + add_ground_truth_to_proposals, + find_top_rpn_proposals, +) +from detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes +from detectron2.utils.events import EventStorage + +logger = logging.getLogger(__name__) + + +class RPNTest(unittest.TestCase): + def get_gt_and_features(self): + num_images = 2 + images_tensor = torch.rand(num_images, 20, 30) + image_sizes = [(10, 10), (20, 30)] + images = ImageList(images_tensor, image_sizes) + image_shape = (15, 15) + num_channels = 1024 + features = {"res4": torch.rand(num_images, num_channels, 1, 2)} + gt_boxes = torch.tensor([[1, 1, 3, 3], [2, 2, 6, 6]], dtype=torch.float32) + gt_instances = Instances(image_shape) + gt_instances.gt_boxes = Boxes(gt_boxes) + return (gt_instances, features, images, image_sizes) + + def test_rpn(self): + torch.manual_seed(121) + cfg = get_cfg() + backbone = build_backbone(cfg) + proposal_generator = RPN(cfg, backbone.output_shape()) + (gt_instances, features, images, image_sizes) = self.get_gt_and_features() + with EventStorage(): # capture events in a new storage to discard them + proposals, proposal_losses = proposal_generator( + images, features, [gt_instances[0], gt_instances[1]] + ) + + expected_losses = { + "loss_rpn_cls": torch.tensor(0.08011703193), + "loss_rpn_loc": torch.tensor(0.101470276), + } + for name in expected_losses.keys(): + err_msg = "proposal_losses[{}] = {}, expected losses = {}".format( + name, proposal_losses[name], expected_losses[name] + ) + self.assertTrue(torch.allclose(proposal_losses[name], expected_losses[name]), err_msg) + + self.assertEqual(len(proposals), 
len(image_sizes)) + for proposal, im_size in zip(proposals, image_sizes): + self.assertEqual(proposal.image_size, im_size) + + expected_proposal_box = torch.tensor([[0, 0, 10, 10], [7.2702, 0, 10, 10]]) + expected_objectness_logit = torch.tensor([0.1596, -0.0007]) + self.assertTrue( + torch.allclose(proposals[0].proposal_boxes.tensor, expected_proposal_box, atol=1e-4) + ) + self.assertTrue( + torch.allclose(proposals[0].objectness_logits, expected_objectness_logit, atol=1e-4) + ) + + def verify_rpn(self, conv_dims, expected_conv_dims): + torch.manual_seed(121) + cfg = get_cfg() + cfg.MODEL.RPN.CONV_DIMS = conv_dims + backbone = build_backbone(cfg) + proposal_generator = RPN(cfg, backbone.output_shape()) + for k, conv in enumerate(proposal_generator.rpn_head.conv): + self.assertEqual(expected_conv_dims[k], conv.out_channels) + return proposal_generator + + def test_rpn_larger_num_convs(self): + conv_dims = [64, 64, 64, 64, 64] + proposal_generator = self.verify_rpn(conv_dims, conv_dims) + (gt_instances, features, images, image_sizes) = self.get_gt_and_features() + with EventStorage(): # capture events in a new storage to discard them + proposals, proposal_losses = proposal_generator( + images, features, [gt_instances[0], gt_instances[1]] + ) + expected_losses = { + "loss_rpn_cls": torch.tensor(0.08122821152), + "loss_rpn_loc": torch.tensor(0.10064548254), + } + for name in expected_losses.keys(): + err_msg = "proposal_losses[{}] = {}, expected losses = {}".format( + name, proposal_losses[name], expected_losses[name] + ) + self.assertTrue(torch.allclose(proposal_losses[name], expected_losses[name]), err_msg) + + def test_rpn_conv_dims_not_set(self): + conv_dims = [-1, -1, -1] + expected_conv_dims = [1024, 1024, 1024] + self.verify_rpn(conv_dims, expected_conv_dims) + + def test_rpn_scriptability(self): + cfg = get_cfg() + proposal_generator = RPN(cfg, {"res4": ShapeSpec(channels=1024, stride=16)}).eval() + num_images = 2 + images_tensor = torch.rand(num_images, 30, 
40) + image_sizes = [(32, 32), (30, 40)] + images = ImageList(images_tensor, image_sizes) + features = {"res4": torch.rand(num_images, 1024, 1, 2)} + + fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor} + proposal_generator_ts = scripting_with_instances(proposal_generator, fields) + + proposals, _ = proposal_generator(images, features) + proposals_ts, _ = proposal_generator_ts(images, features) + + for proposal, proposal_ts in zip(proposals, proposals_ts): + self.assertEqual(proposal.image_size, proposal_ts.image_size) + self.assertTrue( + torch.equal(proposal.proposal_boxes.tensor, proposal_ts.proposal_boxes.tensor) + ) + self.assertTrue(torch.equal(proposal.objectness_logits, proposal_ts.objectness_logits)) + + def test_rrpn(self): + torch.manual_seed(121) + cfg = get_cfg() + cfg.MODEL.PROPOSAL_GENERATOR.NAME = "RRPN" + cfg.MODEL.ANCHOR_GENERATOR.NAME = "RotatedAnchorGenerator" + cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64]] + cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.25, 1]] + cfg.MODEL.ANCHOR_GENERATOR.ANGLES = [[0, 60]] + cfg.MODEL.RPN.BBOX_REG_WEIGHTS = (1, 1, 1, 1, 1) + cfg.MODEL.RPN.HEAD_NAME = "StandardRPNHead" + backbone = build_backbone(cfg) + proposal_generator = build_proposal_generator(cfg, backbone.output_shape()) + num_images = 2 + images_tensor = torch.rand(num_images, 20, 30) + image_sizes = [(10, 10), (20, 30)] + images = ImageList(images_tensor, image_sizes) + image_shape = (15, 15) + num_channels = 1024 + features = {"res4": torch.rand(num_images, num_channels, 1, 2)} + gt_boxes = torch.tensor([[2, 2, 2, 2, 0], [4, 4, 4, 4, 0]], dtype=torch.float32) + gt_instances = Instances(image_shape) + gt_instances.gt_boxes = RotatedBoxes(gt_boxes) + with EventStorage(): # capture events in a new storage to discard them + proposals, proposal_losses = proposal_generator( + images, features, [gt_instances[0], gt_instances[1]] + ) + + expected_losses = { + "loss_rpn_cls": torch.tensor(0.04291602224), + "loss_rpn_loc": 
torch.tensor(0.145077362), + } + for name in expected_losses.keys(): + err_msg = "proposal_losses[{}] = {}, expected losses = {}".format( + name, proposal_losses[name], expected_losses[name] + ) + self.assertTrue(torch.allclose(proposal_losses[name], expected_losses[name]), err_msg) + + expected_proposal_box = torch.tensor( + [ + [-1.77999556, 0.78155339, 68.04367828, 14.78156471, 60.59333801], + [13.82740974, -1.50282836, 34.67269897, 29.19676590, -3.81942749], + [8.10392570, -0.99071521, 145.39100647, 32.13126373, 3.67242432], + [5.00000000, 4.57370186, 10.00000000, 9.14740372, 0.89196777], + ] + ) + + expected_objectness_logit = torch.tensor([0.10924313, 0.09881870, 0.07649877, 0.05858029]) + + torch.set_printoptions(precision=8, sci_mode=False) + + self.assertEqual(len(proposals), len(image_sizes)) + + proposal = proposals[0] + # It seems that there's some randomness in the result across different machines: + # This test can be run on a local machine for 100 times with exactly the same result, + # However, a different machine might produce slightly different results, + # thus the atol here. 
+ err_msg = "computed proposal boxes = {}, expected {}".format( + proposal.proposal_boxes.tensor, expected_proposal_box + ) + self.assertTrue( + torch.allclose(proposal.proposal_boxes.tensor[:4], expected_proposal_box, atol=1e-5), + err_msg, + ) + + err_msg = "computed objectness logits = {}, expected {}".format( + proposal.objectness_logits, expected_objectness_logit + ) + self.assertTrue( + torch.allclose(proposal.objectness_logits[:4], expected_objectness_logit, atol=1e-5), + err_msg, + ) + + def test_find_rpn_proposals_inf(self): + N, Hi, Wi, A = 3, 3, 3, 3 + proposals = [torch.rand(N, Hi * Wi * A, 4)] + pred_logits = [torch.rand(N, Hi * Wi * A)] + pred_logits[0][1][3:5].fill_(float("inf")) + find_top_rpn_proposals(proposals, pred_logits, [(10, 10)], 0.5, 1000, 1000, 0, False) + + def test_find_rpn_proposals_tracing(self): + N, Hi, Wi, A = 3, 50, 50, 9 + proposal = torch.rand(N, Hi * Wi * A, 4) + pred_logit = torch.rand(N, Hi * Wi * A) + + def func(proposal, logit, image_size): + r = find_top_rpn_proposals( + [proposal], [logit], [image_size], 0.7, 1000, 1000, 0, False + )[0] + size = r.image_size + if not isinstance(size, torch.Tensor): + size = torch.tensor(size) + return (size, r.proposal_boxes.tensor, r.objectness_logits) + + other_inputs = [] + # test that it generalizes to other shapes + for Hi, Wi, shp in [(30, 30, 60), (10, 10, 800)]: + other_inputs.append( + ( + torch.rand(N, Hi * Wi * A, 4), + torch.rand(N, Hi * Wi * A), + torch.tensor([shp, shp]), + ) + ) + torch.jit.trace( + func, (proposal, pred_logit, torch.tensor([100, 100])), check_inputs=other_inputs + ) + + def test_append_gt_to_proposal(self): + proposals = Instances( + (10, 10), + **{ + "proposal_boxes": Boxes(torch.empty((0, 4))), + "objectness_logits": torch.tensor([]), + "custom_attribute": torch.tensor([]), + } + ) + gt_boxes = Boxes(torch.tensor([[0, 0, 1, 1]])) + + self.assertRaises(AssertionError, add_ground_truth_to_proposals, [gt_boxes], [proposals]) + + gt_instances = 
Instances((10, 10)) + gt_instances.gt_boxes = gt_boxes + + self.assertRaises( + AssertionError, add_ground_truth_to_proposals, [gt_instances], [proposals] + ) + + gt_instances.custom_attribute = torch.tensor([1]) + gt_instances.custom_attribute2 = torch.tensor([1]) + new_proposals = add_ground_truth_to_proposals([gt_instances], [proposals])[0] + + self.assertEqual(new_proposals.custom_attribute[0], 1) + # new proposals should only include the attributes in proposals + self.assertRaises(AttributeError, lambda: new_proposals.custom_attribute2) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/structures/__init__.py b/data_processing/detectron2/tests/structures/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/detectron2/tests/structures/test_boxes.py b/data_processing/detectron2/tests/structures/test_boxes.py new file mode 100644 index 0000000..1011918 --- /dev/null +++ b/data_processing/detectron2/tests/structures/test_boxes.py @@ -0,0 +1,223 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import json +import math +import numpy as np +import unittest +import torch + +from detectron2.structures import Boxes, BoxMode, pairwise_ioa, pairwise_iou +from detectron2.utils.testing import reload_script_model + + +class TestBoxMode(unittest.TestCase): + def _convert_xy_to_wh(self, x): + return BoxMode.convert(x, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) + + def _convert_xywha_to_xyxy(self, x): + return BoxMode.convert(x, BoxMode.XYWHA_ABS, BoxMode.XYXY_ABS) + + def _convert_xywh_to_xywha(self, x): + return BoxMode.convert(x, BoxMode.XYWH_ABS, BoxMode.XYWHA_ABS) + + def test_convert_int_mode(self): + BoxMode.convert([1, 2, 3, 4], 0, 1) + + def test_box_convert_list(self): + for tp in [list, tuple]: + box = tp([5.0, 5.0, 10.0, 10.0]) + output = self._convert_xy_to_wh(box) + self.assertIsInstance(output, tp) + self.assertIsInstance(output[0], float) + self.assertEqual(output, tp([5.0, 5.0, 5.0, 5.0])) + + with self.assertRaises(Exception): + self._convert_xy_to_wh([box]) + + def test_box_convert_array(self): + box = np.asarray([[5, 5, 10, 10], [1, 1, 2, 3]]) + output = self._convert_xy_to_wh(box) + self.assertEqual(output.dtype, box.dtype) + self.assertEqual(output.shape, box.shape) + self.assertTrue((output[0] == [5, 5, 5, 5]).all()) + self.assertTrue((output[1] == [1, 1, 1, 2]).all()) + + def test_box_convert_cpu_tensor(self): + box = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]]) + output = self._convert_xy_to_wh(box) + self.assertEqual(output.dtype, box.dtype) + self.assertEqual(output.shape, box.shape) + output = output.numpy() + self.assertTrue((output[0] == [5, 5, 5, 5]).all()) + self.assertTrue((output[1] == [1, 1, 1, 2]).all()) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_box_convert_cuda_tensor(self): + box = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]]).cuda() + output = self._convert_xy_to_wh(box) + self.assertEqual(output.dtype, box.dtype) + self.assertEqual(output.shape, box.shape) + self.assertEqual(output.device, 
box.device) + output = output.cpu().numpy() + self.assertTrue((output[0] == [5, 5, 5, 5]).all()) + self.assertTrue((output[1] == [1, 1, 1, 2]).all()) + + def test_box_convert_xywha_to_xyxy_list(self): + for tp in [list, tuple]: + box = tp([50, 50, 30, 20, 0]) + output = self._convert_xywha_to_xyxy(box) + self.assertIsInstance(output, tp) + self.assertEqual(output, tp([35, 40, 65, 60])) + + with self.assertRaises(Exception): + self._convert_xywha_to_xyxy([box]) + + def test_box_convert_xywha_to_xyxy_array(self): + for dtype in [np.float64, np.float32]: + box = np.asarray( + [ + [50, 50, 30, 20, 0], + [50, 50, 30, 20, 90], + [1, 1, math.sqrt(2), math.sqrt(2), -45], + ], + dtype=dtype, + ) + output = self._convert_xywha_to_xyxy(box) + self.assertEqual(output.dtype, box.dtype) + expected = np.asarray([[35, 40, 65, 60], [40, 35, 60, 65], [0, 0, 2, 2]], dtype=dtype) + self.assertTrue(np.allclose(output, expected, atol=1e-6), "output={}".format(output)) + + def test_box_convert_xywha_to_xyxy_tensor(self): + for dtype in [torch.float32, torch.float64]: + box = torch.tensor( + [ + [50, 50, 30, 20, 0], + [50, 50, 30, 20, 90], + [1, 1, math.sqrt(2), math.sqrt(2), -45], + ], + dtype=dtype, + ) + output = self._convert_xywha_to_xyxy(box) + self.assertEqual(output.dtype, box.dtype) + expected = torch.tensor([[35, 40, 65, 60], [40, 35, 60, 65], [0, 0, 2, 2]], dtype=dtype) + + self.assertTrue(torch.allclose(output, expected, atol=1e-6), "output={}".format(output)) + + def test_box_convert_xywh_to_xywha_list(self): + for tp in [list, tuple]: + box = tp([50, 50, 30, 20]) + output = self._convert_xywh_to_xywha(box) + self.assertIsInstance(output, tp) + self.assertEqual(output, tp([65, 60, 30, 20, 0])) + + with self.assertRaises(Exception): + self._convert_xywh_to_xywha([box]) + + def test_box_convert_xywh_to_xywha_array(self): + for dtype in [np.float64, np.float32]: + box = np.asarray([[30, 40, 70, 60], [30, 40, 60, 70], [-1, -1, 2, 2]], dtype=dtype) + output = 
self._convert_xywh_to_xywha(box) + self.assertEqual(output.dtype, box.dtype) + expected = np.asarray( + [[65, 70, 70, 60, 0], [60, 75, 60, 70, 0], [0, 0, 2, 2, 0]], dtype=dtype + ) + self.assertTrue(np.allclose(output, expected, atol=1e-6), "output={}".format(output)) + + def test_box_convert_xywh_to_xywha_tensor(self): + for dtype in [torch.float32, torch.float64]: + box = torch.tensor([[30, 40, 70, 60], [30, 40, 60, 70], [-1, -1, 2, 2]], dtype=dtype) + output = self._convert_xywh_to_xywha(box) + self.assertEqual(output.dtype, box.dtype) + expected = torch.tensor( + [[65, 70, 70, 60, 0], [60, 75, 60, 70, 0], [0, 0, 2, 2, 0]], dtype=dtype + ) + + self.assertTrue(torch.allclose(output, expected, atol=1e-6), "output={}".format(output)) + + def test_json_serializable(self): + payload = {"box_mode": BoxMode.XYWH_REL} + try: + json.dumps(payload) + except Exception: + self.fail("JSON serialization failed") + + def test_json_deserializable(self): + payload = '{"box_mode": 2}' + obj = json.loads(payload) + try: + obj["box_mode"] = BoxMode(obj["box_mode"]) + except Exception: + self.fail("JSON deserialization failed") + + +class TestBoxIOU(unittest.TestCase): + def create_boxes(self): + boxes1 = torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]) + + boxes2 = torch.tensor( + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.5, 1.0], + [0.0, 0.0, 1.0, 0.5], + [0.0, 0.0, 0.5, 0.5], + [0.5, 0.5, 1.0, 1.0], + [0.5, 0.5, 1.5, 1.5], + ] + ) + return boxes1, boxes2 + + def test_pairwise_iou(self): + boxes1, boxes2 = self.create_boxes() + expected_ious = torch.tensor( + [ + [1.0, 0.5, 0.5, 0.25, 0.25, 0.25 / (2 - 0.25)], + [1.0, 0.5, 0.5, 0.25, 0.25, 0.25 / (2 - 0.25)], + ] + ) + + ious = pairwise_iou(Boxes(boxes1), Boxes(boxes2)) + self.assertTrue(torch.allclose(ious, expected_ious)) + + def test_pairwise_ioa(self): + boxes1, boxes2 = self.create_boxes() + expected_ioas = torch.tensor( + [[1.0, 1.0, 1.0, 1.0, 1.0, 0.25], [1.0, 1.0, 1.0, 1.0, 1.0, 0.25]] + ) + ioas = 
pairwise_ioa(Boxes(boxes1), Boxes(boxes2)) + self.assertTrue(torch.allclose(ioas, expected_ioas)) + + +class TestBoxes(unittest.TestCase): + def test_empty_cat(self): + x = Boxes.cat([]) + self.assertEqual(x.tensor.shape, (0, 4)) + + def test_to(self): + x = Boxes(torch.rand(3, 4)) + self.assertEqual(x.to(device="cpu").tensor.device.type, "cpu") + + def test_scriptability(self): + def func(x): + boxes = Boxes(x) + test = boxes.to(torch.device("cpu")).tensor + return boxes.area(), test + + f = torch.jit.script(func) + f = reload_script_model(f) + f(torch.rand((3, 4))) + + data = torch.rand((3, 4)) + + def func_cat(x: torch.Tensor): + boxes1 = Boxes(x) + boxes2 = Boxes(x) + # boxes3 = Boxes.cat([boxes1, boxes2]) # this is not supported by torchscript for now. + boxes3 = boxes1.cat([boxes1, boxes2]) + return boxes3 + + f = torch.jit.script(func_cat) + script_box = f(data) + self.assertTrue(torch.equal(torch.cat([data, data]), script_box.tensor)) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/structures/test_imagelist.py b/data_processing/detectron2/tests/structures/test_imagelist.py new file mode 100644 index 0000000..e446e44 --- /dev/null +++ b/data_processing/detectron2/tests/structures/test_imagelist.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import unittest +from typing import List, Sequence, Tuple +import torch + +from detectron2.structures import ImageList + + +class TestImageList(unittest.TestCase): + def test_imagelist_padding_tracing(self): + # test that the trace does not contain hard-coded constant sizes + def to_imagelist(tensors: Sequence[torch.Tensor]): + image_list = ImageList.from_tensors(tensors, 4) + return image_list.tensor, image_list.image_sizes + + def _tensor(*shape): + return torch.ones(shape, dtype=torch.float32) + + # test CHW (inputs needs padding vs. 
no padding) + for shape in [(3, 10, 10), (3, 12, 12)]: + func = torch.jit.trace(to_imagelist, ([_tensor(*shape)],)) + tensor, image_sizes = func([_tensor(3, 15, 20)]) + self.assertEqual(tensor.shape, (1, 3, 16, 20), tensor.shape) + self.assertEqual(image_sizes[0].tolist(), [15, 20], image_sizes[0]) + + # test HW + func = torch.jit.trace(to_imagelist, ([_tensor(10, 10)],)) + tensor, image_sizes = func([_tensor(15, 20)]) + self.assertEqual(tensor.shape, (1, 16, 20), tensor.shape) + self.assertEqual(image_sizes[0].tolist(), [15, 20], image_sizes[0]) + + # test 2x CHW + func = torch.jit.trace( + to_imagelist, + ([_tensor(3, 16, 10), _tensor(3, 13, 11)],), + ) + tensor, image_sizes = func([_tensor(3, 25, 20), _tensor(3, 10, 10)]) + self.assertEqual(tensor.shape, (2, 3, 28, 20), tensor.shape) + self.assertEqual(image_sizes[0].tolist(), [25, 20], image_sizes[0]) + self.assertEqual(image_sizes[1].tolist(), [10, 10], image_sizes[1]) + # support calling with different spatial sizes, but not with different #images + + def test_imagelist_scriptability(self): + image_nums = 2 + image_tensor = torch.randn((image_nums, 10, 20), dtype=torch.float32) + image_shape = [(10, 20)] * image_nums + + def f(image_tensor, image_shape: List[Tuple[int, int]]): + return ImageList(image_tensor, image_shape) + + ret = f(image_tensor, image_shape) + ret_script = torch.jit.script(f)(image_tensor, image_shape) + + self.assertEqual(len(ret), len(ret_script)) + for i in range(image_nums): + self.assertTrue(torch.equal(ret[i], ret_script[i])) + + def test_imagelist_from_tensors_scriptability(self): + image_tensor_0 = torch.randn(10, 20, dtype=torch.float32) + image_tensor_1 = torch.randn(12, 22, dtype=torch.float32) + inputs = [image_tensor_0, image_tensor_1] + + def f(image_tensor: List[torch.Tensor]): + return ImageList.from_tensors(image_tensor, 10) + + ret = f(inputs) + ret_script = torch.jit.script(f)(inputs) + + self.assertEqual(len(ret), len(ret_script)) + 
self.assertTrue(torch.equal(ret.tensor, ret_script.tensor)) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/structures/test_instances.py b/data_processing/detectron2/tests/structures/test_instances.py new file mode 100644 index 0000000..a352f74 --- /dev/null +++ b/data_processing/detectron2/tests/structures/test_instances.py @@ -0,0 +1,219 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import unittest +import torch +from torch import Tensor + +from detectron2.export.torchscript import patch_instances +from detectron2.structures import Boxes, Instances +from detectron2.utils.testing import convert_scripted_instances + + +class TestInstances(unittest.TestCase): + def test_int_indexing(self): + attr1 = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 0.5], [0.0, 0.0, 1.0], [0.0, 0.5, 0.5]]) + attr2 = torch.tensor([0.1, 0.2, 0.3, 0.4]) + instances = Instances((100, 100)) + instances.attr1 = attr1 + instances.attr2 = attr2 + for i in range(-len(instances), len(instances)): + inst = instances[i] + self.assertEqual((inst.attr1 == attr1[i]).all(), True) + self.assertEqual((inst.attr2 == attr2[i]).all(), True) + + self.assertRaises(IndexError, lambda: instances[len(instances)]) + self.assertRaises(IndexError, lambda: instances[-len(instances) - 1]) + + def test_script_new_fields(self): + def get_mask(x: Instances) -> torch.Tensor: + return x.mask + + class f(torch.nn.Module): + def forward(self, x: Instances): + proposal_boxes = x.proposal_boxes # noqa F841 + objectness_logits = x.objectness_logits # noqa F841 + return x + + class g(torch.nn.Module): + def forward(self, x: Instances): + return get_mask(x) + + class g2(torch.nn.Module): + def __init__(self): + super().__init__() + self.g = g() + + def forward(self, x: Instances): + proposal_boxes = x.proposal_boxes # noqa F841 + return x, self.g(x) + + fields = {"proposal_boxes": Boxes, "objectness_logits": Tensor} + with patch_instances(fields): + torch.jit.script(f()) + + # 
can't script anymore after exiting the context + with self.assertRaises(Exception): + # will create a ConcreteType for g + torch.jit.script(g2()) + + new_fields = {"mask": Tensor} + with patch_instances(new_fields): + # will compile g with a different Instances; this should pass + torch.jit.script(g()) + with self.assertRaises(Exception): + torch.jit.script(g2()) + + new_fields = {"mask": Tensor, "proposal_boxes": Boxes} + with patch_instances(new_fields) as NewInstances: + # get_mask will be compiled with a different Instances; this should pass + scripted_g2 = torch.jit.script(g2()) + x = NewInstances((3, 4)) + x.mask = torch.rand(3) + x.proposal_boxes = Boxes(torch.rand(3, 4)) + scripted_g2(x) # it should accept the new Instances object and run successfully + + def test_script_access_fields(self): + class f(torch.nn.Module): + def forward(self, x: Instances): + proposal_boxes = x.proposal_boxes + objectness_logits = x.objectness_logits + return proposal_boxes.tensor + objectness_logits + + fields = {"proposal_boxes": Boxes, "objectness_logits": Tensor} + with patch_instances(fields): + torch.jit.script(f()) + + def test_script_len(self): + class f(torch.nn.Module): + def forward(self, x: Instances): + return len(x) + + class g(torch.nn.Module): + def forward(self, x: Instances): + return len(x) + + image_shape = (15, 15) + + fields = {"proposal_boxes": Boxes} + with patch_instances(fields) as new_instance: + script_module = torch.jit.script(f()) + x = new_instance(image_shape) + with self.assertRaises(Exception): + script_module(x) + box_tensors = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]]) + x.proposal_boxes = Boxes(box_tensors) + length = script_module(x) + self.assertEqual(length, 2) + + fields = {"objectness_logits": Tensor} + with patch_instances(fields) as new_instance: + script_module = torch.jit.script(g()) + x = new_instance(image_shape) + objectness_logits = torch.tensor([1.0]).reshape(1, 1) + x.objectness_logits = objectness_logits + length = 
script_module(x) + self.assertEqual(length, 1) + + def test_script_has(self): + class f(torch.nn.Module): + def forward(self, x: Instances): + return x.has("proposal_boxes") + + image_shape = (15, 15) + fields = {"proposal_boxes": Boxes} + with patch_instances(fields) as new_instance: + script_module = torch.jit.script(f()) + x = new_instance(image_shape) + self.assertFalse(script_module(x)) + + box_tensors = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]]) + x.proposal_boxes = Boxes(box_tensors) + self.assertTrue(script_module(x)) + + def test_script_to(self): + class f(torch.nn.Module): + def forward(self, x: Instances): + return x.to(torch.device("cpu")) + + image_shape = (15, 15) + fields = {"proposal_boxes": Boxes, "a": Tensor} + with patch_instances(fields) as new_instance: + script_module = torch.jit.script(f()) + x = new_instance(image_shape) + script_module(x) + + box_tensors = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]]) + x.proposal_boxes = Boxes(box_tensors) + x.a = box_tensors + script_module(x) + + def test_script_getitem(self): + class f(torch.nn.Module): + def forward(self, x: Instances, idx): + return x[idx] + + image_shape = (15, 15) + fields = {"proposal_boxes": Boxes, "a": Tensor} + inst = Instances(image_shape) + inst.proposal_boxes = Boxes(torch.rand(4, 4)) + inst.a = torch.rand(4, 10) + idx = torch.tensor([True, False, True, False]) + with patch_instances(fields) as new_instance: + script_module = torch.jit.script(f()) + + out = f()(inst, idx) + out_scripted = script_module(new_instance.from_instances(inst), idx) + self.assertTrue( + torch.equal(out.proposal_boxes.tensor, out_scripted.proposal_boxes.tensor) + ) + self.assertTrue(torch.equal(out.a, out_scripted.a)) + + def test_from_to_instances(self): + orig = Instances((30, 30)) + orig.proposal_boxes = Boxes(torch.rand(3, 4)) + + fields = {"proposal_boxes": Boxes, "a": Tensor} + with patch_instances(fields) as NewInstances: + # convert to NewInstances and back + new1 = 
NewInstances.from_instances(orig) + new2 = convert_scripted_instances(new1) + self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new1.proposal_boxes.tensor)) + self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new2.proposal_boxes.tensor)) + + def test_script_init_args(self): + def f(x: Tensor): + image_shape = (15, 15) + # __init__ can take arguments + inst = Instances(image_shape, a=x, proposal_boxes=Boxes(x)) + inst2 = Instances(image_shape, a=x) + return inst.a, inst2.a + + fields = {"proposal_boxes": Boxes, "a": Tensor} + with patch_instances(fields): + script_f = torch.jit.script(f) + x = torch.randn(3, 4) + outputs = script_f(x) + self.assertTrue(torch.equal(outputs[0], x)) + self.assertTrue(torch.equal(outputs[1], x)) + + def test_script_cat(self): + def f(x: Tensor): + image_shape = (15, 15) + # __init__ can take arguments + inst = Instances(image_shape, a=x) + inst2 = Instances(image_shape, a=x) + + inst3 = Instances(image_shape, proposal_boxes=Boxes(x)) + return inst.cat([inst, inst2]), inst3.cat([inst3, inst3]) + + fields = {"proposal_boxes": Boxes, "a": Tensor} + with patch_instances(fields): + script_f = torch.jit.script(f) + x = torch.randn(3, 4) + output, output2 = script_f(x) + self.assertTrue(torch.equal(output.a, torch.cat([x, x]))) + self.assertFalse(output.has("proposal_boxes")) + self.assertTrue(torch.equal(output2.proposal_boxes.tensor, torch.cat([x, x]))) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/structures/test_keypoints.py b/data_processing/detectron2/tests/structures/test_keypoints.py new file mode 100644 index 0000000..adc616e --- /dev/null +++ b/data_processing/detectron2/tests/structures/test_keypoints.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import unittest +import torch + +from detectron2.structures.keypoints import Keypoints + + +class TestKeypoints(unittest.TestCase): + def test_cat_keypoints(self): + keypoints1 = Keypoints(torch.rand(2, 21, 3)) + keypoints2 = Keypoints(torch.rand(4, 21, 3)) + + cat_keypoints = keypoints1.cat([keypoints1, keypoints2]) + self.assertTrue(torch.all(cat_keypoints.tensor[:2] == keypoints1.tensor).item()) + self.assertTrue(torch.all(cat_keypoints.tensor[2:] == keypoints2.tensor).item()) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/structures/test_masks.py b/data_processing/detectron2/tests/structures/test_masks.py new file mode 100644 index 0000000..7991eb0 --- /dev/null +++ b/data_processing/detectron2/tests/structures/test_masks.py @@ -0,0 +1,53 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import unittest +import torch + +from detectron2.structures.masks import BitMasks, PolygonMasks, polygons_to_bitmask + + +class TestBitMask(unittest.TestCase): + def test_get_bounding_box(self): + masks = torch.tensor( + [ + [ + [False, False, False, True], + [False, False, True, True], + [False, True, True, False], + [False, True, True, False], + ], + [ + [False, False, False, False], + [False, False, True, False], + [False, True, True, False], + [False, True, True, False], + ], + torch.zeros(4, 4), + ] + ) + bitmask = BitMasks(masks) + box_true = torch.tensor([[1, 0, 4, 4], [1, 1, 3, 4], [0, 0, 0, 0]], dtype=torch.float32) + box = bitmask.get_bounding_boxes() + self.assertTrue(torch.all(box.tensor == box_true).item()) + + for box in box_true: + poly = box[[0, 1, 2, 1, 2, 3, 0, 3]].numpy() + mask = polygons_to_bitmask([poly], 4, 4) + reconstruct_box = BitMasks(mask[None, :, :]).get_bounding_boxes()[0].tensor + self.assertTrue(torch.all(box == reconstruct_box).item()) + + reconstruct_box = PolygonMasks([[poly]]).get_bounding_boxes()[0].tensor + self.assertTrue(torch.all(box == reconstruct_box).item()) + + def 
test_from_empty_polygons(self): + masks = BitMasks.from_polygon_masks([], 100, 100) + self.assertEqual(masks.tensor.shape, (0, 100, 100)) + + def test_getitem(self): + masks = BitMasks(torch.ones(3, 10, 10)) + self.assertEqual(masks[1].tensor.shape, (1, 10, 10)) + self.assertEqual(masks[1:3].tensor.shape, (2, 10, 10)) + self.assertEqual(masks[torch.tensor([True, False, False])].tensor.shape, (1, 10, 10)) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/structures/test_rotated_boxes.py b/data_processing/detectron2/tests/structures/test_rotated_boxes.py new file mode 100644 index 0000000..478f034 --- /dev/null +++ b/data_processing/detectron2/tests/structures/test_rotated_boxes.py @@ -0,0 +1,441 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import, division, print_function, unicode_literals +import logging +import math +import random +import unittest +import torch +from fvcore.common.benchmark import benchmark + +from detectron2.layers.rotated_boxes import pairwise_iou_rotated +from detectron2.structures.boxes import Boxes +from detectron2.structures.rotated_boxes import RotatedBoxes, pairwise_iou +from detectron2.utils.testing import reload_script_model + +logger = logging.getLogger(__name__) + + +class TestRotatedBoxesLayer(unittest.TestCase): + def test_iou_0_dim_cpu(self): + boxes1 = torch.rand(0, 5, dtype=torch.float32) + boxes2 = torch.rand(10, 5, dtype=torch.float32) + expected_ious = torch.zeros(0, 10, dtype=torch.float32) + ious = pairwise_iou_rotated(boxes1, boxes2) + self.assertTrue(torch.allclose(ious, expected_ious)) + + boxes1 = torch.rand(10, 5, dtype=torch.float32) + boxes2 = torch.rand(0, 5, dtype=torch.float32) + expected_ious = torch.zeros(10, 0, dtype=torch.float32) + ious = pairwise_iou_rotated(boxes1, boxes2) + self.assertTrue(torch.allclose(ious, expected_ious)) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def 
test_iou_0_dim_cuda(self): + boxes1 = torch.rand(0, 5, dtype=torch.float32) + boxes2 = torch.rand(10, 5, dtype=torch.float32) + expected_ious = torch.zeros(0, 10, dtype=torch.float32) + ious_cuda = pairwise_iou_rotated(boxes1.cuda(), boxes2.cuda()) + self.assertTrue(torch.allclose(ious_cuda.cpu(), expected_ious)) + + boxes1 = torch.rand(10, 5, dtype=torch.float32) + boxes2 = torch.rand(0, 5, dtype=torch.float32) + expected_ious = torch.zeros(10, 0, dtype=torch.float32) + ious_cuda = pairwise_iou_rotated(boxes1.cuda(), boxes2.cuda()) + self.assertTrue(torch.allclose(ious_cuda.cpu(), expected_ious)) + + def test_iou_half_overlap_cpu(self): + boxes1 = torch.tensor([[0.5, 0.5, 1.0, 1.0, 0.0]], dtype=torch.float32) + boxes2 = torch.tensor([[0.25, 0.5, 0.5, 1.0, 0.0]], dtype=torch.float32) + expected_ious = torch.tensor([[0.5]], dtype=torch.float32) + ious = pairwise_iou_rotated(boxes1, boxes2) + self.assertTrue(torch.allclose(ious, expected_ious)) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_iou_half_overlap_cuda(self): + boxes1 = torch.tensor([[0.5, 0.5, 1.0, 1.0, 0.0]], dtype=torch.float32) + boxes2 = torch.tensor([[0.25, 0.5, 0.5, 1.0, 0.0]], dtype=torch.float32) + expected_ious = torch.tensor([[0.5]], dtype=torch.float32) + ious_cuda = pairwise_iou_rotated(boxes1.cuda(), boxes2.cuda()) + self.assertTrue(torch.allclose(ious_cuda.cpu(), expected_ious)) + + def test_iou_precision(self): + for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []): + boxes1 = torch.tensor([[565, 565, 10, 10.0, 0]], dtype=torch.float32, device=device) + boxes2 = torch.tensor([[565, 565, 10, 8.3, 0]], dtype=torch.float32, device=device) + iou = 8.3 / 10.0 + expected_ious = torch.tensor([[iou]], dtype=torch.float32) + ious = pairwise_iou_rotated(boxes1, boxes2) + self.assertTrue(torch.allclose(ious.cpu(), expected_ious)) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_iou_too_many_boxes_cuda(self): + 
s1, s2 = 5, 1289035 + boxes1 = torch.zeros(s1, 5) + boxes2 = torch.zeros(s2, 5) + ious_cuda = pairwise_iou_rotated(boxes1.cuda(), boxes2.cuda()) + self.assertTupleEqual(tuple(ious_cuda.shape), (s1, s2)) + + def test_iou_extreme(self): + # Cause floating point issues in cuda kernels (#1266) + for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []): + boxes1 = torch.tensor([[160.0, 153.0, 230.0, 23.0, -37.0]], device=device) + boxes2 = torch.tensor( + [ + [ + -1.117407639806935e17, + 1.3858420478349148e18, + 1000.0000610351562, + 1000.0000610351562, + 1612.0, + ] + ], + device=device, + ) + ious = pairwise_iou_rotated(boxes1, boxes2) + self.assertTrue(ious.min() >= 0, ious) + + def test_iou_issue_2154(self): + for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []): + boxes1 = torch.tensor( + [ + [ + 296.6620178222656, + 458.73883056640625, + 23.515729904174805, + 47.677001953125, + 0.08795166015625, + ] + ], + device=device, + ) + boxes2 = torch.tensor( + [[296.66201, 458.73882000000003, 23.51573, 47.67702, 0.087951]], + device=device, + ) + ious = pairwise_iou_rotated(boxes1, boxes2) + expected_ious = torch.tensor([[1.0]], dtype=torch.float32) + self.assertTrue(torch.allclose(ious.cpu(), expected_ious)) + + def test_iou_issue_2167(self): + for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []): + boxes1 = torch.tensor( + [ + [ + 2563.74462890625000000000, + 1436.79016113281250000000, + 2174.70336914062500000000, + 214.09500122070312500000, + 115.11834716796875000000, + ] + ], + device=device, + ) + boxes2 = torch.tensor( + [ + [ + 2563.74462890625000000000, + 1436.79028320312500000000, + 2174.70288085937500000000, + 214.09495544433593750000, + 115.11835479736328125000, + ] + ], + device=device, + ) + ious = pairwise_iou_rotated(boxes1, boxes2) + expected_ious = torch.tensor([[1.0]], dtype=torch.float32) + self.assertTrue(torch.allclose(ious.cpu(), expected_ious)) + + +class 
TestRotatedBoxesStructure(unittest.TestCase): + def test_clip_area_0_degree(self): + for _ in range(50): + num_boxes = 100 + boxes_5d = torch.zeros(num_boxes, 5) + boxes_5d[:, 0] = torch.FloatTensor(num_boxes).uniform_(-100, 500) + boxes_5d[:, 1] = torch.FloatTensor(num_boxes).uniform_(-100, 500) + boxes_5d[:, 2] = torch.FloatTensor(num_boxes).uniform_(0, 500) + boxes_5d[:, 3] = torch.FloatTensor(num_boxes).uniform_(0, 500) + # Convert from (x_ctr, y_ctr, w, h, 0) to (x1, y1, x2, y2) + boxes_4d = torch.zeros(num_boxes, 4) + boxes_4d[:, 0] = boxes_5d[:, 0] - boxes_5d[:, 2] / 2.0 + boxes_4d[:, 1] = boxes_5d[:, 1] - boxes_5d[:, 3] / 2.0 + boxes_4d[:, 2] = boxes_5d[:, 0] + boxes_5d[:, 2] / 2.0 + boxes_4d[:, 3] = boxes_5d[:, 1] + boxes_5d[:, 3] / 2.0 + + image_size = (500, 600) + test_boxes_4d = Boxes(boxes_4d) + test_boxes_5d = RotatedBoxes(boxes_5d) + # Before clip + areas_4d = test_boxes_4d.area() + areas_5d = test_boxes_5d.area() + self.assertTrue(torch.allclose(areas_4d, areas_5d, atol=1e-1, rtol=1e-5)) + # After clip + test_boxes_4d.clip(image_size) + test_boxes_5d.clip(image_size) + areas_4d = test_boxes_4d.area() + areas_5d = test_boxes_5d.area() + self.assertTrue(torch.allclose(areas_4d, areas_5d, atol=1e-1, rtol=1e-5)) + + def test_clip_area_arbitrary_angle(self): + num_boxes = 100 + boxes_5d = torch.zeros(num_boxes, 5) + boxes_5d[:, 0] = torch.FloatTensor(num_boxes).uniform_(-100, 500) + boxes_5d[:, 1] = torch.FloatTensor(num_boxes).uniform_(-100, 500) + boxes_5d[:, 2] = torch.FloatTensor(num_boxes).uniform_(0, 500) + boxes_5d[:, 3] = torch.FloatTensor(num_boxes).uniform_(0, 500) + boxes_5d[:, 4] = torch.FloatTensor(num_boxes).uniform_(-1800, 1800) + clip_angle_threshold = random.uniform(0, 180) + + image_size = (500, 600) + test_boxes_5d = RotatedBoxes(boxes_5d) + # Before clip + areas_before = test_boxes_5d.area() + # After clip + test_boxes_5d.clip(image_size, clip_angle_threshold) + areas_diff = test_boxes_5d.area() - areas_before + + # the areas should 
only decrease after clipping + self.assertTrue(torch.all(areas_diff <= 0)) + # whenever the box is clipped (thus the area shrinks), + # the angle for the box must be within the clip_angle_threshold + # Note that the clip function will normalize the angle range + # to be within [-180, 180) + + self.assertTrue( + torch.all( + torch.abs(test_boxes_5d.tensor[:, 4][torch.where(areas_diff < 0)]) + < clip_angle_threshold + ) + ) + + def test_normalize_angles(self): + # torch.manual_seed(0) + for _ in range(50): + num_boxes = 100 + boxes_5d = torch.zeros(num_boxes, 5) + boxes_5d[:, 0] = torch.FloatTensor(num_boxes).uniform_(-100, 500) + boxes_5d[:, 1] = torch.FloatTensor(num_boxes).uniform_(-100, 500) + boxes_5d[:, 2] = torch.FloatTensor(num_boxes).uniform_(0, 500) + boxes_5d[:, 3] = torch.FloatTensor(num_boxes).uniform_(0, 500) + boxes_5d[:, 4] = torch.FloatTensor(num_boxes).uniform_(-1800, 1800) + rotated_boxes = RotatedBoxes(boxes_5d) + normalized_boxes = rotated_boxes.clone() + normalized_boxes.normalize_angles() + self.assertTrue(torch.all(normalized_boxes.tensor[:, 4] >= -180)) + self.assertTrue(torch.all(normalized_boxes.tensor[:, 4] < 180)) + # x, y, w, h should not change + self.assertTrue(torch.allclose(boxes_5d[:, :4], normalized_boxes.tensor[:, :4])) + # the cos/sin values of the angles should stay the same + + self.assertTrue( + torch.allclose( + torch.cos(boxes_5d[:, 4] * math.pi / 180), + torch.cos(normalized_boxes.tensor[:, 4] * math.pi / 180), + atol=1e-5, + ) + ) + + self.assertTrue( + torch.allclose( + torch.sin(boxes_5d[:, 4] * math.pi / 180), + torch.sin(normalized_boxes.tensor[:, 4] * math.pi / 180), + atol=1e-5, + ) + ) + + def test_pairwise_iou_0_degree(self): + for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []): + boxes1 = torch.tensor( + [[0.5, 0.5, 1.0, 1.0, 0.0], [0.5, 0.5, 1.0, 1.0, 0.0]], + dtype=torch.float32, + device=device, + ) + boxes2 = torch.tensor( + [ + [0.5, 0.5, 1.0, 1.0, 0.0], + [0.25, 0.5, 0.5, 1.0, 0.0], +
[0.5, 0.25, 1.0, 0.5, 0.0], + [0.25, 0.25, 0.5, 0.5, 0.0], + [0.75, 0.75, 0.5, 0.5, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + dtype=torch.float32, + device=device, + ) + expected_ious = torch.tensor( + [ + [1.0, 0.5, 0.5, 0.25, 0.25, 0.25 / (2 - 0.25)], + [1.0, 0.5, 0.5, 0.25, 0.25, 0.25 / (2 - 0.25)], + ], + dtype=torch.float32, + device=device, + ) + ious = pairwise_iou(RotatedBoxes(boxes1), RotatedBoxes(boxes2)) + self.assertTrue(torch.allclose(ious, expected_ious)) + + def test_pairwise_iou_45_degrees(self): + for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []): + boxes1 = torch.tensor( + [ + [1, 1, math.sqrt(2), math.sqrt(2), 45], + [1, 1, 2 * math.sqrt(2), 2 * math.sqrt(2), -45], + ], + dtype=torch.float32, + device=device, + ) + boxes2 = torch.tensor([[1, 1, 2, 2, 0]], dtype=torch.float32, device=device) + expected_ious = torch.tensor([[0.5], [0.5]], dtype=torch.float32, device=device) + ious = pairwise_iou(RotatedBoxes(boxes1), RotatedBoxes(boxes2)) + self.assertTrue(torch.allclose(ious, expected_ious)) + + def test_pairwise_iou_orthogonal(self): + for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []): + boxes1 = torch.tensor([[5, 5, 10, 6, 55]], dtype=torch.float32, device=device) + boxes2 = torch.tensor([[5, 5, 10, 6, -35]], dtype=torch.float32, device=device) + iou = (6.0 * 6.0) / (6.0 * 6.0 + 4.0 * 6.0 + 4.0 * 6.0) + expected_ious = torch.tensor([[iou]], dtype=torch.float32, device=device) + ious = pairwise_iou(RotatedBoxes(boxes1), RotatedBoxes(boxes2)) + self.assertTrue(torch.allclose(ious, expected_ious)) + + def test_pairwise_iou_large_close_boxes(self): + for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []): + boxes1 = torch.tensor( + [[299.500000, 417.370422, 600.000000, 364.259186, 27.1828]], + dtype=torch.float32, + device=device, + ) + boxes2 = torch.tensor( + [[299.500000, 417.370422, 600.000000, 364.259155, 27.1828]], + dtype=torch.float32, + device=device, + ) + iou = 364.259155 / 
364.259186 + expected_ious = torch.tensor([[iou]], dtype=torch.float32, device=device) + ious = pairwise_iou(RotatedBoxes(boxes1), RotatedBoxes(boxes2)) + self.assertTrue(torch.allclose(ious, expected_ious)) + + def test_pairwise_iou_many_boxes(self): + for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []): + num_boxes1 = 100 + num_boxes2 = 200 + boxes1 = torch.stack( + [ + torch.tensor( + [5 + 20 * i, 5 + 20 * i, 10, 10, 0], + dtype=torch.float32, + device=device, + ) + for i in range(num_boxes1) + ] + ) + boxes2 = torch.stack( + [ + torch.tensor( + [5 + 20 * i, 5 + 20 * i, 10, 1 + 9 * i / num_boxes2, 0], + dtype=torch.float32, + device=device, + ) + for i in range(num_boxes2) + ] + ) + expected_ious = torch.zeros(num_boxes1, num_boxes2, dtype=torch.float32, device=device) + for i in range(min(num_boxes1, num_boxes2)): + expected_ious[i][i] = (1 + 9 * i / num_boxes2) / 10.0 + ious = pairwise_iou(RotatedBoxes(boxes1), RotatedBoxes(boxes2)) + self.assertTrue(torch.allclose(ious, expected_ious)) + + def test_pairwise_iou_issue1207_simplified(self): + for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []): + # Simplified test case of D2-issue-1207 + boxes1 = torch.tensor([[3, 3, 8, 2, -45.0]], device=device) + boxes2 = torch.tensor([[6, 0, 8, 2, -45.0]], device=device) + iou = 0.0 + expected_ious = torch.tensor([[iou]], dtype=torch.float32, device=device) + + ious = pairwise_iou(RotatedBoxes(boxes1), RotatedBoxes(boxes2)) + self.assertTrue(torch.allclose(ious, expected_ious)) + + def test_pairwise_iou_issue1207(self): + for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []): + # The original test case in D2-issue-1207 + boxes1 = torch.tensor([[160.0, 153.0, 230.0, 23.0, -37.0]], device=device) + boxes2 = torch.tensor([[190.0, 127.0, 80.0, 21.0, -46.0]], device=device) + + iou = 0.0 + expected_ious = torch.tensor([[iou]], dtype=torch.float32, device=device) + + ious = pairwise_iou(RotatedBoxes(boxes1), 
RotatedBoxes(boxes2)) + self.assertTrue(torch.allclose(ious, expected_ious)) + + def test_empty_cat(self): + x = RotatedBoxes.cat([]) + self.assertEqual(x.tensor.shape, (0, 5)) # concatenating nothing must still give an empty (0, 5) box tensor + + def test_scriptability(self): + def func(x): + boxes = RotatedBoxes(x) + test = boxes.to(torch.device("cpu")).tensor + return boxes.area(), test + + f = torch.jit.script(func) + f = reload_script_model(f) + f(torch.rand((3, 5))) + + data = torch.rand((3, 5)) + + def func_cat(x: torch.Tensor): + boxes1 = RotatedBoxes(x) + boxes2 = RotatedBoxes(x) + # this is not supported by torchscript for now. + # boxes3 = RotatedBoxes.cat([boxes1, boxes2]) + boxes3 = boxes1.cat([boxes1, boxes2]) + return boxes3 + + f = torch.jit.script(func_cat) + script_box = f(data) + self.assertTrue(torch.equal(torch.cat([data, data]), script_box.tensor)) + + +def benchmark_rotated_iou(): + num_boxes1 = 200 + num_boxes2 = 500 + boxes1 = torch.stack( + [ + torch.tensor([5 + 20 * i, 5 + 20 * i, 10, 10, 0], dtype=torch.float32) + for i in range(num_boxes1) + ] + ) + boxes2 = torch.stack( + [ + torch.tensor( + [5 + 20 * i, 5 + 20 * i, 10, 1 + 9 * i / num_boxes2, 0], + dtype=torch.float32, + ) + for i in range(num_boxes2) + ] + ) + + def func(dev, n=1): + b1 = boxes1.to(device=dev) + b2 = boxes2.to(device=dev) + + def bench(): + for _ in range(n): + pairwise_iou_rotated(b1, b2) + if dev.type == "cuda": + torch.cuda.synchronize() + + return bench + + # only run it once per timed loop, since it's slow + args = [{"dev": torch.device("cpu"), "n": 1}] + if torch.cuda.is_available(): + args.append({"dev": torch.device("cuda"), "n": 10}) + + benchmark(func, "rotated_iou", args, warmup_iters=3) + + +if __name__ == "__main__": + unittest.main() + benchmark_rotated_iou() diff --git a/data_processing/detectron2/tests/test_checkpoint.py b/data_processing/detectron2/tests/test_checkpoint.py new file mode 100644 index 0000000..6c0b1c1 --- /dev/null +++ b/data_processing/detectron2/tests/test_checkpoint.py @@ -0,0 +1,105 @@ +# Copyright
(c) Facebook, Inc. and its affiliates. +import os +import tempfile +import unittest +from collections import OrderedDict +import torch +from iopath.common.file_io import PathHandler, PathManager +from torch import nn + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.checkpoint.c2_model_loading import ( + _longest_common_prefix_str, + align_and_update_state_dicts, +) +from detectron2.utils.logger import setup_logger + + +class TestCheckpointer(unittest.TestCase): + def setUp(self): + setup_logger() + + def create_complex_model(self): + m = nn.Module() + m.block1 = nn.Module() + m.block1.layer1 = nn.Linear(2, 3) + m.layer2 = nn.Linear(3, 2) + m.res = nn.Module() + m.res.layer2 = nn.Linear(3, 2) + + state_dict = OrderedDict() + state_dict["layer1.weight"] = torch.rand(3, 2) + state_dict["layer1.bias"] = torch.rand(3) + state_dict["layer2.weight"] = torch.rand(2, 3) + state_dict["layer2.bias"] = torch.rand(2) + state_dict["res.layer2.weight"] = torch.rand(2, 3) + state_dict["res.layer2.bias"] = torch.rand(2) + return m, state_dict + + def test_complex_model_loaded(self): + for add_data_parallel in [False, True]: + model, state_dict = self.create_complex_model() + if add_data_parallel: + model = nn.DataParallel(model) + model_sd = model.state_dict() + + sd_to_load = align_and_update_state_dicts(model_sd, state_dict) + model.load_state_dict(sd_to_load) + for loaded, stored in zip(model_sd.values(), state_dict.values()): + # different tensor references + self.assertFalse(id(loaded) == id(stored)) + # same content + self.assertTrue(loaded.to(stored).equal(stored)) + + def test_load_with_matching_heuristics(self): + with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + model, state_dict = self.create_complex_model() + torch.save({"model": state_dict}, os.path.join(d, "checkpoint.pth")) + checkpointer = DetectionCheckpointer(model, save_dir=d) + + with torch.no_grad(): + # use a different weight from the `state_dict`, since torch.rand 
is less than 1 + model.block1.layer1.weight.fill_(1) + + # load checkpoint without matching_heuristics + checkpointer.load(os.path.join(d, "checkpoint.pth")) + self.assertTrue(model.block1.layer1.weight.equal(torch.ones(3, 2))) + + # load checkpoint with matching_heuristics + checkpointer.load(os.path.join(d, "checkpoint.pth?matching_heuristics=True")) + self.assertFalse(model.block1.layer1.weight.equal(torch.ones(3, 2))) + + def test_custom_path_manager_handler(self): + with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + + class CustomPathManagerHandler(PathHandler): + PREFIX = "detectron2_test://" + + def _get_supported_prefixes(self): + return [self.PREFIX] + + def _get_local_path(self, path, **kwargs): + name = path[len(self.PREFIX) :] + return os.path.join(d, name) + + def _open(self, path, mode="r", **kwargs): + return open(self._get_local_path(path), mode, **kwargs) + + pathmgr = PathManager() + pathmgr.register_handler(CustomPathManagerHandler()) + + model, state_dict = self.create_complex_model() + torch.save({"model": state_dict}, os.path.join(d, "checkpoint.pth")) + checkpointer = DetectionCheckpointer(model, save_dir=d) + checkpointer.path_manager = pathmgr + checkpointer.load("detectron2_test://checkpoint.pth") + checkpointer.load("detectron2_test://checkpoint.pth?matching_heuristics=True") + + def test_lcp(self): + self.assertEqual(_longest_common_prefix_str(["class", "dlaps_model"]), "") + self.assertEqual(_longest_common_prefix_str(["classA", "classB"]), "class") + self.assertEqual(_longest_common_prefix_str(["classA", "classB", "clab"]), "cla") + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/test_engine.py b/data_processing/detectron2/tests/test_engine.py new file mode 100644 index 0000000..c97c11b --- /dev/null +++ b/data_processing/detectron2/tests/test_engine.py @@ -0,0 +1,264 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +import json +import math +import os +import tempfile +import time +import unittest +from unittest import mock +import torch +from fvcore.common.checkpoint import Checkpointer +from torch import nn + +from detectron2 import model_zoo +from detectron2.config import configurable, get_cfg +from detectron2.engine import DefaultTrainer, SimpleTrainer, default_setup, hooks +from detectron2.modeling.meta_arch import META_ARCH_REGISTRY +from detectron2.utils.events import CommonMetricPrinter, JSONWriter + + +@META_ARCH_REGISTRY.register() +class _SimpleModel(nn.Module): + @configurable + def __init__(self, sleep_sec=0): + super().__init__() + self.mod = nn.Linear(10, 20) + self.sleep_sec = sleep_sec + + @classmethod + def from_config(cls, cfg): + return {} + + def forward(self, x): + if self.sleep_sec > 0: + time.sleep(self.sleep_sec) + return {"loss": x.sum() + sum([x.mean() for x in self.parameters()])} + + +class TestTrainer(unittest.TestCase): + def _data_loader(self, device): + device = torch.device(device) + while True: + yield torch.rand(3, 3).to(device) + + def test_simple_trainer(self, device="cpu"): + model = _SimpleModel().to(device=device) + trainer = SimpleTrainer( + model, self._data_loader(device), torch.optim.SGD(model.parameters(), 0.1) + ) + trainer.train(0, 10) + + def test_simple_trainer_reset_dataloader(self, device="cpu"): + model = _SimpleModel().to(device=device) + trainer = SimpleTrainer( + model, self._data_loader(device), torch.optim.SGD(model.parameters(), 0.1) + ) + trainer.train(0, 10) + trainer.reset_data_loader(lambda: self._data_loader(device)) + trainer.train(0, 10) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_simple_trainer_cuda(self): + self.test_simple_trainer(device="cuda") + + def test_writer_hooks(self): + model = _SimpleModel(sleep_sec=0.1) + trainer = SimpleTrainer( + model, self._data_loader("cpu"), torch.optim.SGD(model.parameters(), 0.1) + ) + + max_iter = 50 + + with 
tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + json_file = os.path.join(d, "metrics.json") + writers = [CommonMetricPrinter(max_iter), JSONWriter(json_file)] + + trainer.register_hooks( + [hooks.EvalHook(0, lambda: {"metric": 100}), hooks.PeriodicWriter(writers)] + ) + with self.assertLogs(writers[0].logger) as logs: + trainer.train(0, max_iter) + + with open(json_file, "r") as f: + data = [json.loads(line.strip()) for line in f] + self.assertEqual([x["iteration"] for x in data], [19, 39, 49, 50]) + # the eval metric is in the last line with iter 50 + self.assertIn("metric", data[-1], "Eval metric must be in last line of JSON!") + + # test logged messages from CommonMetricPrinter + self.assertEqual(len(logs.output), 3) + for log, iter in zip(logs.output, [19, 39, 49]): + self.assertIn(f"iter: {iter}", log) + + self.assertIn("eta: 0:00:00", logs.output[-1], "Last ETA must be 0!") + + def test_metric_gather_and_write(self): + gather_metric_period = 5 + writer_period = 10 + + model = _SimpleModel(sleep_sec=0.1) + trainer = SimpleTrainer( + model, + self._data_loader("cpu"), + torch.optim.SGD(model.parameters(), 0.1), + gather_metric_period=gather_metric_period, + ) + + max_iter = 50 + with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + json_file = os.path.join(d, "metrics.json") + writers = [JSONWriter(json_file, window_size=writer_period)] + + trainer.register_hooks( + [ + hooks.IterationTimer(), + hooks.PeriodicWriter(writers, period=writer_period), + ] + ) + trainer.train(0, max_iter) + + with open(json_file, "r") as f: + data = [json.loads(line.strip()) for line in f] + self.assertEqual([x["iteration"] for x in data], [9, 19, 29, 39, 49]) + self.assertEqual(len(trainer.storage.history("time").values()), 48) + for key in ["data_time", "total_loss"]: + history = trainer.storage.history(key).values() + history_iters = [h[1] for h in history] + self.assertEqual(history_iters, [4, 9, 14, 19, 24, 29, 34, 39, 44, 49]) + for i in 
range(len(data)): + # written metric should equal to the median of 2 most recent logged metrics + logged1, logged2 = history[2 * i][0], history[2 * i + 1][0] + gt = data[i][key] + self.assertEqual(gt, (logged1 + logged2) / 2.0) + + def test_async_write_metrics(self): + writer_period = 1 + + model = _SimpleModel(sleep_sec=0.1) + trainer = SimpleTrainer( + model, + self._data_loader("cpu"), + torch.optim.SGD(model.parameters(), 0.1), + async_write_metrics=True, + ) + + max_iter = 50 + with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + json_file = os.path.join(d, "metrics.json") + writers = [JSONWriter(json_file, window_size=writer_period)] + + trainer.register_hooks( + [ + hooks.IterationTimer(), + hooks.PeriodicWriter(writers, period=writer_period), + ] + ) + trainer.train(0, max_iter) + + self.assertEqual(len(trainer.storage.history("time").values()), 48) + for key in ["data_time", "total_loss"]: + history = trainer.storage.history(key).values() + history_iters = [h[1] for h in history] + self.assertEqual(history_iters, list(range(50))) + + def test_default_trainer(self): + # TODO: this test requires manifold access, so changed device to CPU. 
see: T88318502 + cfg = get_cfg() + cfg.MODEL.DEVICE = "cpu" + cfg.MODEL.META_ARCHITECTURE = "_SimpleModel" + cfg.DATASETS.TRAIN = ("coco_2017_val_100",) + with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + cfg.OUTPUT_DIR = d + trainer = DefaultTrainer(cfg) + + # test property + self.assertIs(trainer.model, trainer._trainer.model) + trainer.model = _SimpleModel() + self.assertIs(trainer.model, trainer._trainer.model) + + def test_checkpoint_resume(self): + model = _SimpleModel() + dataloader = self._data_loader("cpu") + opt = torch.optim.SGD(model.parameters(), 0.1) + scheduler = torch.optim.lr_scheduler.StepLR(opt, 3) + + with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + trainer = SimpleTrainer(model, dataloader, opt) + checkpointer = Checkpointer(model, d, opt=opt, trainer=trainer) + + trainer.register_hooks( + [ + hooks.LRScheduler(scheduler=scheduler), + # checkpoint after scheduler to properly save the state of scheduler + hooks.PeriodicCheckpointer(checkpointer, 10), + ] + ) + + trainer.train(0, 12) + self.assertAlmostEqual(opt.param_groups[0]["lr"], 1e-5) + self.assertEqual(scheduler.last_epoch, 12) + del trainer + + opt = torch.optim.SGD(model.parameters(), 999) # lr will be loaded + trainer = SimpleTrainer(model, dataloader, opt) + scheduler = torch.optim.lr_scheduler.StepLR(opt, 3) + trainer.register_hooks( + [ + hooks.LRScheduler(scheduler=scheduler), + ] + ) + checkpointer = Checkpointer(model, d, opt=opt, trainer=trainer) + checkpointer.resume_or_load("non_exist.pth") + self.assertEqual(trainer.iter, 11) # last finished iter number (0-based in Trainer) + # number of times `scheduler.step()` was called (1-based) + self.assertEqual(scheduler.last_epoch, 12) + self.assertAlmostEqual(opt.param_groups[0]["lr"], 1e-5) + + def test_eval_hook(self): + model = _SimpleModel() + dataloader = self._data_loader("cpu") + opt = torch.optim.SGD(model.parameters(), 0.1) + + for total_iter, period, eval_count in [(30, 15, 2), (31, 15, 3), 
(20, 0, 1)]: + test_func = mock.Mock(return_value={"metric": 3.0}) + trainer = SimpleTrainer(model, dataloader, opt) + trainer.register_hooks([hooks.EvalHook(period, test_func)]) + trainer.train(0, total_iter) + self.assertEqual(test_func.call_count, eval_count) + + def test_best_checkpointer(self): + model = _SimpleModel() + dataloader = self._data_loader("cpu") + opt = torch.optim.SGD(model.parameters(), 0.1) + metric_name = "metric" + total_iter = 40 + test_period = 10 + test_cases = [ + ("max", iter([0.3, 0.4, 0.35, 0.5]), 3), + ("min", iter([1.0, 0.8, 0.9, 0.9]), 2), + ("min", iter([math.nan, 0.8, 0.9, 0.9]), 1), + ] + for mode, metrics, call_count in test_cases: + trainer = SimpleTrainer(model, dataloader, opt) + with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + checkpointer = Checkpointer(model, d, opt=opt, trainer=trainer) + trainer.register_hooks( + [ + hooks.EvalHook(test_period, lambda: {metric_name: next(metrics)}), + hooks.BestCheckpointer(test_period, checkpointer, metric_name, mode=mode), + ] + ) + with mock.patch.object(checkpointer, "save") as mock_save_method: + trainer.train(0, total_iter) + self.assertEqual(mock_save_method.call_count, call_count) + + def test_setup_config(self): + with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + cfg = get_cfg() + cfg.OUTPUT_DIR = os.path.join(d, "yacs") + default_setup(cfg, {}) + + cfg = model_zoo.get_config("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.py") + cfg.train.output_dir = os.path.join(d, "omegaconf") + default_setup(cfg, {}) diff --git a/data_processing/detectron2/tests/test_events.py b/data_processing/detectron2/tests/test_events.py new file mode 100644 index 0000000..174ca97 --- /dev/null +++ b/data_processing/detectron2/tests/test_events.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import json +import os +import tempfile +import unittest + +from detectron2.utils.events import ( + CommonMetricPrinter, + EventStorage, + JSONWriter, + get_event_storage, + has_event_storage, +) + + +class TestEventWriter(unittest.TestCase): + def testScalar(self): + with tempfile.TemporaryDirectory( + prefix="detectron2_tests" + ) as dir, EventStorage() as storage: + json_file = os.path.join(dir, "test.json") + writer = JSONWriter(json_file) + for k in range(60): + storage.put_scalar("key", k, smoothing_hint=False) + if (k + 1) % 20 == 0: + writer.write() + storage.step() + writer.close() + with open(json_file) as f: + data = [json.loads(l) for l in f] + self.assertTrue([int(k["key"]) for k in data] == [19, 39, 59]) + + def testScalarMismatchedPeriod(self): + with tempfile.TemporaryDirectory( + prefix="detectron2_tests" + ) as dir, EventStorage() as storage: + json_file = os.path.join(dir, "test.json") + + writer = JSONWriter(json_file) + for k in range(60): + if k % 17 == 0: # write in a differnt period + storage.put_scalar("key2", k, smoothing_hint=False) + storage.put_scalar("key", k, smoothing_hint=False) + if (k + 1) % 20 == 0: + writer.write() + storage.step() + writer.close() + with open(json_file) as f: + data = [json.loads(l) for l in f] + self.assertTrue([int(k.get("key2", 0)) for k in data] == [17, 0, 34, 0, 51, 0]) + self.assertTrue([int(k.get("key", 0)) for k in data] == [0, 19, 0, 39, 0, 59]) + self.assertTrue([int(k["iteration"]) for k in data] == [17, 19, 34, 39, 51, 59]) + + def testPrintETA(self): + with EventStorage() as s: + p1 = CommonMetricPrinter(10) + p2 = CommonMetricPrinter() + + s.put_scalar("time", 1.0) + s.step() + s.put_scalar("time", 1.0) + s.step() + + with self.assertLogs("detectron2.utils.events") as logs: + p1.write() + self.assertIn("eta", logs.output[0]) + + with self.assertLogs("detectron2.utils.events") as logs: + p2.write() + self.assertNotIn("eta", logs.output[0]) + + def testPrintNonLosses(self): + with EventStorage() as 
s: + p1 = CommonMetricPrinter(10) + p2 = CommonMetricPrinter() + + s.put_scalar("time", 1.0) + s.put_scalar("[metric]bn_stat", 1.0) + s.step() + s.put_scalar("time", 1.0) + s.put_scalar("[metric]bn_stat", 1.0) + s.step() + + with self.assertLogs("detectron2.utils.events") as logs: + p1.write() + self.assertIn("[metric]bn_stat", logs.output[0]) + + with self.assertLogs("detectron2.utils.events") as logs: + p2.write() + self.assertIn("[metric]bn_stat", logs.output[0]) + + def testSmoothingWithWindowSize(self): + with tempfile.TemporaryDirectory( + prefix="detectron2_tests" + ) as dir, EventStorage() as storage: + json_file = os.path.join(dir, "test.json") + writer = JSONWriter(json_file, window_size=10) + for k in range(20): + storage.put_scalar("key1", k, smoothing_hint=True) + if (k + 1) % 2 == 0: + storage.put_scalar("key2", k, smoothing_hint=True) + if (k + 1) % 5 == 0: + storage.put_scalar("key3", k, smoothing_hint=True) + if (k + 1) % 10 == 0: + writer.write() + storage.step() + + num_samples = {k: storage.count_samples(k, 10) for k in ["key1", "key2", "key3"]} + self.assertEqual(num_samples, {"key1": 10, "key2": 5, "key3": 2}) + writer.close() + with open(json_file) as f: + data = [json.loads(l) for l in f] + self.assertEqual([k["key1"] for k in data], [4.5, 14.5]) + self.assertEqual([k["key2"] for k in data], [5, 15]) + self.assertEqual([k["key3"] for k in data], [6.5, 16.5]) + + def testEventStorage(self): + self.assertFalse(has_event_storage()) + with EventStorage() as storage: + self.assertTrue(has_event_storage()) + self.assertEqual(storage, get_event_storage()) + self.assertFalse(has_event_storage()) diff --git a/data_processing/detectron2/tests/test_export_caffe2.py b/data_processing/detectron2/tests/test_export_caffe2.py new file mode 100644 index 0000000..58e9f68 --- /dev/null +++ b/data_processing/detectron2/tests/test_export_caffe2.py @@ -0,0 +1,62 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# -*- coding: utf-8 -*- + +import copy +import os +import tempfile +import unittest +import torch +from torch.hub import _check_module_exists + +from detectron2 import model_zoo +from detectron2.utils.logger import setup_logger +from detectron2.utils.testing import get_sample_coco_image + +try: + # Caffe2 used to be included in PyTorch, but since PyTorch 1.10+, + # Caffe2 is not included in pre-built packages. This is a safety BC check + from detectron2.export import Caffe2Model, Caffe2Tracer +except ImportError: + raise unittest.SkipTest( + f"PyTorch does not have Caffe2 support. Skipping all tests in {__name__}" + ) from None + + +# TODO: this test requires manifold access, see: T88318502 +# Running it on CircleCI causes crash, not sure why. +@unittest.skipIf(os.environ.get("CIRCLECI"), "Caffe2 tests crash on CircleCI.") +@unittest.skipIf(not _check_module_exists("onnx"), "ONNX not installed.") +class TestCaffe2Export(unittest.TestCase): + def setUp(self): + setup_logger() + + def _test_model(self, config_path, device="cpu"): + cfg = model_zoo.get_config(config_path) + cfg.MODEL.DEVICE = device + model = model_zoo.get(config_path, trained=True, device=device) + + inputs = [{"image": get_sample_coco_image()}] + tracer = Caffe2Tracer(cfg, model, copy.deepcopy(inputs)) + + with tempfile.TemporaryDirectory(prefix="detectron2_unittest") as d: + if not os.environ.get("CI"): + # This requires onnx, which is not yet available on public CI + c2_model = tracer.export_caffe2() + c2_model.save_protobuf(d) + c2_model.save_graph(os.path.join(d, "test.svg"), inputs=copy.deepcopy(inputs)) + + c2_model = Caffe2Model.load_protobuf(d) + c2_model(inputs)[0]["instances"] + + ts_model = tracer.export_torchscript() + ts_model.save(os.path.join(d, "model.ts")) + + def testMaskRCNN(self): + self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml") + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def testMaskRCNNGPU(self): + 
self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", device="cuda") + + def testRetinaNet(self): + self._test_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml") diff --git a/data_processing/detectron2/tests/test_export_onnx.py b/data_processing/detectron2/tests/test_export_onnx.py new file mode 100644 index 0000000..aa15e1a --- /dev/null +++ b/data_processing/detectron2/tests/test_export_onnx.py @@ -0,0 +1,237 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import io +import unittest +import warnings +import torch +from torch.hub import _check_module_exists + +from detectron2 import model_zoo +from detectron2.config import get_cfg +from detectron2.export import STABLE_ONNX_OPSET_VERSION +from detectron2.export.flatten import TracingAdapter +from detectron2.export.torchscript_patch import patch_builtin_len +from detectron2.layers import ShapeSpec +from detectron2.modeling import build_model +from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead +from detectron2.structures import Boxes, Instances +from detectron2.utils.testing import ( + _pytorch1111_symbolic_opset9_repeat_interleave, + _pytorch1111_symbolic_opset9_to, + get_sample_coco_image, + has_dynamic_axes, + random_boxes, + register_custom_op_onnx_export, + skipIfOnCPUCI, + skipIfUnsupportedMinOpsetVersion, + skipIfUnsupportedMinTorchVersion, + unregister_custom_op_onnx_export, +) + + +@unittest.skipIf(not _check_module_exists("onnx"), "ONNX not installed.") +@skipIfUnsupportedMinTorchVersion("1.10") +class TestONNXTracingExport(unittest.TestCase): + opset_version = STABLE_ONNX_OPSET_VERSION + + def testMaskRCNNFPN(self): + def inference_func(model, images): + with warnings.catch_warnings(record=True): + inputs = [{"image": image} for image in images] + inst = model.inference(inputs, do_postprocess=False)[0] + return [{"instances": inst}] + + self._test_model_zoo_from_config_path( + "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func + ) + + 
@skipIfOnCPUCI + def testMaskRCNNC4(self): + def inference_func(model, image): + inputs = [{"image": image}] + return model.inference(inputs, do_postprocess=False)[0] + + self._test_model_zoo_from_config_path( + "COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml", inference_func + ) + + @skipIfOnCPUCI + def testCascadeRCNN(self): + def inference_func(model, image): + inputs = [{"image": image}] + return model.inference(inputs, do_postprocess=False)[0] + + self._test_model_zoo_from_config_path( + "Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml", inference_func + ) + + def testRetinaNet(self): + def inference_func(model, image): + return model.forward([{"image": image}])[0]["instances"] + + self._test_model_zoo_from_config_path( + "COCO-Detection/retinanet_R_50_FPN_3x.yaml", inference_func + ) + + @skipIfOnCPUCI + def testMaskRCNNFPN_batched(self): + def inference_func(model, image1, image2): + inputs = [{"image": image1}, {"image": image2}] + return model.inference(inputs, do_postprocess=False) + + self._test_model_zoo_from_config_path( + "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func, batch=2 + ) + + @skipIfUnsupportedMinOpsetVersion(16, STABLE_ONNX_OPSET_VERSION) + @skipIfUnsupportedMinTorchVersion("1.11.1") + def testMaskRCNNFPN_with_postproc(self): + def inference_func(model, image): + inputs = [{"image": image, "height": image.shape[1], "width": image.shape[2]}] + return model.inference(inputs, do_postprocess=True)[0]["instances"] + + self._test_model_zoo_from_config_path( + "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", + inference_func, + ) + + def testKeypointHead(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = KRCNNConvDeconvUpsampleHead( + ShapeSpec(channels=4, height=14, width=14), num_keypoints=17, conv_dims=(4,) + ) + + def forward(self, x, predbox1, predbox2): + inst = [ + Instances((100, 100), pred_boxes=Boxes(predbox1)), + Instances((100, 100), pred_boxes=Boxes(predbox2)), + ] + 
ret = self.model(x, inst) + return tuple(x.pred_keypoints for x in ret) + + model = M() + model.eval() + + def gen_input(num1, num2): + feat = torch.randn((num1 + num2, 4, 14, 14)) + box1 = random_boxes(num1) + box2 = random_boxes(num2) + return feat, box1, box2 + + with patch_builtin_len(): + onnx_model = self._test_model( + model, + gen_input(1, 2), + input_names=["features", "pred_boxes", "pred_classes"], + output_names=["box1", "box2"], + dynamic_axes={ + "features": {0: "batch", 1: "static_four", 2: "height", 3: "width"}, + "pred_boxes": {0: "batch", 1: "static_four"}, + "pred_classes": {0: "batch", 1: "static_four"}, + "box1": {0: "num_instance", 1: "K", 2: "static_three"}, + "box2": {0: "num_instance", 1: "K", 2: "static_three"}, + }, + ) + + # Although ONNX models are not executable by PyTorch to verify + # support of batches with different sizes, we can verify model's IR + # does not hard-code input and/or output shapes. + # TODO: Add tests with different batch sizes when detectron2's CI + # support ONNX Runtime backend. + assert has_dynamic_axes(onnx_model) + + ################################################################################ + # Testcase internals - DO NOT add tests below this point + ################################################################################ + + def setUp(self): + register_custom_op_onnx_export("::to", _pytorch1111_symbolic_opset9_to, 9, "1.11.1") + register_custom_op_onnx_export( + "::repeat_interleave", + _pytorch1111_symbolic_opset9_repeat_interleave, + 9, + "1.11.1", + ) + + def tearDown(self): + unregister_custom_op_onnx_export("::to", 9, "1.11.1") + unregister_custom_op_onnx_export("::repeat_interleave", 9, "1.11.1") + + def _test_model( + self, + model, + inputs, + inference_func=None, + opset_version=STABLE_ONNX_OPSET_VERSION, + save_onnx_graph_path=None, + **export_kwargs, + ): + # Not imported in the beginning of file to prevent runtime errors + # for environments without ONNX. 
+ # This testcase checks dependencies before running + import onnx # isort:skip + + f = io.BytesIO() + adapter_model = TracingAdapter(model, inputs, inference_func) + adapter_model.eval() + with torch.no_grad(): + try: + torch.onnx.enable_log() + except AttributeError: + # Older ONNX versions does not have this API + pass + torch.onnx.export( + adapter_model, + adapter_model.flattened_inputs, + f, + training=torch.onnx.TrainingMode.EVAL, + opset_version=opset_version, + verbose=True, + **export_kwargs, + ) + onnx_model = onnx.load_from_string(f.getvalue()) + assert onnx_model is not None + if save_onnx_graph_path: + onnx.save(onnx_model, save_onnx_graph_path) + return onnx_model + + def _test_model_zoo_from_config_path( + self, + config_path, + inference_func, + batch=1, + opset_version=STABLE_ONNX_OPSET_VERSION, + save_onnx_graph_path=None, + **export_kwargs, + ): + model = model_zoo.get(config_path, trained=True) + image = get_sample_coco_image() + inputs = tuple(image.clone() for _ in range(batch)) + return self._test_model( + model, inputs, inference_func, opset_version, save_onnx_graph_path, **export_kwargs + ) + + def _test_model_from_config_path( + self, + config_path, + inference_func, + batch=1, + opset_version=STABLE_ONNX_OPSET_VERSION, + save_onnx_graph_path=None, + **export_kwargs, + ): + from projects.PointRend import point_rend # isort:skip + + cfg = get_cfg() + cfg.DATALOADER.NUM_WORKERS = 0 + point_rend.add_pointrend_config(cfg) + cfg.merge_from_file(config_path) + cfg.freeze() + model = build_model(cfg) + image = get_sample_coco_image() + inputs = tuple(image.clone() for _ in range(batch)) + return self._test_model( + model, inputs, inference_func, opset_version, save_onnx_graph_path, **export_kwargs + ) diff --git a/data_processing/detectron2/tests/test_export_torchscript.py b/data_processing/detectron2/tests/test_export_torchscript.py new file mode 100644 index 0000000..b9905a6 --- /dev/null +++ 
b/data_processing/detectron2/tests/test_export_torchscript.py @@ -0,0 +1,336 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import copy +import glob +import json +import os +import random +import tempfile +import unittest +import zipfile +import torch +from torch import Tensor, nn + +from detectron2 import model_zoo +from detectron2.config import get_cfg +from detectron2.config.instantiate import dump_dataclass, instantiate +from detectron2.export import dump_torchscript_IR, scripting_with_instances +from detectron2.export.flatten import TracingAdapter, flatten_to_tuple +from detectron2.export.torchscript_patch import patch_builtin_len +from detectron2.layers import ShapeSpec +from detectron2.modeling import build_backbone +from detectron2.modeling.postprocessing import detector_postprocess +from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead +from detectron2.structures import Boxes, Instances +from detectron2.utils.env import TORCH_VERSION +from detectron2.utils.testing import ( + assert_instances_allclose, + convert_scripted_instances, + get_sample_coco_image, + random_boxes, + skipIfOnCPUCI, +) + + +""" +https://detectron2.readthedocs.io/tutorials/deployment.html +contains some explanations of this file. 
+""" + + +class TestScripting(unittest.TestCase): + def testMaskRCNNFPN(self): + self._test_rcnn_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml") + + @skipIfOnCPUCI + def testMaskRCNNC4(self): + self._test_rcnn_model("COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml") + + def testRetinaNet(self): + self._test_retinanet_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml") + + def _test_rcnn_model(self, config_path): + model = model_zoo.get(config_path, trained=True) + model.eval() + + fields = { + "proposal_boxes": Boxes, + "objectness_logits": Tensor, + "pred_boxes": Boxes, + "scores": Tensor, + "pred_classes": Tensor, + "pred_masks": Tensor, + } + script_model = scripting_with_instances(model, fields) + + # Test that batch inference with different shapes are supported + image = get_sample_coco_image() + small_image = nn.functional.interpolate(image, scale_factor=0.5) + inputs = [{"image": image}, {"image": small_image}] + with torch.no_grad(): + instance = model.inference(inputs, do_postprocess=False)[0] + scripted_instance = script_model.inference(inputs, do_postprocess=False)[0] + assert_instances_allclose(instance, scripted_instance) + + def _test_retinanet_model(self, config_path): + model = model_zoo.get(config_path, trained=True) + model.eval() + + fields = { + "pred_boxes": Boxes, + "scores": Tensor, + "pred_classes": Tensor, + } + script_model = scripting_with_instances(model, fields) + + img = get_sample_coco_image() + inputs = [{"image": img}] * 2 + with torch.no_grad(): + instance = model(inputs)[0]["instances"] + scripted_instance = convert_scripted_instances(script_model(inputs)[0]) + scripted_instance = detector_postprocess(scripted_instance, img.shape[1], img.shape[2]) + assert_instances_allclose(instance, scripted_instance) + # Note that the model currently cannot be saved and loaded into a new process: + # https://github.com/pytorch/pytorch/issues/46944 + + +# TODO: this test 
requires manifold access, see: T88318502 +class TestTracing(unittest.TestCase): + def testMaskRCNNFPN(self): + def inference_func(model, image): + inputs = [{"image": image}] + return model.inference(inputs, do_postprocess=False)[0] + + self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func) + + def testMaskRCNNFPN_with_postproc(self): + def inference_func(model, image): + inputs = [{"image": image, "height": image.shape[1], "width": image.shape[2]}] + return model.inference(inputs, do_postprocess=True)[0]["instances"] + + self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func) + + @skipIfOnCPUCI + def testMaskRCNNC4(self): + def inference_func(model, image): + inputs = [{"image": image}] + return model.inference(inputs, do_postprocess=False)[0] + + self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml", inference_func) + + @skipIfOnCPUCI + def testCascadeRCNN(self): + def inference_func(model, image): + inputs = [{"image": image}] + return model.inference(inputs, do_postprocess=False)[0] + + self._test_model("Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml", inference_func) + + # bug fixed by https://github.com/pytorch/pytorch/pull/67734 + @unittest.skipIf(TORCH_VERSION == (1, 10) and os.environ.get("CI"), "1.10 has bugs.") + def testRetinaNet(self): + def inference_func(model, image): + return model.forward([{"image": image}])[0]["instances"] + + self._test_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml", inference_func) + + def _check_torchscript_no_hardcoded_device(self, jitfile, extract_dir, device): + zipfile.ZipFile(jitfile).extractall(extract_dir) + dir_path = os.path.join(extract_dir, os.path.splitext(os.path.basename(jitfile))[0]) + error_files = [] + for f in glob.glob(f"{dir_path}/code/**/*.py", recursive=True): + content = open(f).read() + if device in content: + error_files.append((f, content)) + if len(error_files): + msg 
= "\n".join(f"{f}\n{content}" for f, content in error_files) + raise ValueError(f"Found device '{device}' in following files:\n{msg}") + + def _get_device_casting_test_cases(self, model): + # Indexing operation can causes hardcoded device type before 1.10 + if not TORCH_VERSION >= (1, 10) or torch.cuda.device_count() == 0: + return [None] + + testing_devices = ["cpu", "cuda:0"] + if torch.cuda.device_count() > 1: + testing_devices.append(f"cuda:{torch.cuda.device_count() - 1}") + assert str(model.device) in testing_devices + testing_devices.remove(str(model.device)) + testing_devices = [None] + testing_devices # test no casting first + + return testing_devices + + def _test_model(self, config_path, inference_func, batch=1): + model = model_zoo.get(config_path, trained=True) + image = get_sample_coco_image() + inputs = tuple(image.clone() for _ in range(batch)) + + wrapper = TracingAdapter(model, inputs, inference_func) + wrapper.eval() + with torch.no_grad(): + # trace with smaller images, and the trace must still work + trace_inputs = tuple( + nn.functional.interpolate(image, scale_factor=random.uniform(0.5, 0.7)) + for _ in range(batch) + ) + traced_model = torch.jit.trace(wrapper, trace_inputs) + + testing_devices = self._get_device_casting_test_cases(model) + # save and load back the model in order to show traceback of TorchScript + with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + basename = "model" + jitfile = f"{d}/{basename}.jit" + torch.jit.save(traced_model, jitfile) + traced_model = torch.jit.load(jitfile) + + if any(device and "cuda" in device for device in testing_devices): + self._check_torchscript_no_hardcoded_device(jitfile, d, "cuda") + + for device in testing_devices: + print(f"Testing casting to {device} for inference (traced on {model.device}) ...") + with torch.no_grad(): + outputs = inference_func(copy.deepcopy(model).to(device), *inputs) + traced_outputs = wrapper.outputs_schema(traced_model.to(device)(*inputs)) + if batch > 
1: + for output, traced_output in zip(outputs, traced_outputs): + assert_instances_allclose(output, traced_output, size_as_tensor=True) + else: + assert_instances_allclose(outputs, traced_outputs, size_as_tensor=True) + + @skipIfOnCPUCI + def testMaskRCNNFPN_batched(self): + def inference_func(model, image1, image2): + inputs = [{"image": image1}, {"image": image2}] + return model.inference(inputs, do_postprocess=False) + + self._test_model( + "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func, batch=2 + ) + + def testKeypointHead(self): + class M(nn.Module): + def __init__(self): + super().__init__() + self.model = KRCNNConvDeconvUpsampleHead( + ShapeSpec(channels=4, height=14, width=14), num_keypoints=17, conv_dims=(4,) + ) + + def forward(self, x, predbox1, predbox2): + inst = [ + Instances((100, 100), pred_boxes=Boxes(predbox1)), + Instances((100, 100), pred_boxes=Boxes(predbox2)), + ] + ret = self.model(x, inst) + return tuple(x.pred_keypoints for x in ret) + + model = M() + model.eval() + + def gen_input(num1, num2): + feat = torch.randn((num1 + num2, 4, 14, 14)) + box1 = random_boxes(num1) + box2 = random_boxes(num2) + return feat, box1, box2 + + with torch.no_grad(), patch_builtin_len(): + trace = torch.jit.trace(model, gen_input(15, 15), check_trace=False) + + inputs = gen_input(12, 10) + trace_outputs = trace(*inputs) + true_outputs = model(*inputs) + for trace_output, true_output in zip(trace_outputs, true_outputs): + self.assertTrue(torch.allclose(trace_output, true_output)) + + +class TestTorchscriptUtils(unittest.TestCase): + # TODO: add test to dump scripting + def test_dump_IR_tracing(self): + cfg = get_cfg() + cfg.MODEL.RESNETS.DEPTH = 18 + cfg.MODEL.RESNETS.RES2_OUT_CHANNELS = 64 + + class Mod(nn.Module): + def forward(self, x): + return tuple(self.m(x).values()) + + model = Mod() + model.m = build_backbone(cfg) + model.eval() + + with torch.no_grad(): + ts_model = torch.jit.trace(model, (torch.rand(2, 3, 224, 224),)) + + with 
tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + dump_torchscript_IR(ts_model, d) + # check that the files are created + for name in ["model_ts_code", "model_ts_IR", "model_ts_IR_inlined", "model"]: + fname = os.path.join(d, name + ".txt") + self.assertTrue(os.stat(fname).st_size > 0, fname) + + def test_dump_IR_function(self): + @torch.jit.script + def gunc(x, y): + return x + y + + def func(x, y): + return x + y + gunc(x, y) + + ts_model = torch.jit.trace(func, (torch.rand(3), torch.rand(3))) + with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: + dump_torchscript_IR(ts_model, d) + for name in ["model_ts_code", "model_ts_IR", "model_ts_IR_inlined"]: + fname = os.path.join(d, name + ".txt") + self.assertTrue(os.stat(fname).st_size > 0, fname) + + def test_flatten_basic(self): + obj = [3, ([5, 6], {"name": [7, 9], "name2": 3})] + res, schema = flatten_to_tuple(obj) + self.assertEqual(res, (3, 5, 6, 7, 9, 3)) + new_obj = schema(res) + self.assertEqual(new_obj, obj) + + _, new_schema = flatten_to_tuple(new_obj) + self.assertEqual(schema, new_schema) # test __eq__ + self._check_schema(schema) + + def _check_schema(self, schema): + dumped_schema = dump_dataclass(schema) + # Check that the schema is json-serializable + # Although in reality you might want to use yaml because it often has many levels + json.dumps(dumped_schema) + + # Check that the schema can be deserialized + new_schema = instantiate(dumped_schema) + self.assertEqual(schema, new_schema) + + def test_flatten_instances_boxes(self): + inst = Instances( + torch.tensor([5, 8]), pred_masks=torch.tensor([3]), pred_boxes=Boxes(torch.ones((1, 4))) + ) + obj = [3, ([5, 6], inst)] + res, schema = flatten_to_tuple(obj) + self.assertEqual(res[:3], (3, 5, 6)) + for r, expected in zip(res[3:], (inst.pred_boxes.tensor, inst.pred_masks, inst.image_size)): + self.assertIs(r, expected) + new_obj = schema(res) + assert_instances_allclose(new_obj[1][1], inst, rtol=0.0, size_as_tensor=True) + + 
self._check_schema(schema) + + def test_allow_non_tensor(self): + data = (torch.tensor([5, 8]), 3) # contains non-tensor + + class M(nn.Module): + def forward(self, input, number): + return input + + model = M() + with self.assertRaisesRegex(ValueError, "must only contain tensors"): + adap = TracingAdapter(model, data, allow_non_tensor=False) + + adap = TracingAdapter(model, data, allow_non_tensor=True) + _ = adap(*adap.flattened_inputs) + + newdata = (data[0].clone(),) + with self.assertRaisesRegex(ValueError, "cannot generalize"): + _ = adap(*newdata) diff --git a/data_processing/detectron2/tests/test_model_analysis.py b/data_processing/detectron2/tests/test_model_analysis.py new file mode 100644 index 0000000..c01b7af --- /dev/null +++ b/data_processing/detectron2/tests/test_model_analysis.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + + +import unittest +import torch +from torch import nn + +from detectron2.utils.analysis import find_unused_parameters, flop_count_operators, parameter_count +from detectron2.utils.testing import get_model_no_weights + + +class RetinaNetTest(unittest.TestCase): + def setUp(self): + self.model = get_model_no_weights("COCO-Detection/retinanet_R_50_FPN_1x.yaml") + + def test_flop(self): + # RetinaNet supports flop-counting with random inputs + inputs = [{"image": torch.rand(3, 800, 800), "test_unused": "abcd"}] + res = flop_count_operators(self.model, inputs) + self.assertEqual(int(res["conv"]), 146) # 146B flops + + def test_param_count(self): + res = parameter_count(self.model) + self.assertEqual(res[""], 37915572) + self.assertEqual(res["backbone"], 31452352) + + +class FasterRCNNTest(unittest.TestCase): + def setUp(self): + self.model = get_model_no_weights("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml") + + def test_flop(self): + # Faster R-CNN supports flop-counting with random inputs + inputs = [{"image": torch.rand(3, 800, 800)}] + res = flop_count_operators(self.model, inputs) + + # This only checks 
flops for backbone & proposal generator + # Flops for box head is not conv, and depends on #proposals, which is + # almost 0 for random inputs. + self.assertEqual(int(res["conv"]), 117) + + def test_flop_with_output_shape(self): + inputs = [{"image": torch.rand(3, 800, 800), "height": 700, "width": 700}] + res = flop_count_operators(self.model, inputs) + self.assertEqual(int(res["conv"]), 117) + + def test_param_count(self): + res = parameter_count(self.model) + self.assertEqual(res[""], 41699936) + self.assertEqual(res["backbone"], 26799296) + + +class MaskRCNNTest(unittest.TestCase): + def setUp(self): + self.model = get_model_no_weights("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml") + + def test_flop(self): + inputs1 = [{"image": torch.rand(3, 800, 800)}] + inputs2 = [{"image": torch.rand(3, 800, 800), "height": 700, "width": 700}] + + for inputs in [inputs1, inputs2]: + res = flop_count_operators(self.model, inputs) + # The mask head could have extra conv flops, so total >= 117 + self.assertGreaterEqual(int(res["conv"]), 117) + + +class UnusedParamTest(unittest.TestCase): + def test_unused(self): + class TestMod(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10, 10) + self.t = nn.Linear(10, 10) + + def forward(self, x): + return self.fc1(x).mean() + + m = TestMod() + ret = find_unused_parameters(m, torch.randn(10, 10)) + self.assertEqual(set(ret), {"t.weight", "t.bias"}) diff --git a/data_processing/detectron2/tests/test_model_zoo.py b/data_processing/detectron2/tests/test_model_zoo.py new file mode 100644 index 0000000..e3360a7 --- /dev/null +++ b/data_processing/detectron2/tests/test_model_zoo.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import logging +import unittest + +from detectron2 import model_zoo +from detectron2.config import instantiate +from detectron2.modeling import FPN, GeneralizedRCNN + +logger = logging.getLogger(__name__) + + +class TestModelZoo(unittest.TestCase): + def test_get_returns_model(self): + model = model_zoo.get("Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.yaml", trained=False) + self.assertIsInstance(model, GeneralizedRCNN) + self.assertIsInstance(model.backbone, FPN) + + def test_get_invalid_model(self): + self.assertRaises(RuntimeError, model_zoo.get, "Invalid/config.yaml") + + def test_get_url(self): + url = model_zoo.get_checkpoint_url("Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.yaml") + self.assertEqual( + url, + "https://dl.fbaipublicfiles.com/detectron2/Misc/scratch_mask_rcnn_R_50_FPN_3x_gn/138602908/model_final_01ca85.pkl", # noqa + ) + url2 = model_zoo.get_checkpoint_url("Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.py") + self.assertEqual(url, url2) + + def _build_lazy_model(self, name): + cfg = model_zoo.get_config("common/models/" + name) + instantiate(cfg.model) + + def test_mask_rcnn_fpn(self): + self._build_lazy_model("mask_rcnn_fpn.py") + + def test_mask_rcnn_c4(self): + self._build_lazy_model("mask_rcnn_c4.py") + + def test_panoptic_fpn(self): + self._build_lazy_model("panoptic_fpn.py") + + def test_schedule(self): + cfg = model_zoo.get_config("common/coco_schedule.py") + for _, v in cfg.items(): + instantiate(v) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/test_packaging.py b/data_processing/detectron2/tests/test_packaging.py new file mode 100644 index 0000000..a5b1661 --- /dev/null +++ b/data_processing/detectron2/tests/test_packaging.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import unittest + +from detectron2.utils.collect_env import collect_env_info + + +class TestProjects(unittest.TestCase): + def test_import(self): + from detectron2.projects import point_rend + + _ = point_rend.add_pointrend_config + + import detectron2.projects.deeplab as deeplab + + _ = deeplab.add_deeplab_config + + # import detectron2.projects.panoptic_deeplab as panoptic_deeplab + + # _ = panoptic_deeplab.add_panoptic_deeplab_config + + +class TestCollectEnv(unittest.TestCase): + def test(self): + _ = collect_env_info() diff --git a/data_processing/detectron2/tests/test_registry.py b/data_processing/detectron2/tests/test_registry.py new file mode 100644 index 0000000..4e425a6 --- /dev/null +++ b/data_processing/detectron2/tests/test_registry.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import unittest +import torch + +from detectron2.modeling.meta_arch import GeneralizedRCNN +from detectron2.utils.registry import _convert_target_to_string, locate + + +class A: + class B: + pass + + +class TestLocate(unittest.TestCase): + def _test_obj(self, obj): + name = _convert_target_to_string(obj) + newobj = locate(name) + self.assertIs(obj, newobj) + + def test_basic(self): + self._test_obj(GeneralizedRCNN) + + def test_inside_class(self): + # requires using __qualname__ instead of __name__ + self._test_obj(A.B) + + def test_builtin(self): + self._test_obj(len) + self._test_obj(dict) + + def test_pytorch_optim(self): + # pydoc.locate does not work for it + self._test_obj(torch.optim.SGD) + + def test_failure(self): + with self.assertRaises(ImportError): + locate("asdf") + + def test_compress_target(self): + from detectron2.data.transforms import RandomCrop + + name = _convert_target_to_string(RandomCrop) + # name shouldn't contain 'augmentation_impl' + self.assertEqual(name, "detectron2.data.transforms.RandomCrop") + self.assertIs(RandomCrop, locate(name)) diff --git a/data_processing/detectron2/tests/test_scheduler.py 
b/data_processing/detectron2/tests/test_scheduler.py new file mode 100644 index 0000000..5649a4a --- /dev/null +++ b/data_processing/detectron2/tests/test_scheduler.py @@ -0,0 +1,158 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import math +import numpy as np +from unittest import TestCase +import torch +from fvcore.common.param_scheduler import ( + CosineParamScheduler, + MultiStepParamScheduler, + StepWithFixedGammaParamScheduler, +) +from torch import nn + +from detectron2.solver import LRMultiplier, WarmupParamScheduler, build_lr_scheduler + + +class TestScheduler(TestCase): + def test_warmup_multistep(self): + p = nn.Parameter(torch.zeros(0)) + opt = torch.optim.SGD([p], lr=5) + + multiplier = WarmupParamScheduler( + MultiStepParamScheduler( + [1, 0.1, 0.01, 0.001], + milestones=[10, 15, 20], + num_updates=30, + ), + 0.001, + 5 / 30, + ) + sched = LRMultiplier(opt, multiplier, 30) + # This is an equivalent of: + # sched = WarmupMultiStepLR( + # opt, milestones=[10, 15, 20], gamma=0.1, warmup_factor=0.001, warmup_iters=5) + + p.sum().backward() + opt.step() + + lrs = [0.005] + for _ in range(30): + sched.step() + lrs.append(opt.param_groups[0]["lr"]) + self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001])) + self.assertTrue(np.allclose(lrs[5:10], 5.0)) + self.assertTrue(np.allclose(lrs[10:15], 0.5)) + self.assertTrue(np.allclose(lrs[15:20], 0.05)) + self.assertTrue(np.allclose(lrs[20:], 0.005)) + + def test_warmup_cosine(self): + p = nn.Parameter(torch.zeros(0)) + opt = torch.optim.SGD([p], lr=5) + multiplier = WarmupParamScheduler( + CosineParamScheduler(1, 0), + 0.001, + 5 / 30, + ) + sched = LRMultiplier(opt, multiplier, 30) + + p.sum().backward() + opt.step() + self.assertEqual(opt.param_groups[0]["lr"], 0.005) + lrs = [0.005] + + for _ in range(30): + sched.step() + lrs.append(opt.param_groups[0]["lr"]) + for idx, lr in enumerate(lrs): + expected_cosine = 2.5 * (1.0 + math.cos(math.pi * idx / 30)) + if idx >= 5: + 
self.assertAlmostEqual(lr, expected_cosine) + else: + self.assertNotAlmostEqual(lr, expected_cosine) + + def test_warmup_cosine_end_value(self): + from detectron2.config import CfgNode, get_cfg + + def _test_end_value(cfg_dict): + cfg = get_cfg() + cfg.merge_from_other_cfg(CfgNode(cfg_dict)) + + p = nn.Parameter(torch.zeros(0)) + opt = torch.optim.SGD([p], lr=cfg.SOLVER.BASE_LR) + + scheduler = build_lr_scheduler(cfg, opt) + + p.sum().backward() + opt.step() + self.assertEqual( + opt.param_groups[0]["lr"], cfg.SOLVER.BASE_LR * cfg.SOLVER.WARMUP_FACTOR + ) + + lrs = [] + for _ in range(cfg.SOLVER.MAX_ITER): + scheduler.step() + lrs.append(opt.param_groups[0]["lr"]) + + self.assertAlmostEqual(lrs[-1], cfg.SOLVER.BASE_LR_END) + + _test_end_value( + { + "SOLVER": { + "LR_SCHEDULER_NAME": "WarmupCosineLR", + "MAX_ITER": 100, + "WARMUP_ITERS": 10, + "WARMUP_FACTOR": 0.1, + "BASE_LR": 5.0, + "BASE_LR_END": 0.0, + } + } + ) + + _test_end_value( + { + "SOLVER": { + "LR_SCHEDULER_NAME": "WarmupCosineLR", + "MAX_ITER": 100, + "WARMUP_ITERS": 10, + "WARMUP_FACTOR": 0.1, + "BASE_LR": 5.0, + "BASE_LR_END": 0.5, + } + } + ) + + def test_warmup_stepwithfixedgamma(self): + p = nn.Parameter(torch.zeros(0)) + opt = torch.optim.SGD([p], lr=5) + + multiplier = WarmupParamScheduler( + StepWithFixedGammaParamScheduler( + base_value=1.0, + gamma=0.1, + num_decays=4, + num_updates=30, + ), + 0.001, + 5 / 30, + rescale_interval=True, + ) + sched = LRMultiplier(opt, multiplier, 30) + + p.sum().backward() + opt.step() + + lrs = [0.005] + for _ in range(29): + sched.step() + lrs.append(opt.param_groups[0]["lr"]) + self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001])) + self.assertTrue(np.allclose(lrs[5:10], 5.0)) + self.assertTrue(np.allclose(lrs[10:15], 0.5)) + self.assertTrue(np.allclose(lrs[15:20], 0.05)) + self.assertTrue(np.allclose(lrs[20:25], 0.005)) + self.assertTrue(np.allclose(lrs[25:], 0.0005)) + + # Calling sche.step() after the last training iteration is done 
will trigger IndexError + with self.assertRaises(IndexError, msg="list index out of range"): + sched.step() diff --git a/data_processing/detectron2/tests/test_solver.py b/data_processing/detectron2/tests/test_solver.py new file mode 100644 index 0000000..6b3ae84 --- /dev/null +++ b/data_processing/detectron2/tests/test_solver.py @@ -0,0 +1,66 @@ +import unittest + +from detectron2.solver.build import _expand_param_groups, reduce_param_groups + + +class TestOptimizer(unittest.TestCase): + def testExpandParamsGroups(self): + params = [ + { + "params": ["p1", "p2", "p3", "p4"], + "lr": 1.0, + "weight_decay": 3.0, + }, + { + "params": ["p2", "p3", "p5"], + "lr": 2.0, + "momentum": 2.0, + }, + { + "params": ["p1"], + "weight_decay": 4.0, + }, + ] + out = _expand_param_groups(params) + gt = [ + dict(params=["p1"], lr=1.0, weight_decay=4.0), # noqa + dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa + dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa + dict(params=["p4"], lr=1.0, weight_decay=3.0), # noqa + dict(params=["p5"], lr=2.0, momentum=2.0), # noqa + ] + self.assertEqual(out, gt) + + def testReduceParamGroups(self): + params = [ + dict(params=["p1"], lr=1.0, weight_decay=4.0), # noqa + dict(params=["p2", "p6"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa + dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa + dict(params=["p4"], lr=1.0, weight_decay=3.0), # noqa + dict(params=["p5"], lr=2.0, momentum=2.0), # noqa + ] + gt_groups = [ + { + "lr": 1.0, + "weight_decay": 4.0, + "params": ["p1"], + }, + { + "lr": 2.0, + "weight_decay": 3.0, + "momentum": 2.0, + "params": ["p2", "p6", "p3"], + }, + { + "lr": 1.0, + "weight_decay": 3.0, + "params": ["p4"], + }, + { + "lr": 2.0, + "momentum": 2.0, + "params": ["p5"], + }, + ] + out = reduce_param_groups(params) + self.assertEqual(out, gt_groups) diff --git a/data_processing/detectron2/tests/test_visualizer.py b/data_processing/detectron2/tests/test_visualizer.py new 
file mode 100644 index 0000000..646e5f3 --- /dev/null +++ b/data_processing/detectron2/tests/test_visualizer.py @@ -0,0 +1,278 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import numpy as np +import os +import tempfile +import unittest +import cv2 +import torch + +from detectron2.data import MetadataCatalog +from detectron2.structures import BoxMode, Instances, RotatedBoxes +from detectron2.utils.visualizer import ColorMode, Visualizer + + +class TestVisualizer(unittest.TestCase): + def _random_data(self): + H, W = 100, 100 + N = 10 + img = np.random.rand(H, W, 3) * 255 + boxxy = np.random.rand(N, 2) * (H // 2) + boxes = np.concatenate((boxxy, boxxy + H // 2), axis=1) + + def _rand_poly(): + return np.random.rand(3, 2).flatten() * H + + polygons = [[_rand_poly() for _ in range(np.random.randint(1, 5))] for _ in range(N)] + + mask = np.zeros_like(img[:, :, 0], dtype=bool) + mask[:40, 10:20] = 1 + + labels = [str(i) for i in range(N)] + return img, boxes, labels, polygons, [mask] * N + + @property + def metadata(self): + return MetadataCatalog.get("coco_2017_train") + + def test_draw_dataset_dict(self): + img = np.random.rand(512, 512, 3) * 255 + dic = { + "annotations": [ + { + "bbox": [ + 368.9946492271106, + 330.891438763377, + 13.148537455410235, + 13.644708680142685, + ], + "bbox_mode": BoxMode.XYWH_ABS, + "category_id": 0, + "iscrowd": 1, + "segmentation": { + "counts": "_jh52m?2N2N2N2O100O10O001N1O2MceP2", + "size": [512, 512], + }, + } + ], + "height": 512, + "image_id": 1, + "width": 512, + } + v = Visualizer(img) + v.draw_dataset_dict(dic) + + v = Visualizer(img, self.metadata) + v.draw_dataset_dict(dic) + + def test_draw_rotated_dataset_dict(self): + img = np.random.rand(512, 512, 3) * 255 + dic = { + "annotations": [ + { + "bbox": [ + 368.9946492271106, + 330.891438763377, + 13.148537455410235, + 13.644708680142685, + 45.0, + ], + "bbox_mode": BoxMode.XYWHA_ABS, + "category_id": 0, + "iscrowd": 1, + } + ], + "height": 
512, + "image_id": 1, + "width": 512, + } + v = Visualizer(img, self.metadata) + v.draw_dataset_dict(dic) + + def test_overlay_instances(self): + img, boxes, labels, polygons, masks = self._random_data() + + v = Visualizer(img, self.metadata) + output = v.overlay_instances(masks=polygons, boxes=boxes, labels=labels).get_image() + self.assertEqual(output.shape, img.shape) + + # Test 2x scaling + v = Visualizer(img, self.metadata, scale=2.0) + output = v.overlay_instances(masks=polygons, boxes=boxes, labels=labels).get_image() + self.assertEqual(output.shape[0], img.shape[0] * 2) + + # Test overlay masks + v = Visualizer(img, self.metadata) + output = v.overlay_instances(masks=masks, boxes=boxes, labels=labels).get_image() + self.assertEqual(output.shape, img.shape) + + def test_overlay_instances_no_boxes(self): + img, boxes, labels, polygons, _ = self._random_data() + v = Visualizer(img, self.metadata) + v.overlay_instances(masks=polygons, boxes=None, labels=labels).get_image() + + def test_draw_instance_predictions(self): + img, boxes, _, _, masks = self._random_data() + num_inst = len(boxes) + inst = Instances((img.shape[0], img.shape[1])) + inst.pred_classes = torch.randint(0, 80, size=(num_inst,)) + inst.scores = torch.rand(num_inst) + inst.pred_boxes = torch.from_numpy(boxes) + inst.pred_masks = torch.from_numpy(np.asarray(masks)) + + v = Visualizer(img) + v.draw_instance_predictions(inst) + + v = Visualizer(img, self.metadata) + v.draw_instance_predictions(inst) + + def test_BWmode_nomask(self): + img, boxes, _, _, masks = self._random_data() + num_inst = len(boxes) + inst = Instances((img.shape[0], img.shape[1])) + inst.pred_classes = torch.randint(0, 80, size=(num_inst,)) + inst.scores = torch.rand(num_inst) + inst.pred_boxes = torch.from_numpy(boxes) + + v = Visualizer(img, self.metadata, instance_mode=ColorMode.IMAGE_BW) + v.draw_instance_predictions(inst) + + # check that output is grayscale + inst = inst[:0] + v = Visualizer(img, self.metadata, 
instance_mode=ColorMode.IMAGE_BW) + output = v.draw_instance_predictions(inst).get_image() + self.assertTrue(np.allclose(output[:, :, 0], output[:, :, 1])) + self.assertTrue(np.allclose(output[:, :, 0], output[:, :, 2])) + + def test_draw_empty_mask_predictions(self): + img, boxes, _, _, masks = self._random_data() + num_inst = len(boxes) + inst = Instances((img.shape[0], img.shape[1])) + inst.pred_classes = torch.randint(0, 80, size=(num_inst,)) + inst.scores = torch.rand(num_inst) + inst.pred_boxes = torch.from_numpy(boxes) + inst.pred_masks = torch.from_numpy(np.zeros_like(np.asarray(masks))) + + v = Visualizer(img, self.metadata) + v.draw_instance_predictions(inst) + + def test_correct_output_shape(self): + img = np.random.rand(928, 928, 3) * 255 + v = Visualizer(img, self.metadata) + out = v.output.get_image() + self.assertEqual(out.shape, img.shape) + + def test_overlay_rotated_instances(self): + H, W = 100, 150 + img = np.random.rand(H, W, 3) * 255 + num_boxes = 50 + boxes_5d = torch.zeros(num_boxes, 5) + boxes_5d[:, 0] = torch.FloatTensor(num_boxes).uniform_(-0.1 * W, 1.1 * W) + boxes_5d[:, 1] = torch.FloatTensor(num_boxes).uniform_(-0.1 * H, 1.1 * H) + boxes_5d[:, 2] = torch.FloatTensor(num_boxes).uniform_(0, max(W, H)) + boxes_5d[:, 3] = torch.FloatTensor(num_boxes).uniform_(0, max(W, H)) + boxes_5d[:, 4] = torch.FloatTensor(num_boxes).uniform_(-1800, 1800) + rotated_boxes = RotatedBoxes(boxes_5d) + labels = [str(i) for i in range(num_boxes)] + + v = Visualizer(img, self.metadata) + output = v.overlay_instances(boxes=rotated_boxes, labels=labels).get_image() + self.assertEqual(output.shape, img.shape) + + def test_draw_no_metadata(self): + img, boxes, _, _, masks = self._random_data() + num_inst = len(boxes) + inst = Instances((img.shape[0], img.shape[1])) + inst.pred_classes = torch.randint(0, 80, size=(num_inst,)) + inst.scores = torch.rand(num_inst) + inst.pred_boxes = torch.from_numpy(boxes) + inst.pred_masks = torch.from_numpy(np.asarray(masks)) + + 
v = Visualizer(img, MetadataCatalog.get("asdfasdf")) + v.draw_instance_predictions(inst) + + def test_draw_binary_mask(self): + img, boxes, _, _, masks = self._random_data() + img[:, :, 0] = 0 # remove red color + mask = masks[0] + mask_with_hole = np.zeros_like(mask).astype("uint8") + mask_with_hole = cv2.rectangle(mask_with_hole, (10, 10), (50, 50), 1, 5) + + for m in [mask, mask_with_hole]: + for save in [True, False]: + v = Visualizer(img) + o = v.draw_binary_mask(m, color="red", text="test") + if save: + with tempfile.TemporaryDirectory(prefix="detectron2_viz") as d: + path = os.path.join(d, "output.png") + o.save(path) + o = cv2.imread(path)[:, :, ::-1] + else: + o = o.get_image().astype("float32") + # red color is drawn on the image + self.assertTrue(o[:, :, 0].sum() > 0) + + def test_draw_soft_mask(self): + img = np.random.rand(100, 100, 3) * 255 + img[:, :, 0] = 0 # remove red color + mask = np.zeros((100, 100), dtype=np.float32) + mask[30:50, 40:50] = 1.0 + cv2.GaussianBlur(mask, (21, 21), 10) + + v = Visualizer(img) + o = v.draw_soft_mask(mask, color="red", text="test") + o = o.get_image().astype("float32") + # red color is drawn on the image + self.assertTrue(o[:, :, 0].sum() > 0) + + # test draw empty mask + v = Visualizer(img) + o = v.draw_soft_mask(np.zeros((100, 100), dtype=np.float32), color="red", text="test") + o = o.get_image().astype("float32") + + def test_border_mask_with_holes(self): + H, W = 200, 200 + img = np.zeros((H, W, 3)) + img[:, :, 0] = 255.0 + v = Visualizer(img, scale=3) + + mask = np.zeros((H, W)) + mask[:, 100:150] = 1 + # create a hole, to trigger imshow + mask = cv2.rectangle(mask, (110, 110), (130, 130), 0, thickness=-1) + output = v.draw_binary_mask(mask, color="blue") + output = output.get_image()[:, :, ::-1] + + first_row = {tuple(x.tolist()) for x in output[0]} + last_row = {tuple(x.tolist()) for x in output[-1]} + # Check quantization / off-by-1 error: the first and last row must have two colors + 
self.assertEqual(len(last_row), 2) + self.assertEqual(len(first_row), 2) + self.assertIn((0, 0, 255), last_row) + self.assertIn((0, 0, 255), first_row) + + def test_border_polygons(self): + H, W = 200, 200 + img = np.zeros((H, W, 3)) + img[:, :, 0] = 255.0 + v = Visualizer(img, scale=3) + mask = np.zeros((H, W)) + mask[:, 100:150] = 1 + + output = v.draw_binary_mask(mask, color="blue") + output = output.get_image()[:, :, ::-1] + + first_row = {tuple(x.tolist()) for x in output[0]} + last_row = {tuple(x.tolist()) for x in output[-1]} + # Check quantization / off-by-1 error: + # the first and last row must have >=2 colors, because the polygon + # touches both rows + self.assertGreaterEqual(len(last_row), 2) + self.assertGreaterEqual(len(first_row), 2) + self.assertIn((0, 0, 255), last_row) + self.assertIn((0, 0, 255), first_row) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/tracking/__init__.py b/data_processing/detectron2/tests/tracking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/detectron2/tests/tracking/test_bbox_iou_tracker.py b/data_processing/detectron2/tests/tracking/test_bbox_iou_tracker.py new file mode 100644 index 0000000..e720b2e --- /dev/null +++ b/data_processing/detectron2/tests/tracking/test_bbox_iou_tracker.py @@ -0,0 +1,160 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import numpy as np +import unittest +from copy import deepcopy +from typing import Dict +import torch + +from detectron2.config import CfgNode as CfgNode_ +from detectron2.config import instantiate +from detectron2.structures import Boxes, Instances +from detectron2.tracking.base_tracker import build_tracker_head +from detectron2.tracking.bbox_iou_tracker import BBoxIOUTracker # noqa + + +class TestBBoxIOUTracker(unittest.TestCase): + def setUp(self): + self._img_size = np.array([600, 800]) + self._prev_boxes = np.array( + [ + [101, 101, 200, 200], + [301, 301, 450, 450], + ] + ).astype(np.float32) + self._prev_scores = np.array([0.9, 0.9]) + self._prev_classes = np.array([1, 1]) + self._prev_masks = np.ones((2, 600, 800)).astype("uint8") + self._curr_boxes = np.array( + [ + [302, 303, 451, 452], + [101, 102, 201, 203], + ] + ).astype(np.float32) + self._curr_scores = np.array([0.95, 0.85]) + self._curr_classes = np.array([1, 1]) + self._curr_masks = np.ones((2, 600, 800)).astype("uint8") + + self._prev_instances = { + "image_size": self._img_size, + "pred_boxes": self._prev_boxes, + "scores": self._prev_scores, + "pred_classes": self._prev_classes, + "pred_masks": self._prev_masks, + } + self._prev_instances = self._convertDictPredictionToInstance(self._prev_instances) + self._curr_instances = { + "image_size": self._img_size, + "pred_boxes": self._curr_boxes, + "scores": self._curr_scores, + "pred_classes": self._curr_classes, + "pred_masks": self._curr_masks, + } + self._curr_instances = self._convertDictPredictionToInstance(self._curr_instances) + + self._max_num_instances = 200 + self._max_lost_frame_count = 0 + self._min_box_rel_dim = 0.02 + self._min_instance_period = 1 + self._track_iou_threshold = 0.5 + + def _convertDictPredictionToInstance(self, prediction: Dict) -> Instances: + """ + convert prediction from Dict to D2 Instances format + """ + res = Instances( + image_size=torch.IntTensor(prediction["image_size"]), + 
pred_boxes=Boxes(torch.FloatTensor(prediction["pred_boxes"])), + pred_masks=torch.IntTensor(prediction["pred_masks"]), + pred_classes=torch.IntTensor(prediction["pred_classes"]), + scores=torch.FloatTensor(prediction["scores"]), + ) + return res + + def test_init(self): + cfg = { + "_target_": "detectron2.tracking.bbox_iou_tracker.BBoxIOUTracker", + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + self.assertTrue(tracker._video_height == self._img_size[0]) + + def test_from_config(self): + cfg = CfgNode_() + cfg.TRACKER_HEADS = CfgNode_() + cfg.TRACKER_HEADS.TRACKER_NAME = "BBoxIOUTracker" + cfg.TRACKER_HEADS.VIDEO_HEIGHT = int(self._img_size[0]) + cfg.TRACKER_HEADS.VIDEO_WIDTH = int(self._img_size[1]) + cfg.TRACKER_HEADS.MAX_NUM_INSTANCES = self._max_num_instances + cfg.TRACKER_HEADS.MAX_LOST_FRAME_COUNT = self._max_lost_frame_count + cfg.TRACKER_HEADS.MIN_BOX_REL_DIM = self._min_box_rel_dim + cfg.TRACKER_HEADS.MIN_INSTANCE_PERIOD = self._min_instance_period + cfg.TRACKER_HEADS.TRACK_IOU_THRESHOLD = self._track_iou_threshold + tracker = build_tracker_head(cfg) + self.assertTrue(tracker._video_height == self._img_size[0]) + + def test_initialize_extra_fields(self): + cfg = { + "_target_": "detectron2.tracking.bbox_iou_tracker.BBoxIOUTracker", + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + instances = tracker._initialize_extra_fields(self._curr_instances) + 
self.assertTrue(instances.has("ID")) + self.assertTrue(instances.has("ID_period")) + self.assertTrue(instances.has("lost_frame_count")) + + def test_assign_new_id(self): + cfg = { + "_target_": "detectron2.tracking.bbox_iou_tracker.BBoxIOUTracker", + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + instances = deepcopy(self._curr_instances) + instances = tracker._initialize_extra_fields(instances) + instances = tracker._assign_new_id(instances) + self.assertTrue(len(instances.ID) == 2) + self.assertTrue(instances.ID[0] == 2) + self.assertTrue(instances.ID[1] == 3) + + def test_update(self): + cfg = { + "_target_": "detectron2.tracking.bbox_iou_tracker.BBoxIOUTracker", + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + prev_instances = tracker.update(self._prev_instances) + self.assertTrue(len(prev_instances.ID) == 2) + self.assertTrue(prev_instances.ID[0] == 0) + self.assertTrue(prev_instances.ID[1] == 1) + curr_instances = tracker.update(self._curr_instances) + self.assertTrue(len(curr_instances.ID) == 2) + self.assertTrue(curr_instances.ID[0] == 1) + self.assertTrue(curr_instances.ID[1] == 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/tracking/test_hungarian_tracker.py b/data_processing/detectron2/tests/tracking/test_hungarian_tracker.py new file mode 100644 index 0000000..660c635 --- /dev/null +++ 
b/data_processing/detectron2/tests/tracking/test_hungarian_tracker.py @@ -0,0 +1,102 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +import unittest +from typing import Dict +import torch + +from detectron2.config import instantiate +from detectron2.structures import Boxes, Instances + + +class TestBaseHungarianTracker(unittest.TestCase): + def setUp(self): + self._img_size = np.array([600, 800]) + self._prev_boxes = np.array( + [ + [101, 101, 200, 200], + [301, 301, 450, 450], + ] + ).astype(np.float32) + self._prev_scores = np.array([0.9, 0.9]) + self._prev_classes = np.array([1, 1]) + self._prev_masks = np.ones((2, 600, 800)).astype("uint8") + self._curr_boxes = np.array( + [ + [302, 303, 451, 452], + [101, 102, 201, 203], + ] + ).astype(np.float32) + self._curr_scores = np.array([0.95, 0.85]) + self._curr_classes = np.array([1, 1]) + self._curr_masks = np.ones((2, 600, 800)).astype("uint8") + + self._prev_instances = { + "image_size": self._img_size, + "pred_boxes": self._prev_boxes, + "scores": self._prev_scores, + "pred_classes": self._prev_classes, + "pred_masks": self._prev_masks, + } + self._prev_instances = self._convertDictPredictionToInstance(self._prev_instances) + self._curr_instances = { + "image_size": self._img_size, + "pred_boxes": self._curr_boxes, + "scores": self._curr_scores, + "pred_classes": self._curr_classes, + "pred_masks": self._curr_masks, + } + self._curr_instances = self._convertDictPredictionToInstance(self._curr_instances) + + self._max_num_instances = 200 + self._max_lost_frame_count = 0 + self._min_box_rel_dim = 0.02 + self._min_instance_period = 1 + self._track_iou_threshold = 0.5 + + def _convertDictPredictionToInstance(self, prediction: Dict) -> Instances: + """ + convert prediction from Dict to D2 Instances format + """ + res = Instances( + image_size=torch.IntTensor(prediction["image_size"]), + pred_boxes=Boxes(torch.FloatTensor(prediction["pred_boxes"])), + 
pred_masks=torch.IntTensor(prediction["pred_masks"]), + pred_classes=torch.IntTensor(prediction["pred_classes"]), + scores=torch.FloatTensor(prediction["scores"]), + ) + return res + + def test_init(self): + cfg = { + "_target_": "detectron2.tracking.hungarian_tracker.BaseHungarianTracker", + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + self.assertTrue(tracker._video_height == self._img_size[0]) + + def test_initialize_extra_fields(self): + cfg = { + "_target_": "detectron2.tracking.hungarian_tracker.BaseHungarianTracker", + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + instances = tracker._initialize_extra_fields(self._curr_instances) + self.assertTrue(instances.has("ID")) + self.assertTrue(instances.has("ID_period")) + self.assertTrue(instances.has("lost_frame_count")) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/tracking/test_iou_weighted_hungarian_bbox_iou_tracker.py b/data_processing/detectron2/tests/tracking/test_iou_weighted_hungarian_bbox_iou_tracker.py new file mode 100644 index 0000000..6947399 --- /dev/null +++ b/data_processing/detectron2/tests/tracking/test_iou_weighted_hungarian_bbox_iou_tracker.py @@ -0,0 +1,225 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import copy +import numpy as np +import unittest +from typing import Dict +import torch + +from detectron2.config import CfgNode as CfgNode_ +from detectron2.config import instantiate +from detectron2.structures import Boxes, Instances +from detectron2.tracking.base_tracker import build_tracker_head +from detectron2.tracking.iou_weighted_hungarian_bbox_iou_tracker import ( # noqa + IOUWeightedHungarianBBoxIOUTracker, +) + + +class TestIOUWeightedHungarianBBoxIOUTracker(unittest.TestCase): + def setUp(self): + self._img_size = np.array([600, 800]) + self._prev_boxes = np.array( + [ + [101, 101, 200, 200], + [301, 301, 450, 450], + ] + ).astype(np.float32) + self._prev_scores = np.array([0.9, 0.9]) + self._prev_classes = np.array([1, 1]) + self._prev_masks = np.ones((2, 600, 800)).astype("uint8") + self._curr_boxes = np.array( + [ + [302, 303, 451, 452], + [101, 102, 201, 203], + ] + ).astype(np.float32) + self._curr_scores = np.array([0.95, 0.85]) + self._curr_classes = np.array([1, 1]) + self._curr_masks = np.ones((2, 600, 800)).astype("uint8") + + self._prev_instances = { + "image_size": self._img_size, + "pred_boxes": self._prev_boxes, + "scores": self._prev_scores, + "pred_classes": self._prev_classes, + "pred_masks": self._prev_masks, + } + self._prev_instances = self._convertDictPredictionToInstance(self._prev_instances) + self._curr_instances = { + "image_size": self._img_size, + "pred_boxes": self._curr_boxes, + "scores": self._curr_scores, + "pred_classes": self._curr_classes, + "pred_masks": self._curr_masks, + } + self._curr_instances = self._convertDictPredictionToInstance(self._curr_instances) + + self._max_num_instances = 10 + self._max_lost_frame_count = 3 + self._min_box_rel_dim = 0.02 + self._min_instance_period = 1 + self._track_iou_threshold = 0.5 + + def _convertDictPredictionToInstance(self, prediction: Dict) -> Instances: + """ + convert prediction from Dict to D2 Instances format + """ + res = Instances( + 
image_size=torch.IntTensor(prediction["image_size"]), + pred_boxes=Boxes(torch.FloatTensor(prediction["pred_boxes"])), + pred_masks=torch.IntTensor(prediction["pred_masks"]), + pred_classes=torch.IntTensor(prediction["pred_classes"]), + scores=torch.FloatTensor(prediction["scores"]), + ) + return res + + def test_init(self): + cfg = { + "_target_": "detectron2.tracking.iou_weighted_hungarian_bbox_iou_tracker.IOUWeightedHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + self.assertTrue(tracker._video_height == self._img_size[0]) + + def test_from_config(self): + cfg = CfgNode_() + cfg.TRACKER_HEADS = CfgNode_() + cfg.TRACKER_HEADS.TRACKER_NAME = "IOUWeightedHungarianBBoxIOUTracker" + cfg.TRACKER_HEADS.VIDEO_HEIGHT = int(self._img_size[0]) + cfg.TRACKER_HEADS.VIDEO_WIDTH = int(self._img_size[1]) + cfg.TRACKER_HEADS.MAX_NUM_INSTANCES = self._max_num_instances + cfg.TRACKER_HEADS.MAX_LOST_FRAME_COUNT = self._max_lost_frame_count + cfg.TRACKER_HEADS.MIN_BOX_REL_DIM = self._min_box_rel_dim + cfg.TRACKER_HEADS.MIN_INSTANCE_PERIOD = self._min_instance_period + cfg.TRACKER_HEADS.TRACK_IOU_THRESHOLD = self._track_iou_threshold + tracker = build_tracker_head(cfg) + self.assertTrue(tracker._video_height == self._img_size[0]) + + def test_initialize_extra_fields(self): + cfg = { + "_target_": "detectron2.tracking.iou_weighted_hungarian_bbox_iou_tracker.IOUWeightedHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": 
self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + instances = tracker._initialize_extra_fields(self._curr_instances) + self.assertTrue(instances.has("ID")) + self.assertTrue(instances.has("ID_period")) + self.assertTrue(instances.has("lost_frame_count")) + + def test_process_matched_idx(self): + cfg = { + "_target_": "detectron2.tracking.iou_weighted_hungarian_bbox_iou_tracker.IOUWeightedHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + prev_instances = tracker._initialize_extra_fields(self._prev_instances) + tracker._prev_instances = prev_instances + curr_instances = tracker._initialize_extra_fields(self._curr_instances) + matched_idx = np.array([0]) + matched_prev_idx = np.array([1]) + curr_instances = tracker._process_matched_idx(curr_instances, matched_idx, matched_prev_idx) + self.assertTrue(curr_instances.ID[0] == 1) + + def test_process_unmatched_idx(self): + cfg = { + "_target_": "detectron2.tracking.iou_weighted_hungarian_bbox_iou_tracker.IOUWeightedHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + prev_instances = tracker._initialize_extra_fields(self._prev_instances) + tracker._prev_instances = prev_instances + curr_instances = tracker._initialize_extra_fields(self._curr_instances) + matched_idx = np.array([0]) + 
matched_prev_idx = np.array([1]) + curr_instances = tracker._process_matched_idx(curr_instances, matched_idx, matched_prev_idx) + curr_instances = tracker._process_unmatched_idx(curr_instances, matched_idx) + self.assertTrue(curr_instances.ID[1] == 2) + + def test_process_unmatched_prev_idx(self): + cfg = { + "_target_": "detectron2.tracking.iou_weighted_hungarian_bbox_iou_tracker.IOUWeightedHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + prev_instances = tracker._initialize_extra_fields(self._prev_instances) + prev_instances.ID_period = [3, 3] + tracker._prev_instances = prev_instances + curr_instances = tracker._initialize_extra_fields(self._curr_instances) + matched_idx = np.array([0]) + matched_prev_idx = np.array([1]) + curr_instances = tracker._process_matched_idx(curr_instances, matched_idx, matched_prev_idx) + curr_instances = tracker._process_unmatched_idx(curr_instances, matched_idx) + curr_instances = tracker._process_unmatched_prev_idx(curr_instances, matched_prev_idx) + self.assertTrue(curr_instances.ID[2] == 0) + + def test_assign_cost_matrix_values(self): + cfg = { + "_target_": "detectron2.tracking.iou_weighted_hungarian_bbox_iou_tracker.IOUWeightedHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + pair1 = {"idx": 0, "prev_idx": 1, "IoU": 0.6} + pair2 = {"idx": 1, "prev_idx": 0, 
"IoU": 0.8} + bbox_pairs = [pair1, pair2] + cost_matrix = np.full((2, 2), np.inf) + target_matrix = copy.deepcopy(cost_matrix) + target_matrix[0, 1] = -0.6 + target_matrix[1, 0] = -0.8 + cost_matrix = tracker.assign_cost_matrix_values(cost_matrix, bbox_pairs) + self.assertTrue(np.allclose(cost_matrix, target_matrix)) + + def test_update(self): + cfg = { + "_target_": "detectron2.tracking.iou_weighted_hungarian_bbox_iou_tracker.IOUWeightedHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + _ = tracker.update(self._prev_instances) + curr_instances = tracker.update(self._curr_instances) + self.assertTrue(curr_instances.ID[0] == 1) + self.assertTrue(curr_instances.ID[1] == 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/tracking/test_vanilla_hungarian_bbox_iou_tracker.py b/data_processing/detectron2/tests/tracking/test_vanilla_hungarian_bbox_iou_tracker.py new file mode 100644 index 0000000..c33e3d9 --- /dev/null +++ b/data_processing/detectron2/tests/tracking/test_vanilla_hungarian_bbox_iou_tracker.py @@ -0,0 +1,225 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import copy +import numpy as np +import unittest +from typing import Dict +import torch + +from detectron2.config import CfgNode as CfgNode_ +from detectron2.config import instantiate +from detectron2.structures import Boxes, Instances +from detectron2.tracking.base_tracker import build_tracker_head +from detectron2.tracking.vanilla_hungarian_bbox_iou_tracker import ( # noqa + VanillaHungarianBBoxIOUTracker, +) + + +class TestVanillaHungarianBBoxIOUTracker(unittest.TestCase): + def setUp(self): + self._img_size = np.array([600, 800]) + self._prev_boxes = np.array( + [ + [101, 101, 200, 200], + [301, 301, 450, 450], + ] + ).astype(np.float32) + self._prev_scores = np.array([0.9, 0.9]) + self._prev_classes = np.array([1, 1]) + self._prev_masks = np.ones((2, 600, 800)).astype("uint8") + self._curr_boxes = np.array( + [ + [302, 303, 451, 452], + [101, 102, 201, 203], + ] + ).astype(np.float32) + self._curr_scores = np.array([0.95, 0.85]) + self._curr_classes = np.array([1, 1]) + self._curr_masks = np.ones((2, 600, 800)).astype("uint8") + + self._prev_instances = { + "image_size": self._img_size, + "pred_boxes": self._prev_boxes, + "scores": self._prev_scores, + "pred_classes": self._prev_classes, + "pred_masks": self._prev_masks, + } + self._prev_instances = self._convertDictPredictionToInstance(self._prev_instances) + self._curr_instances = { + "image_size": self._img_size, + "pred_boxes": self._curr_boxes, + "scores": self._curr_scores, + "pred_classes": self._curr_classes, + "pred_masks": self._curr_masks, + } + self._curr_instances = self._convertDictPredictionToInstance(self._curr_instances) + + self._max_num_instances = 10 + self._max_lost_frame_count = 3 + self._min_box_rel_dim = 0.02 + self._min_instance_period = 1 + self._track_iou_threshold = 0.5 + + def _convertDictPredictionToInstance(self, prediction: Dict) -> Instances: + """ + convert prediction from Dict to D2 Instances format + """ + res = Instances( + 
image_size=torch.IntTensor(prediction["image_size"]), + pred_boxes=Boxes(torch.FloatTensor(prediction["pred_boxes"])), + pred_masks=torch.IntTensor(prediction["pred_masks"]), + pred_classes=torch.IntTensor(prediction["pred_classes"]), + scores=torch.FloatTensor(prediction["scores"]), + ) + return res + + def test_init(self): + cfg = { + "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + self.assertTrue(tracker._video_height == self._img_size[0]) + + def test_from_config(self): + cfg = CfgNode_() + cfg.TRACKER_HEADS = CfgNode_() + cfg.TRACKER_HEADS.TRACKER_NAME = "VanillaHungarianBBoxIOUTracker" + cfg.TRACKER_HEADS.VIDEO_HEIGHT = int(self._img_size[0]) + cfg.TRACKER_HEADS.VIDEO_WIDTH = int(self._img_size[1]) + cfg.TRACKER_HEADS.MAX_NUM_INSTANCES = self._max_num_instances + cfg.TRACKER_HEADS.MAX_LOST_FRAME_COUNT = self._max_lost_frame_count + cfg.TRACKER_HEADS.MIN_BOX_REL_DIM = self._min_box_rel_dim + cfg.TRACKER_HEADS.MIN_INSTANCE_PERIOD = self._min_instance_period + cfg.TRACKER_HEADS.TRACK_IOU_THRESHOLD = self._track_iou_threshold + tracker = build_tracker_head(cfg) + self.assertTrue(tracker._video_height == self._img_size[0]) + + def test_initialize_extra_fields(self): + cfg = { + "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + 
"track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + instances = tracker._initialize_extra_fields(self._curr_instances) + self.assertTrue(instances.has("ID")) + self.assertTrue(instances.has("ID_period")) + self.assertTrue(instances.has("lost_frame_count")) + + def test_process_matched_idx(self): + cfg = { + "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + prev_instances = tracker._initialize_extra_fields(self._prev_instances) + tracker._prev_instances = prev_instances + curr_instances = tracker._initialize_extra_fields(self._curr_instances) + matched_idx = np.array([0]) + matched_prev_idx = np.array([1]) + curr_instances = tracker._process_matched_idx(curr_instances, matched_idx, matched_prev_idx) + self.assertTrue(curr_instances.ID[0] == 1) + + def test_process_unmatched_idx(self): + cfg = { + "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + prev_instances = tracker._initialize_extra_fields(self._prev_instances) + tracker._prev_instances = prev_instances + curr_instances = tracker._initialize_extra_fields(self._curr_instances) + matched_idx = np.array([0]) + matched_prev_idx = np.array([1]) + curr_instances = 
tracker._process_matched_idx(curr_instances, matched_idx, matched_prev_idx) + curr_instances = tracker._process_unmatched_idx(curr_instances, matched_idx) + self.assertTrue(curr_instances.ID[1] == 2) + + def test_process_unmatched_prev_idx(self): + cfg = { + "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + prev_instances = tracker._initialize_extra_fields(self._prev_instances) + prev_instances.ID_period = [3, 3] + tracker._prev_instances = prev_instances + curr_instances = tracker._initialize_extra_fields(self._curr_instances) + matched_idx = np.array([0]) + matched_prev_idx = np.array([1]) + curr_instances = tracker._process_matched_idx(curr_instances, matched_idx, matched_prev_idx) + curr_instances = tracker._process_unmatched_idx(curr_instances, matched_idx) + curr_instances = tracker._process_unmatched_prev_idx(curr_instances, matched_prev_idx) + self.assertTrue(curr_instances.ID[2] == 0) + + def test_assign_cost_matrix_values(self): + cfg = { + "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + pair1 = {"idx": 0, "prev_idx": 1} + pair2 = {"idx": 1, "prev_idx": 0} + bbox_pairs = [pair1, pair2] + cost_matrix = np.full((2, 2), np.inf) + 
target_matrix = copy.deepcopy(cost_matrix) + target_matrix[0, 1] = -1 + target_matrix[1, 0] = -1 + cost_matrix = tracker.assign_cost_matrix_values(cost_matrix, bbox_pairs) + self.assertTrue(np.allclose(cost_matrix, target_matrix)) + + def test_update(self): + cfg = { + "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker", # noqa + "video_height": self._img_size[0], + "video_width": self._img_size[1], + "max_num_instances": self._max_num_instances, + "max_lost_frame_count": self._max_lost_frame_count, + "min_box_rel_dim": self._min_box_rel_dim, + "min_instance_period": self._min_instance_period, + "track_iou_threshold": self._track_iou_threshold, + } + tracker = instantiate(cfg) + _ = tracker.update(self._prev_instances) + curr_instances = tracker.update(self._curr_instances) + self.assertTrue(curr_instances.ID[0] == 1) + self.assertTrue(curr_instances.ID[1] == 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_processing/detectron2/tests/utils/test_tensorboardx.py b/data_processing/detectron2/tests/utils/test_tensorboardx.py new file mode 100644 index 0000000..885fb8d --- /dev/null +++ b/data_processing/detectron2/tests/utils/test_tensorboardx.py @@ -0,0 +1,23 @@ +import os +import tempfile +import unittest + +from detectron2.utils.events import TensorboardXWriter + + +# TODO Fix up capitalization +class TestTensorboardXWriter(unittest.TestCase): + def test_no_files_created(self) -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + writer = TensorboardXWriter(tmp_dir) + writer.close() + + self.assertFalse(os.listdir(tmp_dir)) + + def test_single_write(self) -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + writer = TensorboardXWriter(tmp_dir) + writer._writer.add_scalar("testing", 1, 1) + writer.close() + + self.assertTrue(os.listdir(tmp_dir)) diff --git a/data_processing/detectron2/tools/README.md b/data_processing/detectron2/tools/README.md new file mode 100644 index 
0000000..0b40d53 --- /dev/null +++ b/data_processing/detectron2/tools/README.md @@ -0,0 +1,49 @@ + +This directory contains a few example scripts that demonstrate features of detectron2. + + +* `train_net.py` + +An example training script that's made to train builtin models of detectron2. + +For usage, see [GETTING_STARTED.md](../GETTING_STARTED.md). + +* `plain_train_net.py` + +Similar to `train_net.py`, but implements a training loop instead of using `Trainer`. +This script includes fewer features but it may be more friendly to hackers. + +* `benchmark.py` + +Benchmark the training speed, inference speed or data loading speed of a given config. + +Usage: +``` +python benchmark.py --config-file config.yaml --task train/eval/data/data_advanced [optional DDP flags] +``` + +* `analyze_model.py` + +Analyze FLOPs, parameters, activations of a detectron2 model. See its `--help` for usage. + +* `visualize_json_results.py` + +Visualize the json instance detection/segmentation results dumped by `COCOEvaluator` or `LVISEvaluator`. + +Usage: +``` +python visualize_json_results.py --input x.json --output dir/ --dataset coco_2017_val +``` +If not using a builtin dataset, you'll need to write your own script or modify this one. + +* `visualize_data.py` + +Visualize ground truth raw annotations or training data (after preprocessing/augmentations). + +Usage: +``` +python visualize_data.py --config-file config.yaml --source annotation/dataloader --output-dir dir/ [--show] +``` + +NOTE: the script does not stop by itself when using `--source dataloader` because a training +dataloader is usually infinite.
diff --git a/data_processing/detectron2/tools/__init__.py b/data_processing/detectron2/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/detectron2/tools/analyze_model.py b/data_processing/detectron2/tools/analyze_model.py new file mode 100644 index 0000000..8e38f8b --- /dev/null +++ b/data_processing/detectron2/tools/analyze_model.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import logging +import numpy as np +from collections import Counter +import tqdm +from fvcore.nn import flop_count_table # can also try flop_count_str + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import CfgNode, LazyConfig, get_cfg, instantiate +from detectron2.data import build_detection_test_loader +from detectron2.engine import default_argument_parser +from detectron2.modeling import build_model +from detectron2.utils.analysis import ( + FlopCountAnalysis, + activation_count_operators, + parameter_count_table, +) +from detectron2.utils.logger import setup_logger + +logger = logging.getLogger("detectron2") + + +def setup(args): + if args.config_file.endswith(".yaml"): + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + cfg.DATALOADER.NUM_WORKERS = 0 + cfg.merge_from_list(args.opts) + cfg.freeze() + else: + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + setup_logger(name="fvcore") + setup_logger() + return cfg + + +def do_flop(cfg): + if isinstance(cfg, CfgNode): + data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) + model = build_model(cfg) + DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS) + else: + data_loader = instantiate(cfg.dataloader.test) + model = instantiate(cfg.model) + model.to(cfg.train.device) + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + model.eval() + + counts = Counter() + total_flops = [] + for idx, data in zip(tqdm.trange(args.num_inputs), 
data_loader): # noqa + flops = FlopCountAnalysis(model, data) + if idx > 0: + flops.unsupported_ops_warnings(False).uncalled_modules_warnings(False) + counts += flops.by_operator() + total_flops.append(flops.total()) + + logger.info("Flops table computed from only one input sample:\n" + flop_count_table(flops)) + logger.info( + "Average GFlops for each type of operators:\n" + + str([(k, v / (idx + 1) / 1e9) for k, v in counts.items()]) + ) + logger.info( + "Total GFlops: {:.1f}±{:.1f}".format(np.mean(total_flops) / 1e9, np.std(total_flops) / 1e9) + ) + + +def do_activation(cfg): + if isinstance(cfg, CfgNode): + data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) + model = build_model(cfg) + DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS) + else: + data_loader = instantiate(cfg.dataloader.test) + model = instantiate(cfg.model) + model.to(cfg.train.device) + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + model.eval() + + counts = Counter() + total_activations = [] + for idx, data in zip(tqdm.trange(args.num_inputs), data_loader): # noqa + count = activation_count_operators(model, data) + counts += count + total_activations.append(sum(count.values())) + logger.info( + "(Million) Activations for Each Type of Operators:\n" + + str([(k, v / (idx + 1)) for k, v in counts.items()])  # idx is 0-based + ) + logger.info( + "Total (Million) Activations: {}±{}".format( + np.mean(total_activations), np.std(total_activations) + ) + ) + + +def do_parameter(cfg): + if isinstance(cfg, CfgNode): + model = build_model(cfg) + else: + model = instantiate(cfg.model) + logger.info("Parameter Count:\n" + parameter_count_table(model, max_depth=5)) + + +def do_structure(cfg): + if isinstance(cfg, CfgNode): + model = build_model(cfg) + else: + model = instantiate(cfg.model) + logger.info("Model Structure:\n" + str(model)) + + +if __name__ == "__main__": + parser = default_argument_parser( + epilog=""" +Examples: + +To show parameters of a model: +$ ./analyze_model.py --tasks
parameter \\ + --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml + +Flops and activations are data-dependent, therefore inputs and model weights +are needed to count them: + +$ ./analyze_model.py --num-inputs 100 --tasks flop \\ + --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \\ + MODEL.WEIGHTS /path/to/model.pkl +""" + ) + parser.add_argument( + "--tasks", + choices=["flop", "activation", "parameter", "structure"], + required=True, + nargs="+", + ) + parser.add_argument( + "-n", + "--num-inputs", + default=100, + type=int, + help="number of inputs used to compute statistics for flops/activations, " + "both are data dependent.", + ) + args = parser.parse_args() + assert not args.eval_only + assert args.num_gpus == 1 + + cfg = setup(args) + + for task in args.tasks: + { + "flop": do_flop, + "activation": do_activation, + "parameter": do_parameter, + "structure": do_structure, + }[task](cfg) diff --git a/data_processing/detectron2/tools/benchmark.py b/data_processing/detectron2/tools/benchmark.py new file mode 100644 index 0000000..c2d673f --- /dev/null +++ b/data_processing/detectron2/tools/benchmark.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +A script to benchmark builtin models. + +Note: this script has an extra dependency of psutil. 
+""" + +import itertools +import logging +import psutil +import torch +import tqdm +from fvcore.common.timer import Timer +from torch.nn.parallel import DistributedDataParallel + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import LazyConfig, get_cfg, instantiate +from detectron2.data import ( + DatasetFromList, + build_detection_test_loader, + build_detection_train_loader, +) +from detectron2.data.benchmark import DataLoaderBenchmark +from detectron2.engine import AMPTrainer, SimpleTrainer, default_argument_parser, hooks, launch +from detectron2.modeling import build_model +from detectron2.solver import build_optimizer +from detectron2.utils import comm +from detectron2.utils.collect_env import collect_env_info +from detectron2.utils.events import CommonMetricPrinter +from detectron2.utils.logger import setup_logger + +logger = logging.getLogger("detectron2") + + +def setup(args): + if args.config_file.endswith(".yaml"): + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + cfg.SOLVER.BASE_LR = 0.001 # Avoid NaNs. Not useful in this script anyway. 
+ cfg.merge_from_list(args.opts) + cfg.freeze() + else: + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + setup_logger(distributed_rank=comm.get_rank()) + return cfg + + +def create_data_benchmark(cfg, args): + if args.config_file.endswith(".py"): + dl_cfg = cfg.dataloader.train + dl_cfg._target_ = DataLoaderBenchmark + return instantiate(dl_cfg) + else: + kwargs = build_detection_train_loader.from_config(cfg) + kwargs.pop("aspect_ratio_grouping", None) + kwargs["_target_"] = DataLoaderBenchmark + return instantiate(kwargs) + + +def RAM_msg(): + vram = psutil.virtual_memory() + return "RAM Usage: {:.2f}/{:.2f} GB".format( + (vram.total - vram.available) / 1024**3, vram.total / 1024**3 + ) + + +def benchmark_data(args): + cfg = setup(args) + logger.info("After spawning " + RAM_msg()) + + benchmark = create_data_benchmark(cfg, args) + benchmark.benchmark_distributed(250, 10) + # test for a few more rounds + for k in range(10): + logger.info(f"Iteration {k} " + RAM_msg()) + benchmark.benchmark_distributed(250, 1) + + +def benchmark_data_advanced(args): + # benchmark dataloader with more details to help analyze performance bottleneck + cfg = setup(args) + benchmark = create_data_benchmark(cfg, args) + + if comm.get_rank() == 0: + benchmark.benchmark_dataset(100) + benchmark.benchmark_mapper(100) + benchmark.benchmark_workers(100, warmup=10) + benchmark.benchmark_IPC(100, warmup=10) + if comm.get_world_size() > 1: + benchmark.benchmark_distributed(100) + logger.info("Rerun ...") + benchmark.benchmark_distributed(100) + + +def benchmark_train(args): + cfg = setup(args) + model = build_model(cfg) + logger.info("Model:\n{}".format(model)) + if comm.get_world_size() > 1: + model = DistributedDataParallel( + model, device_ids=[comm.get_local_rank()], broadcast_buffers=False + ) + optimizer = build_optimizer(cfg, model) + checkpointer = DetectionCheckpointer(model, optimizer=optimizer) + checkpointer.load(cfg.MODEL.WEIGHTS) + + 
cfg.defrost() + cfg.DATALOADER.NUM_WORKERS = 2 + data_loader = build_detection_train_loader(cfg) + dummy_data = list(itertools.islice(data_loader, 100)) + + def f(): + data = DatasetFromList(dummy_data, copy=False, serialize=False) + while True: + yield from data + + max_iter = 400 + trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(model, f(), optimizer) + trainer.register_hooks( + [ + hooks.IterationTimer(), + hooks.PeriodicWriter([CommonMetricPrinter(max_iter)]), + hooks.TorchProfiler( + lambda trainer: trainer.iter == max_iter - 1, cfg.OUTPUT_DIR, save_tensorboard=True + ), + ] + ) + trainer.train(1, max_iter) + + +@torch.no_grad() +def benchmark_eval(args): + cfg = setup(args) + if args.config_file.endswith(".yaml"): + model = build_model(cfg) + DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS) + + cfg.defrost() + cfg.DATALOADER.NUM_WORKERS = 0 + data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) + else: + model = instantiate(cfg.model) + model.to(cfg.train.device) + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + + cfg.dataloader.num_workers = 0 + data_loader = instantiate(cfg.dataloader.test) + + model.eval() + logger.info("Model:\n{}".format(model)) + dummy_data = DatasetFromList(list(itertools.islice(data_loader, 100)), copy=False) + + def f(): + while True: + yield from dummy_data + + for k in range(5): # warmup + model(dummy_data[k]) + + max_iter = 300 + timer = Timer() + with tqdm.tqdm(total=max_iter) as pbar: + for idx, d in enumerate(f()): + if idx == max_iter: + break + model(d) + pbar.update() + logger.info("{} iters in {} seconds.".format(max_iter, timer.seconds())) + + +if __name__ == "__main__": + parser = default_argument_parser() + parser.add_argument("--task", choices=["train", "eval", "data", "data_advanced"], required=True) + args = parser.parse_args() + assert not args.eval_only + + logger.info("Environment info:\n" + collect_env_info()) + if "data" in args.task: + print("Initial " + 
RAM_msg()) + if args.task == "data": + f = benchmark_data + if args.task == "data_advanced": + f = benchmark_data_advanced + elif args.task == "train": + """ + Note: training speed may not be representative. + The training cost of a R-CNN model varies with the content of the data + and the quality of the model. + """ + f = benchmark_train + elif args.task == "eval": + f = benchmark_eval + # only benchmark single-GPU inference. + assert args.num_gpus == 1 and args.num_machines == 1 + launch(f, args.num_gpus, args.num_machines, args.machine_rank, args.dist_url, args=(args,)) diff --git a/data_processing/detectron2/tools/convert-torchvision-to-d2.py b/data_processing/detectron2/tools/convert-torchvision-to-d2.py new file mode 100644 index 0000000..4b827d9 --- /dev/null +++ b/data_processing/detectron2/tools/convert-torchvision-to-d2.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. + +import pickle as pkl +import sys +import torch + +""" +Usage: + # download one of the ResNet{18,34,50,101,152} models from torchvision: + wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth + # run the conversion + ./convert-torchvision-to-d2.py r50.pth r50.pkl + + # Then, use r50.pkl with the following changes in config: + +MODEL: + WEIGHTS: "/path/to/r50.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + RESNETS: + DEPTH: 50 + STRIDE_IN_1X1: False +INPUT: + FORMAT: "RGB" + + These models typically produce slightly worse results than the + pre-trained ResNets we use in official configs, which are the + original ResNet models released by MSRA. +""" + +if __name__ == "__main__": + input = sys.argv[1] + + obj = torch.load(input, map_location="cpu") + + newmodel = {} + for k in list(obj.keys()): + old_k = k + if "layer" not in k: + k = "stem." 
+ k + for t in [1, 2, 3, 4]: + k = k.replace("layer{}".format(t), "res{}".format(t + 1)) + for t in [1, 2, 3]: + k = k.replace("bn{}".format(t), "conv{}.norm".format(t)) + k = k.replace("downsample.0", "shortcut") + k = k.replace("downsample.1", "shortcut.norm") + print(old_k, "->", k) + newmodel[k] = obj.pop(old_k).detach().numpy() + + res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True} + + with open(sys.argv[2], "wb") as f: + pkl.dump(res, f) + if obj: + print("Unconverted keys:", obj.keys()) diff --git a/data_processing/detectron2/tools/deploy/CMakeLists.txt b/data_processing/detectron2/tools/deploy/CMakeLists.txt new file mode 100644 index 0000000..80dae12 --- /dev/null +++ b/data_processing/detectron2/tools/deploy/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# See https://pytorch.org/tutorials/advanced/cpp_frontend.html +cmake_minimum_required(VERSION 3.12 FATAL_ERROR) +project(torchscript_mask_rcnn) + +find_package(Torch REQUIRED) +find_package(OpenCV REQUIRED) +find_package(TorchVision REQUIRED) # needed by export-method=tracing/scripting + +add_executable(torchscript_mask_rcnn torchscript_mask_rcnn.cpp) +target_link_libraries( + torchscript_mask_rcnn + -Wl,--no-as-needed TorchVision::TorchVision -Wl,--as-needed + "${TORCH_LIBRARIES}" ${OpenCV_LIBS}) +set_property(TARGET torchscript_mask_rcnn PROPERTY CXX_STANDARD 14) diff --git a/data_processing/detectron2/tools/deploy/README.md b/data_processing/detectron2/tools/deploy/README.md new file mode 100644 index 0000000..e33cbeb --- /dev/null +++ b/data_processing/detectron2/tools/deploy/README.md @@ -0,0 +1,66 @@ +See [deployment tutorial](https://detectron2.readthedocs.io/tutorials/deployment.html) +for some high-level background about deployment. + +This directory contains the following examples: + +1. 
An example script `export_model.py` + that exports a detectron2 model for deployment using different methods and formats. + +2. A C++ example that runs inference with Mask R-CNN model in TorchScript format. + +## Build +Deployment depends on libtorch and OpenCV. Some require more dependencies: + +* Running TorchScript-format models produced by `--export-method=caffe2_tracing` requires libtorch + to be built with caffe2 enabled. +* Running TorchScript-format models produced by `--export-method=tracing/scripting` requires libtorchvision (C++ library of torchvision). + +All methods are supported in one C++ file that requires all the above dependencies. +Adjust it and remove code you don't need. +As a reference, we provide a [Dockerfile](../../docker/deploy.Dockerfile) that installs all the above dependencies and builds the C++ example. + +## Use + +We show a few example commands to export and execute a Mask R-CNN model in C++. + +* `export-method=tracing, format=torchscript`: +``` +./export_model.py --config-file ../../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \ + --output ./output --export-method tracing --format torchscript \ + MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl \ + MODEL.DEVICE cuda + +./build/torchscript_mask_rcnn output/model.ts input.jpg tracing +``` + +* `export-method=scripting, format=torchscript`: +``` +./export_model.py --config-file ../../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \ + --output ./output --export-method scripting --format torchscript \ + MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl \ + +./build/torchscript_mask_rcnn output/model.ts input.jpg scripting +``` + +* `export-method=caffe2_tracing, format=torchscript`: + +``` +./export_model.py --config-file ../../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \ + --output ./output --export-method caffe2_tracing 
--format torchscript \ + MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl \ + +./build/torchscript_mask_rcnn output/model.ts input.jpg caffe2_tracing +``` + + +## Notes: + +1. Tracing/Caffe2-tracing requires valid weights & sample inputs. + Therefore the above commands require pre-trained models and [COCO dataset](https://detectron2.readthedocs.io/tutorials/builtin_datasets.html). + You can modify the script to obtain sample inputs in other ways instead of from COCO. + +2. `--run-eval` is implemented only for tracing mode + to evaluate the exported model using the dataset in the config. + It's recommended to always verify the accuracy in case the conversion is not successful. + Evaluation can be slow if model is exported to CPU or dataset is too large ("coco_2017_val_100" is a small subset of COCO useful for evaluation). + `caffe2_tracing` accuracy may be slightly different (within 0.1 AP) from the original model due to differences in numerical precision between runtimes. diff --git a/data_processing/detectron2/tools/deploy/export_model.py b/data_processing/detectron2/tools/deploy/export_model.py new file mode 100644 index 0000000..f507dff --- /dev/null +++ b/data_processing/detectron2/tools/deploy/export_model.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates.
def setup_cfg(args):
    """Build the frozen detectron2 config used for export.

    Args:
        args: parsed command-line namespace; ``args.config_file`` is merged
            first and ``args.opts`` (a flat list of overrides) second.

    Returns:
        The frozen config node.
    """
    config = get_cfg()
    # cuda context is initialized before creating dataloader, so we don't fork anymore
    config.DATALOADER.NUM_WORKERS = 0
    add_pointrend_config(config)
    # The file and the command-line overrides are applied after the default
    # above, so either of them may still change NUM_WORKERS.
    for merge, source in (
        (config.merge_from_file, args.config_file),
        (config.merge_from_list, args.opts),
    ):
        merge(source)
    config.freeze()
    return config
# experimental. API not yet final
def export_scripting(torch_model):
    """Export ``torch_model`` with TorchScript scripting and save it to disk.

    Reads the module-level ``args`` for the output directory and format.
    Writes ``model.ts`` into ``args.output`` and dumps the TorchScript IR
    next to it.

    Returns:
        None: Python-side inference for scripted models is not implemented,
        so there is nothing to hand back for evaluation.
    """
    assert TORCH_VERSION >= (1, 8)
    assert args.format == "torchscript", "Scripting only supports torchscript format."

    # Field name -> type of every Instances attribute the scripted model may use.
    instance_fields = {
        "proposal_boxes": Boxes,
        "objectness_logits": Tensor,
        "pred_boxes": Boxes,
        "scores": Tensor,
        "pred_classes": Tensor,
        "pred_masks": Tensor,
        "pred_keypoints": Tensor,
        "pred_keypoint_heatmaps": Tensor,
    }

    class ScriptableAdapterBase(nn.Module):
        # Use this adapter to workaround https://github.com/pytorch/pytorch/issues/46944
        # by not returning instances but dicts. Otherwise the exported model is not deployable
        def __init__(self):
            super().__init__()
            self.model = torch_model
            self.eval()

    if isinstance(torch_model, GeneralizedRCNN):

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                # Skip postprocessing: the raw fields are what gets exported.
                predictions = self.model.inference(inputs, do_postprocess=False)
                return [p.get_fields() for p in predictions]

    else:

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                predictions = self.model(inputs)
                return [p.get_fields() for p in predictions]

    scripted = scripting_with_instances(ScriptableAdapter(), instance_fields)
    model_path = os.path.join(args.output, "model.ts")
    with PathManager.open(model_path, "wb") as fh:
        torch.jit.save(scripted, fh)
    dump_torchscript_IR(scripted, args.output)
    # TODO inference in Python now missing postprocessing glue code
    return None
# experimental. API not yet final
def export_tracing(torch_model, inputs):
    """Export ``torch_model`` by tracing it on one sample image.

    Reads the module-level ``args`` (output directory, format) and ``logger``.

    Args:
        torch_model: the model to trace.
        inputs: a batch in detectron2 input format; only ``inputs[0]["image"]``
            is used as the tracing sample.

    Returns:
        A callable usable for evaluation when the format is ``torchscript``
        and the model is a ``GeneralizedRCNN`` or ``RetinaNet``; otherwise None.
    """
    assert TORCH_VERSION >= (1, 8)
    image = inputs[0]["image"]
    trace_inputs = [{"image": image}]  # remove other unused keys

    inference = None  # assume that we just call the model directly
    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # use do_postprocess=False so it returns ROI mask
            first = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": first}]

    traceable_model = TracingAdapter(torch_model, trace_inputs, inference)
    export_format = args.format

    if export_format == "torchscript":
        ts_model = torch.jit.trace(traceable_model, (image,))
        ts_path = os.path.join(args.output, "model.ts")
        with PathManager.open(ts_path, "wb") as fh:
            torch.jit.save(ts_model, fh)
        dump_torchscript_IR(ts_model, args.output)
    elif export_format == "onnx":
        onnx_path = os.path.join(args.output, "model.onnx")
        with PathManager.open(onnx_path, "wb") as fh:
            torch.onnx.export(
                traceable_model, (image,), fh, opset_version=STABLE_ONNX_OPSET_VERSION
            )
    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

    # Python-side evaluation is only wired up for traced torchscript detectors.
    if args.format != "torchscript" or not isinstance(
        torch_model, (GeneralizedRCNN, RetinaNet)
    ):
        return None

    def eval_wrapper(inputs):
        """
        The exported model does not contain the final resize step, which is typically
        unused in deployment but needed for evaluation. We add it manually here.
        """
        sample = inputs[0]
        raw_outputs = ts_model(sample["image"])
        instances = traceable_model.outputs_schema(raw_outputs)[0]["instances"]
        resized = detector_postprocess(instances, sample["height"], sample["width"])
        return [{"instances": resized}]

    return eval_wrapper
= setup_logger() + logger.info("Command line arguments: " + str(args)) + PathManager.mkdirs(args.output) + # Disable re-specialization on new shapes. Otherwise --run-eval will be slow + torch._C._jit_set_bailout_depth(1) + + cfg = setup_cfg(args) + + # create a torch model + torch_model = build_model(cfg) + DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS) + torch_model.eval() + + # convert and save model + if args.export_method == "caffe2_tracing": + sample_inputs = get_sample_inputs(args) + exported_model = export_caffe2_tracing(cfg, torch_model, sample_inputs) + elif args.export_method == "scripting": + exported_model = export_scripting(torch_model) + elif args.export_method == "tracing": + sample_inputs = get_sample_inputs(args) + exported_model = export_tracing(torch_model, sample_inputs) + + # run evaluation with the converted model + if args.run_eval: + assert exported_model is not None, ( + "Python inference is not yet implemented for " + f"export_method={args.export_method}, format={args.format}." + ) + logger.info("Running evaluation ... this takes a long time if you export to CPU.") + dataset = cfg.DATASETS.TEST[0] + data_loader = build_detection_test_loader(cfg, dataset) + # NOTE: hard-coded evaluator. change to the evaluator for your dataset + evaluator = COCOEvaluator(dataset, output_dir=args.output) + metrics = inference_on_dataset(exported_model, data_loader, evaluator) + print_csv_format(metrics) + logger.info("Success.") diff --git a/data_processing/detectron2/tools/deploy/torchscript_mask_rcnn.cpp b/data_processing/detectron2/tools/deploy/torchscript_mask_rcnn.cpp new file mode 100644 index 0000000..fd6e1e9 --- /dev/null +++ b/data_processing/detectron2/tools/deploy/torchscript_mask_rcnn.cpp @@ -0,0 +1,188 @@ +// Copyright (c) Facebook, Inc. and its affiliates. 
+// @lint-ignore-every CLANGTIDY +// This is an example code that demonstrates how to run inference +// with a torchscript format Mask R-CNN model exported by ./export_model.py +// using export method=tracing, caffe2_tracing & scripting. + +#include +#include +#include + +#include +#include +#include +#include + +// only needed for export_method=tracing +#include // @oss-only +// @fb-only: #include + +using namespace std; + +c10::IValue get_caffe2_tracing_inputs(cv::Mat& img, c10::Device device) { + const int height = img.rows; + const int width = img.cols; + // FPN models require divisibility of 32. + // Tracing mode does padding inside the graph, but caffe2_tracing does not. + assert(height % 32 == 0 && width % 32 == 0); + const int channels = 3; + + auto input = + torch::from_blob(img.data, {1, height, width, channels}, torch::kUInt8); + // NHWC to NCHW + input = input.to(device, torch::kFloat).permute({0, 3, 1, 2}).contiguous(); + + std::array im_info_data{height * 1.0f, width * 1.0f, 1.0f}; + auto im_info = + torch::from_blob(im_info_data.data(), {1, 3}).clone().to(device); + return std::make_tuple(input, im_info); +} + +c10::IValue get_tracing_inputs(cv::Mat& img, c10::Device device) { + const int height = img.rows; + const int width = img.cols; + const int channels = 3; + + auto input = + torch::from_blob(img.data, {height, width, channels}, torch::kUInt8); + // HWC to CHW + input = input.to(device, torch::kFloat).permute({2, 0, 1}).contiguous(); + return input; +} + +// create a Tuple[Dict[str, Tensor]] which is the input type of scripted model +c10::IValue get_scripting_inputs(cv::Mat& img, c10::Device device) { + const int height = img.rows; + const int width = img.cols; + const int channels = 3; + + auto img_tensor = + torch::from_blob(img.data, {height, width, channels}, torch::kUInt8); + // HWC to CHW + img_tensor = + img_tensor.to(device, torch::kFloat).permute({2, 0, 1}).contiguous(); + auto dic = c10::Dict(); + dic.insert("image", img_tensor); + 
return std::make_tuple(dic); +} + +c10::IValue +get_inputs(std::string export_method, cv::Mat& img, c10::Device device) { + // Given an image, create inputs in the format required by the model. + if (export_method == "tracing") + return get_tracing_inputs(img, device); + if (export_method == "caffe2_tracing") + return get_caffe2_tracing_inputs(img, device); + if (export_method == "scripting") + return get_scripting_inputs(img, device); + abort(); +} + +struct MaskRCNNOutputs { + at::Tensor pred_boxes, pred_classes, pred_masks, scores; + int num_instances() const { + return pred_boxes.sizes()[0]; + } +}; + +MaskRCNNOutputs get_outputs(std::string export_method, c10::IValue outputs) { + // Given outputs of the model, extract tensors from it to turn into a + // common MaskRCNNOutputs format. + if (export_method == "tracing") { + auto out_tuple = outputs.toTuple()->elements(); + // They are ordered alphabetically by their field name in Instances + return MaskRCNNOutputs{ + out_tuple[0].toTensor(), + out_tuple[1].toTensor(), + out_tuple[2].toTensor(), + out_tuple[3].toTensor()}; + } + if (export_method == "caffe2_tracing") { + auto out_tuple = outputs.toTuple()->elements(); + // A legacy order used by caffe2 models + return MaskRCNNOutputs{ + out_tuple[0].toTensor(), + out_tuple[2].toTensor(), + out_tuple[3].toTensor(), + out_tuple[1].toTensor()}; + } + if (export_method == "scripting") { + // With the ScriptableAdapter defined in export_model.py, the output is + // List[Dict[str, Any]]. + auto out_dict = outputs.toList().get(0).toGenericDict(); + return MaskRCNNOutputs{ + out_dict.at("pred_boxes").toTensor(), + out_dict.at("pred_classes").toTensor(), + out_dict.at("pred_masks").toTensor(), + out_dict.at("scores").toTensor()}; + } + abort(); +} + +int main(int argc, const char* argv[]) { + if (argc != 4) { + cerr << R"xx( +Usage: + ./torchscript_mask_rcnn model.ts input.jpg EXPORT_METHOD + + EXPORT_METHOD can be "tracing", "caffe2_tracing" or "scripting". 
+)xx"; + return 1; + } + std::string image_file = argv[2]; + std::string export_method = argv[3]; + assert( + export_method == "caffe2_tracing" || export_method == "tracing" || + export_method == "scripting"); + + torch::jit::FusionStrategy strat = {{torch::jit::FusionBehavior::DYNAMIC, 1}}; + torch::jit::setFusionStrategy(strat); + torch::autograd::AutoGradMode guard(false); + auto module = torch::jit::load(argv[1]); + + assert(module.buffers().size() > 0); + // Assume that the entire model is on the same device. + // We just put input to this device. + auto device = (*begin(module.buffers())).device(); + + cv::Mat input_img = cv::imread(image_file, cv::IMREAD_COLOR); + auto inputs = get_inputs(export_method, input_img, device); + + // Run the network + auto output = module.forward({inputs}); + if (device.is_cuda()) + c10::cuda::getCurrentCUDAStream().synchronize(); + + // run 3 more times to benchmark + int N_benchmark = 3, N_warmup = 1; + auto start_time = chrono::high_resolution_clock::now(); + for (int i = 0; i < N_benchmark + N_warmup; ++i) { + if (i == N_warmup) + start_time = chrono::high_resolution_clock::now(); + output = module.forward({inputs}); + if (device.is_cuda()) + c10::cuda::getCurrentCUDAStream().synchronize(); + } + auto end_time = chrono::high_resolution_clock::now(); + auto ms = chrono::duration_cast(end_time - start_time) + .count(); + cout << "Latency (should vary with different inputs): " + << ms * 1.0 / 1e6 / N_benchmark << " seconds" << endl; + + // Parse Mask R-CNN outputs + auto rcnn_outputs = get_outputs(export_method, output); + cout << "Number of detected objects: " << rcnn_outputs.num_instances() + << endl; + + cout << "pred_boxes: " << rcnn_outputs.pred_boxes.toString() << " " + << rcnn_outputs.pred_boxes.sizes() << endl; + cout << "scores: " << rcnn_outputs.scores.toString() << " " + << rcnn_outputs.scores.sizes() << endl; + cout << "pred_classes: " << rcnn_outputs.pred_classes.toString() << " " + << 
rcnn_outputs.pred_classes.sizes() << endl; + cout << "pred_masks: " << rcnn_outputs.pred_masks.toString() << " " + << rcnn_outputs.pred_masks.sizes() << endl; + + cout << rcnn_outputs.pred_boxes << endl; + return 0; +} diff --git a/data_processing/detectron2/tools/lazyconfig_train_net.py b/data_processing/detectron2/tools/lazyconfig_train_net.py new file mode 100644 index 0000000..bb62d36 --- /dev/null +++ b/data_processing/detectron2/tools/lazyconfig_train_net.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Training script using the new "LazyConfig" python config files. + +This scripts reads a given python config file and runs the training or evaluation. +It can be used to train any models or dataset as long as they can be +instantiated by the recursive construction defined in the given config file. + +Besides lazy construction of models, dataloader, etc., this scripts expects a +few common configuration parameters currently defined in "configs/common/train.py". +To add more complicated training logic, you can easily add other configs +in the config file and implement a new train_net.py to handle them. 
def do_train(args, cfg):
    """
    Instantiate everything described by ``cfg`` and run the training loop.

    Args:
        args: parsed command-line namespace; only ``args.resume`` is read.
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    train_cfg = cfg.train

    net = instantiate(cfg.model)
    log = logging.getLogger("detectron2")
    log.info("Model:\n{}".format(net))
    net.to(train_cfg.device)

    # The optimizer config needs the live model to collect its parameters.
    cfg.optimizer.params.model = net
    optimizer = instantiate(cfg.optimizer)

    loader = instantiate(cfg.dataloader.train)

    net = create_ddp_model(net, **train_cfg.ddp)
    trainer_cls = AMPTrainer if train_cfg.amp.enabled else SimpleTrainer
    trainer = trainer_cls(net, loader, optimizer)
    checkpointer = DetectionCheckpointer(
        net,
        train_cfg.output_dir,
        trainer=trainer,
    )

    # Hooks are built in registration order; checkpointing and writing are
    # only set up on the main process and left as None elsewhere.
    timer_hook = hooks.IterationTimer()
    lr_hook = hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier))
    if comm.is_main_process():
        checkpoint_hook = hooks.PeriodicCheckpointer(checkpointer, **train_cfg.checkpointer)
    else:
        checkpoint_hook = None
    eval_hook = hooks.EvalHook(train_cfg.eval_period, lambda: do_test(cfg, net))
    if comm.is_main_process():
        writer_hook = hooks.PeriodicWriter(
            default_writers(train_cfg.output_dir, train_cfg.max_iter),
            period=train_cfg.log_period,
        )
    else:
        writer_hook = None
    trainer.register_hooks([timer_hook, lr_hook, checkpoint_hook, eval_hook, writer_hook])

    checkpointer.resume_or_load(train_cfg.init_checkpoint, resume=args.resume)
    # The checkpoint stores the training iteration that just finished, thus we start
    # at the next iteration
    resumed = args.resume and checkpointer.has_checkpoint()
    first_iter = trainer.iter + 1 if resumed else 0
    trainer.train(first_iter, train_cfg.max_iter)
class TrainingModule(LightningModule):
    """LightningModule that trains and evaluates a detectron2 model built from ``cfg``.

    Metrics are recorded through a detectron2 ``EventStorage`` that is created
    lazily on the first training step, so ``self.storage`` is ``None`` until
    then (and for the whole run when only validation is executed).
    """

    def __init__(self, cfg):
        super().__init__()
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        self.cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
        # Created in training_step(); see the class docstring.
        self.storage: EventStorage = None
        self.model = build_model(self.cfg)

        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Store the current iteration in the Lightning checkpoint."""
        # Fall back to start_iter when no training step has created the storage yet.
        checkpoint["iteration"] = (
            self.storage.iter if self.storage is not None else self.start_iter
        )

    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
        """Restore the iteration counter saved by ``on_save_checkpoint``."""
        self.start_iter = checkpointed_state["iteration"]
        # The storage normally does not exist yet at load time; training_step()
        # seeds it from start_iter when it is created.
        if self.storage is not None:
            self.storage.iter = self.start_iter

    def setup(self, stage: str):
        """Create the checkpointer and timers; load ``MODEL.WEIGHTS`` if configured."""
        # Always create the checkpointer: training_epoch_end() saves through it
        # even when no initial weights were configured.
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            self.cfg.OUTPUT_DIR,
        )
        if self.cfg.MODEL.WEIGHTS:
            logger.info(f"Load model weights from checkpoint: {self.cfg.MODEL.WEIGHTS}.")
            # Only load weights, use lightning checkpointing if you want to resume
            self.checkpointer.load(self.cfg.MODEL.WEIGHTS)

        self.iteration_timer = hooks.IterationTimer()
        self.iteration_timer.before_train()
        self.data_start = time.perf_counter()
        self.writers = None

    def training_step(self, batch, batch_idx):
        """Run one forward pass, record metrics and return the summed loss."""
        data_time = time.perf_counter() - self.data_start
        # Need to manually enter/exit since trainer may launch processes
        # This ideally belongs in setup, but setup seems to run before processes are spawned
        if self.storage is None:
            # Start counting from the restored iteration (0 for a fresh run).
            self.storage = EventStorage(self.start_iter)
            self.storage.__enter__()
            self.iteration_timer.trainer = weakref.proxy(self)
            self.iteration_timer.before_step()
            self.writers = (
                default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
                if comm.is_main_process()
                else []
            )

        loss_dict = self.model(batch)
        SimpleTrainer.write_metrics(loss_dict, data_time)

        opt = self.optimizers()
        self.storage.put_scalar(
            "lr", opt.param_groups[self._best_param_group_id]["lr"], smoothing_hint=False
        )
        self.iteration_timer.after_step()
        self.storage.step()
        # A little odd to put before step here, but it's the best way to get a proper timing
        self.iteration_timer.before_step()

        if self.storage.iter % 20 == 0:
            for writer in self.writers:
                writer.write()
        return sum(loss_dict.values())

    def training_step_end(self, training_step_outpus):
        """Restart the data-loading timer and pass the step output through."""
        self.data_start = time.perf_counter()
        return training_step_outpus

    def training_epoch_end(self, training_step_outputs):
        """Save the final model, flush and close the writers, and leave the storage."""
        self.iteration_timer.after_train()
        if comm.is_main_process():
            self.checkpointer.save("model_final")
        # writers/storage stay None if no training step ever ran.
        for writer in self.writers or []:
            writer.write()
            writer.close()
        if self.storage is not None:
            self.storage.__exit__(None, None, None)

    def _process_dataset_evaluation_results(self) -> OrderedDict:
        """Collect results from every evaluator, keyed by test dataset name.

        Returns the single result directly when only one dataset is evaluated.
        """
        results = OrderedDict()
        for idx, dataset_name in enumerate(self.cfg.DATASETS.TEST):
            results[dataset_name] = self._evaluators[idx].evaluate()
            if comm.is_main_process():
                print_csv_format(results[dataset_name])

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    def _reset_dataset_evaluators(self):
        """Create one freshly reset evaluator per test dataset."""
        self._evaluators = []
        for dataset_name in self.cfg.DATASETS.TEST:
            evaluator = build_evaluator(self.cfg, dataset_name)
            evaluator.reset()
            self._evaluators.append(evaluator)

    def on_validation_epoch_start(self, _outputs=None):
        # ``_outputs`` is unused; it is optional so the hook also works when
        # it is invoked without arguments.
        self._reset_dataset_evaluators()

    def validation_epoch_end(self, _outputs):
        """Evaluate all datasets and, when training, log the flattened results."""
        results = self._process_dataset_evaluation_results()

        flattened_results = flatten_results_dict(results)
        for k, v in flattened_results.items():
            # Validation only: every metric must be convertible to float.
            try:
                float(v)
            except Exception as e:
                raise ValueError(
                    "[EvalHook] eval_function should return a nested dict of float. "
                    "Got '{}: {}' instead.".format(k, v)
                ) from e
        # No storage exists in a validate-only run; the results were already
        # printed by _process_dataset_evaluation_results().
        if self.storage is not None:
            self.storage.put_scalars(**flattened_results, smoothing_hint=False)

    def validation_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Run inference on ``batch`` and feed the outputs to the matching evaluator."""
        if not isinstance(batch, List):
            batch = [batch]
        outputs = self.model(batch)
        self._evaluators[dataloader_idx].process(batch, outputs)

    def configure_optimizers(self):
        """Build the optimizer and a per-step LR scheduler from the config."""
        optimizer = build_optimizer(self.cfg, self.model)
        # Remembered so training_step() can log the learning rate of this group.
        self._best_param_group_id = hooks.LRScheduler.get_best_param_group_id(optimizer)
        scheduler = build_lr_scheduler(self.cfg, optimizer)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
    # Give logging a placeholder for ``args``. Passing it as a bare extra
    # argument makes the handler fail with a formatting error and the line
    # is never written.
    logger.info("Command Line Args: %s", args)
    main(args)
def get_evaluator(cfg, dataset_name, output_folder=None):
    """
    Create evaluator(s) for a given dataset.

    The choice is driven by the "evaluator_type" metadata registered for the
    dataset. For your own dataset you can simply construct an evaluator
    directly instead of going through this if-else logic.

    Args:
        cfg: config node; ``cfg.OUTPUT_DIR`` is used for the default output folder.
        dataset_name: name of a dataset registered in ``MetadataCatalog``.
        output_folder: where evaluators dump results; defaults to
            ``<cfg.OUTPUT_DIR>/inference``.

    Returns:
        A single evaluator, or a ``DatasetEvaluators`` wrapping several.

    Raises:
        NotImplementedError: if no evaluator matches the dataset's type.
    """
    if output_folder is None:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
    kind = MetadataCatalog.get(dataset_name).evaluator_type

    # Types that may combine several evaluators (panoptic needs all three).
    evaluators = []
    if kind in ("sem_seg", "coco_panoptic_seg"):
        sem_seg = SemSegEvaluator(
            dataset_name,
            distributed=True,
            output_dir=output_folder,
        )
        evaluators.append(sem_seg)
    if kind in ("coco", "coco_panoptic_seg"):
        evaluators.append(COCOEvaluator(dataset_name, output_dir=output_folder))
    if kind == "coco_panoptic_seg":
        evaluators.append(COCOPanopticEvaluator(dataset_name, output_folder))

    # Types that map to exactly one dedicated evaluator.
    if kind == "cityscapes_instance":
        return CityscapesInstanceEvaluator(dataset_name)
    if kind == "cityscapes_sem_seg":
        return CityscapesSemSegEvaluator(dataset_name)
    if kind == "pascal_voc":
        return PascalVOCDetectionEvaluator(dataset_name)
    if kind == "lvis":
        return LVISEvaluator(dataset_name, cfg, True, output_folder)

    if not evaluators:
        raise NotImplementedError(
            "no Evaluator for the dataset {} with the type {}".format(dataset_name, kind)
        )
    return evaluators[0] if len(evaluators) == 1 else DatasetEvaluators(evaluators)
DetectionCheckpointer( + model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler + ) + start_iter = ( + checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 + ) + max_iter = cfg.SOLVER.MAX_ITER + + periodic_checkpointer = PeriodicCheckpointer( + checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter + ) + + writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else [] + + # compared to "train_net.py", we do not support accurate timing and + # precise BN here, because they are not trivial to implement in a small training loop + data_loader = build_detection_train_loader(cfg) + logger.info("Starting training from iteration {}".format(start_iter)) + with EventStorage(start_iter) as storage: + for data, iteration in zip(data_loader, range(start_iter, max_iter)): + storage.iter = iteration + + loss_dict = model(data) + losses = sum(loss_dict.values()) + assert torch.isfinite(losses).all(), loss_dict + + loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()} + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + if comm.is_main_process(): + storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced) + + optimizer.zero_grad() + losses.backward() + optimizer.step() + storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False) + scheduler.step() + + if ( + cfg.TEST.EVAL_PERIOD > 0 + and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0 + and iteration != max_iter - 1 + ): + do_test(cfg, model) + # Compared to "train_net.py", the test results are not dumped to EventStorage + comm.synchronize() + + if iteration - start_iter > 5 and ( + (iteration + 1) % 20 == 0 or iteration == max_iter - 1 + ): + for writer in writers: + writer.write() + periodic_checkpointer.step(iteration) + + +def setup(args): + """ + Create configs and perform basic setups. 
+ """ + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + default_setup( + cfg, args + ) # if you don't like any of the default setup, write your own setup code + return cfg + + +def main(args): + cfg = setup(args) + + model = build_model(cfg) + logger.info("Model:\n{}".format(model)) + if args.eval_only: + DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( + cfg.MODEL.WEIGHTS, resume=args.resume + ) + return do_test(cfg, model) + + distributed = comm.get_world_size() > 1 + if distributed: + model = DistributedDataParallel( + model, device_ids=[comm.get_local_rank()], broadcast_buffers=False + ) + + do_train(cfg, model, resume=args.resume) + return do_test(cfg, model) + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + print("Command Line Args:", args) + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/data_processing/detectron2/tools/train_net.py b/data_processing/detectron2/tools/train_net.py new file mode 100644 index 0000000..8a6f297 --- /dev/null +++ b/data_processing/detectron2/tools/train_net.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +A main training script. + +This scripts reads a given config file and runs the training or evaluation. +It is an entry point that is made to train standard models in detectron2. + +In order to let one script support training of many models, +this script contains logic that are specific to these built-in models and therefore +may not be suitable for your own project. +For example, your research project perhaps only needs a single "evaluator". + +Therefore, we recommend you to use detectron2 as an library and take +this file as an example of how to use the library. +You may want to write your own script with your datasets and other customizations. 
import logging
import os
from collections import OrderedDict

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
)
from detectron2.modeling import GeneralizedRCNNWithTTA


def build_evaluator(cfg, dataset_name, output_folder=None):
    """
    Pick the evaluator(s) matching a dataset's "evaluator_type" metadata.

    The if-chain below only knows about the builtin dataset types; for a custom
    dataset it is simpler to construct the evaluator directly in your own script.

    Returns a single evaluator, or a ``DatasetEvaluators`` when several apply.
    Raises ``NotImplementedError`` when the type is not handled.
    """
    if output_folder is None:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
    kind = MetadataCatalog.get(dataset_name).evaluator_type

    collected = []
    if kind in ("sem_seg", "coco_panoptic_seg"):
        sem_seg = SemSegEvaluator(
            dataset_name,
            distributed=True,
            output_dir=output_folder,
        )
        collected.append(sem_seg)
    if kind in ("coco", "coco_panoptic_seg"):
        collected.append(COCOEvaluator(dataset_name, output_dir=output_folder))
    if kind == "coco_panoptic_seg":
        collected.append(COCOPanopticEvaluator(dataset_name, output_folder))

    # These dataset types map to exactly one evaluator and return immediately.
    if kind == "cityscapes_instance":
        return CityscapesInstanceEvaluator(dataset_name)
    if kind == "cityscapes_sem_seg":
        return CityscapesSemSegEvaluator(dataset_name)
    if kind == "pascal_voc":
        return PascalVOCDetectionEvaluator(dataset_name)
    if kind == "lvis":
        return LVISEvaluator(dataset_name, output_dir=output_folder)

    if not collected:
        raise NotImplementedError(
            "no Evaluator for the dataset {} with the type {}".format(dataset_name, kind)
        )
    return collected[0] if len(collected) == 1 else DatasetEvaluators(collected)


class Trainer(DefaultTrainer):
    """
    Thin subclass of "DefaultTrainer" that plugs in the evaluator selection
    above and adds an evaluation pass with test-time augmentation (TTA).
    For anything beyond the standard workflow, write your own loop; see
    "tools/plain_train_net.py" for an example.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        # Delegate to the module-level function of the same name.
        return build_evaluator(cfg, dataset_name, output_folder)

    @classmethod
    def test_with_TTA(cls, cfg, model):
        """Evaluate ``model`` wrapped in GeneralizedRCNNWithTTA; result keys get a "_TTA" suffix."""
        log = logging.getLogger("detectron2.trainer")
        # Run at the end of training; only some R-CNN models are supported.
        log.info("Running inference with test-time augmentation ...")
        model = GeneralizedRCNNWithTTA(cfg, model)
        tta_dir = os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
        evaluators = []
        for name in cfg.DATASETS.TEST:
            evaluators.append(cls.build_evaluator(cfg, name, output_folder=tta_dir))
        plain = cls.test(cfg, model, evaluators)
        return OrderedDict((key + "_TTA", value) for key, value in plain.items())


def setup(args):
    """
    Build the frozen config from ``args.config_file`` and ``args.opts``,
    then run the default setup.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def _evaluate_only(cfg, resume):
    """Load weights into a freshly built model, evaluate it and verify the results."""
    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
    res = Trainer.test(cfg, model)
    if cfg.TEST.AUG.ENABLED:
        res.update(Trainer.test_with_TTA(cfg, model))
    if comm.is_main_process():
        verify_results(cfg, res)
    return res


def main(args):
    """Entry point run by ``launch``: evaluate only, or train with the Trainer."""
    cfg = setup(args)

    if args.eval_only:
        return _evaluate_only(cfg, args.resume)

    # If you'd like to do anything fancier than the standard training logic,
    # consider writing your own training loop (see plain_train_net.py) or
    # subclassing the trainer.
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        tta_hook = hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))
        trainer.register_hooks([tta_hook])
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
import argparse
import os
from itertools import chain
import cv2
import tqdm

from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_train_loader
from detectron2.data import detection_utils as utils
from detectron2.data.build import filter_images_with_few_keypoints
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer


def setup(args):
    """
    Build a frozen config from the optional config file and command-line
    overrides, with data-loader workers forced to 0.
    """
    cfg = get_cfg()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.DATALOADER.NUM_WORKERS = 0
    cfg.freeze()
    return cfg


def parse_args(in_args=None):
    """Parse the command line (or ``in_args`` when given) for this script."""
    ap = argparse.ArgumentParser(description="Visualize ground-truth data")
    ap.add_argument(
        "--source",
        choices=["annotation", "dataloader"],
        required=True,
        help="visualize the annotations or the data loader (with pre-processing)",
    )
    ap.add_argument("--config-file", metavar="FILE", help="path to config file")
    ap.add_argument("--output-dir", default="./", help="path to output directory")
    ap.add_argument("--show", action="store_true", help="show output in a window")
    ap.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return ap.parse_args(in_args)


if __name__ == "__main__":
    args = parse_args()
    logger = setup_logger()
    logger.info("Arguments: " + str(args))
    cfg = setup(args)

    dirname = args.output_dir
    os.makedirs(dirname, exist_ok=True)
    # Metadata of the first training dataset is used for every image.
    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

    def output(vis, fname):
        """Show the visualization in a window (--show) or save it under the output dir."""
        if not args.show:
            filepath = os.path.join(dirname, fname)
            print("Saving to {} ...".format(filepath))
            vis.save(filepath)
            return
        print(fname)
        # Channel order is reversed for cv2.imshow.
        cv2.imshow("window", vis.get_image()[:, :, ::-1])
        cv2.waitKey()

    scale = 1.0
    if args.source == "dataloader":
        train_data_loader = build_detection_train_loader(cfg)
        for batch in train_data_loader:
            for sample in batch:
                # Pytorch tensor is in (C, H, W) format
                hwc = sample["image"].permute(1, 2, 0).cpu().detach().numpy()
                img = utils.convert_image_to_rgb(hwc, cfg.INPUT.FORMAT)

                visualizer = Visualizer(img, metadata=metadata, scale=scale)
                target_fields = sample["instances"].get_fields()
                labels = [metadata.thing_classes[i] for i in target_fields["gt_classes"]]
                vis = visualizer.overlay_instances(
                    labels=labels,
                    boxes=target_fields.get("gt_boxes", None),
                    masks=target_fields.get("gt_masks", None),
                    keypoints=target_fields.get("gt_keypoints", None),
                )
                output(vis, str(sample["image_id"]) + ".jpg")
    else:
        per_dataset = [DatasetCatalog.get(name) for name in cfg.DATASETS.TRAIN]
        dicts = list(chain.from_iterable(per_dataset))
        if cfg.MODEL.KEYPOINT_ON:
            dicts = filter_images_with_few_keypoints(dicts, 1)
        for record in tqdm.tqdm(dicts):
            img = utils.read_image(record["file_name"], "RGB")
            visualizer = Visualizer(img, metadata=metadata, scale=scale)
            vis = visualizer.draw_dataset_dict(record)
            output(vis, os.path.basename(record["file_name"]))
import argparse
import json
import numpy as np
import os
from collections import defaultdict
import cv2
import tqdm

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import Boxes, BoxMode, Instances
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer


def create_instances(predictions, image_size):
    """
    Convert the json predictions of one image into an ``Instances`` object.

    NOTE: reads the module-level ``args.conf_threshold`` and ``dataset_id_map``,
    which are only defined by the ``__main__`` section of this script.

    Args:
        predictions: list of prediction dicts with "score", "bbox",
            "category_id" and optionally "segmentation" keys.
        image_size: passed through to ``Instances``.

    Returns:
        ``Instances`` with ``scores``, ``pred_boxes``, ``pred_classes`` and,
        when every kept prediction has a "segmentation", ``pred_masks``.
    """
    ret = Instances(image_size)

    score = np.asarray([x["score"] for x in predictions])
    # Keep only predictions strictly above the confidence threshold.
    chosen = (score > args.conf_threshold).nonzero()[0]
    score = score[chosen]
    bbox = np.asarray([predictions[i]["bbox"] for i in chosen]).reshape(-1, 4)
    bbox = BoxMode.convert(bbox, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)

    labels = np.asarray([dataset_id_map(predictions[i]["category_id"]) for i in chosen])

    ret.scores = score
    ret.pred_boxes = Boxes(bbox)
    ret.pred_classes = labels

    try:
        ret.pred_masks = [predictions[i]["segmentation"] for i in chosen]
    except KeyError:
        # Predictions without "segmentation" simply get no masks.
        pass
    return ret


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A script that visualizes the json predictions from COCO or LVIS dataset."
    )
    parser.add_argument("--input", required=True, help="JSON file produced by the model")
    parser.add_argument("--output", required=True, help="output directory")
    parser.add_argument("--dataset", help="name of the dataset", default="coco_2017_val")
    parser.add_argument("--conf-threshold", default=0.5, type=float, help="confidence threshold")
    args = parser.parse_args()

    logger = setup_logger()

    with PathManager.open(args.input, "r") as f:
        predictions = json.load(f)

    # Group predictions by the image they belong to.
    pred_by_image = defaultdict(list)
    for p in predictions:
        pred_by_image[p["image_id"]].append(p)

    dicts = list(DatasetCatalog.get(args.dataset))
    metadata = MetadataCatalog.get(args.dataset)
    if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):

        def dataset_id_map(ds_id):
            return metadata.thing_dataset_id_to_contiguous_id[ds_id]

    elif "lvis" in args.dataset:
        # LVIS results are in the same format as COCO results, but have a different
        # mapping from dataset category id to contiguous category id in [0, #categories - 1]
        def dataset_id_map(ds_id):
            return ds_id - 1

    else:
        raise ValueError("Unsupported dataset: {}".format(args.dataset))

    os.makedirs(args.output, exist_ok=True)

    for dic in tqdm.tqdm(dicts):
        img = cv2.imread(dic["file_name"], cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None;
            # skip it instead of failing on the slice below with a TypeError.
            logger.warning("Could not read image {}; skipping".format(dic["file_name"]))
            continue
        # Reverse the channel order for the Visualizer.
        img = img[:, :, ::-1]
        basename = os.path.basename(dic["file_name"])

        instances = create_instances(pred_by_image[dic["image_id"]], img.shape[:2])
        vis = Visualizer(img, metadata)
        vis_pred = vis.draw_instance_predictions(instances).get_image()

        vis = Visualizer(img, metadata)
        vis_gt = vis.draw_dataset_dict(dic).get_image()

        # Predictions on the left, ground truth on the right.
        concat = np.concatenate((vis_pred, vis_gt), axis=1)
        cv2.imwrite(os.path.join(args.output, basename), concat[:, :, ::-1])
+pexels-photo-10049392_0_h.png +pexels-photo-10049570_0_h.png +pexels-photo-10057590_1.png +pexels-photo-10345681_0_s.png +pexels-photo-10375949_0_h.png +pexels-photo-10375949_0_s.png +pexels-photo-10375949_1_h.png +pexels-photo-10378628_0_s.png +pexels-photo-10397913_0_h.png +pexels-photo-10398357_1_s.png +pexels-photo-10427630_0.png +pexels-photo-10505884_0_s.png +pexels-photo-10505884_1_s.png +pexels-photo-10576766_3_s.png +pexels-photo-10576766_4.png +pexels-photo-10576766_5_h.png +pexels-photo-10576766_6_s.png +pexels-photo-10762921_2_s.png +pexels-photo-10831721_0_h.png +pexels-photo-10963534_0_s.png +pexels-photo-11034416_0_s.png +pexels-photo-11049820_0_h.png +pexels-photo-11208501_0_h.png +pexels-photo-1124837_0_s.png +pexels-photo-11256844_0_h.png +pexels-photo-11367431_0_h.png +pexels-photo-11367431_1_h.png +pexels-photo-11367431_2_h.png +pexels-photo-11370108_0_h.png +pexels-photo-11428507_0_h.png +pexels-photo-11470784_0_h.png +pexels-photo-11540389_2.png +pexels-photo-11566664_0.png +pexels-photo-11581929_0.png +pexels-photo-1168742_0.png +pexels-photo-11738318_0_s.png +pexels-photo-11745218_0_s.png +pexels-photo-1174589_0_h.png +pexels-photo-1181352_0.png +pexels-photo-1188084_0_s.png +pexels-photo-11899525_0_h.png +pexels-photo-12084380_0.png +pexels-photo-1212984_0_s.png +pexels-photo-12211769_0_h.png +pexels-photo-12421586_0_h.png +pexels-photo-12443969_0_h.png +pexels-photo-12443969_1_h.png +pexels-photo-12496864_0_s.png +pexels-photo-12513374_0_s.png +pexels-photo-12526402_0.png +pexels-photo-12536721_0_s.png +pexels-photo-12575855_0_s.png +pexels-photo-1261164_0.png +pexels-photo-12719297_1_h.png +pexels-photo-12742232_0_s.png +pexels-photo-12788467_1.png +pexels-photo-12871329_0_s.png +pexels-photo-12899798_0_s.png +pexels-photo-12920533_0_h.png +pexels-photo-12920533_1.png +pexels-photo-12920533_2_h.png +pexels-photo-12920533_3_h.png +pexels-photo-1292129_0.png +pexels-photo-13007291_0_s.png +pexels-photo-13022373_0_s.png 
+pexels-photo-13062548_3_s.png +pexels-photo-13086623_1.png +pexels-photo-13086623_2.png +pexels-photo-1325723_0_s.png +pexels-photo-13417785_0_h.png +pexels-photo-13417785_1.png +pexels-photo-13929227_0_h.png +pexels-photo-13929227_1_h.png +pexels-photo-13997557_3.png +pexels-photo-14025620_1_h.png +pexels-photo-14025620_11_s.png +pexels-photo-14025620_13_s.png +pexels-photo-14025620_15_h.png +pexels-photo-14025620_16_s.png +pexels-photo-14025620_17.png +pexels-photo-14025620_23_h.png +pexels-photo-14025620_26.png +pexels-photo-14025620_5_s.png +pexels-photo-1415268_0.png +pexels-photo-1420695_0_s.png +pexels-photo-1420695_1.png +pexels-photo-14235983_0_h.png +pexels-photo-14267674_0_s.png +pexels-photo-14408448_1_h.png +pexels-photo-14408448_2_s.png +pexels-photo-14514432_0_s.png +pexels-photo-14523206_0_s.png +pexels-photo-1474233_0.png +pexels-photo-1474233_1_h.png +pexels-photo-14844812_0_s.png +pexels-photo-1484796_0_h.png +pexels-photo-1486064_0_h.png +pexels-photo-1510542_0_h.png +pexels-photo-15268397_0_h.png +pexels-photo-15268397_1_h.png +pexels-photo-15304241_0_h.png +pexels-photo-15464476_0.png +pexels-photo-15498262_0_h.png +pexels-photo-15498262_3_h.png +pexels-photo-15498262_4_h.png +pexels-photo-15787353_0_s.png +pexels-photo-15797148_0_h.png +pexels-photo-15810517_0_h.png +pexels-photo-15823832_0_s.png +pexels-photo-15832913_0.png +pexels-photo-1617610_0_s.png +pexels-photo-1620788_0_h.png +pexels-photo-1620788_1_h.png +pexels-photo-1620788_3_s.png +pexels-photo-1630784_3_s.png +pexels-photo-1630784_4.png +pexels-photo-1630784_5_h.png +pexels-photo-1683974_0_h.png +pexels-photo-1757923_0_s.png +pexels-photo-1772724_0.png +pexels-photo-1843863_0_s.png +pexels-photo-1863476_0_h.png +pexels-photo-1926769_0_s.png +pexels-photo-2058608_0_s.png +pexels-photo-206402_0.png +pexels-photo-206593_0_s.png +pexels-photo-2247814_0_h.png +pexels-photo-2272941_0_h.png +pexels-photo-2343157_0_h.png +pexels-photo-2410576_0.png +pexels-photo-2430945_0_h.png 
+pexels-photo-2430945_1.png +pexels-photo-2430945_2.png +pexels-photo-2442399_0_s.png +pexels-photo-2539269_0_h.png +pexels-photo-2539269_1_h.png +pexels-photo-260111_0_h.png +pexels-photo-2602545_0.png +pexels-photo-2730217_0_h.png +pexels-photo-2734302_0_s.png +pexels-photo-274577_0_h.png +pexels-photo-2814239_0_s.png +pexels-photo-2853592_0_h.png +pexels-photo-2859374_0_h.png +pexels-photo-2896464_0.png +pexels-photo-290416_0_h.png +pexels-photo-2913125_0_h.png +pexels-photo-2927584_20.png +pexels-photo-2962147_0.png +pexels-photo-2976107_0_s.png +pexels-photo-3023746_0_s.png +pexels-photo-3026283_0_s.png +pexels-photo-307847_0_h.png +pexels-photo-3171067_0_h.png +pexels-photo-3182748_0_h.png +pexels-photo-3182748_2.png +pexels-photo-3182748_3_s.png +pexels-photo-3182748_4.png +pexels-photo-3182748_5_s.png +pexels-photo-3184436_0_h.png +pexels-photo-3184436_1.png +pexels-photo-3184436_10_h.png +pexels-photo-3184436_16_s.png +pexels-photo-3184436_3_h.png +pexels-photo-3184436_4_h.png +pexels-photo-319899_1.png +pexels-photo-3201696_0.png +pexels-photo-3205741_0_s.png +pexels-photo-3275945_0.png +pexels-photo-3276582_0_h.png +pexels-photo-3317750_1_s.png +pexels-photo-3352734_0.png +pexels-photo-3363968_0_s.png +pexels-photo-3527089_0_s.png +pexels-photo-3536435_2.png +pexels-photo-3564649_0_h.png +pexels-photo-3661452_0_h.png +pexels-photo-3661452_1_h.png +pexels-photo-3662649_1_h.png +pexels-photo-3703966_0_h.png +pexels-photo-3730941_0_h.png +pexels-photo-3754255_0_s.png +pexels-photo-3754255_1_s.png +pexels-photo-3755714_0_h.png +pexels-photo-3755714_1_h.png +pexels-photo-3755714_1_s.png +pexels-photo-3756785_0_h.png +pexels-photo-3758012_1_s.png +pexels-photo-3758012_2_s.png +pexels-photo-3758012_3_s.png +pexels-photo-3760923_0_h.png +pexels-photo-3763999_0_h.png +pexels-photo-3768879_0_h.png +pexels-photo-3771672_0_h.png +pexels-photo-3776847_0.png +pexels-photo-3776847_1_h.png +pexels-photo-3776847_2.png +pexels-photo-3776847_3.png 
+pexels-photo-3776847_5.png +pexels-photo-3777884_0_h.png +pexels-photo-3777884_1_h.png +pexels-photo-3780027_0_h.png +pexels-photo-3781911_0_h.png +pexels-photo-3783512_0_s.png +pexels-photo-37839_0.png +pexels-photo-3813041_0_h.png +pexels-photo-3817646_0_s.png +pexels-photo-3819576_0_h.png +pexels-photo-3819576_1_h.png +pexels-photo-3819950_0_h.png +pexels-photo-3820428_1.png +pexels-photo-3822724_0_s.png +pexels-photo-3822724_1_s.png +pexels-photo-3822724_2_s.png +pexels-photo-3822724_3_s.png +pexels-photo-3823490_0_h.png +pexels-photo-3845625_0_s.png +pexels-photo-3851853_1_h.png +pexels-photo-3855442_0_h.png +pexels-photo-3857525_0.png +pexels-photo-3867382_0_h.png +pexels-photo-3873029_0_s.png +pexels-photo-3933029_0_h.png +pexels-photo-3933395_1.png +pexels-photo-3933410_1_s.png +pexels-photo-3933896_1_h.png +pexels-photo-3951399_0_s.png +pexels-photo-3965391_2.png +pexels-photo-3967782_0.png +pexels-photo-3971474_0_h.png +pexels-photo-3983667_0_s.png +pexels-photo-3991771_1_s.png +pexels-photo-3992368_0_h.png +pexels-photo-3992368_1_h.png +pexels-photo-4009009_0_h.png +pexels-photo-4009592_0_s.png +pexels-photo-4009592_1_s.png +pexels-photo-4019754_0_s.png +pexels-photo-4019754_1_h.png +pexels-photo-4019754_10_h.png +pexels-photo-4019754_12_s.png +pexels-photo-4019754_2_h.png +pexels-photo-4019754_4.png +pexels-photo-4019754_8_s.png +pexels-photo-4040874_0.png +pexels-photo-4046104_0_h.png +pexels-photo-4046104_1.png +pexels-photo-4047023_0_h.png +pexels-photo-4047829_0_s.png +pexels-photo-4050392_0_h.png +pexels-photo-4057689_0.png +pexels-photo-4100421_0_h.png +pexels-photo-4100421_1.png +pexels-photo-4101187_0_h.png +pexels-photo-4101187_1_h.png +pexels-photo-4127873_1_s.png +pexels-photo-4132340_0_h.png +pexels-photo-4132358_0_h.png +pexels-photo-4153176_0_h.png +pexels-photo-4164759_0.png +pexels-photo-4260102_0.png +pexels-photo-4260102_1.png +pexels-photo-4260102_2.png +pexels-photo-4339514_0_h.png +pexels-photo-4342098_0_s.png 
+pexels-photo-4395319_0_s.png +pexels-photo-4473890_0.png +pexels-photo-4498150_0_s.png +pexels-photo-4510854_0.png +pexels-photo-4543732_0_h.png +pexels-photo-4543732_1_h.png +pexels-photo-4543732_2_s.png +pexels-photo-4546135_0_h.png +pexels-photo-4555327_0_s.png +pexels-photo-4555327_1_h.png +pexels-photo-4555327_2.png +pexels-photo-4571260_0_s.png +pexels-photo-4584582_0_h.png +pexels-photo-4586678_0_h.png +pexels-photo-4587421_0_h.png +pexels-photo-4623085_0.png +pexels-photo-4623525_0_h.png +pexels-photo-4623525_2.png +pexels-photo-4624913_0_h.png +pexels-photo-4624913_1.png +pexels-photo-4638830_0_s.png +pexels-photo-4668946_0_h.png +pexels-photo-4668946_2.png +pexels-photo-4672484_0_s.png +pexels-photo-4672484_3_h.png +pexels-photo-4672484_4_h.png +pexels-photo-4672484_6_s.png +pexels-photo-4720309_0_s.png +pexels-photo-4720500_0_s.png +pexels-photo-4751203_0.png +pexels-photo-4769468_0_h.png +pexels-photo-4781458_0_h.png +pexels-photo-4783338_0_s.png +pexels-photo-4842498_0_h.png +pexels-photo-4872091_0.png +pexels-photo-4872091_1_s.png +pexels-photo-4874336_0.png +pexels-photo-4874917_0_h.png +pexels-photo-4874917_1_h.png +pexels-photo-4874917_2.png +pexels-photo-4877850_0_h.png +pexels-photo-4877850_1_s.png +pexels-photo-4877850_4_s.png +pexels-photo-4881613_0_h.png +pexels-photo-4881613_1.png +pexels-photo-4881613_3_s.png +pexels-photo-4881613_4_h.png +pexels-photo-4890273_0_s.png +pexels-photo-4894830_0_h.png +pexels-photo-4911750_0_s.png +pexels-photo-4917820_0_h.png +pexels-photo-4939552_0_s.png +pexels-photo-4946531_0_h.png +pexels-photo-4977411_0_s.png +pexels-photo-4977411_1.png +pexels-photo-4980302_0_h.png +pexels-photo-4985017_0.png +pexels-photo-4989266_0_h.png +pexels-photo-5020368_1_s.png +pexels-photo-5029344_0_h.png +pexels-photo-5037007_0_h.png +pexels-photo-5037007_3_h.png +pexels-photo-5037285_0_h.png +pexels-photo-5037285_0_s.png +pexels-photo-5047063_0_h.png +pexels-photo-5055244_0_s.png +pexels-photo-5055248_0.png 
+pexels-photo-5055248_1_h.png +pexels-photo-5055421_0_h.png +pexels-photo-5055421_1_h.png +pexels-photo-5060987_0_h.png +pexels-photo-5063299_0.png +pexels-photo-5094096_0_h.png +pexels-photo-5094104_0.png +pexels-photo-5098287_1_s.png +pexels-photo-5098287_11_h.png +pexels-photo-5098287_14_h.png +pexels-photo-5098287_16_h.png +pexels-photo-5098287_2.png +pexels-photo-5098287_3_s.png +pexels-photo-5098287_6.png +pexels-photo-5126956_0_h.png +pexels-photo-5158233_0_s.png +pexels-photo-5205275_0_s.png +pexels-photo-5205275_1_s.png +pexels-photo-5211447_0_h.png +pexels-photo-5212668_2.png +pexels-photo-5212699_1_s.png +pexels-photo-5225446_0_h.png +pexels-photo-5239523_0_s.png +pexels-photo-5239523_1.png +pexels-photo-5240605_1_h.png +pexels-photo-5241025_0_h.png +pexels-photo-5256916_0.png +pexels-photo-5257266_0_h.png +pexels-photo-5257454_0.png +pexels-photo-5257454_1_s.png +pexels-photo-5257454_2_s.png +pexels-photo-5257454_3_h.png +pexels-photo-5257497_0_h.png +pexels-photo-5257497_0_s.png +pexels-photo-5257547_0_h.png +pexels-photo-5258251_0_s.png +pexels-photo-5258907_0_h.png +pexels-photo-5263833_0_h.png +pexels-photo-5263833_1_h.png +pexels-photo-5273059_0_h.png +pexels-photo-5274600_0_h.png +pexels-photo-5310786_0_s.png +pexels-photo-5329068_0.png +pexels-photo-5335170_0.png +pexels-photo-5349756_1.png +pexels-photo-5356823_0_h.png +pexels-photo-5357336_0.png +pexels-photo-5357615_0_s.png +pexels-photo-5366313_0.png +pexels-photo-5386148_0_s.png +pexels-photo-5386459_0_s.png +pexels-photo-5388321_0_s.png +pexels-photo-5390335_0_s.png +pexels-photo-5393445_1.png +pexels-photo-5405024_0.png +pexels-photo-5427143_0.png +pexels-photo-5439478_0_s.png +pexels-photo-5439478_1_h.png +pexels-photo-5439478_2_h.png +pexels-photo-5488943_0_s.png +pexels-photo-5490267_0_s.png +pexels-photo-5538615_0_h.png +pexels-photo-5543181_0_s.png +pexels-photo-5553671_0_s.png +pexels-photo-5555111_0_s.png +pexels-photo-5560039_0_h.png +pexels-photo-5561169_0_h.png 
+pexels-photo-5561455_0_h.png +pexels-photo-5593618_0_h.png +pexels-photo-5600112_0.png +pexels-photo-5622327_0_h.png +pexels-photo-5691296_0_h.png +pexels-photo-5691845_1_h.png +pexels-photo-5691845_2_h.png +pexels-photo-5692182_0_h.png +pexels-photo-5692182_1_h.png +pexels-photo-5692691_0_h.png +pexels-photo-5692691_1_h.png +pexels-photo-5692997_0_h.png +pexels-photo-5698208_0_h.png +pexels-photo-5698369_0_h.png +pexels-photo-5709530_0_h.png +pexels-photo-5710602_0_h.png +pexels-photo-5710946_0_s.png +pexels-photo-5711233_0.png +pexels-photo-5717051_0_s.png +pexels-photo-5721093_0.png +pexels-photo-5727759_0_h.png +pexels-photo-5727759_1_s.png +pexels-photo-5727759_2_s.png +pexels-photo-5727759_3_s.png +pexels-photo-5727759_4_h.png +pexels-photo-5727775_1_s.png +pexels-photo-5727775_2_s.png +pexels-photo-5727775_3_h.png +pexels-photo-5727775_4_s.png +pexels-photo-5728206_0_s.png +pexels-photo-5728206_1.png +pexels-photo-5762495_0_h.png +pexels-photo-5764902_0_s.png +pexels-photo-5764902_1.png +pexels-photo-5764903_0.png +pexels-photo-5795419_0_h.png +pexels-photo-5814298_1_h.png +pexels-photo-5814298_2_h.png +pexels-photo-5814298_3_h.png +pexels-photo-583124_0_h.png +pexels-photo-5847798_0.png +pexels-photo-5876654_0.png +pexels-photo-5896471_0.png +pexels-photo-5896471_1_h.png +pexels-photo-5905494_0_h.png +pexels-photo-5905494_1_s.png +pexels-photo-5911942_0_s.png +pexels-photo-5915298_0_h.png +pexels-photo-5917712_0_s.png +pexels-photo-5922070_0.png +pexels-photo-5933917_0_h.png +pexels-photo-5935233_0_h.png +pexels-photo-5935233_1_s.png +pexels-photo-5935233_2_h.png +pexels-photo-5940841_0_s.png +pexels-photo-5940841_1_h.png +pexels-photo-5940841_2.png +pexels-photo-5940841_3_h.png +pexels-photo-5940841_4_s.png +pexels-photo-5961074_0_s.png +pexels-photo-5961074_1_s.png +pexels-photo-5961074_2_s.png +pexels-photo-5999085_0_h.png +pexels-photo-6006255_0.png +pexels-photo-6015886_0_s.png +pexels-photo-6015935_0_s.png +pexels-photo-6023601_0_s.png 
+pexels-photo-6025211_3.png +pexels-photo-6039870_7_s.png +pexels-photo-6113555_0_h.png +pexels-photo-6113555_2.png +pexels-photo-6132889_0.png +pexels-photo-6140366_0.png +pexels-photo-6140366_1_h.png +pexels-photo-6140366_2.png +pexels-photo-6140723_0_s.png +pexels-photo-6141083_1_h.png +pexels-photo-6141233_0_s.png +pexels-photo-6141233_1_h.png +pexels-photo-6147015_0.png +pexels-photo-6147369_1.png +pexels-photo-6147369_2_s.png +pexels-photo-6150579_0_s.png +pexels-photo-6150579_2.png +pexels-photo-6169668_0_s.png +pexels-photo-6169668_1_s.png +pexels-photo-6193433_1_s.png +pexels-photo-6201979_0_h.png +pexels-photo-6202790_0_s.png +pexels-photo-620340_0_s.png +pexels-photo-6204234_0.png +pexels-photo-6209065_2_s.png +pexels-photo-6210267_0_h.png +pexels-photo-6210267_1.png +pexels-photo-6220702_1.png +pexels-photo-6257132_0.png +pexels-photo-6281724_0_s.png +pexels-photo-6281724_1.png +pexels-photo-6281724_2_h.png +pexels-photo-6297603_0_h.png +pexels-photo-6297603_1_s.png +pexels-photo-6299291_0_h.png +pexels-photo-6299291_1_h.png +pexels-photo-6299291_2_h.png +pexels-photo-6299291_3_h.png +pexels-photo-6311134_0_s.png +pexels-photo-6339324_1_h.png +pexels-photo-6339324_2.png +pexels-photo-634007_0.png +pexels-photo-634007_1.png +pexels-photo-6340620_0_h.png +pexels-photo-6340620_1_s.png +pexels-photo-6340620_2_h.png +pexels-photo-6340620_3.png +pexels-photo-6340620_4.png +pexels-photo-6340628_1_s.png +pexels-photo-6340628_2.png +pexels-photo-6345387_0.png +pexels-photo-6453628_0_h.png +pexels-photo-6453958_0_s.png +pexels-photo-6455834_0_s.png +pexels-photo-6457490_0_s.png +pexels-photo-6457490_1_h.png +pexels-photo-6476344_0.png +pexels-photo-6478306_0_s.png +pexels-photo-6491794_3_h.png +pexels-photo-6512495_0.png +pexels-photo-6530738_0_s.png +pexels-photo-6551237_0_h.png +pexels-photo-6551494_0_s.png +pexels-photo-6578394_0_s.png +pexels-photo-6612632_0_s.png +pexels-photo-6612632_1_h.png +pexels-photo-6620720_0_s.png +pexels-photo-6626000_0_s.png 
+pexels-photo-6626000_1_h.png +pexels-photo-6692899_0_h.png +pexels-photo-6694742_0_s.png +pexels-photo-6694742_1_s.png +pexels-photo-6714616_0_h.png +pexels-photo-674833_0_s.png +pexels-photo-6770359_0_s.png +pexels-photo-6774173_0_s.png +pexels-photo-6777188_2.png +pexels-photo-6777188_3.png +pexels-photo-6777188_4_h.png +pexels-photo-6781177_0.png +pexels-photo-6784855_0_h.png +pexels-photo-6784898_0_h.png +pexels-photo-6785010_0_s.png +pexels-photo-6815668_0_h.png +pexels-photo-6829484_0_s.png +pexels-photo-6835956_0.png +pexels-photo-6874028_0_s.png +pexels-photo-6874659_0_s.png +pexels-photo-6874659_1.png +pexels-photo-6878686_0_h.png +pexels-photo-6878686_1_h.png +pexels-photo-6897918_0_s.png +pexels-photo-6914062_0_h.png +pexels-photo-6914062_1.png +pexels-photo-6914062_2_h.png +pexels-photo-6914062_3_s.png +pexels-photo-6914062_4.png +pexels-photo-6930406_1.png +pexels-photo-6935992_0_h.png +pexels-photo-6941674_0_h.png +pexels-photo-6941674_1_h.png +pexels-photo-6948104_0_h.png +pexels-photo-6953854_0.png +pexels-photo-6953854_1_h.png +pexels-photo-6969970_0_h.png +pexels-photo-6975208_0_s.png +pexels-photo-6975208_1_h.png +pexels-photo-6975640_0_h.png +pexels-photo-6995845_0_h.png +pexels-photo-7010106_0_s.png +pexels-photo-7013903_0.png +pexels-photo-7020617_0_s.png +pexels-photo-7035396_0_h.png +pexels-photo-7035541_0_h.png +pexels-photo-7063751_0.png +pexels-photo-7065297_0_s.png +pexels-photo-7065297_1_s.png +pexels-photo-7065436_1_s.png +pexels-photo-7065455_0_s.png +pexels-photo-7065455_1_h.png +pexels-photo-7084410_0_s.png +pexels-photo-7084410_1.png +pexels-photo-7084418_0_s.png +pexels-photo-7104232_0_h.png +pexels-photo-7104232_1.png +pexels-photo-7148031_1.png +pexels-photo-7202771_0_s.png +pexels-photo-7202771_1_s.png +pexels-photo-7213203_0_s.png +pexels-photo-7213203_1_s.png +pexels-photo-7213203_2.png +pexels-photo-7213203_3_h.png +pexels-photo-7213366_0.png +pexels-photo-7213366_2_h.png +pexels-photo-7219206_1_h.png 
+pexels-photo-7232041_0_h.png +pexels-photo-7236174_0_h.png +pexels-photo-7244740_0_h.png +pexels-photo-7249421_0_s.png +pexels-photo-7249421_1_h.png +pexels-photo-7249421_2_h.png +pexels-photo-7249421_3_s.png +pexels-photo-7249421_4.png +pexels-photo-7266752_0_s.png +pexels-photo-7266752_1.png +pexels-photo-7270922_1.png +pexels-photo-7283533_0_s.png +pexels-photo-7283533_1_s.png +pexels-photo-7295889_0_h.png +pexels-photo-7296266_0_h.png +pexels-photo-7296266_1_s.png +pexels-photo-7318674_0_s.png +pexels-photo-7322106_0.png +pexels-photo-7322106_1.png +pexels-photo-7322192_0_h.png +pexels-photo-7322192_1_s.png +pexels-photo-7322489_0_s.png +pexels-photo-7322489_1_s.png +pexels-photo-7322492_0_s.png +pexels-photo-7322492_1_s.png +pexels-photo-7328434_0_h.png +pexels-photo-7328434_1_h.png +pexels-photo-7345473_0_s.png +pexels-photo-7354924_0_h.png +pexels-photo-7368179_0_h.png +pexels-photo-7368179_1_s.png +pexels-photo-7431357_0_h.png +pexels-photo-7432114_0_s.png +pexels-photo-7439284_0.png +pexels-photo-7479863_0_s.png +pexels-photo-7479863_1_s.png +pexels-photo-749072_0_h.png +pexels-photo-7490853_0.png +pexels-photo-7490853_1_s.png +pexels-photo-7495122_1_s.png +pexels-photo-7495122_2.png +pexels-photo-7500331_0_h.png +pexels-photo-7507087_0_s.png +pexels-photo-7513029_0.png +pexels-photo-7513092_0_h.png +pexels-photo-7520932_0_s.png +pexels-photo-7520932_1.png +pexels-photo-7529990_0.png +pexels-photo-7529990_1_s.png +pexels-photo-7551383_0_h.png +pexels-photo-7551383_1.png +pexels-photo-7551609_0_s.png +pexels-photo-7551609_1_h.png +pexels-photo-7569423_5_h.png +pexels-photo-7569423_7_h.png +pexels-photo-7573405_0.png +pexels-photo-7573408_0_h.png +pexels-photo-7599683_0_h.png +pexels-photo-7605899_0_s.png +pexels-photo-7607782_0_h.png +pexels-photo-7609958_0.png +pexels-photo-7624882_0_s.png +pexels-photo-7640437_0_s.png +pexels-photo-7644155_0_s.png +pexels-photo-7648029_0_s.png +pexels-photo-7648029_1_s.png +pexels-photo-7648041_1_s.png 
+pexels-photo-7651753_0.png +pexels-photo-7652179_0_h.png +pexels-photo-7652179_2.png +pexels-photo-7671060_0_h.png +pexels-photo-7674834_0_h.png +pexels-photo-7675912_0_h.png +pexels-photo-7676827_0.png +pexels-photo-7677597_0_h.png +pexels-photo-7677597_2_h.png +pexels-photo-7677951_0_h.png +pexels-photo-7677951_1_s.png +pexels-photo-7683876_1_s.png +pexels-photo-7698396_0_h.png +pexels-photo-7705684_1_h.png +pexels-photo-7749406_0_h.png +pexels-photo-7750690_0.png +pexels-photo-7750690_1.png +pexels-photo-7750690_2_s.png +pexels-photo-7787678_0.png +pexels-photo-7799636_0.png +pexels-photo-7799636_1_s.png +pexels-photo-7821545_0_h.png +pexels-photo-7841444_0_s.png +pexels-photo-7841467_0_h.png +pexels-photo-7841467_1.png +pexels-photo-7856901_0_s.png +pexels-photo-7867640_0_s.png +pexels-photo-7869685_0_s.png +pexels-photo-7872622_0_s.png +pexels-photo-7874479_0_h.png +pexels-photo-7876092_0.png +pexels-photo-7876092_1.png +pexels-photo-7876092_2.png +pexels-photo-7876149_0.png +pexels-photo-7876149_1_s.png +pexels-photo-7876897_0_h.png +pexels-photo-7876897_1.png +pexels-photo-7876897_2_s.png +pexels-photo-7879720_0.png +pexels-photo-7879720_1_h.png +pexels-photo-7884124_0.png +pexels-photo-7886856_0_h.png +pexels-photo-7929483_3_h.png +pexels-photo-7935813_0_h.png +pexels-photo-7935813_1.png +pexels-photo-7938036_0_h.png +pexels-photo-7938036_1_h.png +pexels-photo-7938036_2_h.png +pexels-photo-7938044_0_h.png +pexels-photo-7938044_1_h.png +pexels-photo-7938044_2_h.png +pexels-photo-7973031_0.png +pexels-photo-7973031_1_s.png +pexels-photo-7983588_0_s.png +pexels-photo-7989150_4_s.png +pexels-photo-8055523_0_h.png +pexels-photo-8057362_0.png +pexels-photo-8058287_0_h.png +pexels-photo-8083025_0.png +pexels-photo-8083025_1.png +pexels-photo-8083025_2.png +pexels-photo-8084057_1.png +pexels-photo-8084057_3_s.png +pexels-photo-8090448_0_s.png +pexels-photo-8101710_0.png +pexels-photo-8101710_1_h.png +pexels-photo-8101710_2_h.png +pexels-photo-8101710_3_h.png 
+pexels-photo-8104174_1_h.png +pexels-photo-8104174_3_h.png +pexels-photo-8104852_0.png +pexels-photo-8104852_1_s.png +pexels-photo-8107081_0.png +pexels-photo-8112164_0_h.png +pexels-photo-8112164_1_s.png +pexels-photo-8112164_2.png +pexels-photo-8124250_0.png +pexels-photo-8124250_1.png +pexels-photo-8133858_0.png +pexels-photo-8133858_1.png +pexels-photo-8133858_2.png +pexels-photo-8133995_0_h.png +pexels-photo-8133995_1_s.png +pexels-photo-8133995_2_h.png +pexels-photo-8153903_0_s.png +pexels-photo-8170251_0_h.png +pexels-photo-8170255_0_h.png +pexels-photo-8170255_1_h.png +pexels-photo-8170255_2.png +pexels-photo-8170255_3.png +pexels-photo-8170294_0_h.png +pexels-photo-8171211_0.png +pexels-photo-8171211_1.png +pexels-photo-818801_0_s.png +pexels-photo-8213234_0_h.png +pexels-photo-8213234_1_h.png +pexels-photo-8213977_0_s.png +pexels-photo-8259896_0_h.png +pexels-photo-8259896_1_s.png +pexels-photo-8259896_2_h.png +pexels-photo-8259896_3_s.png +pexels-photo-8307438_0_h.png +pexels-photo-8307438_1_h.png +pexels-photo-8307438_4_s.png +pexels-photo-8355405_0.png +pexels-photo-8355405_1_h.png +pexels-photo-8355405_2_s.png +pexels-photo-8355405_3_h.png +pexels-photo-8364014_0_h.png +pexels-photo-8376149_0.png +pexels-photo-8380092_2.png +pexels-photo-8380092_6_h.png +pexels-photo-8390281_0.png +pexels-photo-8390281_2_s.png +pexels-photo-8417276_0.png +pexels-photo-8417323_0_s.png +pexels-photo-8419207_0_s.png +pexels-photo-8419492_1_h.png +pexels-photo-8419492_3_s.png +pexels-photo-8422258_0_h.png +pexels-photo-8422258_1.png +pexels-photo-8424561_0.png +pexels-photo-8424561_1_s.png +pexels-photo-8429912_0_h.png +pexels-photo-8429912_1_h.png +pexels-photo-8430129_1_h.png +pexels-photo-8441414_0_s.png +pexels-photo-8471836_0_h.png +pexels-photo-8502053_0.png +pexels-photo-8512442_0_h.png +pexels-photo-8540374_0_h.png +pexels-photo-8544490_4_h.png +pexels-photo-8544490_5_h.png +pexels-photo-8612899_2_h.png +pexels-photo-8612899_7_h.png +pexels-photo-8613094_0_h.png 
+pexels-photo-8617765_1_h.png +pexels-photo-8617765_2_h.png +pexels-photo-8617765_3_h.png +pexels-photo-8617765_4_s.png +pexels-photo-8617765_5_h.png +pexels-photo-8633364_1_s.png +pexels-photo-8638142_0_s.png +pexels-photo-8717389_0.png +pexels-photo-8730181_0.png +pexels-photo-8730181_2_s.png +pexels-photo-8730181_4_s.png +pexels-photo-8761322_0_s.png +pexels-photo-8789214_0_s.png +pexels-photo-8790334_0_h.png +pexels-photo-8790334_1.png +pexels-photo-8790334_2_h.png +pexels-photo-8795805_0_s.png +pexels-photo-8837759_0_h.png +pexels-photo-8837759_1_h.png +pexels-photo-8837759_2.png +pexels-photo-8837759_3.png +pexels-photo-8837759_7_h.png +pexels-photo-8865140_1_h.png +pexels-photo-8872678_1.png +pexels-photo-8872678_3_h.png +pexels-photo-8872678_4_s.png +pexels-photo-8922545_1_h.png +pexels-photo-8922545_2.png +pexels-photo-8922803_0_h.png +pexels-photo-8923264_0.png +pexels-photo-8923264_1_h.png +pexels-photo-8927013_0_h.png +pexels-photo-8933955_0_h.png +pexels-photo-8941577_1_h.png +pexels-photo-8941577_2_h.png +pexels-photo-8959533_0_h.png +pexels-photo-8972512_1_h.png +pexels-photo-8986675_0_s.png +pexels-photo-9034737_1_s.png +pexels-photo-9034767_2_s.png +pexels-photo-905336_0_s.png +pexels-photo-905336_1_s.png +pexels-photo-905336_2_s.png +pexels-photo-9208549_0.png +pexels-photo-9211772_0.png +pexels-photo-9268713_0_s.png +pexels-photo-935953_0_h.png +pexels-photo-936094_0_s.png +pexels-photo-938642_1_s.png +pexels-photo-9399190_0_h.png +pexels-photo-943241_0.png +pexels-photo-9518018_0_h.png +pexels-photo-9558694_0_s.png +pexels-photo-9628357_0_h.png +pexels-photo-9756301_0_h.png +pexels-photo-975680_0_h.png +pexels-photo-9759336_0_h.png +pexels-photo-9783374_0_h.png +pexels-photo-9784059_0.png +pexels-photo-9784393_0_h.png +pexels-photo-983197_0_s.png +pexels-photo-9834555_1.png +pexels-photo-9872289_0_h.png +pexels-photo-9872289_1_h.png +pexels-photo-9872289_2_h.png +pexels-photo-9885405_1.png +pexels-photo-9885405_2_h.png 
class Config:
    """Process-wide settings for the pose-estimation data-processing pipeline.

    All defaults are class attributes evaluated once, when the class body
    runs. That includes the timestamped ``output_dir`` and the ``print`` of
    it, so importing this module has a visible side effect.  Per-run values
    are written onto the instance by :meth:`set_args`, :meth:`set_data_dir`
    and :meth:`update`.
    """

    ## dataset
    # MuCo, Human36M, MSCOCO, PW3D, FreiHAND
    trainset_3d = ['Human36M']  # 'Human36M', 'MuCo'
    trainset_2d = ['MSCOCO']  # 'MSCOCO', 'MPII', 'CrowdPose'
    testset = 'PW3D'  # 'MuPoTs' 'MSCOCO' Human36M, MSCOCO, 'PW3D'

    ## model setting
    resnet_type = 50  # 50, 101, 152

    ## input, output
    input_img_shape = (256, 256)  # (256, 192)
    output_hm_shape = (64, 64, 64)  # (64, 64, 48)
    bbox_3d_size = 2 if 'FreiHAND' not in trainset_3d + trainset_2d + [testset] else 0.3
    sigma = 2.5
    focal = (5000, 5000)  # virtual focal lengths
    princpt = (input_img_shape[1] / 2, input_img_shape[0] / 2)  # virtual principal point position

    ## training config
    lr_dec_epoch = [15] if 'FreiHAND' not in trainset_3d + trainset_2d + [testset] else [17, 21]
    end_epoch = 20  # 13 if 'FreiHAND' not in trainset_3d + trainset_2d + [testset] else 25
    lr = 1e-4
    lr_backbone = 1e-4
    lr_dec_factor = 10
    train_batch_size = 32
    use_gt_info = True

    ## testing config
    test_batch_size = 32
    crowd = False
    vis = False
    render = False
    use_bbox_in_ann = True

    ## others
    num_thread = 0  # 0 if use windows
    gpu_ids = '0'
    num_gpus = 1
    continue_train = False
    finetune = False

    ## directory
    cur_dir = osp.dirname(os.path.abspath(__file__))
    root_dir = osp.join(cur_dir, '..')
    data_dir = osp.join(root_dir, 'data')
    output_dir = osp.join(root_dir, 'output')
    # Experiment folder name is derived from the local wall-clock time:
    # the slice drops the year prefix and the seconds/microseconds suffix.
    save_folder = 'exp_' + str(datetime.datetime.now())[5:-10].replace(':', '-')
    save_folder = save_folder.replace(" ", "_")
    output_dir = osp.join(output_dir, save_folder)
    print('output dir: ', output_dir)

    model_dir = osp.join(output_dir, 'checkpoint')
    vis_dir = osp.join(output_dir, 'vis')
    log_dir = osp.join(output_dir, 'log')
    result_dir = osp.join(output_dir, 'result')
    mano_path = osp.join(root_dir, 'common', 'utils', 'manopth')
    smpl_path = osp.join(root_dir, 'common', 'utils', 'smplpytorch')
    human_model_path = osp.join(root_dir, 'common', 'utils', 'human_model_files')

    def set_args(self, gpu_ids, continue_train=False, is_test=False, exp_dir=''):
        """Apply command-line style run options to this configuration.

        Args:
            gpu_ids: Comma-separated GPU ids, e.g. ``'0,1'``. Exported as
                ``CUDA_VISIBLE_DEVICES``.
            continue_train: When training, copy a starting checkpoint into
                ``self.model_dir`` before the run.
            is_test: Select the test-time branch instead of the training one.
            exp_dir: Existing experiment directory. For training it is where
                the newest checkpoint is taken from; for testing it replaces
                ``output_dir`` and every directory derived from it.

        Raises:
            IndexError: ``continue_train`` with an ``exp_dir`` that holds no
                ``checkpoint/*.pth.tar`` file.
            OSError: The checkpoint to copy cannot be read or written.
        """
        print('exp_dir: ', exp_dir)
        self.gpu_ids = gpu_ids
        self.num_gpus = len(self.gpu_ids.split(','))
        self.bbox_3d_size = 2
        self.camera_3d_size = 2.5

        if not is_test:
            self.continue_train = continue_train
            if self.continue_train:
                start = time.time()
                if exp_dir:
                    # Sort by the integer between the last '_' and the
                    # 8-character '.pth.tar' suffix, then take the newest.
                    checkpoints = sorted(
                        glob.glob(osp.join(exp_dir, 'checkpoint') + '/*.pth.tar'),
                        key=lambda x: int(x.split('_')[-1][:-8]))
                    src = checkpoints[-1]
                else:
                    src = osp.join(self.root_dir, 'tool', 'snapshot_0.pth.tar')
                dst = osp.join(self.model_dir, os.path.basename(src))
                shutil.copyfile(src, dst)

                # Report the paths actually used; the previous code read
                # `checkpoints` here even when that name was never bound.
                print('>>> Copying checkpoint file from {} to {} takes {:.2f} seconds'.format(
                    src, dst, time.time() - start))

        elif is_test and exp_dir:
            self.output_dir = exp_dir
            self.model_dir = osp.join(self.output_dir, 'checkpoint')
            self.vis_dir = osp.join(self.output_dir, 'vis')
            self.log_dir = osp.join(self.output_dir, 'log')
            self.result_dir = osp.join(self.output_dir, 'result')

        os.environ["CUDA_VISIBLE_DEVICES"] = self.gpu_ids
        print('>>> Using GPU: {}'.format(self.gpu_ids))

        if self.testset == 'FreiHAND':
            assert self.trainset_3d[0] == 'FreiHAND'
            assert len(self.trainset_3d) == 1
            assert len(self.trainset_2d) == 0

    def set_data_dir(self, dir):
        """Override the dataset root directory."""
        self.data_dir = dir

    def update(self, config_file):
        """Overlay settings from a YAML file onto this configuration.

        Args:
            config_file: Path to a YAML mapping of existing attribute names
                to new values.

        Raises:
            ValueError: The file names a key that is not already an
                attribute of this configuration.
        """
        with open(config_file) as f:
            # safe_load: plain-data configs load as before, but the file can
            # no longer construct arbitrary Python objects, and the call
            # works on PyYAML releases where load() requires a Loader.
            exp_config = edict(yaml.safe_load(f))
        for k, v in exp_config.items():
            if hasattr(self, k):
                setattr(self, k, v)
            else:
                raise ValueError("{} not exist in config.py".format(k))
def cartesian_to_spherical(x, y, z):
    """Convert a Cartesian point to (radius, theta, phi) with y as the polar axis.

    theta is the azimuth in the x-z plane, ``atan2(z, x)``, so it lies in
    [-pi, pi].  phi is the angle from the +y axis, ``acos(y / r)``, so it
    lies in [0, pi].  The origin is not handled: r == 0 raises
    ZeroDivisionError.
    """
    radius = math.sqrt(x**2 + y**2 + z**2)
    azimuth = math.atan2(z, x)
    polar = math.acos(y / radius)
    return radius, azimuth, polar
in range(180//stride): + phis_imgs.append([]) +for i in range(360//stride): + thetas_imgs.append([]) +for i in range(0,5): + path = f'G:/full-head-dataset/pexels/{i * 1000:08d}' + image_list = glob.glob(f'{path}/aligned_images/*') + result_json_path = os.path.join(path, 'result.json') + with open(result_json_path, 'r') as f: + result = json.load(f) + + + + + for aligned_image_path in image_list: + aligned_image_name = os.path.basename(aligned_image_path) + + camera_pose = result[aligned_image_name]['camera_pose'] + camera_pose = np.reshape(camera_pose, (4, 4)) + #radius = np.linalg.norm(camera_pose[:3,3]) + _, theta, phi = cartesian_to_spherical(camera_pose[0,3], camera_pose[1,3], camera_pose[2,3]) + + thetas.append(theta) + phis.append(phi) + + flip_camerapose_in_pyrender = np.array(result[aligned_image_name]['normalized_camerapose_in_pyrender']) + flip_camerapose_in_pyrender = flip_yaw(flip_camerapose_in_pyrender) + flip_world2camera_matrix = np.linalg.inv(flip_camerapose_in_pyrender) + flip_world2camera_matrix[[1, 2]] *= -1 + camera_pose = np.linalg.inv(flip_world2camera_matrix) + _, theta, phi = cartesian_to_spherical(camera_pose[0, 3], camera_pose[1, 3], camera_pose[2, 3]) + + thetas.append(theta) + phis.append(phi) + + + +plt.scatter(thetas, phis) +plt.show() + +# if abs(theta - np.pi/2) < 0.1: +# phi_bin = int(phi/stride_rad) +# phis_imgs[phi_bin].append(aligned_image_path) +# +# count = 0 +# for i in range(len(phis_imgs)): +# if len(phis_imgs[i]) > 0: +# cv2.imwrite(f'G:/full-head-dataset/pexels/theta_phi/{count}.png', cv2.imread(random.choice(phis_imgs[i]))) +# count+=1 \ No newline at end of file diff --git a/data_processing/main/model.py b/data_processing/main/model.py new file mode 100644 index 0000000..1c984a8 --- /dev/null +++ b/data_processing/main/model.py @@ -0,0 +1,634 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +from nets.resnet import ResNetBackbone +from nets.module import Pose2Feat, PositionNet, RotationNet, 
    def __init__(self, backbone, pose2feat, position_net, rotation_net, vposer):
        """Store the sub-networks and precompute the neck-aligned template mesh.

        Every tensor built here is moved with ``.cuda()``, so constructing the
        model requires a CUDA device.

        Args:
            backbone, pose2feat, position_net, rotation_net, vposer:
                Sub-modules stored unchanged on the instance.
        """
        super(Model, self).__init__()
        self.backbone = backbone
        self.pose2feat = pose2feat
        self.position_net = position_net
        self.rotation_net = rotation_net
        self.vposer = vposer

        # Pick the parametric model from the configured datasets: MANO when
        # 'FreiHAND' appears anywhere, otherwise the neutral SMPL layer.
        if 'FreiHAND' in cfg.trainset_3d + cfg.trainset_2d + [cfg.testset]:
            self.human_model = MANO()
            self.human_model_layer = self.human_model.layer.cuda()
        else:
            self.human_model = SMPL()
            self.human_model_layer = self.human_model.layer['neutral'].cuda()
        self.root_joint_idx = self.human_model.root_joint_idx
        self.mesh_face = self.human_model.face
        self.joint_regressor = self.human_model.joint_regressor

        self.coord_loss = CoordLoss()
        self.param_loss = ParamLoss()

        # The joint that we want to align to the origin
        self.align_joint_name = 'Neck'
        # 0.0649 is the height between the neck joint and head joint of the template
        self.init_camera_location = torch.tensor([0, 0.0649, 2.7]).float().cuda()

        # get template mesh: all-zero root pose, body pose, translation and shape
        root_pose = torch.zeros((1, 3)).cuda()
        pose_param = torch.zeros((1, 69)).cuda()
        cam_trans = torch.zeros((1, 3)).cuda()
        shape_param = torch.zeros((1, 10)).cuda()
        pose_param = pose_param.view(-1, self.human_model.orig_joint_num - 1, 3)
        pose_param = torch.cat((root_pose[:, None, :], pose_param), 1).view(-1, self.human_model.orig_joint_num * 3)
        coord_output = self.get_coord(pose_param, shape_param, cam_trans)
        self.template_mesh_cam_render = coord_output['mesh_cam_render']

        # align neck joint to origin
        template_align_joint_coorinate = coord_output['align_joint_coorinate']  # 1 x 1 x 3
        # In-place subtraction: this also modifies the tensor still held in
        # coord_output['mesh_cam_render'].
        self.template_mesh_cam_render -= template_align_joint_coorinate  # 1 x 6890 x 3
        self.template_align_joint_coorinate = template_align_joint_coorinate

        # used for real world rendering, should not rotate
        self.template_mesh_cam_render_no_flip = self.template_mesh_cam_render.clone()

        # Joints regressed from the unrotated, neck-aligned template.
        # repeat(1, 1, 1) leaves the regressor unchanged (batch of one).
        self.template_mesh_cam_render_no_flip_joint = torch.bmm(
            torch.from_numpy(self.joint_regressor).cuda()[None, :, :].repeat(1, 1, 1),
            self.template_mesh_cam_render_no_flip)

        # in pyrender, should rotate 180 degree around x axis (since y and z axis are flipped)
        R = torch.eye(4).cuda()
        angle = torch.FloatTensor([np.pi]).cuda()
        R[1, 1] = torch.cos(angle)
        R[1, 2] = -torch.sin(angle)
        R[2, 1] = torch.sin(angle)
        R[2, 2] = torch.cos(angle)

        self.template_mesh_R = R

        # Only the rotated copy is reassigned here; the *_no_flip tensors
        # cloned above keep the unrotated coordinates.
        self.template_mesh_cam_render = torch.matmul(self.template_mesh_cam_render, R[:3, :3])

        # Reference axes stored as the columns of a 3 x 3 matrix; note the
        # third axis points along -z.
        x_axis_ = np.array([1, 0, 0])
        y_axis_ = np.array([0, 1, 0])
        z_axis_ = np.array([0, 0, -1])
        self.Axis_original = np.concatenate([x_axis_[:, None], y_axis_[:, None], z_axis_[:, None]], axis=1)

        # Unset until set_min_box_stride() is called.
        self.min_box_stride = None
torch.matmul(mesh_cam_render, self.template_mesh_R[:3, :3]) + + return mesh_cam_render + + def get_neck_head_rotated_template_mesh_joint(self, pose_params_input): + root_pose = torch.zeros((1, 3)).cuda() + pose_param = torch.zeros((1, 69)).cuda() + cam_trans = torch.zeros((1, 3)).cuda() + shape_param = torch.zeros((1, 10)).cuda() + pose_param = pose_param.view(-1, self.human_model.orig_joint_num - 1, 3) + + pose_param[:, [11, 14], :] = pose_params_input[:, [11, 14], :] + + pose_param = torch.cat((root_pose[:, None, :], pose_param), 1).view(-1, self.human_model.orig_joint_num * 3) + coord_output = self.get_coord(pose_param, shape_param, cam_trans) + joints_3d = coord_output['joints_3d'] + joints_3d -= self.template_align_joint_coorinate + + return joints_3d + + def set_min_box_stride(self, min_box_stride): + self.min_box_stride = min_box_stride + + def compute_shoulder_points_R(self, mesh_a, mesh_b): + ''' + :param mesh_a: 1 x 6890 x 3 + :param mesh_b: 1 x 6890 x 3 + + shoulder_vertex_index: 55, + ''' + + joints_a = torch.bmm(torch.from_numpy(self.joint_regressor).cuda()[None, :, :].repeat(mesh_a.shape[0], 1, 1), + mesh_a) + joints_b = torch.bmm(torch.from_numpy(self.joint_regressor).cuda()[None, :, :].repeat(mesh_b.shape[0], 1, 1), + mesh_b) + + selected_joints = [ + 'L_Shoulder', 'R_Shoulder', + # 'L_Thorax', 'R_Thorax', + 'Neck', + # 'Chest', + 'Pelvis' + ] + selected_joints_index = [self.human_model.joints_name.index(joints_name) for joints_name in selected_joints] + + points_a = joints_a[:, selected_joints_index, :] + points_b = joints_b[:, selected_joints_index, :] + + A = points_a[0, :, :].cpu().numpy() # 55 x 3 + B = points_b[0, :, :].cpu().numpy() # 55 x 3 + mean_A = np.mean(A, axis=0, keepdims=True) + mean_B = np.mean(B, axis=0, keepdims=True) + + A = A - mean_A + B = B - mean_B + + H = np.transpose(A) @ B + + U, S, Vt = np.linalg.svd(H) + R = Vt.T @ U.T + + if np.linalg.det(R) < 0: + Vt[2, :] *= -1 + R = Vt.T @ U.T + + return 
torch.from_numpy(R).cuda().float() # 3 x 3 + + def get_camera_trans(self, cam_param, bbox, is_render): + # camera translation + t_xy = cam_param[:, :2] + gamma = torch.sigmoid(cam_param[:, 2]) # apply sigmoid to make it positive + k_value = torch.FloatTensor([math.sqrt(cfg.focal[0] * cfg.focal[1] * cfg.camera_3d_size * cfg.camera_3d_size / ( + cfg.input_img_shape[0] * cfg.input_img_shape[1]))]).cuda().view(-1) + if is_render: + k_value = k_value * math.sqrt(cfg.input_img_shape[0] * cfg.input_img_shape[1]) / ( + bbox[:, 2] * bbox[:, 3]).sqrt() + t_z = k_value * gamma + cam_trans = torch.cat((t_xy, t_z[:, None]), 1) + return cam_trans + + def make_2d_gaussian_heatmap(self, joint_coord_img): + x = torch.arange(cfg.output_hm_shape[2]) + y = torch.arange(cfg.output_hm_shape[1]) + yy, xx = torch.meshgrid(y, x) + xx = xx[None, None, :, :].cuda().float(); + yy = yy[None, None, :, :].cuda().float(); + + x = joint_coord_img[:, :, 0, None, None]; + y = joint_coord_img[:, :, 1, None, None]; + heatmap = torch.exp( + -(((xx - x) / cfg.sigma) ** 2) / 2 - (((yy - y) / cfg.sigma) ** 2) / 2) + return heatmap + + def get_coord(self, smpl_pose, smpl_shape, smpl_trans): + batch_size = smpl_pose.shape[0] + mesh_cam, mesh_joints = self.human_model_layer(smpl_pose, smpl_shape, smpl_trans) + # camera-centered 3D coordinate + joint_cam = torch.bmm(torch.from_numpy(self.joint_regressor).cuda()[None, :, :].repeat(batch_size, 1, 1), + mesh_cam) + joints_3d = joint_cam.clone() + # head + align_joint_coorinate = joint_cam[:, self.human_model.joints_name.index(self.align_joint_name), None, :] + + root_joint_idx = self.human_model.root_joint_idx + + # project 3D coordinates to 2D space + x = joint_cam[:, :, 0] / (joint_cam[:, :, 2] + 1e-4) * cfg.focal[0] + cfg.princpt[0] + y = joint_cam[:, :, 1] / (joint_cam[:, :, 2] + 1e-4) * cfg.focal[1] + cfg.princpt[1] + x = x / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + y = y / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + joint_proj = 
torch.stack((x, y), 2) + + mesh_cam_render = mesh_cam.clone() + # root-relative 3D coordinates + root_cam = joint_cam[:, root_joint_idx, None, :] + joint_cam = joint_cam - root_cam + mesh_cam = mesh_cam - root_cam + return { + 'joint_proj': joint_proj, + 'joint_cam': joint_cam, + 'mesh_cam': mesh_cam, + 'mesh_cam_render': mesh_cam_render, + 'align_joint_coorinate': align_joint_coorinate, + 'root_cam': root_cam, + 'joints_3d': joints_3d + } + + def generate_visualization(self, image, mesh_cam_render, joint): + + # princpt = (bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2) + # generate random color + color = colorsys.hsv_to_rgb(np.random.rand(), 0.5, 1.0) + # bbox = out['bbox'][0].cpu().numpy() + + mesh_image = render_mesh(image.copy() * 255, mesh_cam_render, self.human_model.face, + {'focal': cfg.focal, 'princpt': cfg.princpt}, color=color) + + joint_image = vis_keypoints(image.copy() * 255, joint) + + viz = np.concatenate([image.copy() * 255, + joint_image.astype(np.uint8), + mesh_image.astype(np.uint8)], + axis=1)[:, :, ::-1] + return viz + + def get_visualization(self, inputs, targets, meta_info): + inputs = inputs + for key in inputs: + inputs[key] = inputs[key].cuda() + output = self.forward(inputs, targets, meta_info, mode='test') + viz_predicts = [] + for i in range(inputs['img'].shape[0]): + viz_predict = self.generate_visualization(image=inputs['img'][i].cpu().numpy().transpose(1, 2, 0), + mesh_cam_render=output['mesh_cam_render'][ + i].detach().cpu().numpy(), + joint=inputs['joints'][i].detach().cpu().numpy() * ( + cfg.input_img_shape[1] / cfg.output_hm_shape[2]) + ) + viz_predicts.append(viz_predict) + + return viz_predicts + + def forward(self, inputs, targets, meta_info, mode): + early_img_feat = self.backbone(inputs['img']) # pose_guided_img_feat + + # get pose gauided image feature + joint_coord_img = inputs['joints'] + with torch.no_grad(): + joint_heatmap = self.make_2d_gaussian_heatmap(joint_coord_img.detach()) + # remove blob centered at (0,0) == 
invalid ones + joint_heatmap = joint_heatmap * inputs['joints_mask'][:, :, :, None] + pose_img_feat = self.pose2feat(early_img_feat, joint_heatmap) + pose_guided_img_feat = self.backbone(pose_img_feat, skip_early=True) # 2048 x 8 x 8 + + joint_img, joint_score = self.position_net(pose_guided_img_feat) # refined 2D pose or 3D pose + + # estimate model parameters + root_pose_6d, z, shape_param, cam_param = self.rotation_net(pose_guided_img_feat, joint_img.detach(), + joint_score.detach()) + # change root pose 6d + latent code -> axis angles + root_pose = rot6d_to_axis_angle(root_pose_6d) + pose_param = self.vposer(z) + cam_trans = self.get_camera_trans(cam_param, meta_info['bbox'], is_render=(cfg.render and (mode == 'test'))) + pose_param = pose_param.view(-1, self.human_model.orig_joint_num - 1, 3) + + body_pose_param = pose_param.clone() + + pose_param = torch.cat((root_pose[:, None, :], pose_param), 1).view(-1, self.human_model.orig_joint_num * 3) + coord_output = self.get_coord(pose_param, shape_param, cam_trans) + joint_proj, joint_cam, mesh_cam, mesh_cam_render = coord_output['joint_proj'], coord_output['joint_cam'], \ + coord_output['mesh_cam'], coord_output['mesh_cam_render'] + + if mode == 'train': + # loss functions + loss = {} + # joint_img: 0~8, joint_proj: 0~64, target: 0~64 + loss['body_joint_img'] = (1 / 8) * self.coord_loss(joint_img * 8, self.human_model.reduce_joint_set( + targets['orig_joint_img']), self.human_model.reduce_joint_set(meta_info['orig_joint_trunc']), + meta_info['is_3D']) + loss['smpl_joint_img'] = (1 / 8) * self.coord_loss(joint_img * 8, self.human_model.reduce_joint_set( + targets['fit_joint_img']), + self.human_model.reduce_joint_set( + meta_info['fit_joint_trunc']) * meta_info[ + 'is_valid_fit'][ + :, None, None]) + loss['smpl_pose'] = self.param_loss(pose_param, targets['pose_param'], + meta_info['fit_param_valid'] * meta_info['is_valid_fit'][:, None]) + loss['smpl_shape'] = self.param_loss(shape_param, targets['shape_param'], + 
meta_info['is_valid_fit'][:, None]) + loss['body_joint_proj'] = (1 / 8) * self.coord_loss(joint_proj, targets['orig_joint_img'][:, :, :2], + meta_info['orig_joint_trunc']) + loss['body_joint_cam'] = self.coord_loss(joint_cam, targets['orig_joint_cam'], + meta_info['orig_joint_valid'] * meta_info['is_3D'][:, None, None]) + loss['smpl_joint_cam'] = self.coord_loss(joint_cam, targets['fit_joint_cam'], + meta_info['is_valid_fit'][:, None, None]) + + return loss + + else: + # test output + out = {'cam_param': cam_param} + # out['input_joints'] = joint_coord_img + out['joint_img'] = joint_img * 8 + out['joint_proj'] = joint_proj + out['joint_score'] = joint_score + out['smpl_mesh_cam'] = mesh_cam + out['smpl_pose'] = pose_param.clone() + out['smpl_shape'] = shape_param.clone() + out['cam_trans'] = cam_trans.clone() + + out['mesh_cam_render'] = mesh_cam_render + out['mesh_cam_render_joints_3d'] = coord_output['joints_3d'] + + if 'smpl_mesh_cam' in targets: + out['smpl_mesh_cam_target'] = targets['smpl_mesh_cam'] + if 'bb2img_trans' in meta_info: + out['bb2img_trans'] = meta_info['bb2img_trans'] + if 'img2bb_trans' in meta_info: + out['img2bb_trans'] = meta_info['img2bb_trans'] + if 'bbox' in meta_info: + out['bbox'] = meta_info['bbox'] + if 'tight_bbox' in meta_info: + out['tight_bbox'] = meta_info['tight_bbox'] + if 'aid' in meta_info: + out['aid'] = meta_info['aid'] + + out['neck_head_rotated_template_mesh'] = self.get_neck_head_rotated_template_mesh(body_pose_param) + + cam_trans_crop = self.get_camera_trans(cam_param, meta_info['bbox'], is_render=False) + coord_output_crop = self.get_coord(pose_param, shape_param, cam_trans_crop) + mesh_cam_render_crop = coord_output_crop['mesh_cam_render'] + out['mesh_cam_render_crop'] = mesh_cam_render_crop + out['align_joint_coorinate_crop'] = coord_output_crop['align_joint_coorinate'] + # align neck joint to origin + align_joint_coorinate = coord_output['align_joint_coorinate'] # 1 x 1 x 3 + mesh_cam_render_aligned = 
mesh_cam_render.clone() # 1 x 6890 x 3 + # align neck joint to origin + mesh_cam_render_aligned -= align_joint_coorinate + out['mesh_cam_render_aligned'] = mesh_cam_render_aligned + out['align_joint_coorinate'] = align_joint_coorinate + + # ======================translation =================== + translation_in_pyrender = torch.eye(4, device=mesh_cam_render_aligned.device) + translation_in_pyrender[:3, 3:4] = -align_joint_coorinate.squeeze(1).T + + # flip y axis and z axis to render in pyrender correctly + translation_in_pyrender[[1, 2], 3] *= -1 + # ===================================================== + + # ======================rotaion ======================= + rotation_in_pyrender = torch.eye(4, device=mesh_cam_render_aligned.device) + # compute the rotation matrix that rotate template to the aligned mesh + R = self.compute_shoulder_points_R(self.template_mesh_cam_render, mesh_cam_render_aligned) + # flip y axis and z axis to render in pyrender correctly + angles = cv2.Rodrigues(torch.inverse(R).cpu().numpy())[0] + angles[[1, 2], :] *= -1 + R_in_pyrender = cv2.Rodrigues(angles)[0] + rotation_in_pyrender[:3, :3] = torch.from_numpy(R_in_pyrender).to(mesh_cam_render_aligned.device) + # ======================================================== + + # ========================== remder template on original image ================================== + out['camera_pose_in_pyrender'] = rotation_in_pyrender @ translation_in_pyrender + out['camera_to_render_template_in_pyrender'] = translation_in_pyrender + out['no_rotation_world2camera_transformation_in_real_world'] = torch.inverse(translation_in_pyrender) # + out['no_rotation_world2camera_transformation_in_real_world'][[1, 2]] *= -1 + + # ========================== Normalized camera ========================== + normalized_camerapose_in_pyrender = out['camera_pose_in_pyrender'].cpu().numpy() + + camera_position = normalized_camerapose_in_pyrender[:3, 3] + camera_position = camera_position / np.linalg.norm(camera_position) * 
def get_projected_joints(self, joint_cam, world_2_camera_matrix, princpt):
    """Project 3D joints onto the image plane with a pinhole camera.

    Args:
        joint_cam: tensor of joint positions; a leading singleton dimension
            is squeezed away, leaving one xyz row per joint.
        world_2_camera_matrix: 4 x 4 tensor applied to the homogeneous joints.
            It must live on the same device as ``joint_cam``.
        princpt: ``(cx, cy)`` principal point in pixels.

    Returns:
        ``numpy`` array with one ``(x, y)`` pixel row per joint.
    """
    joint_cam = joint_cam.squeeze(0)  # J x 3
    # Build the homogeneous column on the joints' own device instead of a
    # hard-coded `.cuda()`, so CPU tensors and non-default GPUs work too.
    ones = torch.ones(joint_cam.shape[0], 1, dtype=joint_cam.dtype, device=joint_cam.device)
    joint_homogeneous = torch.cat([joint_cam, ones], dim=1)  # J x 4

    joint_in_camera = world_2_camera_matrix @ joint_homogeneous.T  # 4 x J
    joint_in_camera = joint_in_camera[:3, :].cpu().numpy()

    # The focal length comes from the global config; the principal point is
    # supplied by the caller because it depends on the crop being used.
    intrinsic_matrix = np.eye(3)
    intrinsic_matrix[0, 0] = cfg.focal[0]
    intrinsic_matrix[1, 1] = cfg.focal[1]
    intrinsic_matrix[0, 2] = princpt[0]
    intrinsic_matrix[1, 2] = princpt[1]

    joint_on_input_image = intrinsic_matrix @ joint_in_camera
    # Perspective divide by depth, then keep only the pixel coordinates.
    joint_on_input_image = joint_on_input_image / joint_on_input_image[2, :]
    return joint_on_input_image[:2, :].T  # J x 2
def __crop_and_process_camera_matrix__(self, model_output, input_image, joint_2d, crop_image_size, model_input_bbox,
                                       head_bbox, use_head_bbox):
    """Crop the input image around the shoulders and derive the crop's intrinsics.

    The square crop is placed so that the detected 2D shoulder centre lands
    where the template's shoulder centre projects under the normalized camera.

    Args:
        model_output: dict holding
            ``'no_rotation_world2camera_transformation_in_real_world'`` (tensor)
            and ``'normalized_transformation_in_realworld'`` (numpy array).
        input_image: full-size image passed to ``generate_patch_image``.
        joint_2d: 2D keypoints; rows 5 and 6 are used as left/right shoulder.
        crop_image_size: side length, in pixels, of the square output crop.
        model_input_bbox: ``(x, y, w, h)`` box the camera parameters refer to.
        head_bbox: ``(x, y, w, h)`` head box; only read when ``use_head_bbox``.
        use_head_bbox: size the crop from the head box instead of the
            projected shoulder distance.

    Returns:
        Dict with keys ``'intrisics'``, ``'cropped_image'``, ``'bbox'`` and
        ``'bbox_stride'``, or ``None`` when the crop would be smaller than
        ``self.min_box_stride`` or the patch cannot be generated.
    """
    # Principal point of the original camera: centre of the model-input bbox.
    input_princpt = (model_input_bbox[0] + model_input_bbox[2] / 2,
                     model_input_bbox[1] + model_input_bbox[3] / 2)

    # Project the template joints onto the full-size input image.
    template_joint_on_input_image = self.get_projected_joints(
        self.template_mesh_cam_render_no_flip_joint,
        model_output['no_rotation_world2camera_transformation_in_real_world'],
        input_princpt)

    l_shoulder_idx = self.human_model.joints_name.index('L_Shoulder')
    r_shoulder_idx = self.human_model.joints_name.index('R_Shoulder')
    L_Shoulder_2d = template_joint_on_input_image[l_shoulder_idx, :]
    R_Shoulder_2d = template_joint_on_input_image[r_shoulder_idx, :]

    # Project the template joints with the normalized camera onto the crop.
    template_joint_on_crop_image = self.get_projected_joints(
        self.template_mesh_cam_render_no_flip_joint,
        torch.from_numpy(model_output['normalized_transformation_in_realworld']).float().cuda(),
        (crop_image_size / 2, crop_image_size / 2))
    Shoulder_center_on_crop_image = (template_joint_on_crop_image[l_shoulder_idx, :] +
                                     template_joint_on_crop_image[r_shoulder_idx, :]) / 2.0

    L_Shoulder_coco = joint_2d[5, :2]
    R_Shoulder_coco = joint_2d[6, :2]
    shoulder_center_coco = (L_Shoulder_coco + R_Shoulder_coco) / 2.0

    # Side length of the crop: from the head size, or from the shoulder distance.
    if use_head_bbox:
        assert len(head_bbox) == 4
        head_stride = min(head_bbox[2], head_bbox[3])
        bbox_stride = head_stride * 2.6
    else:
        shoulder_stride = np.linalg.norm(L_Shoulder_2d - R_Shoulder_2d)
        bbox_stride = shoulder_stride * 1.6

    # Solve for the crop centre such that
    #   (shoulder_center_coco - bbox_center) / bbox_stride * crop_image_size
    #       == Shoulder_center_on_crop_image - crop_image_size / 2
    bbox_center = shoulder_center_coco - (
            Shoulder_center_on_crop_image - crop_image_size / 2.0) * bbox_stride / crop_image_size

    bbox = np.array([bbox_center[0] - bbox_stride / 2, bbox_center[1] - bbox_stride / 2, bbox_stride, bbox_stride])

    if bbox[2] < self.min_box_stride or bbox[3] < self.min_box_stride:
        return None

    # Best effort: a failed crop drops this sample. Catch `Exception` rather
    # than using a bare `except`, so Ctrl-C and SystemExit still propagate.
    try:
        img, img2bb_trans, bb2img_trans = generate_patch_image(input_image, bbox, 1.0, 0.0, False,
                                                               (crop_image_size, crop_image_size),
                                                               enable_padding=True)
    except Exception:
        return None

    # Shift the principal point into the crop's coordinate frame ...
    cx_new = input_princpt[0] - bbox[0]
    cy_new = input_princpt[1] - bbox[1]

    # ... then rescale focal length and principal point to the output resolution.
    new_focal = (cfg.focal[0] / bbox[2] * crop_image_size, cfg.focal[1] / bbox[3] * crop_image_size)
    new_princpt = (cx_new / bbox[2] * crop_image_size, cy_new / bbox[3] * crop_image_size)

    out = {}
    out['intrisics'] = {'focal': new_focal, 'princpt': new_princpt}
    out['cropped_image'] = img
    out['bbox'] = bbox
    out['bbox_stride'] = bbox_stride

    return out
def init_weights(m):
    """Initialise one module in place; meant for ``nn.Module.apply``.

    * ``ConvTranspose2d``: weight ~ N(0, 0.001); bias left untouched.
    * ``Conv2d``: weight ~ N(0, 0.001), bias = 0.
    * ``BatchNorm2d``: weight = 1, bias = 0.
    * ``Linear``: weight ~ N(0, 0.01), bias = 0.

    Modules of any other exact type are ignored. Missing parameters
    (``bias=False`` layers, ``affine=False`` batch norms) are skipped
    instead of raising.
    """
    layer_type = type(m)  # exact-type match, as before: subclasses are not touched

    if layer_type is nn.ConvTranspose2d:
        nn.init.normal_(m.weight, std=0.001)
    elif layer_type is nn.Conv2d:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif layer_type is nn.BatchNorm2d:
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif layer_type is nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def parse_args():
    """Parse the command line of the test script.

    ``--gpu`` accepts a comma separated list (``"0,1"``) or an inclusive
    range (``"0-3"``); a range is expanded to ``"0,1,2,3"``.

    Returns:
        The parsed namespace with ``gpu_ids``, ``test_epoch``, ``exp_dir``
        and ``cfg``.

    Raises:
        SystemExit: via ``parser.error`` when ``--gpu`` or ``--test_epoch``
            is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, dest='gpu_ids')
    parser.add_argument('--test_epoch', type=str, dest='test_epoch')
    parser.add_argument('--exp_dir', type=str, default='')
    parser.add_argument('--cfg', type=str, default='', help='experiment configure file name')

    args = parser.parse_args()

    # Validate explicitly: `assert` is stripped under `python -O`, which would
    # let a missing value slip through to a confusing TypeError below.
    if not args.gpu_ids:
        parser.error('Please set proper gpu ids')

    if '-' in args.gpu_ids:
        first, last = (int(part) for part in args.gpu_ids.split('-'))
        args.gpu_ids = ','.join(str(gpu) for gpu in range(first, last + 1))

    if not args.test_epoch:
        parser.error('Test epoch is required.')

    return args
def main():
    """Run the training loop with optional mixed precision.

    Parses the command line, configures ``cfg``, builds the trainer, then
    iterates over epochs and batches. Each loss term is logged to
    TensorBoard, a visualization is saved every 400 iterations and a
    checkpoint every 1000 iterations.
    """
    # argument parse and create log
    args = parse_args()
    cfg.set_args(args.gpu_ids, args.continue_train, exp_dir=args.exp_dir)
    cudnn.benchmark = True
    if args.cfg:
        cfg.update(args.cfg)
    # Truthiness test instead of `is not ''`: identity comparison against a
    # literal only works by accident and emits a SyntaxWarning on Python 3.8+.
    if args.data_dir:
        cfg.set_data_dir(args.data_dir)
    writer = SummaryWriter(cfg.log_dir)
    trainer = Trainer()
    trainer._make_batch_generator()
    trainer._make_model()

    scaler = amp.GradScaler(init_scale=args.init_scale, enabled=args.use_mixed_precision)
    global_step = 0

    # train
    for epoch in range(trainer.start_epoch, cfg.end_epoch):
        print('Epoch %d/%d' % (epoch, cfg.end_epoch))
        trainer.set_lr(epoch)
        trainer.tot_timer.tic()
        trainer.read_timer.tic()
        for itr, (inputs, targets, meta_info) in enumerate(trainer.batch_generator):
            trainer.read_timer.toc()
            trainer.gpu_timer.tic()

            # forward
            trainer.optimizer.zero_grad()
            with amp.autocast(args.use_mixed_precision):
                loss = trainer.model(inputs, targets, meta_info, 'train')
                loss = {k: loss[k].mean() for k in loss}
                for k in loss:
                    writer.add_scalar('train/loss_' + k, loss[k].detach(), global_step)
                _loss = sum(loss[k] for k in loss)

            # backward
            with amp.autocast(False):
                _loss = scaler.scale(_loss)
                _loss.backward()
                scaler.step(trainer.optimizer)

            # NOTE(review): passing `new_scale` resets the loss scale to
            # `init_scale` on every step, i.e. a fixed scale rather than a
            # dynamic one. Kept as in the original; confirm it is intended.
            scaler.update(args.init_scale)

            trainer.gpu_timer.toc()
            screen = [
                'Epoch %d/%d itr %d/%d:' % (epoch, cfg.end_epoch, itr, trainer.itr_per_epoch),
                'lr: %g' % (trainer.get_lr()),
                'speed: %.2f(%.2fs r%.2f)s/itr' % (
                    trainer.tot_timer.average_time, trainer.gpu_timer.average_time, trainer.read_timer.average_time),
                '%.2fh/epoch' % (trainer.tot_timer.average_time / 3600. * trainer.itr_per_epoch),
            ]
            screen += ['%s: %.4f' % ('loss_' + k, v.detach()) for k, v in loss.items()]
            trainer.logger.info(' '.join(screen))

            trainer.tot_timer.toc()
            trainer.tot_timer.tic()
            trainer.read_timer.tic()

            if itr % 400 == 0:
                trainer.save_visualization(inputs, targets, meta_info, epoch, itr)
            if itr % 1000 == 0:
                trainer.save_model({
                    'epoch': epoch,
                    'network': trainer.model.state_dict(),
                    'optimizer': trainer.optimizer.state_dict(),
                }, epoch, itr)
            global_step += 1
"""Run the full preprocessing pipeline on every sample folder for inversion.

For each ``<test_data_dir>/<name>/samples`` directory that has no sibling
``samples_new_crop`` yet, the stages below are run in order:

1. move loose ``*.png`` files into ``samples/images``
2. ``prepare_data.py``       (keypoints + head boxes)
3. ``runmy.py``              (camera parameter extraction)
4. DensePose ``apply_net.py``
5. ``runmy_new_crop.py``     (writes ``samples_new_crop``)
6. ``segmentation_example.py`` (writes ``samples_new_crop/mask``)
"""
import argparse
import glob
import os
import subprocess


def iter_sample_dirs(test_data_dir):
    """Yield ``(sub_dir, samples_dir)`` for every sub-folder that has ``samples``."""
    for sub_dir in glob.glob(os.path.join(test_data_dir, '*')):
        samples_dir = os.path.join(sub_dir, 'samples')
        if os.path.exists(samples_dir):
            yield sub_dir, samples_dir


def iter_pending_samples(test_data_dir):
    """Yield ``samples`` directories whose ``samples_new_crop`` does not exist yet."""
    for sub_dir, samples_dir in iter_sample_dirs(test_data_dir):
        if not os.path.exists(os.path.join(sub_dir, 'samples_new_crop')):
            yield samples_dir


def run_stage(command, cwd):
    """Run one pipeline command in ``cwd``; failures are ignored, as with ``os.system``."""
    subprocess.run(command, cwd=cwd, check=False)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_data_dir', type=str, default='../test_data')
    opt = parser.parse_args()

    # Resolve once, relative to the launch directory. The stages run in other
    # working directories, where a relative path would point somewhere else.
    test_data_dir = os.path.abspath(opt.test_data_dir)

    # stage 1: collect loose images into samples/images
    for samples_dir in iter_pending_samples(test_data_dir):
        images_dir = os.path.join(samples_dir, 'images')
        os.makedirs(images_dir, exist_ok=True)
        for image in glob.glob(os.path.join(samples_dir, '*.png')):
            os.rename(image, os.path.join(images_dir, os.path.basename(image)))

    root = os.path.dirname(os.path.abspath(__file__))
    print(root)
    densepose_dir = os.path.join(root, 'detectron2/projects/DensePose')

    # Commands are argument lists, so paths with spaces reach the child intact.
    # stage 2
    for samples_dir in iter_pending_samples(test_data_dir):
        run_stage(['python', 'prepare_data.py', '--input_dir', samples_dir], cwd=root)

    # stage 3
    for samples_dir in iter_pending_samples(test_data_dir):
        run_stage(['python', 'runmy.py', '--input_dir', samples_dir], cwd=root)

    # stage 4
    for samples_dir in iter_pending_samples(test_data_dir):
        run_stage(['python', 'apply_net.py', 'show',
                   'configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml',
                   'R_101_FPN_DL_soft_s1x.pkl',
                   f'{samples_dir}/aligned_images', 'dp_vertex',
                   '--output', f'{samples_dir}/seg', '--min_score', '0.8'],
                  cwd=densepose_dir)

    # stage 5
    for samples_dir in iter_pending_samples(test_data_dir):
        run_stage(['python', 'runmy_new_crop.py', '--input_dir', samples_dir], cwd=root)

    # stage 6: segment every new crop that has no mask folder yet
    for sub_dir, _ in iter_sample_dirs(test_data_dir):
        new_crop_samples_dir = os.path.join(sub_dir, 'samples_new_crop')
        new_crop_mask_samples_dir = os.path.join(sub_dir, 'samples_new_crop/mask')
        if os.path.exists(new_crop_mask_samples_dir):
            continue
        os.makedirs(new_crop_mask_samples_dir, exist_ok=True)
        run_stage(['python', 'segmentation_example.py', '--base_path', new_crop_samples_dir], cwd=root)


if __name__ == '__main__':
    main()
-0,0 +1 @@ +Our image processing code is mainly adapted from [hongsukchoi/3DCrowdNet_RELEASE](https://github.com/hongsukchoi/3DCrowdNet_RELEASE), thanks for their valuable contributions! \ No newline at end of file diff --git a/data_processing/requirements.txt b/data_processing/requirements.txt new file mode 100644 index 0000000..3a014fa --- /dev/null +++ b/data_processing/requirements.txt @@ -0,0 +1,21 @@ +opencv-contrib-python==4.7.0.72 +opencv-python==4.7.0.72 +opencv-python-headless==4.7.0.72 +urllib3==1.26.15 +json_tricks==3.17.3 +pytz==2022.7.1 +munkres==1.1.4 +scipy==1.9.1 +pyrender==0.1.45 +ImageHash==4.3.1 +easydict==1.10 +human-body-prior==0.8.5.0 +progressbar2==4.3b0 +numpy==1.21.6 +omegaconf==2.3.0 +Pillow==9.5.0 +pycocotools==2.0.6 +av==10.0.0 +pandas==1.3.5 +seaborn==0.12.2 +chumpy==0.70 diff --git a/data_processing/run_unsplash.py b/data_processing/run_unsplash.py new file mode 100644 index 0000000..59947f6 --- /dev/null +++ b/data_processing/run_unsplash.py @@ -0,0 +1,29 @@ +import os +import argparse + +parser = argparse.ArgumentParser(description=' ') +# general + +for i in range(0,1): + path = f'E:/project/unsplash/{i*1000:08d}' + head_box_path = f'{path}/head_bbox_yolov5_crowdhuman.json' + if not os.path.exists(path) or not os.path.exists(head_box_path): + continue + + os.chdir('E:/project/3DCrowdNet_upper_body-main/demo') + command =f'python extract_camera_parameter.py --gpu 0 --input_dir {path} --output_dir {path} --data_dir E:/project/3DCrowdNet_upper_body-main/data' + print(command) + os.system(command) + + os.chdir('E:/project/3DCrowdNet_upper_body-main/MANIQA') + command =f'python imagedups.py -r -d -N -p {path}/aligned_images' + print(command) + os.system(command) + + command = f'python remove_blurr_images.py --input_dir {path}' + print(command) + os.system(command) + + command = f'python delete_images.py --input_dir {path}' + print(command) + os.system(command) \ No newline at end of 
"""Extract camera parameters for one sample directory.

Usage: python runmy.py --input_dir <samples_dir>

``<samples_dir>`` must already contain ``head_bbox_yolov5_crowdhuman.json``.
"""
import argparse
import os
import subprocess


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='')
    args = parser.parse_args()

    path = args.input_dir

    root = os.path.dirname(os.path.abspath(__file__))
    print(root)

    head_box_path = f'{path}/head_bbox_yolov5_crowdhuman.json'
    if not os.path.exists(path) or not os.path.exists(head_box_path):
        # FileNotFoundError is still an `Exception`, so existing handlers keep working.
        raise FileNotFoundError('path or head_box_path not exists')

    # The child runs inside `demo/`, so hand it an absolute path; a relative
    # one was validated here but would be resolved against `demo/` there.
    path = os.path.abspath(path)
    data_dir = os.path.join(root, 'data')

    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed through untouched.
    command = ['python', 'extract_camera_parameter.py', '--gpu', '0',
               '--input_dir', path, '--output_dir', path, '--data_dir', data_dir]
    print(' '.join(command))
    subprocess.run(command, cwd=os.path.join(root, 'demo'), check=False)


if __name__ == '__main__':
    main()
def get_and_save_mask(device, base_path):
    """Segment every image under ``<base_path>/aligned_images`` and save person masks.

    Each mask is written next to its image, with ``aligned_images`` replaced
    by ``mask`` in the image path.

    Args:
        device: torch device the segmentation network and batches run on.
        base_path: folder passed to ``LoadData``.
    """
    batch_size = 8
    # Named `loader` so the `dataset` module imported at file level is not shadowed.
    loader = torch.utils.data.DataLoader(
        dataset=LoadData(base_path),
        batch_size=batch_size,
        shuffle=False
    )

    # Load the segmentation network once. It used to be rebuilt (and its
    # weights re-read) for every batch inside the loop.
    seg_net = deeplabv3_resnet101(pretrained=True, progress=False).to(device)
    seg_net.requires_grad_(False)
    seg_net.eval()

    to_pil = ToPILImage()

    for input_tensor, image_paths in tqdm(loader):
        input_batch = input_tensor.to(device)  # batch x 3 x H x W

        # class id 15 is 'person' in the class list documented in get_mask
        mask = get_mask(seg_net, input_batch, 15)
        print(mask.shape)  # batch x H x W

        mask = mask.unsqueeze(1)  # batch x 1 x H x W

        for single_mask, image_path in zip(mask, image_paths):
            # Drop the channel dimension before converting to a PIL image.
            mask_image = to_pil(torch.squeeze(single_mask))
            mask_path = image_path.replace('aligned_images', 'mask')
            mask_image.save(mask_path)
--git a/data_processing/tasks/4.22-4.23.xlsx b/data_processing/tasks/4.22-4.23.xlsx new file mode 100644 index 0000000..deec8b1 Binary files /dev/null and b/data_processing/tasks/4.22-4.23.xlsx differ diff --git a/data_processing/tasks/4.23-4.24.xlsx b/data_processing/tasks/4.23-4.24.xlsx new file mode 100644 index 0000000..534b0b4 Binary files /dev/null and b/data_processing/tasks/4.23-4.24.xlsx differ diff --git a/data_processing/tasks/4.24-4.25.xlsx b/data_processing/tasks/4.24-4.25.xlsx new file mode 100644 index 0000000..405e12b Binary files /dev/null and b/data_processing/tasks/4.24-4.25.xlsx differ diff --git a/data_processing/tool/check_crowdidx.py b/data_processing/tool/check_crowdidx.py new file mode 100644 index 0000000..9f99ad8 --- /dev/null +++ b/data_processing/tool/check_crowdidx.py @@ -0,0 +1,241 @@ +import pickle + +import numpy as np +import os.path as osp +from pycocotools.coco import COCO + + +def compute_CrowdIndex(ref_bbox, ref_kps, intf_kps): + + na = 0 + for ref_kp in ref_kps: + count = get_inclusion(ref_bbox, ref_kp) + na += count + + nb = 0 + for intf_kp in intf_kps: + count = get_inclusion(ref_bbox, intf_kp) + nb += count + + if na < 4: # invalid ones, e.g. 
def compute_iou(src_roi, dst_roi):
    """Return the pairwise IoU between two sets of ``(x, y, w, h)`` boxes.

    Args:
        src_roi: array of shape ``(M, 4)``.
        dst_roi: array of shape ``(N, 4)``.

    Returns:
        Array of shape ``(N, M)``; entry ``[i, j]`` is the IoU of
        ``dst_roi[i]`` with ``src_roi[j]``.
    """
    # Broadcast explicitly to N x M. The previous mix of element-wise and
    # tiled operations was only correct when both inputs held a single box.
    src = np.asarray(src_roi)[None, :, :]  # 1 x M x 4
    dst = np.asarray(dst_roi)[:, None, :]  # N x 1 x 4

    xmin = np.maximum(dst[..., 0], src[..., 0])
    ymin = np.maximum(dst[..., 1], src[..., 1])
    xmax = np.minimum(dst[..., 0] + dst[..., 2], src[..., 0] + src[..., 2])
    ymax = np.minimum(dst[..., 1] + dst[..., 3], src[..., 1] + src[..., 3])

    # Clamp at zero so disjoint boxes contribute no intersection.
    inter_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

    dst_area = dst[..., 2] * dst[..., 3]  # N x 1
    src_area = src[..., 2] * src[..., 3]  # 1 x M
    sum_area = dst_area + src_area

    # The epsilon keeps zero-area boxes from dividing by zero.
    return inter_area / (sum_area - inter_area + 1e-5)
img_name) + img_width, img_height = img['height'], img['width'] + + aids = db.getAnnIds(iid) + if len(aids) < 2: + continue + + data_dict = {} + data_dict['img_id'] = iid + data_dict['img_path'] = img_path + + # compute iou + ann1 = db.anns[aids[0]] + ann2 = db.anns[aids[1]] + + bbox1 = np.array(ann1['bbox']) + bbox2 = np.array(ann2['bbox']) + iou = compute_iou(bbox1[None, :], bbox2[None, :])[0,0] + + seq_iou_list.setdefault(sequence_name, []).append(iou) + + # compute crowd index + joint_img1 = np.array(ann1['joint_img'], dtype=np.float32).reshape(-1, 2) + joint_img2 = np.array(ann2['joint_img'], dtype=np.float32).reshape(-1, 2) + + ci1 = compute_CrowdIndex(bbox1, joint_img1, joint_img2) + ci2 = compute_CrowdIndex(bbox2, joint_img2, joint_img1) + + crowd_idx = (ci1+ci2) / 2 + + seq_crowd_idx_list.setdefault(sequence_name, []).append(crowd_idx) + + for seq in seq_iou_list.keys(): + seq_iou_list[seq] = sum(seq_iou_list[seq]) / len(seq_iou_list[seq]) + for seq in seq_crowd_idx_list.keys(): + seq_crowd_idx_list[seq] = sum(seq_crowd_idx_list[seq]) / len(seq_crowd_idx_list[seq]) + + return seq_iou_list, seq_crowd_idx_list + + def print_statistics(self): + total_mean_iou, total_mean_crowd_idx = 0, 0 + for seq in self.seq_iou_list: + print(f"Average iou / crowd index of {seq}: {self.seq_iou_list[seq]}, {self.seq_crowd_idx_list[seq]}") + total_mean_iou += self.seq_iou_list[seq] + total_mean_crowd_idx += self.seq_crowd_idx_list[seq] + print(f"All iou / crowd index: {total_mean_iou/len(self.seq_iou_list)}, {total_mean_crowd_idx/len(self.seq_iou_list)}") + + +class MuPoTs(): + def __init__(self): + self.test_annot_path = osp.join('..', 'data', 'MuPoTs', 'data', 'MuPoTS-3D.json') + self.seq_iou_list, self.seq_crowd_idx_list = self.load_data() + + def load_data(self): + db = COCO(self.test_annot_path) + + seq_iou_list = {} + seq_crowd_idx_list = {} + for iid in db.imgs.keys(): + img = db.imgs[iid] + img_name = img['file_name'] + sequence_name = img_name.split('/')[0] + + aids = 
def print_statistics(self):
    """Report per-sequence mean IoU and crowd index, followed by their average.

    ``self.seq_iou_list`` and ``self.seq_crowd_idx_list`` are dictionaries keyed
    by sequence name; the summary line averages the per-sequence means over the
    number of sequences present in ``self.seq_iou_list``.
    """
    iou_sum, crowd_sum = 0, 0
    for seq_name in self.seq_iou_list.keys():
        mean_iou = self.seq_iou_list[seq_name]
        mean_crowd = self.seq_crowd_idx_list[seq_name]
        print(f"Average iou / crowd index of {seq_name}: {mean_iou}, {mean_crowd}")
        iou_sum = iou_sum + mean_iou
        crowd_sum = crowd_sum + mean_crowd

    n_seq = len(self.seq_iou_list)
    print(f"All iou / crowd index: {iou_sum/n_seq}, {crowd_sum/n_seq}")
import torch

# Re-package a pose-network checkpoint into the snapshot layout written below:
# every entry of the loaded state dict is stored under 'network' with a
# 'module.backbone.' key prefix, and 'epoch' is reset to 0.
#
# NOTE(review): torch.load unpickles the file, so only run this on a
# checkpoint obtained from a trusted source.
#
# map_location='cpu' lets the conversion run on machines without a GPU even
# when the checkpoint was saved from CUDA tensors; the tensors are stored on
# the CPU in the output either way.
model = torch.load('pose_resnet_50_256x192.pth.tar', map_location='cpu')  # load simple
model_save = {'network': {}, 'epoch': 0}
for k, v in model.items():
    model_save['network']['module.backbone.' + k] = v.cpu()

torch.save(model_save, 'snapshot_0.pth.tar')
+ self.openpose_joints_name = ('Nose', 'Neck', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Eye', 'L_Eye', 'R_Ear', 'L_Ear', 'Pelvis') + + self.smpl_joints_name = ('Pelvis', 'L_Hip', 'R_Hip', 'Torso', 'L_Knee', 'R_Knee', 'Spine', 'L_Ankle', 'R_Ankle', 'Chest', 'L_Toe', 'R_Toe', 'Neck', 'L_Thorax', 'R_Thorax', 'Head', 'L_Shoulder', 'R_Shoulder', + 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hand', 'R_Hand') + self.coco_skeleton = ( (1, 2), (0, 1), (0, 2), (2, 4), (1, 3), (6, 8), (8, 10), (5, 7), (7, 9), (12, 14), (14, 16), (11, 13), (13, 15), (5, 6), (11, 12) ) + + self.datalist = self.load_data() + print("data len: ", len(self.datalist)) + + def load_data(self): + db = COCO(osp.join(self.data_path, '3DPW_latest_' + self.data_split + '.json')) + + if self.get_crowd: + with open(osp.join(self.data_path, f'3DPW_{self.data_split}_crowd_yolo_result.json')) as f: + yolo_bbox = json.load(f) + else: + with open(osp.join(self.data_path, '3DPW_test_yolo_result.json')) as f: + yolo_bbox = json.load(f) + + datalist = [] + aid_keys = sorted(yolo_bbox.keys(), key=lambda x: int(x)) if self.get_crowd else db.anns.keys() + for aid in aid_keys: + aid = int(aid) + ann = db.anns[aid] + image_id = ann['image_id'] + img = db.loadImgs(image_id)[0] + sequence_name = img['sequence'] + img_name = img['file_name'] + img_path = osp.join(self.data_path, 'imageFiles', sequence_name, img_name) + cam_param = {k: np.array(v, dtype=np.float32) for k, v in img['cam_param'].items()} + + openpose = np.array(ann['openpose_result'], dtype=np.float32).reshape(-1, 3) + openpose = transform_joint_to_other_db(openpose, self.openpose_joints_name, self.coco_joints_name) + + """ + # TEMP + centerpose = temp_result[str(aid)]['coco_joints'] + centerpose = np.array(centerpose).reshape(-1,2) + + tmpimg = cv2.imread(img_path) + oimg = vis_keypoints(tmpimg, openpose) + cv2.imshow('openpose', oimg/255) + cv2.waitKey(0) + cimg 
def getitem(self, idx):
    """Return the fields of sample ``idx`` that are needed for pose matching.

    The entry of ``self.datalist`` is deep-copied first, so callers can modify
    the returned objects without touching the cached data list.

    Returns:
        tuple: ``(img_path, img_name, img_id, ann_id, openpose, smplpose)``.
    """
    sample = copy.deepcopy(self.datalist[idx])

    # 'img_shape' is looked up like the other fields but is not returned.
    wanted = ('img_name', 'img_shape', 'img_id', 'ann_id', 'openpose', 'smplpose')
    img_name, _img_shape, img_id, ann_id, openpose, smplpose = (sample[key] for key in wanted)

    return sample['img_path'], img_name, img_id, ann_id, openpose, smplpose
def load_2dpose_results(self, result_path):
    """Load the 2D pose detections stored as JSON files in ``result_path``.

    Every ``*.json`` file directly inside ``result_path`` is read.  The
    sequence name is whatever follows the last ``'openpose_result_'`` in the
    file path, minus its final five characters (the ``.json`` suffix).

    Returns:
        dict: ``'<sequence>_<image name>'`` -> float32 array holding the first
        three values of every joint of every candidate pose, or an empty list
        when the file stores no candidates for that image.
    """
    prefix = 'openpose_result_'  # 'higher_hrnet_result_'

    results_by_image = {}
    for json_path in glob.glob(f'{result_path}/*.json'):
        with open(json_path) as fp:
            per_image = json.load(fp)

        sequence = json_path.split(prefix)[-1][:-5]
        for frame_name in sorted(per_image.keys()):
            raw = per_image[frame_name]
            try:
                poses = np.asarray(raw, dtype=np.float32)[:, :, :3]
            except IndexError:  # when the result is empty
                poses = []
            results_by_image[sequence + '_' + frame_name] = poses

    return results_by_image
def save_output(output):
    """Serialize ``output`` as JSON into the current working directory.

    The file name is fixed to ``3DPW_test_hhrnet_result.json``; an existing
    file of that name is overwritten.
    """
    target_name = '3DPW_test_hhrnet_result.json'
    print("Saving result to ", target_name)
    with open(target_name, 'w') as handle:
        json.dump(output, handle)
def get_bbox(joint_img, joint_valid, extend_ratio=1.2):
    """Return the box enclosing the valid joints, scaled about its center.

    Args:
        joint_img: array whose first two columns are the x and y coordinates.
        joint_valid: per-joint validity value; joints with a value above 0.01
            are used.
        extend_ratio: factor applied to the tight box's width and height while
            its center stays fixed.

    Returns:
        float32 array ``[xmin, ymin, width, height]``.
    """
    keep = joint_valid > 0.01
    xs = joint_img[:, 0][keep]
    ys = joint_img[:, 1][keep]

    def _scaled_extent(values):
        # Tight [lo, hi] interval of one axis, widened symmetrically.
        lo, hi = min(values), max(values)
        center = (lo + hi) / 2.
        span = hi - lo
        new_lo = center - 0.5 * span * extend_ratio
        new_hi = center + 0.5 * span * extend_ratio
        return new_lo, new_hi - new_lo

    x0, box_w = _scaled_extent(xs)
    y0, box_h = _scaled_extent(ys)

    return np.array([x0, y0, box_w, box_h]).astype(np.float32)
def vis_keypoints(img, kps, alpha=1):
    """Draw every keypoint of ``kps`` with its index and blend it onto ``img``.

    Each keypoint is drawn as a filled circle of radius 3 in its own color from
    the 'rainbow' colormap, labelled with its index in white.  Drawing happens
    on a copy of ``img``; the copy is then blended with the original using
    weight ``alpha`` for the drawing and ``1 - alpha`` for the original.
    """
    # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
    rainbow = plt.get_cmap('rainbow')
    palette = [
        (rgba[2] * 255, rgba[1] * 255, rgba[0] * 255)
        for rgba in (rainbow(t) for t in np.linspace(0, 1, len(kps) + 2))
    ]

    # Perform the drawing on a copy of the image, to allow for blending.
    overlay = np.copy(img)

    for joint_idx, joint in enumerate(kps):
        center = joint[0].astype(np.int32), joint[1].astype(np.int32)
        cv2.circle(overlay, center, radius=3, color=palette[joint_idx], thickness=-1, lineType=cv2.LINE_AA)
        cv2.putText(overlay, str(joint_idx), center, cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2)

    # Blend the keypoints.
    return cv2.addWeighted(img, 1.0 - alpha, overlay, alpha, 0)
def getitem(self, idx):
    """Return the matching-relevant fields of sample ``idx``.

    Works on a deep copy of the ``self.datalist`` entry, so the returned
    ground-truth pose can be modified without affecting the cached list.

    Returns:
        tuple: ``(img_path, img_name, img_id, ann_id, gtpose)`` where
        ``gtpose`` is the entry's ``'joint_img'`` value.
    """
    entry = copy.deepcopy(self.datalist[idx])

    fields = ('img_name', 'img_id', 'ann_id', 'joint_img')
    img_name, img_id, ann_id, gtpose = (entry[key] for key in fields)

    return entry['img_path'], img_name, img_id, ann_id, gtpose
# open pose valid joint compare
def filter_bbox(output_list):
    """Pick, for every annotation, the detected pose closest to its GT pose.

    Each element of ``output_list`` is a dict with
    ``'candidates'`` (indexable collection of per-person pose arrays whose
    third column is a confidence score), ``'target'['gtpose']`` (array whose
    third column is a validity value) and ``'meta'`` (``'aid'``, ``'img_id'``).

    A candidate is compared with the ground truth on the joints that are valid
    in both (GT value > 0.0, predicted score > 0.1) using the mean Euclidean
    distance in pixels.  The closest candidate below 60 pixels wins.  When no
    candidate qualifies, an all-zero dummy pose shaped like the first
    candidate is stored instead and counted.

    Entries without any candidate are skipped entirely.

    Returns:
        dict: annotation id -> ``{'coco_joints': nested list, 'img_id': ...}``.
    """
    max_match_dist = 60  # pixel

    result = {}
    d_count = 0
    for out in output_list:
        candidates = out['candidates']
        if len(candidates) == 0:
            continue

        gtpose = out['target']['gtpose']
        valid_gt_joints = gtpose[:, 2] > 0.0

        best_pose = None
        best_err = max_match_dist
        for pred_pose in candidates:
            valid_pred_joints = pred_pose[:, 2] > 0.1
            shared_idx = np.logical_and(valid_gt_joints, valid_pred_joints).nonzero()[0]
            if shared_idx.size == 0:
                # Nothing to compare; averaging an empty set would yield NaN.
                continue

            offsets = pred_pose[shared_idx, :2] - gtpose[shared_idx, :2]
            euc_dst = np.sqrt((offsets ** 2).sum(axis=1)).mean()
            if euc_dst < best_err:
                best_pose = pred_pose
                best_err = euc_dst

        if best_pose is None:
            # The candidate array is shared by every annotation of the same
            # image, so the dummy must be a fresh array: overwriting a
            # candidate in place would erase a detection that a later
            # annotation of this image may still need to match.
            best_pose = np.zeros_like(candidates[0])
            d_count += 1

        result[out['meta']['aid']] = {
            'coco_joints': best_pose.tolist(),  # 17 x2
            'img_id': out['meta']['img_id'],
        }

    print(f"{d_count} dummy out of {len(output_list)}")

    return result
def transform_joint_to_other_db(src_joint, src_name, dst_name):
    """Re-order joints from one naming convention into another.

    Args:
        src_joint: array with one row per entry of ``src_name``.
        src_name: joint names describing the rows of ``src_joint``.
        dst_name: joint names of the target layout.

    Returns:
        float32 array with one row per entry of ``dst_name``.  Rows whose name
        also occurs in ``src_name`` are copied from ``src_joint``; all other
        rows stay zero.
    """
    remapped = np.zeros((len(dst_name),) + src_joint.shape[1:], dtype=np.float32)

    for row, joint_name in enumerate(src_name):
        if joint_name not in dst_name:
            continue
        remapped[dst_name.index(joint_name)] = src_joint[row]

    return remapped
def vis_keypoints(img, kps, alpha=1):
    """Render indexed keypoint markers on ``img`` and return the blended image.

    For each entry of ``kps`` a filled circle (radius 3) is drawn at its first
    two coordinates, colored from the 'rainbow' colormap, together with the
    keypoint index in white.  The markers are painted on a copy of ``img``
    which is then mixed with the original: ``alpha`` weights the painted copy,
    ``1 - alpha`` the original.
    """
    # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
    rainbow = plt.get_cmap('rainbow')
    palette = [
        (rgba[2] * 255, rgba[1] * 255, rgba[0] * 255)
        for rgba in (rainbow(t) for t in np.linspace(0, 1, len(kps) + 2))
    ]

    # Perform the drawing on a copy of the image, to allow for blending.
    overlay = np.copy(img)

    for joint_idx, joint in enumerate(kps):
        center = joint[0].astype(np.int32), joint[1].astype(np.int32)
        cv2.circle(overlay, center, radius=3, color=palette[joint_idx], thickness=-1, lineType=cv2.LINE_AA)
        cv2.putText(overlay, str(joint_idx), center, cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2)

    # Blend the keypoints.
    return cv2.addWeighted(img, 1.0 - alpha, overlay, alpha, 0)
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv* +venv*/ +ENV*/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + + +# https://github.com/github/gitignore/blob/master/Global/macOS.gitignore ----------------------------------------------- + +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon +Icon? 
+ +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + + +# https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff: +.idea/* +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/dictionaries +.html # Bokeh Plots +.pg # TensorFlow Frozen Graphs +.avi # videos + +# Sensitive or high-churn files: +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml + +# Gradle: +.idea/**/gradle.xml +.idea/**/libraries + +# CMake +cmake-build-debug/ +cmake-build-release/ + +# Mongo Explorer plugin: +.idea/**/mongoSettings.xml + +## File-based project format: +*.iws + +## Plugin-specific files: + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties diff --git a/data_processing/yolov5_crowdhuman/.gitattributes b/data_processing/yolov5_crowdhuman/.gitattributes new file mode 100644 index 0000000..dad4239 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/.gitattributes @@ -0,0 +1,2 @@ +# this drop notebooks from GitHub language stats +*.ipynb linguist-vendored diff --git 
a/data_processing/yolov5_crowdhuman/.github/ISSUE_TEMPLATE/bug-report.md b/data_processing/yolov5_crowdhuman/.github/ISSUE_TEMPLATE/bug-report.md new file mode 100644 index 0000000..362059b --- /dev/null +++ b/data_processing/yolov5_crowdhuman/.github/ISSUE_TEMPLATE/bug-report.md @@ -0,0 +1,55 @@ +--- +name: "🐛 Bug report" +about: Create a report to help us improve +title: '' +labels: bug +assignees: '' + +--- + +Before submitting a bug report, please be aware that your issue **must be reproducible** with all of the following, otherwise it is non-actionable, and we cannot help you: + - **Current repo**: run `git fetch && git status -uno` to check and `git pull` to update repo + - **Common dataset**: coco.yaml or coco128.yaml + - **Common environment**: Colab, Google Cloud, or Docker image. See https://github.com/ultralytics/yolov5#environments + +If this is a custom dataset/training question you **must include** your `train*.jpg`, `test*.jpg` and `results.png` figures, or we cannot help you. You can generate these with `utils.plot_results()`. + + +## 🐛 Bug +A clear and concise description of what the bug is. + + +## To Reproduce (REQUIRED) + +Input: +``` +import torch + +a = torch.tensor([5]) +c = a / 0 +``` + +Output: +``` +Traceback (most recent call last): + File "/Users/glennjocher/opt/anaconda3/envs/env1/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code + exec(code_obj, self.user_global_ns, self.user_ns) + File "", line 5, in + c = a / 0 +RuntimeError: ZeroDivisionError +``` + + +## Expected behavior +A clear and concise description of what you expected to happen. + + +## Environment +Describe the environment in which the bug occurs. If applicable, add screenshots to help explain your problem. + + - OS: [e.g. Ubuntu] + - GPU: [e.g. 2080 Ti] + + +## Additional context +Add any other context about the problem here.
diff --git a/data_processing/yolov5_crowdhuman/.github/ISSUE_TEMPLATE/feature-request.md b/data_processing/yolov5_crowdhuman/.github/ISSUE_TEMPLATE/feature-request.md new file mode 100644 index 0000000..87db3ea --- /dev/null +++ b/data_processing/yolov5_crowdhuman/.github/ISSUE_TEMPLATE/feature-request.md @@ -0,0 +1,27 @@ +--- +name: "🚀 Feature request" +about: Suggest an idea for this project +title: '' +labels: enhancement +assignees: '' + +--- + +## 🚀 Feature + + +## Motivation + + + +## Pitch + + + +## Alternatives + + + +## Additional context + + diff --git a/data_processing/yolov5_crowdhuman/.github/ISSUE_TEMPLATE/question.md b/data_processing/yolov5_crowdhuman/.github/ISSUE_TEMPLATE/question.md new file mode 100644 index 0000000..2c22aea --- /dev/null +++ b/data_processing/yolov5_crowdhuman/.github/ISSUE_TEMPLATE/question.md @@ -0,0 +1,13 @@ +--- +name: "❓Question" +about: Ask a general question +title: '' +labels: question +assignees: '' + +--- + +## ❔Question + + +## Additional context diff --git a/data_processing/yolov5_crowdhuman/.github/dependabot.yml b/data_processing/yolov5_crowdhuman/.github/dependabot.yml new file mode 100644 index 0000000..9910689 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/.github/dependabot.yml @@ -0,0 +1,12 @@ +version: 2 +updates: +- package-ecosystem: pip + directory: "/" + schedule: + interval: weekly + time: "04:00" + open-pull-requests-limit: 10 + reviewers: + - glenn-jocher + labels: + - dependencies diff --git a/data_processing/yolov5_crowdhuman/.github/workflows/ci-testing.yml b/data_processing/yolov5_crowdhuman/.github/workflows/ci-testing.yml new file mode 100644 index 0000000..df50847 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/.github/workflows/ci-testing.yml @@ -0,0 +1,80 @@ +name: CI CPU testing + +on: # https://help.github.com/en/actions/reference/events-that-trigger-workflows + push: + branches: [ master ] + pull_request: + # The branches 
below must be a subset of the branches above + branches: [ master ] + schedule: + - cron: '0 0 * * *' # Runs at 00:00 UTC every day + +jobs: + cpu-tests: + + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: [3.8] + model: ['yolov5s'] # models to test + + # Timeout: https://stackoverflow.com/a/59076067/4521646 + timeout-minutes: 50 + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + # Note: This uses an internal pip API and may not always work + # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow + - name: Get pip cache + id: pip-cache + run: | + python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)" + + - name: Cache pip + uses: actions/cache@v1 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.python-version }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -qr requirements.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install -q onnx + python --version + pip --version + pip list + shell: bash + + - name: Download data + run: | + # curl -L -o tmp.zip https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip + # unzip -q tmp.zip -d ../ + # rm tmp.zip + + - name: Tests workflow + run: | + # export PYTHONPATH="$PWD" # to run '$ python *.py' files in subdirectories + di=cpu # inference devices # define device + + # train + 
python train.py --img 128 --batch 16 --weights weights/${{ matrix.model }}.pt --cfg models/${{ matrix.model }}.yaml --epochs 1 --device $di + # detect + python detect.py --weights weights/${{ matrix.model }}.pt --device $di + python detect.py --weights runs/train/exp/weights/last.pt --device $di + # test + python test.py --img 128 --batch 16 --weights weights/${{ matrix.model }}.pt --device $di + python test.py --img 128 --batch 16 --weights runs/train/exp/weights/last.pt --device $di + + python hubconf.py # hub + python models/yolo.py --cfg models/${{ matrix.model }}.yaml # inspect + python models/export.py --img 128 --batch 1 --weights weights/${{ matrix.model }}.pt # export + shell: bash diff --git a/data_processing/yolov5_crowdhuman/.github/workflows/codeql-analysis.yml b/data_processing/yolov5_crowdhuman/.github/workflows/codeql-analysis.yml new file mode 100644 index 0000000..1f07888 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/.github/workflows/codeql-analysis.yml @@ -0,0 +1,54 @@ +# This action runs GitHub's industry-leading static analysis engine, CodeQL, against a repository's source code to find security vulnerabilities. +# https://github.com/github/codeql-action + +name: "CodeQL" + +on: + schedule: + - cron: '0 0 1 * *' # Runs at 00:00 UTC on the 1st of every month + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + language: [ 'python' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] + # Learn more: + # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + # Initializes the CodeQL tools for scanning. 
+ - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # ℹ️ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/data_processing/yolov5_crowdhuman/.github/workflows/greetings.yml b/data_processing/yolov5_crowdhuman/.github/workflows/greetings.yml new file mode 100644 index 0000000..ee47229 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/.github/workflows/greetings.yml @@ -0,0 +1,56 @@ +name: Greetings + +on: [pull_request_target, issues] + +jobs: + greeting: + runs-on: ubuntu-latest + steps: + - uses: actions/first-interaction@v1 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + pr-message: | + 👋 Hello @${{ github.actor }}, thank you for submitting a 🚀 PR! 
To allow your work to be integrated as seamlessly as possible, we advise you to: + - ✅ Verify your PR is **up-to-date with origin/master.** If your PR is behind origin/master an automatic [GitHub actions](https://github.com/ultralytics/yolov5/blob/master/.github/workflows/rebase.yml) rebase may be attempted by including the /rebase command in a comment body, or by running the following code, replacing 'feature' with the name of your local branch: + ```bash + git remote add upstream https://github.com/ultralytics/yolov5.git + git fetch upstream + git checkout feature # <----- replace 'feature' with local branch name + git rebase upstream/master + git push -u origin -f + ``` + - ✅ Verify all Continuous Integration (CI) **checks are passing**. + - ✅ Reduce changes to the absolute **minimum** required for your bug fix or feature addition. _"It is not daily increase but daily decrease, hack away the unessential. The closer to the source, the less wastage there is."_ -Bruce Lee + + issue-message: | + 👋 Hello @${{ github.actor }}, thank you for your interest in 🚀 YOLOv5! Please visit our ⭐️ [Tutorials](https://github.com/ultralytics/yolov5/wiki#tutorials) to get started, where you can find quickstart guides for simple tasks like [Custom Data Training](https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data) all the way to advanced concepts like [Hyperparameter Evolution](https://github.com/ultralytics/yolov5/issues/607). + + If this is a 🐛 Bug Report, please provide screenshots and **minimum viable code to reproduce your issue**, otherwise we can not help you. 
+ + If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online [W&B logging](https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data#visualize) if available. + + For business inquiries or professional support requests please visit https://www.ultralytics.com or email Glenn Jocher at glenn.jocher@ultralytics.com. + + ## Requirements + + Python 3.8 or later with all [requirements.txt](https://github.com/ultralytics/yolov5/blob/master/requirements.txt) dependencies installed, including `torch>=1.7`. To install run: + ```bash + $ pip install -r requirements.txt + ``` + + ## Environments + + YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including [CUDA](https://developer.nvidia.com/cuda)/[CUDNN](https://developer.nvidia.com/cudnn), [Python](https://www.python.org/) and [PyTorch](https://pytorch.org/) preinstalled): + + - **Google Colab and Kaggle** notebooks with free GPU: Open In Colab Open In Kaggle + - **Google Cloud** Deep Learning VM. See [GCP Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart) + - **Amazon** Deep Learning AMI. See [AWS Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/AWS-Quickstart) + - **Docker Image**. 
See [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart) Docker Pulls + + + ## Status + + ![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg) + + If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/models/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit. 
+ diff --git a/data_processing/yolov5_crowdhuman/.github/workflows/rebase.yml b/data_processing/yolov5_crowdhuman/.github/workflows/rebase.yml new file mode 100644 index 0000000..e86c577 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/.github/workflows/rebase.yml @@ -0,0 +1,21 @@ +name: Automatic Rebase +# https://github.com/marketplace/actions/automatic-rebase + +on: + issue_comment: + types: [created] + +jobs: + rebase: + name: Rebase + if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/rebase') + runs-on: ubuntu-latest + steps: + - name: Checkout the latest code + uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Automatic Rebase + uses: cirrus-actions/rebase@1.3.1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/data_processing/yolov5_crowdhuman/.github/workflows/stale.yml b/data_processing/yolov5_crowdhuman/.github/workflows/stale.yml new file mode 100644 index 0000000..0a094e2 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/.github/workflows/stale.yml @@ -0,0 +1,18 @@ +name: Close stale issues +on: + schedule: + - cron: "0 0 * * *" + +jobs: + stale: + runs-on: ubuntu-latest + steps: + - uses: actions/stale@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: 'This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.' + stale-pr-message: 'This pull request has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.' + days-before-stale: 30 + days-before-close: 5 + exempt-issue-labels: 'documentation,tutorial' + operations-per-run: 100 # The maximum number of operations per run, used to control rate limiting.
diff --git a/data_processing/yolov5_crowdhuman/.gitignore b/data_processing/yolov5_crowdhuman/.gitignore new file mode 100644 index 0000000..91ce33f --- /dev/null +++ b/data_processing/yolov5_crowdhuman/.gitignore @@ -0,0 +1,252 @@ +# Repo-specific GitIgnore ---------------------------------------------------------------------------------------------- +*.jpg +*.jpeg +*.png +*.bmp +*.tif +*.tiff +*.heic +*.JPG +*.JPEG +*.PNG +*.BMP +*.TIF +*.TIFF +*.HEIC +*.mp4 +*.mov +*.MOV +*.avi +*.data +*.json + +*.cfg +!cfg/yolov3*.cfg + +storage.googleapis.com +runs/* +data/* +!data/images/zidane.jpg +!data/images/bus.jpg +!data/coco.names +!data/coco_paper.names +!data/coco.data +!data/coco_*.data +!data/coco_*.txt +!data/trainvalno5k.shapes +!data/*.sh + +pycocotools/* +results*.txt +gcp_test*.sh + +# Datasets ------------------------------------------------------------------------------------------------------------- +coco/ +coco128/ +VOC/ + +# MATLAB GitIgnore ----------------------------------------------------------------------------------------------------- +*.m~ +*.mat +!targets*.mat + +# Neural Network weights ----------------------------------------------------------------------------------------------- +*.weights +*.pt +*.onnx +*.mlmodel +*.torchscript +darknet53.conv.74 +yolov3-tiny.conv.15 + +# GitHub Python GitIgnore ---------------------------------------------------------------------------------------------- +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +wandb/ +.installed.cfg +*.egg + + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv* +venv*/ +ENV*/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + + +# https://github.com/github/gitignore/blob/master/Global/macOS.gitignore ----------------------------------------------- + +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon +Icon? 
+ +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + + +# https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff: +.idea/* +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/dictionaries +.html # Bokeh Plots +.pg # TensorFlow Frozen Graphs +.avi # videos + +# Sensitive or high-churn files: +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml + +# Gradle: +.idea/**/gradle.xml +.idea/**/libraries + +# CMake +cmake-build-debug/ +cmake-build-release/ + +# Mongo Explorer plugin: +.idea/**/mongoSettings.xml + +## File-based project format: +*.iws + +## Plugin-specific files: + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties diff --git a/data_processing/yolov5_crowdhuman/Dockerfile b/data_processing/yolov5_crowdhuman/Dockerfile new file mode 100644 index 0000000..af8f7b4 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/Dockerfile @@ -0,0 +1,56 @@ +# Start FROM Nvidia PyTorch image 
https://ngc.nvidia.com/catalog/containers/nvidia:pytorch +FROM nvcr.io/nvidia/pytorch:20.12-py3 + +# Install linux packages +RUN apt update && apt install -y zip screen libgl1-mesa-glx + +RUN apt-get install -y vim + +# Install python dependencies +RUN python -m pip install --upgrade pip +COPY requirements.txt . +RUN pip install -r requirements.txt gsutil + +# Create working directory +RUN mkdir -p /usr/src/app +WORKDIR /usr/src/app + +# Copy contents +COPY . /usr/src/app + +# Copy weights +#RUN python3 -c "from models import *; \ +#attempt_download('weights/yolov5s.pt'); \ +#attempt_download('weights/yolov5m.pt'); \ +#attempt_download('weights/yolov5l.pt')" + + +# --------------------------------------------------- Extras Below --------------------------------------------------- + +# Build and Push +# t=ultralytics/yolov5:latest && sudo docker build -t $t . && sudo docker push $t +# for v in {300..303}; do t=ultralytics/coco:v$v && sudo docker build -t $t .
&& sudo docker push $t; done + +# Pull and Run +# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all $t + +# Pull and Run with local directory access +# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all -v "$(pwd)"/coco:/usr/src/coco $t + +# Kill all +# sudo docker kill $(sudo docker ps -q) + +# Kill all image-based +# sudo docker kill $(sudo docker ps -qa --filter ancestor=ultralytics/yolov5:latest) + +# Bash into running container +# sudo docker exec -it 5a9b5863d93d bash + +# Bash into stopped container +# id=$(sudo docker ps -qa) && sudo docker start $id && sudo docker exec -it $id bash + +# Send weights to GCP +# python -c "from utils.general import *; strip_optimizer('runs/train/exp0_*/weights/best.pt', 'tmp.pt')" && gsutil cp tmp.pt gs://*.pt + +# Clean up +# docker system prune -a --volumes diff --git a/data_processing/yolov5_crowdhuman/LICENSE b/data_processing/yolov5_crowdhuman/LICENSE new file mode 100644 index 0000000..9e419e0 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/LICENSE @@ -0,0 +1,674 @@ +GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. 
+ + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. 
Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. 
+ + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. 
+ + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. 
Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. 
+ + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. 
+ + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. 
If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. 
+ + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. 
+ + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the 
material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. 
+ + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. 
If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. 
+ + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. 
The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. 
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. 
+ + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>.
\ No newline at end of file diff --git a/data_processing/yolov5_crowdhuman/README.md b/data_processing/yolov5_crowdhuman/README.md new file mode 100644 index 0000000..934e298 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/README.md @@ -0,0 +1,36 @@ +## Head & Person Detection Model + +### Download model trained on crowd human using yolov5(m) architecture +Download Link: [YOLOv5m-crowd-human](https://drive.google.com/file/d/1gglIwqxaH2iTvy6lZlXuAcMpd_U0GCUb/view?usp=sharing) + + +
+ +**Output (Crowd Human Model)** + +![image](https://drive.google.com/uc?export=view&id=1ZOhDBRXj-Ra0vPL7iG6lrxCWAFhJTAti) + +
+ + + +## Test + +```bash +$ python detect.py --weights crowdhuman_yolov5m.pt --source _test/ --view-img + +``` + + +## Test (Only Person Class) + +```bash +python3 detect.py --weights crowdhuman_yolov5m.pt --source _test/ --view-img --person +``` + + +## Test (Only Heads) + +```bash +python3 detect.py --weights crowdhuman_yolov5m.pt --source _test/ --view-img --heads +```
diff --git a/data_processing/yolov5_crowdhuman/detect.py b/data_processing/yolov5_crowdhuman/detect.py
new file mode 100644
index 0000000..df8af24
--- /dev/null
+++ b/data_processing/yolov5_crowdhuman/detect.py
@@ -0,0 +1,198 @@
+import argparse
+import time
+from pathlib import Path
+
+import cv2
+import torch
+import torch.backends.cudnn as cudnn
+from numpy import random
+
+from models.experimental import attempt_load
+from utils.datasets import LoadStreams, LoadImages
+from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
+    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
+from utils.plots import plot_one_box
+from utils.torch_utils import select_device, load_classifier, time_synchronized
+
+
+def detect(save_img=False):
+    """Run CrowdHuman YOLOv5 detection on ``opt.source`` and save and/or display the results.
+
+    All settings are read from the module-level ``opt`` namespace parsed in the
+    ``__main__`` block. Outputs go to a run directory under ``opt.project/opt.name``.
+    ``--heads`` / ``--person`` only restrict which boxes are drawn on the image;
+    labels written by ``--save-txt`` are not filtered by these two flags.
+
+    Args:
+        save_img: Write annotated images/videos. Forced to True for file/folder sources.
+    """
+    source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
+    # Numeric ids, *.txt stream lists and network URLs are read through LoadStreams
+    webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
+        ('rtsp://', 'rtmp://', 'http://', 'https://'))
+
+    # Directories
+    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
+    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
+
+    # Initialize
+    set_logging()
+    device = select_device(opt.device)
+    half = device.type != 'cpu'  # half precision only supported on CUDA
+
+    # Load model
+    model = attempt_load(weights, map_location=device)  # load FP32 model
+    stride = int(model.stride.max())  # model stride
+    imgsz = check_img_size(imgsz, s=stride)  # check img_size
+    if half:
+        model.half()  # to FP16
+
+    # Second-stage classifier
+    classify = False
+    if classify:
+        modelc = load_classifier(name='resnet101', n=2)  # initialize
+        # load_state_dict() returns a key report, not the module, so .to()/.eval() must be called on modelc itself
+        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])
+        modelc.to(device).eval()
+
+    # Set Dataloader
+    vid_path, vid_writer = None, None
+    if webcam:
+        view_img = check_imshow()
+        cudnn.benchmark = True  # set True to speed up constant image size inference
+        dataset = LoadStreams(source, img_size=imgsz, stride=stride)
+    else:
+        save_img = True
+        dataset = LoadImages(source, img_size=imgsz, stride=stride)
+
+    # Get names and colors
+    names = model.module.names if hasattr(model, 'module') else model.names
+    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
+
+    # Run inference
+    if device.type != 'cpu':
+        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
+    t0 = time.time()
+    for path, img, im0s, vid_cap in dataset:
+        img = torch.from_numpy(img).to(device)
+        img = img.half() if half else img.float()  # uint8 to fp16/32
+        img /= 255.0  # 0 - 255 to 0.0 - 1.0
+        if img.ndimension() == 3:
+            img = img.unsqueeze(0)
+
+        # Inference
+        t1 = time_synchronized()
+        pred = model(img, augment=opt.augment)[0]
+
+        # Apply NMS
+        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
+        t2 = time_synchronized()
+
+        # Apply Classifier
+        if classify:
+            pred = apply_classifier(pred, modelc, img, im0s)
+
+        # Process detections
+        for i, det in enumerate(pred):  # detections per image
+            if webcam:  # batch_size >= 1
+                p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
+            else:
+                p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
+
+            p = Path(p)  # to Path
+            save_path = str(save_dir / p.name)  # img.jpg
+            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
+            s += '%gx%g ' % img.shape[2:]  # print string
+            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
+            if len(det):
+                # Rescale boxes from img_size to im0 size
+                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
+
+                # Print results
+                for c in det[:, -1].unique():
+                    n = (det[:, -1] == c).sum()  # detections per class
+                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
+
+                # Write results
+                for *xyxy, conf, cls in reversed(det):
+                    if save_txt:  # Write to file
+                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
+                        line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
+                        with open(txt_path + '.txt', 'a') as f:
+                            f.write(('%g ' * len(line)).rstrip() % line + '\n')
+
+                    if save_img or view_img:  # Add bbox to image
+                        label = f'{names[int(cls)]} {conf:.2f}'
+                        # With --heads/--person draw only the requested classes; otherwise draw everything
+                        wanted = ('head' in label and opt.heads) or ('person' in label and opt.person)
+                        if wanted or not (opt.heads or opt.person):
+                            plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
+
+            # Print time (inference + NMS)
+            print(f'{s}Done. ({t2 - t1:.3f}s)')
+
+            # Stream results
+            if view_img:
+                cv2.imshow(str(p), im0)
+                # Wait for a key press on still images; poll 1 ms on video/streams so playback is not frozen
+                cv2.waitKey(0 if dataset.mode == 'image' else 1)
+
+            # Save results (image with detections)
+            if save_img:
+                if dataset.mode == 'image':
+                    cv2.imwrite(save_path, im0)
+                else:  # 'video'
+                    if vid_path != save_path:  # new video
+                        vid_path = save_path
+                        if isinstance(vid_writer, cv2.VideoWriter):
+                            vid_writer.release()  # release previous video writer
+
+                        fourcc = 'mp4v'  # output video codec
+                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
+                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
+                    vid_writer.write(im0)
+
+    # Flush and close the last video file
+    if isinstance(vid_writer, cv2.VideoWriter):
+        vid_writer.release()
+
+    if save_txt or save_img:
+        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
+        print(f"Results saved to {save_dir}{s}")
+
+    print(f'Done. ({time.time() - t0:.3f}s)')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
+    parser.add_argument('--source', type=str, default='data/images', help='source')  # file/folder, 0 for webcam
+    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
+    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
+    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
+    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--view-img', action='store_true', help='display results')
+    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
+    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
+    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
+    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
+    parser.add_argument('--augment', action='store_true', help='augmented inference')
+    parser.add_argument('--update', action='store_true', help='update all models')
+    parser.add_argument('--project', default='runs/detect', help='save results to project/name')
+    parser.add_argument('--name', default='exp', help='save results to project/name')
+    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+    parser.add_argument('--person', action='store_true', help='displays only person')
+    parser.add_argument('--heads', action='store_true', help='displays only heads')
+    opt = parser.parse_args()
+    print(opt)
+    #check_requirements()
+
+    with torch.no_grad():
+        if opt.update:  # update all models (to fix SourceChangeWarning)
+            for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
+                detect()
+                strip_optimizer(opt.weights)
+        else:
+            detect()
diff --git a/data_processing/yolov5_crowdhuman/detect_head_bbox.py b/data_processing/yolov5_crowdhuman/detect_head_bbox.py new file mode 100644 index 0000000..77971b0 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/detect_head_bbox.py @@ -0,0 +1,294 @@ +import argparse +import os.path +import time +from pathlib import Path + +import cv2 +import torch +import torch.backends.cudnn as cudnn +from numpy import random + +from models.experimental import attempt_load +from utils.datasets import LoadStreams, LoadImages +from utils.general import check_img_size, check_requirements, 
check_imshow, non_max_suppression, apply_classifier, \ + scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path +from utils.plots import plot_one_box +from utils.torch_utils import select_device, load_classifier, time_synchronized +import json +import numpy as np + +def detect(save_img=False): + source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size + + skeleton_path = os.path.join(opt.source,'2d_pose_result_hrnet.json') + source = os.path.join(opt.source,'images') + + with open(skeleton_path) as f: + pose2d_result = json.load(f) + + webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( + ('rtsp://', 'rtmp://', 'http://')) + + # Directories + save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run + (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir + + # Initialize + set_logging() + device = select_device(opt.device) + half = device.type != 'cpu' # half precision only supported on CUDA + + # Load model + model = attempt_load(weights, map_location=device) # load FP32 model + stride = int(model.stride.max()) # model stride + imgsz = check_img_size(imgsz, s=stride) # check img_size + if half: + model.half() # to FP16 + + # Second-stage classifier + classify = False + if classify: + modelc = load_classifier(name='resnet101', n=2) # initialize + modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval() + + bbox_results = {} + result_json_path = os.path.join(opt.source, 'head_bbox_yolov5_crowdhuman.json') + print('result_json_path', result_json_path) + + if os.path.exists(result_json_path): + with open(result_json_path) as f: + bbox_results = json.load(f) + + + # Set Dataloader + vid_path, vid_writer = None, None + if webcam: + view_img = check_imshow() + cudnn.benchmark = True # set True to speed up constant image size inference + 
dataset = LoadStreams(source, img_size=imgsz, stride=stride) + else: + save_img = True + dataset = LoadImages(source,bbox_results, img_size=imgsz, stride=stride) + + # Get names and colors + names = model.module.names if hasattr(model, 'module') else model.names + colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] + + # Run inference + if device.type != 'cpu': + model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once + t0 = time.time() + + + + + + for path, img, im0s, vid_cap in dataset: + img_name = os.path.basename(path) + if img_name in bbox_results: + continue + + img = torch.from_numpy(img).to(device) + img = img.half() if half else img.float() # uint8 to fp16/32 + img /= 255.0 # 0 - 255 to 0.0 - 1.0 + if img.ndimension() == 3: + img = img.unsqueeze(0) + + + coco_joint_list = pose2d_result[img_name] + bbox_list_wo_sort = [] + + + # Inference + t1 = time_synchronized() + pred = model(img, augment=opt.augment)[0] + + # Apply NMS + pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) + t2 = time_synchronized() + + # Apply Classifier + if classify: + pred = apply_classifier(pred, modelc, img, im0s) + + # Process detections + for i, det in enumerate(pred): # detections per image + if webcam: # batch_size >= 1 + p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count + else: + p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) + + p = Path(p) # to Path + save_path = str(save_dir / p.name) # img.jpg + txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt + s += '%gx%g ' % img.shape[2:] # print string + gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh + if len(det): + # Rescale boxes from img_size to im0 size + det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() + + # Print results + for c in det[:, -1].unique(): + n = (det[:, -1] == 
c).sum() # detections per class + s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string + + # Write results + for *xyxy, conf, cls in reversed(det): + if save_txt: # Write to file + xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh + line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh) # label format + with open(txt_path + '.txt', 'a') as f: + f.write(('%g ' * len(line)).rstrip() % line + '\n') + + label = f'{names[int(cls)]} {conf:.2f}' + if 'head' in label: + bbox = [float(xyxy[0]), float(xyxy[1]), float(xyxy[2]-xyxy[0]), float(xyxy[3]-xyxy[1])] # x, y, w, h + #print(im0.shape) + bbox_list_wo_sort.append(bbox) + + + + + + if save_img or view_img: # Add bbox to image + label = f'{names[int(cls)]} {conf:.2f}' + if opt.heads or opt.person: + if 'head' in label and opt.heads: + plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3) + if 'person' in label and opt.person: + plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3) + else: + plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3) + + # Print time (inference + NMS) + print(f'{s}Done. 
({t2 - t1:.3f}s)') + + # Stream results + if view_img: + print('resize to',(512, int(im0.shape[0]/im0.shape[1]*512))) + cv2.imshow(str(p), cv2.resize(im0, (512, int(im0.shape[0]/im0.shape[1]*512)))) + cv2.waitKey(0) # 1 millisecond + + # Save results (image with detections) + # if save_img: + # if dataset.mode == 'image': + # cv2.imwrite(save_path, im0) + # else: # 'video' + # if vid_path != save_path: # new video + # vid_path = save_path + # if isinstance(vid_writer, cv2.VideoWriter): + # vid_writer.release() # release previous video writer + # + # fourcc = 'mp4v' # output video codec + # fps = vid_cap.get(cv2.CAP_PROP_FPS) + # w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + # h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + # vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)) + # vid_writer.write(im0) + + # sort bbox + bbox_list_sort = [] + + for idx in range(len(coco_joint_list)): + coco_joint_img = np.asarray(coco_joint_list[idx])[:, :3] + + face_points = coco_joint_img[:5, :3] + face_center = np.mean(face_points[:,:2], axis=0, keepdims=True) + #print('face_points', face_points) + # + clip_tresh = 0.5 + face_points_valid = face_points[face_points[:,2] > clip_tresh] + face_center_valid = np.mean(face_points_valid[:,:2], axis=0, keepdims=True) + #print('face_points_valid', face_points_valid) + # if valid face num >=1, match bbox to coco joint + if face_points_valid.shape[0] >= 1: + for bbox in bbox_list_wo_sort: + relax = 0.1 + relaxed_bbox = [bbox[0] - bbox[2] * relax, bbox[1] - bbox[3] * relax, bbox[2] * (1 + 2 * relax), + bbox[3] * (1 + 2 * relax)] + check = True + for points_idx in range(face_points.shape[0]): + if not (relaxed_bbox[0] <= face_points[points_idx][0] <= relaxed_bbox[0] + relaxed_bbox[2] and + relaxed_bbox[1] <= face_points[points_idx][1] <= relaxed_bbox[1] + relaxed_bbox[3]): + check = False + break + if check: + bbox_list_sort.append({'bbox':bbox,'score':1.0}) + break + else: + # if no valid face, use face 
center to match bbox (nearest ) + min_dist = 1e8 + min_bbox = None + for bbox in bbox_list_wo_sort: + bbox_c = [bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2] + if np.linalg.norm(bbox_c - face_center) < min_dist: + min_dist = np.linalg.norm(bbox_c - face_center) + min_bbox = bbox + + if min_bbox is not None: + bbox_list_sort.append({'bbox':min_bbox,'score':1.0}) + + + + + # no bbox detec, use coco joint to generate bbox + if len(bbox_list_sort) != idx+1: + # face_points_valid = face_points[face_points[:, 2] > clip_tresh] + # face_center_valid = np.mean(face_points_valid, axis=0, keepdims=True) + + if face_points_valid.shape[0] < 2: + bbox_list_sort.append({'bbox':[],'score':0.0}) + continue + + head_stride = max(np.max(face_points[:, 0]) - np.min(face_points[:, 0]), + np.max(face_points[:, 1]) - np.min(face_points[:, 1])) * 1.3 + temp_bbox = [face_center[0][0]-head_stride/2, face_center[0][1]-head_stride/2, head_stride, head_stride] + bbox_list_sort.append({'bbox':temp_bbox,'score':0.0}) + + if len(bbox_list_sort) != len(coco_joint_list): + raise ValueError('bbox_list_sort and coco_joint_list have different length') + + bbox_results[img_name] = bbox_list_sort + # save bbox + with open(result_json_path, 'w') as f: + json.dump(bbox_results, f) + + + # if save_txt or save_img: + # s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' + # print(f"Results saved to {save_dir}{s}") + + print(f'Done. 
({time.time() - t0:.3f}s)') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)') + parser.add_argument('--source', type=str, default='data/images', help='source') # file/folder, 0 for webcam + parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') + parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold') + parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--view-img', action='store_true', help='display results') + parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') + parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') + parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') + parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') + parser.add_argument('--augment', action='store_true', help='augmented inference') + parser.add_argument('--update', action='store_true', help='update all models') + parser.add_argument('--project', default='runs/detect', help='save results to project/name') + parser.add_argument('--name', default='exp', help='save results to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--person', action='store_true', help='displays only person') + parser.add_argument('--heads', action='store_true', help='displays only person') + opt = parser.parse_args() + print(opt) + #check_requirements() + + with torch.no_grad(): + if opt.update: # update all models (to fix SourceChangeWarning) + for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 
'yolov5x.pt']: + detect() + strip_optimizer(opt.weights) + else: + detect() diff --git a/data_processing/yolov5_crowdhuman/hubconf.py b/data_processing/yolov5_crowdhuman/hubconf.py new file mode 100644 index 0000000..47eee44 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/hubconf.py @@ -0,0 +1,146 @@ +"""File for accessing YOLOv5 via PyTorch Hub https://pytorch.org/hub/ + +Usage: + import torch + model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80) +""" + +from pathlib import Path + +import torch + +from models.yolo import Model +from utils.general import set_logging +from utils.google_utils import attempt_download + +dependencies = ['torch', 'yaml'] +set_logging() + + +def create(name, pretrained, channels, classes, autoshape): + """Creates a specified YOLOv5 model + + Arguments: + name (str): name of model, i.e. 'yolov5s' + pretrained (bool): load pretrained weights into the model + channels (int): number of input channels + classes (int): number of model classes + + Returns: + pytorch model + """ + config = Path(__file__).parent / 'models' / f'{name}.yaml' # model.yaml path + try: + model = Model(config, channels, classes) + if pretrained: + fname = f'{name}.pt' # checkpoint filename + attempt_download(fname) # download if not found locally + ckpt = torch.load(fname, map_location=torch.device('cpu')) # load + state_dict = ckpt['model'].float().state_dict() # to FP32 + state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter + model.load_state_dict(state_dict, strict=False) # load + if len(ckpt['model'].names) == classes: + model.names = ckpt['model'].names # set class names attribute + if autoshape: + model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS + return model + + except Exception as e: + help_url = 
'https://github.com/ultralytics/yolov5/issues/36' + s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url + raise Exception(s) from e + + +def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True): + """YOLOv5-small model from https://github.com/ultralytics/yolov5 + + Arguments: + pretrained (bool): load pretrained weights into the model, default=False + channels (int): number of input channels, default=3 + classes (int): number of model classes, default=80 + + Returns: + pytorch model + """ + return create('yolov5s', pretrained, channels, classes, autoshape) + + +def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True): + """YOLOv5-medium model from https://github.com/ultralytics/yolov5 + + Arguments: + pretrained (bool): load pretrained weights into the model, default=False + channels (int): number of input channels, default=3 + classes (int): number of model classes, default=80 + + Returns: + pytorch model + """ + return create('yolov5m', pretrained, channels, classes, autoshape) + + +def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True): + """YOLOv5-large model from https://github.com/ultralytics/yolov5 + + Arguments: + pretrained (bool): load pretrained weights into the model, default=False + channels (int): number of input channels, default=3 + classes (int): number of model classes, default=80 + + Returns: + pytorch model + """ + return create('yolov5l', pretrained, channels, classes, autoshape) + + +def yolov5x(pretrained=False, channels=3, classes=80, autoshape=True): + """YOLOv5-xlarge model from https://github.com/ultralytics/yolov5 + + Arguments: + pretrained (bool): load pretrained weights into the model, 
default=False + channels (int): number of input channels, default=3 + classes (int): number of model classes, default=80 + + Returns: + pytorch model + """ + return create('yolov5x', pretrained, channels, classes, autoshape) + + +def custom(path_or_model='path/to/model.pt', autoshape=True): + """YOLOv5-custom model from https://github.com/ultralytics/yolov5 + + Arguments (3 options): + path_or_model (str): 'path/to/model.pt' + path_or_model (dict): torch.load('path/to/model.pt') + path_or_model (nn.Module): torch.load('path/to/model.pt')['model'] + + Returns: + pytorch model + """ + model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint + if isinstance(model, dict): + model = model['model'] # load model + + hub_model = Model(model.yaml).to(next(model.parameters()).device) # create + hub_model.load_state_dict(model.float().state_dict()) # load state_dict + hub_model.names = model.names # class names + return hub_model.autoshape() if autoshape else hub_model + + +if __name__ == '__main__': + model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True) # pretrained example + # model = custom(path_or_model='path/to/model.pt') # custom example + + # Verify inference + import numpy as np + from PIL import Image + + imgs = [Image.open('data/images/bus.jpg'), # PIL + 'data/images/zidane.jpg', # filename + 'https://github.com/ultralytics/yolov5/raw/master/data/images/bus.jpg', # URI + np.zeros((640, 480, 3))] # numpy + + results = model(imgs) # batched inference + results.print() + results.save() diff --git a/data_processing/yolov5_crowdhuman/models/__init__.py b/data_processing/yolov5_crowdhuman/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/yolov5_crowdhuman/models/common.py b/data_processing/yolov5_crowdhuman/models/common.py new file mode 
100644 index 0000000..ad35f90 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/common.py @@ -0,0 +1,308 @@ +# This file contains modules common to various models + +import math +from pathlib import Path + +import numpy as np +import requests +import torch +import torch.nn as nn +from PIL import Image + +from utils.datasets import letterbox +from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh +from utils.plots import color_list, plot_one_box + + +def autopad(k, p=None): # kernel, padding + # Pad to 'same' + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + + +def DWConv(c1, c2, k=1, s=1, act=True): + # Depthwise convolution + return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act) + + +class Conv(nn.Module): + # Standard convolution + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super(Conv, self).__init__() + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion + super(Bottleneck, self).__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_, c2, 3, 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class BottleneckCSP(nn.Module): + # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, 
number, shortcut, groups, expansion + super(BottleneckCSP, self).__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) + self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) + self.cv4 = Conv(2 * c_, c2, 1, 1) + self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) + self.act = nn.LeakyReLU(0.1, inplace=True) + self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)]) + + def forward(self, x): + y1 = self.cv3(self.m(self.cv1(x))) + y2 = self.cv2(x) + return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) + + +class C3(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super(C3, self).__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) + self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)]) + # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)]) + + def forward(self, x): + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) + + +class SPP(nn.Module): + # Spatial pyramid pooling layer used in YOLOv3-SPP + def __init__(self, c1, c2, k=(5, 9, 13)): + super(SPP, self).__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + + def forward(self, x): + x = self.cv1(x) + return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) + + +class Focus(nn.Module): + # Focus wh information into c-space + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super(Focus, self).__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, g, act) + # self.contract = 
Contract(gain=2) + + def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) + return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) + # return self.conv(self.contract(x)) + + +class Contract(nn.Module): + # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40) + def __init__(self, gain=2): + super().__init__() + self.gain = gain + + def forward(self, x): + N, C, H, W = x.size() # assert (H / s == 0) and (W / s == 0), 'Indivisible gain' + s = self.gain + x = x.view(N, C, H // s, s, W // s, s) # x(1,64,40,2,40,2) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40) + return x.view(N, C * s * s, H // s, W // s) # x(1,256,40,40) + + +class Expand(nn.Module): + # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160) + def __init__(self, gain=2): + super().__init__() + self.gain = gain + + def forward(self, x): + N, C, H, W = x.size() # assert C / s ** 2 == 0, 'Indivisible gain' + s = self.gain + x = x.view(N, s, s, C // s ** 2, H, W) # x(1,2,2,16,80,80) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2) + return x.view(N, C // s ** 2, H * s, W * s) # x(1,16,160,160) + + +class Concat(nn.Module): + # Concatenate a list of tensors along dimension + def __init__(self, dimension=1): + super(Concat, self).__init__() + self.d = dimension + + def forward(self, x): + return torch.cat(x, self.d) + + +class NMS(nn.Module): + # Non-Maximum Suppression (NMS) module + conf = 0.25 # confidence threshold + iou = 0.45 # IoU threshold + classes = None # (optional list) filter by class + + def __init__(self): + super(NMS, self).__init__() + + def forward(self, x): + return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) + + +class autoShape(nn.Module): + # input-robust model wrapper for passing cv2/np/PIL/torch inputs. 
Includes preprocessing, inference and NMS + img_size = 640 # inference size (pixels) + conf = 0.25 # NMS confidence threshold + iou = 0.45 # NMS IoU threshold + classes = None # (optional list) filter by class + + def __init__(self, model): + super(autoShape, self).__init__() + self.model = model.eval() + + def autoshape(self): + print('autoShape already enabled, skipping... ') # model already converted to model.autoshape() + return self + + def forward(self, imgs, size=640, augment=False, profile=False): + # Inference from various sources. For height=720, width=1280, RGB images example inputs are: + # filename: imgs = 'data/samples/zidane.jpg' + # URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg' + # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) + # PIL: = Image.open('image.jpg') # HWC x(720,1280,3) + # numpy: = np.zeros((720,1280,3)) # HWC + # torch: = torch.zeros(16,3,720,1280) # BCHW + # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] 
# list of images + + p = next(self.model.parameters()) # for device and type + if isinstance(imgs, torch.Tensor): # torch + return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference + + # Pre-process + n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images + shape0, shape1, files = [], [], [] # image and inference shapes, filenames + for i, im in enumerate(imgs): + if isinstance(im, str): # filename or uri + im, f = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im), im # open + im.filename = f # for uri + files.append(Path(im.filename).with_suffix('.jpg').name if isinstance(im, Image.Image) else f'image{i}.jpg') + im = np.array(im) # to numpy + if im.shape[0] < 5: # image in CHW + im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) + im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input + s = im.shape[:2] # HWC + shape0.append(s) # image shape + g = (size / max(s)) # gain + shape1.append([y * g for y in s]) + imgs[i] = im # update + shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape + x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad + x = np.stack(x, 0) if n > 1 else x[0][None] # stack + x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW + x = torch.from_numpy(x).to(p.device).type_as(p) / 255. 
# uint8 to fp16/32 + + # Inference + with torch.no_grad(): + y = self.model(x, augment, profile)[0] # forward + y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS + + # Post-process + for i in range(n): + scale_coords(shape1, y[i][:, :4], shape0[i]) + + return Detections(imgs, y, files, self.names) + + +class Detections: + # detections class for YOLOv5 inference results + def __init__(self, imgs, pred, files, names=None): + super(Detections, self).__init__() + d = pred[0].device # device + gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations + self.imgs = imgs # list of images as numpy arrays + self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) + self.names = names # class names + self.files = files # image filenames + self.xyxy = pred # xyxy pixels + self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels + self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized + self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized + self.n = len(self.pred) + + def display(self, pprint=False, show=False, save=False, render=False, save_dir=''): + colors = color_list() + for i, (img, pred) in enumerate(zip(self.imgs, self.pred)): + str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} ' + if pred is not None: + for c in pred[:, -1].unique(): + n = (pred[:, -1] == c).sum() # detections per class + str += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string + if show or save or render: + for *box, conf, cls in pred: # xyxy, confidence, class + label = f'{self.names[int(cls)]} {conf:.2f}' + plot_one_box(box, img, label=label, color=colors[int(cls) % 10]) + img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np + if pprint: + print(str.rstrip(', ')) + if show: + img.show(self.files[i]) # show + if save: + f = Path(save_dir) / self.files[i] + img.save(f) # save + print(f"{'Saving' * 
(i == 0)} {f},", end='' if i < self.n - 1 else ' done.\n') + if render: + self.imgs[i] = np.asarray(img) + + def print(self): + self.display(pprint=True) # print results + + def show(self): + self.display(show=True) # show results + + def save(self, save_dir='results/'): + Path(save_dir).mkdir(exist_ok=True) + self.display(save=True, save_dir=save_dir) # save results + + def render(self): + self.display(render=True) # render results + return self.imgs + + def __len__(self): + return self.n + + def tolist(self): + # return a list of Detections objects, i.e. 'for result in results.tolist():' + x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)] + for d in x: + for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']: + setattr(d, k, getattr(d, k)[0]) # pop out of list + return x + + +class Classify(nn.Module): + # Classification head, i.e. x(b,c1,20,20) to x(b,c2) + def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups + super(Classify, self).__init__() + self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1) + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1) + self.flat = nn.Flatten() + + def forward(self, x): + z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list + return self.flat(self.conv(z)) # flatten to x(b,c2) diff --git a/data_processing/yolov5_crowdhuman/models/experimental.py b/data_processing/yolov5_crowdhuman/models/experimental.py new file mode 100644 index 0000000..5fe5685 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/experimental.py @@ -0,0 +1,133 @@ +# This file contains experimental modules + +import numpy as np +import torch +import torch.nn as nn + +from models.common import Conv, DWConv +from utils.google_utils import attempt_download + + +class CrossConv(nn.Module): + # Cross Convolution Downsample + def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): + # ch_in, ch_out, 
kernel, stride, groups, expansion, shortcut + super(CrossConv, self).__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, (1, k), (1, s)) + self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class Sum(nn.Module): + # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 + def __init__(self, n, weight=False): # n: number of inputs + super(Sum, self).__init__() + self.weight = weight # apply weights boolean + self.iter = range(n - 1) # iter object + if weight: + self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True) # layer weights + + def forward(self, x): + y = x[0] # no weight + if self.weight: + w = torch.sigmoid(self.w) * 2 + for i in self.iter: + y = y + x[i + 1] * w[i] + else: + for i in self.iter: + y = y + x[i + 1] + return y + + +class GhostConv(nn.Module): + # Ghost Convolution https://github.com/huawei-noah/ghostnet + def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups + super(GhostConv, self).__init__() + c_ = c2 // 2 # hidden channels + self.cv1 = Conv(c1, c_, k, s, None, g, act) + self.cv2 = Conv(c_, c_, 5, 1, None, c_, act) + + def forward(self, x): + y = self.cv1(x) + return torch.cat([y, self.cv2(y)], 1) + + +class GhostBottleneck(nn.Module): + # Ghost Bottleneck https://github.com/huawei-noah/ghostnet + def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride + super(GhostBottleneck, self).__init__() + c_ = c2 // 2 + self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw + DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw + GhostConv(c_, c2, 1, 1, act=False)) # pw-linear + self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), + 
Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity() + + def forward(self, x): + return self.conv(x) + self.shortcut(x) + + +class MixConv2d(nn.Module): + # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 + def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): + super(MixConv2d, self).__init__() + groups = len(k) + if equal_ch: # equal c_ per group + i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices + c_ = [(i == g).sum() for g in range(groups)] # intermediate channels + else: # equal weight.numel() per group + b = [c2] + [0] * groups + a = np.eye(groups + 1, groups, k=-1) + a -= np.roll(a, 1, axis=1) + a *= np.array(k) ** 2 + a[0] = 1 + c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b + + self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.LeakyReLU(0.1, inplace=True) + + def forward(self, x): + return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) + + +class Ensemble(nn.ModuleList): + # Ensemble of models + def __init__(self): + super(Ensemble, self).__init__() + + def forward(self, x, augment=False): + y = [] + for module in self: + y.append(module(x, augment)[0]) + # y = torch.stack(y).max(0)[0] # max ensemble + # y = torch.stack(y).mean(0) # mean ensemble + y = torch.cat(y, 1) # nms ensemble + return y, None # inference, train output + + +def attempt_load(weights, map_location=None): + # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a + model = Ensemble() + for w in weights if isinstance(weights, list) else [weights]: + attempt_download(w) + model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model + + # Compatibility updates + for m in model.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: + m.inplace = 
True # pytorch 1.7.0 compatibility + elif type(m) is Conv: + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + + if len(model) == 1: + return model[-1] # return model + else: + print('Ensemble created with %s\n' % weights) + for k in ['names', 'stride']: + setattr(model, k, getattr(model[-1], k)) + return model # return ensemble diff --git a/data_processing/yolov5_crowdhuman/models/export.py b/data_processing/yolov5_crowdhuman/models/export.py new file mode 100644 index 0000000..cc81787 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/export.py @@ -0,0 +1,100 @@ +"""Exports a YOLOv5 *.pt model to ONNX and TorchScript formats + +Usage: + $ export PYTHONPATH="$PWD" && python models/export.py --weights ./weights/yolov5s.pt --img 640 --batch 1 +""" + +import argparse +import sys +import time + +sys.path.append('./') # to run '$ python *.py' files in subdirectories + +import torch +import torch.nn as nn + +import models +from models.experimental import attempt_load +from utils.activations import Hardswish, SiLU +from utils.general import set_logging, check_img_size + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path') # from yolov5/models/ + parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width + parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') + parser.add_argument('--batch-size', type=int, default=1, help='batch size') + opt = parser.parse_args() + opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand + print(opt) + set_logging() + t = time.time() + + # Load PyTorch model + model = attempt_load(opt.weights, map_location=torch.device('cpu')) # load FP32 model + labels = model.names + + # Checks + gs = int(max(model.stride)) # grid size (max stride) + opt.img_size = [check_img_size(x, gs) for x in opt.img_size] # verify img_size are gs-multiples 
+ + # Input + img = torch.zeros(opt.batch_size, 3, *opt.img_size) # image size(1,3,320,192) iDetection + + # Update model + for k, m in model.named_modules(): + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + if isinstance(m, models.common.Conv): # assign export-friendly activations + if isinstance(m.act, nn.Hardswish): + m.act = Hardswish() + elif isinstance(m.act, nn.SiLU): + m.act = SiLU() + # elif isinstance(m, models.yolo.Detect): + # m.forward = m.forward_export # assign forward (optional) + model.model[-1].export = True # set Detect() layer export=True + y = model(img) # dry run + + # TorchScript export + try: + print('\nStarting TorchScript export with torch %s...' % torch.__version__) + f = opt.weights.replace('.pt', '.torchscript.pt') # filename + ts = torch.jit.trace(model, img) + ts.save(f) + print('TorchScript export success, saved as %s' % f) + except Exception as e: + print('TorchScript export failure: %s' % e) + + # ONNX export + try: + import onnx + + print('\nStarting ONNX export with onnx %s...' % onnx.__version__) + f = opt.weights.replace('.pt', '.onnx') # filename + torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'], + output_names=['classes', 'boxes'] if y is None else ['output'], + dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640) + 'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None) + + # Checks + onnx_model = onnx.load(f) # load onnx model + onnx.checker.check_model(onnx_model) # check onnx model + # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model + print('ONNX export success, saved as %s' % f) + except Exception as e: + print('ONNX export failure: %s' % e) + + # CoreML export + try: + import coremltools as ct + + print('\nStarting CoreML export with coremltools %s...' 
% ct.__version__) + # convert model from torchscript and apply pixel scaling as per detect.py + model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])]) + f = opt.weights.replace('.pt', '.mlmodel') # filename + model.save(f) + print('CoreML export success, saved as %s' % f) + except Exception as e: + print('CoreML export failure: %s' % e) + + # Finish + print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t)) diff --git a/data_processing/yolov5_crowdhuman/models/hub/anchors.yaml b/data_processing/yolov5_crowdhuman/models/hub/anchors.yaml new file mode 100644 index 0000000..a07a4dc --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/anchors.yaml @@ -0,0 +1,58 @@ +# Default YOLOv5 anchors for COCO data + + +# P5 ------------------------------------------------------------------------------------------------------------------- +# P5-640: +anchors_p5_640: + - [ 10,13, 16,30, 33,23 ] # P3/8 + - [ 30,61, 62,45, 59,119 ] # P4/16 + - [ 116,90, 156,198, 373,326 ] # P5/32 + + +# P6 ------------------------------------------------------------------------------------------------------------------- +# P6-640: thr=0.25: 0.9964 BPR, 5.54 anchors past thr, n=12, img_size=640, metric_all=0.281/0.716-mean/best, past_thr=0.469-mean: 9,11, 21,19, 17,41, 43,32, 39,70, 86,64, 65,131, 134,130, 120,265, 282,180, 247,354, 512,387 +anchors_p6_640: + - [ 9,11, 21,19, 17,41 ] # P3/8 + - [ 43,32, 39,70, 86,64 ] # P4/16 + - [ 65,131, 134,130, 120,265 ] # P5/32 + - [ 282,180, 247,354, 512,387 ] # P6/64 + +# P6-1280: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1280, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 19,27, 44,40, 38,94, 96,68, 86,152, 180,137, 140,301, 303,264, 238,542, 436,615, 739,380, 925,792 +anchors_p6_1280: + - [ 19,27, 44,40, 38,94 ] # P3/8 + - [ 96,68, 86,152, 180,137 ] # P4/16 + - [ 
140,301, 303,264, 238,542 ] # P5/32 + - [ 436,615, 739,380, 925,792 ] # P6/64 + +# P6-1920: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1920, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 28,41, 67,59, 57,141, 144,103, 129,227, 270,205, 209,452, 455,396, 358,812, 653,922, 1109,570, 1387,1187 +anchors_p6_1920: + - [ 28,41, 67,59, 57,141 ] # P3/8 + - [ 144,103, 129,227, 270,205 ] # P4/16 + - [ 209,452, 455,396, 358,812 ] # P5/32 + - [ 653,922, 1109,570, 1387,1187 ] # P6/64 + + +# P7 ------------------------------------------------------------------------------------------------------------------- +# P7-640: thr=0.25: 0.9962 BPR, 6.76 anchors past thr, n=15, img_size=640, metric_all=0.275/0.733-mean/best, past_thr=0.466-mean: 11,11, 13,30, 29,20, 30,46, 61,38, 39,92, 78,80, 146,66, 79,163, 149,150, 321,143, 157,303, 257,402, 359,290, 524,372 +anchors_p7_640: + - [ 11,11, 13,30, 29,20 ] # P3/8 + - [ 30,46, 61,38, 39,92 ] # P4/16 + - [ 78,80, 146,66, 79,163 ] # P5/32 + - [ 149,150, 321,143, 157,303 ] # P6/64 + - [ 257,402, 359,290, 524,372 ] # P7/128 + +# P7-1280: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1280, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 19,22, 54,36, 32,77, 70,83, 138,71, 75,173, 165,159, 148,334, 375,151, 334,317, 251,626, 499,474, 750,326, 534,814, 1079,818 +anchors_p7_1280: + - [ 19,22, 54,36, 32,77 ] # P3/8 + - [ 70,83, 138,71, 75,173 ] # P4/16 + - [ 165,159, 148,334, 375,151 ] # P5/32 + - [ 334,317, 251,626, 499,474 ] # P6/64 + - [ 750,326, 534,814, 1079,818 ] # P7/128 + +# P7-1920: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1920, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 29,34, 81,55, 47,115, 105,124, 207,107, 113,259, 247,238, 222,500, 563,227, 501,476, 376,939, 749,711, 1126,489, 801,1222, 1618,1227 +anchors_p7_1920: + - [ 29,34, 81,55, 47,115 ] # P3/8 + - [ 105,124, 207,107, 113,259 ] # P4/16 + - [ 247,238, 222,500, 563,227 ] # P5/32 + - [ 501,476, 376,939, 
749,711 ] # P6/64 + - [ 1126,489, 801,1222, 1618,1227 ] # P7/128 diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov3-spp.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov3-spp.yaml new file mode 100644 index 0000000..38dcc44 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov3-spp.yaml @@ -0,0 +1,51 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# darknet53 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [32, 3, 1]], # 0 + [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 + [-1, 1, Bottleneck, [64]], + [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 + [-1, 2, Bottleneck, [128]], + [-1, 1, Conv, [256, 3, 2]], # 5-P3/8 + [-1, 8, Bottleneck, [256]], + [-1, 1, Conv, [512, 3, 2]], # 7-P4/16 + [-1, 8, Bottleneck, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32 + [-1, 4, Bottleneck, [1024]], # 10 + ] + +# YOLOv3-SPP head +head: + [[-1, 1, Bottleneck, [1024, False]], + [-1, 1, SPP, [512, [5, 9, 13]]], + [-1, 1, Conv, [1024, 3, 1]], + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large) + + [-2, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 8], 1, Concat, [1]], # cat backbone P4 + [-1, 1, Bottleneck, [512, False]], + [-1, 1, Bottleneck, [512, False]], + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium) + + [-2, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P3 + [-1, 1, Bottleneck, [256, False]], + [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small) + + [[27, 22, 15], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov3-tiny.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov3-tiny.yaml new file mode 
100644 index 0000000..ff7638c --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov3-tiny.yaml @@ -0,0 +1,41 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [10,14, 23,27, 37,58] # P4/16 + - [81,82, 135,169, 344,319] # P5/32 + +# YOLOv3-tiny backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [16, 3, 1]], # 0 + [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2 + [-1, 1, Conv, [32, 3, 1]], + [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4 + [-1, 1, Conv, [64, 3, 1]], + [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8 + [-1, 1, Conv, [128, 3, 1]], + [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16 + [-1, 1, Conv, [256, 3, 1]], + [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32 + [-1, 1, Conv, [512, 3, 1]], + [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11 + [-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12 + ] + +# YOLOv3-tiny head +head: + [[-1, 1, Conv, [1024, 3, 1]], + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large) + + [-2, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 8], 1, Concat, [1]], # cat backbone P4 + [-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium) + + [[19, 15], 1, Detect, [nc, anchors]], # Detect(P4, P5) + ] diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov3.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov3.yaml new file mode 100644 index 0000000..f2e7613 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov3.yaml @@ -0,0 +1,51 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# darknet53 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [32, 3, 1]], # 0 + [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 + [-1, 1, Bottleneck, 
[64]], + [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 + [-1, 2, Bottleneck, [128]], + [-1, 1, Conv, [256, 3, 2]], # 5-P3/8 + [-1, 8, Bottleneck, [256]], + [-1, 1, Conv, [512, 3, 2]], # 7-P4/16 + [-1, 8, Bottleneck, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32 + [-1, 4, Bottleneck, [1024]], # 10 + ] + +# YOLOv3 head +head: + [[-1, 1, Bottleneck, [1024, False]], + [-1, 1, Conv, [512, [1, 1]]], + [-1, 1, Conv, [1024, 3, 1]], + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large) + + [-2, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 8], 1, Concat, [1]], # cat backbone P4 + [-1, 1, Bottleneck, [512, False]], + [-1, 1, Bottleneck, [512, False]], + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium) + + [-2, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P3 + [-1, 1, Bottleneck, [256, False]], + [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small) + + [[27, 22, 15], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov5-fpn.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov5-fpn.yaml new file mode 100644 index 0000000..e772bff --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov5-fpn.yaml @@ -0,0 +1,42 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, Bottleneck, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, BottleneckCSP, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, BottleneckCSP, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 
13]]], + [-1, 6, BottleneckCSP, [1024]], # 9 + ] + +# YOLOv5 FPN head +head: + [[-1, 3, BottleneckCSP, [1024, False]], # 10 (P5/32-large) + + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 1, Conv, [512, 1, 1]], + [-1, 3, BottleneckCSP, [512, False]], # 14 (P4/16-medium) + + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 1, Conv, [256, 1, 1]], + [-1, 3, BottleneckCSP, [256, False]], # 18 (P3/8-small) + + [[18, 14, 10], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov5-p2.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov5-p2.yaml new file mode 100644 index 0000000..0633a90 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov5-p2.yaml @@ -0,0 +1,54 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: 3 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 + [ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 + [ -1, 3, C3, [ 128 ] ], + [ -1, 1, Conv, [ 256, 3, 2 ] ], # 3-P3/8 + [ -1, 9, C3, [ 256 ] ], + [ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 + [ -1, 9, C3, [ 512 ] ], + [ -1, 1, Conv, [ 1024, 3, 2 ] ], # 7-P5/32 + [ -1, 1, SPP, [ 1024, [ 5, 9, 13 ] ] ], + [ -1, 3, C3, [ 1024, False ] ], # 9 + ] + +# YOLOv5 head +head: + [ [ -1, 1, Conv, [ 512, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 + [ -1, 3, C3, [ 512, False ] ], # 13 + + [ -1, 1, Conv, [ 256, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 + [ -1, 3, C3, [ 256, False ] ], # 17 (P3/8-small) + + [ -1, 1, Conv, [ 128, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 2 ], 1, Concat, [ 1 ] ], # cat backbone P2 + [ -1, 1, C3, 
[ 128, False ] ], # 21 (P2/4-xsmall) + + [ -1, 1, Conv, [ 128, 3, 2 ] ], + [ [ -1, 18 ], 1, Concat, [ 1 ] ], # cat head P3 + [ -1, 3, C3, [ 256, False ] ], # 24 (P3/8-small) + + [ -1, 1, Conv, [ 256, 3, 2 ] ], + [ [ -1, 14 ], 1, Concat, [ 1 ] ], # cat head P4 + [ -1, 3, C3, [ 512, False ] ], # 27 (P4/16-medium) + + [ -1, 1, Conv, [ 512, 3, 2 ] ], + [ [ -1, 10 ], 1, Concat, [ 1 ] ], # cat head P5 + [ -1, 3, C3, [ 1024, False ] ], # 30 (P5/32-large) + + [ [ 24, 27, 30 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5) + ] diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov5-p6.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov5-p6.yaml new file mode 100644 index 0000000..3728a11 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov5-p6.yaml @@ -0,0 +1,56 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: 3 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 + [ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 + [ -1, 3, C3, [ 128 ] ], + [ -1, 1, Conv, [ 256, 3, 2 ] ], # 3-P3/8 + [ -1, 9, C3, [ 256 ] ], + [ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 + [ -1, 9, C3, [ 512 ] ], + [ -1, 1, Conv, [ 768, 3, 2 ] ], # 7-P5/32 + [ -1, 3, C3, [ 768 ] ], + [ -1, 1, Conv, [ 1024, 3, 2 ] ], # 9-P6/64 + [ -1, 1, SPP, [ 1024, [ 3, 5, 7 ] ] ], + [ -1, 3, C3, [ 1024, False ] ], # 11 + ] + +# YOLOv5 head +head: + [ [ -1, 1, Conv, [ 768, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 8 ], 1, Concat, [ 1 ] ], # cat backbone P5 + [ -1, 3, C3, [ 768, False ] ], # 15 + + [ -1, 1, Conv, [ 512, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 + [ -1, 3, C3, [ 512, False ] ], # 19 + + [ -1, 1, Conv, [ 256, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 + [ 
-1, 3, C3, [ 256, False ] ], # 23 (P3/8-small) + + [ -1, 1, Conv, [ 256, 3, 2 ] ], + [ [ -1, 20 ], 1, Concat, [ 1 ] ], # cat head P4 + [ -1, 3, C3, [ 512, False ] ], # 26 (P4/16-medium) + + [ -1, 1, Conv, [ 512, 3, 2 ] ], + [ [ -1, 16 ], 1, Concat, [ 1 ] ], # cat head P5 + [ -1, 3, C3, [ 768, False ] ], # 29 (P5/32-large) + + [ -1, 1, Conv, [ 768, 3, 2 ] ], + [ [ -1, 12 ], 1, Concat, [ 1 ] ], # cat head P6 + [ -1, 3, C3, [ 1024, False ] ], # 32 (P6/64-xlarge) + + [ [ 23, 26, 29, 32 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5, P6) + ] diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov5-p7.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov5-p7.yaml new file mode 100644 index 0000000..ca8f849 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov5-p7.yaml @@ -0,0 +1,67 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: 3 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 + [ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 + [ -1, 3, C3, [ 128 ] ], + [ -1, 1, Conv, [ 256, 3, 2 ] ], # 3-P3/8 + [ -1, 9, C3, [ 256 ] ], + [ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 + [ -1, 9, C3, [ 512 ] ], + [ -1, 1, Conv, [ 768, 3, 2 ] ], # 7-P5/32 + [ -1, 3, C3, [ 768 ] ], + [ -1, 1, Conv, [ 1024, 3, 2 ] ], # 9-P6/64 + [ -1, 3, C3, [ 1024 ] ], + [ -1, 1, Conv, [ 1280, 3, 2 ] ], # 11-P7/128 + [ -1, 1, SPP, [ 1280, [ 3, 5 ] ] ], + [ -1, 3, C3, [ 1280, False ] ], # 13 + ] + +# YOLOv5 head +head: + [ [ -1, 1, Conv, [ 1024, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 10 ], 1, Concat, [ 1 ] ], # cat backbone P6 + [ -1, 3, C3, [ 1024, False ] ], # 17 + + [ -1, 1, Conv, [ 768, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 8 ], 1, Concat, [ 1 ] ], # cat backbone P5 + [ -1, 3, C3, [ 768, False ] ], # 21 + + [ -1, 1, Conv, [ 512, 1, 1 ] ], + [ -1, 1,
nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 + [ -1, 3, C3, [ 512, False ] ], # 25 + + [ -1, 1, Conv, [ 256, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 + [ -1, 3, C3, [ 256, False ] ], # 29 (P3/8-small) + + [ -1, 1, Conv, [ 256, 3, 2 ] ], + [ [ -1, 26 ], 1, Concat, [ 1 ] ], # cat head P4 + [ -1, 3, C3, [ 512, False ] ], # 32 (P4/16-medium) + + [ -1, 1, Conv, [ 512, 3, 2 ] ], + [ [ -1, 22 ], 1, Concat, [ 1 ] ], # cat head P5 + [ -1, 3, C3, [ 768, False ] ], # 35 (P5/32-large) + + [ -1, 1, Conv, [ 768, 3, 2 ] ], + [ [ -1, 18 ], 1, Concat, [ 1 ] ], # cat head P6 + [ -1, 3, C3, [ 1024, False ] ], # 38 (P6/64-xlarge) + + [ -1, 1, Conv, [ 1024, 3, 2 ] ], + [ [ -1, 14 ], 1, Concat, [ 1 ] ], # cat head P7 + [ -1, 3, C3, [ 1280, False ] ], # 41 (P7/128-xxlarge) + + [ [ 29, 32, 35, 38, 41 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5, P6, P7) + ] diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov5-panet.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov5-panet.yaml new file mode 100644 index 0000000..340f95a --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov5-panet.yaml @@ -0,0 +1,48 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, BottleneckCSP, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, BottleneckCSP, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, BottleneckCSP, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, BottleneckCSP, [1024, False]], # 9 + ] + +# YOLOv5 PANet head +head: + 
[[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, BottleneckCSP, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov5l6.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov5l6.yaml new file mode 100644 index 0000000..11298b0 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov5l6.yaml @@ -0,0 +1,60 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [ 19,27, 44,40, 38,94 ] # P3/8 + - [ 96,68, 86,152, 180,137 ] # P4/16 + - [ 140,301, 303,264, 238,542 ] # P5/32 + - [ 436,615, 739,380, 925,792 ] # P6/64 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 + [ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 + [ -1, 3, C3, [ 128 ] ], + [ -1, 1, Conv, [ 256, 3, 2 ] ], # 3-P3/8 + [ -1, 9, C3, [ 256 ] ], + [ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 + [ -1, 9, C3, [ 512 ] ], + [ -1, 1, Conv, [ 768, 3, 2 ] ], # 7-P5/32 + [ -1, 3, C3, [ 768 ] ], + [ -1, 1, Conv, [ 1024, 3, 2 ] ], # 9-P6/64 + [ -1, 1, SPP, [ 1024, [ 3, 5, 7 ] ] ], + [ -1, 3, C3, [ 1024, False ] ], # 11 + ] + +# YOLOv5 head +head: + [ [ -1, 1, Conv, [ 768, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 8 ], 1, Concat, [ 1 ] ], # cat backbone P5 + [ -1, 3, C3, [ 768, False ] ], # 15 + + [ 
-1, 1, Conv, [ 512, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 + [ -1, 3, C3, [ 512, False ] ], # 19 + + [ -1, 1, Conv, [ 256, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 + [ -1, 3, C3, [ 256, False ] ], # 23 (P3/8-small) + + [ -1, 1, Conv, [ 256, 3, 2 ] ], + [ [ -1, 20 ], 1, Concat, [ 1 ] ], # cat head P4 + [ -1, 3, C3, [ 512, False ] ], # 26 (P4/16-medium) + + [ -1, 1, Conv, [ 512, 3, 2 ] ], + [ [ -1, 16 ], 1, Concat, [ 1 ] ], # cat head P5 + [ -1, 3, C3, [ 768, False ] ], # 29 (P5/32-large) + + [ -1, 1, Conv, [ 768, 3, 2 ] ], + [ [ -1, 12 ], 1, Concat, [ 1 ] ], # cat head P6 + [ -1, 3, C3, [ 1024, False ] ], # 32 (P6/64-xlarge) + + [ [ 23, 26, 29, 32 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5, P6) + ] diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov5m6.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov5m6.yaml new file mode 100644 index 0000000..48afc86 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov5m6.yaml @@ -0,0 +1,60 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple + +# anchors +anchors: + - [ 19,27, 44,40, 38,94 ] # P3/8 + - [ 96,68, 86,152, 180,137 ] # P4/16 + - [ 140,301, 303,264, 238,542 ] # P5/32 + - [ 436,615, 739,380, 925,792 ] # P6/64 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 + [ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 + [ -1, 3, C3, [ 128 ] ], + [ -1, 1, Conv, [ 256, 3, 2 ] ], # 3-P3/8 + [ -1, 9, C3, [ 256 ] ], + [ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 + [ -1, 9, C3, [ 512 ] ], + [ -1, 1, Conv, [ 768, 3, 2 ] ], # 7-P5/32 + [ -1, 3, C3, [ 768 ] ], + [ -1, 1, Conv, [ 1024, 3, 2 ] ], # 9-P6/64 + [ -1, 1, SPP, [ 1024, [ 3, 5, 7 ] ] ], + [ -1, 3, C3, [ 1024, False ] ], # 11 + ] + +# YOLOv5 head +head: + [ [ -1, 
1, Conv, [ 768, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 8 ], 1, Concat, [ 1 ] ], # cat backbone P5 + [ -1, 3, C3, [ 768, False ] ], # 15 + + [ -1, 1, Conv, [ 512, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 + [ -1, 3, C3, [ 512, False ] ], # 19 + + [ -1, 1, Conv, [ 256, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 + [ -1, 3, C3, [ 256, False ] ], # 23 (P3/8-small) + + [ -1, 1, Conv, [ 256, 3, 2 ] ], + [ [ -1, 20 ], 1, Concat, [ 1 ] ], # cat head P4 + [ -1, 3, C3, [ 512, False ] ], # 26 (P4/16-medium) + + [ -1, 1, Conv, [ 512, 3, 2 ] ], + [ [ -1, 16 ], 1, Concat, [ 1 ] ], # cat head P5 + [ -1, 3, C3, [ 768, False ] ], # 29 (P5/32-large) + + [ -1, 1, Conv, [ 768, 3, 2 ] ], + [ [ -1, 12 ], 1, Concat, [ 1 ] ], # cat head P6 + [ -1, 3, C3, [ 1024, False ] ], # 32 (P6/64-xlarge) + + [ [ 23, 26, 29, 32 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5, P6) + ] diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov5s6.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov5s6.yaml new file mode 100644 index 0000000..1df577a --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov5s6.yaml @@ -0,0 +1,60 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple + +# anchors +anchors: + - [ 19,27, 44,40, 38,94 ] # P3/8 + - [ 96,68, 86,152, 180,137 ] # P4/16 + - [ 140,301, 303,264, 238,542 ] # P5/32 + - [ 436,615, 739,380, 925,792 ] # P6/64 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 + [ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 + [ -1, 3, C3, [ 128 ] ], + [ -1, 1, Conv, [ 256, 3, 2 ] ], # 3-P3/8 + [ -1, 9, C3, [ 256 ] ], + [ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 + [ -1, 9, C3, [ 512 ] ], + [ -1, 1, Conv, [ 768, 3, 2 ] ], # 7-P5/32 + [ -1, 3, C3, [ 
768 ] ], + [ -1, 1, Conv, [ 1024, 3, 2 ] ], # 9-P6/64 + [ -1, 1, SPP, [ 1024, [ 3, 5, 7 ] ] ], + [ -1, 3, C3, [ 1024, False ] ], # 11 + ] + +# YOLOv5 head +head: + [ [ -1, 1, Conv, [ 768, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 8 ], 1, Concat, [ 1 ] ], # cat backbone P5 + [ -1, 3, C3, [ 768, False ] ], # 15 + + [ -1, 1, Conv, [ 512, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 + [ -1, 3, C3, [ 512, False ] ], # 19 + + [ -1, 1, Conv, [ 256, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 + [ -1, 3, C3, [ 256, False ] ], # 23 (P3/8-small) + + [ -1, 1, Conv, [ 256, 3, 2 ] ], + [ [ -1, 20 ], 1, Concat, [ 1 ] ], # cat head P4 + [ -1, 3, C3, [ 512, False ] ], # 26 (P4/16-medium) + + [ -1, 1, Conv, [ 512, 3, 2 ] ], + [ [ -1, 16 ], 1, Concat, [ 1 ] ], # cat head P5 + [ -1, 3, C3, [ 768, False ] ], # 29 (P5/32-large) + + [ -1, 1, Conv, [ 768, 3, 2 ] ], + [ [ -1, 12 ], 1, Concat, [ 1 ] ], # cat head P6 + [ -1, 3, C3, [ 1024, False ] ], # 32 (P6/64-xlarge) + + [ [ 23, 26, 29, 32 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5, P6) + ] diff --git a/data_processing/yolov5_crowdhuman/models/hub/yolov5x6.yaml b/data_processing/yolov5_crowdhuman/models/hub/yolov5x6.yaml new file mode 100644 index 0000000..5ebc021 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/hub/yolov5x6.yaml @@ -0,0 +1,60 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.33 # model depth multiple +width_multiple: 1.25 # layer channel multiple + +# anchors +anchors: + - [ 19,27, 44,40, 38,94 ] # P3/8 + - [ 96,68, 86,152, 180,137 ] # P4/16 + - [ 140,301, 303,264, 238,542 ] # P5/32 + - [ 436,615, 739,380, 925,792 ] # P6/64 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 + [ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 + [ -1, 3, C3, [ 128 ] ], + [ -1, 1, Conv, [ 256, 3, 
2 ] ], # 3-P3/8 + [ -1, 9, C3, [ 256 ] ], + [ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 + [ -1, 9, C3, [ 512 ] ], + [ -1, 1, Conv, [ 768, 3, 2 ] ], # 7-P5/32 + [ -1, 3, C3, [ 768 ] ], + [ -1, 1, Conv, [ 1024, 3, 2 ] ], # 9-P6/64 + [ -1, 1, SPP, [ 1024, [ 3, 5, 7 ] ] ], + [ -1, 3, C3, [ 1024, False ] ], # 11 + ] + +# YOLOv5 head +head: + [ [ -1, 1, Conv, [ 768, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 8 ], 1, Concat, [ 1 ] ], # cat backbone P5 + [ -1, 3, C3, [ 768, False ] ], # 15 + + [ -1, 1, Conv, [ 512, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 + [ -1, 3, C3, [ 512, False ] ], # 19 + + [ -1, 1, Conv, [ 256, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 + [ -1, 3, C3, [ 256, False ] ], # 23 (P3/8-small) + + [ -1, 1, Conv, [ 256, 3, 2 ] ], + [ [ -1, 20 ], 1, Concat, [ 1 ] ], # cat head P4 + [ -1, 3, C3, [ 512, False ] ], # 26 (P4/16-medium) + + [ -1, 1, Conv, [ 512, 3, 2 ] ], + [ [ -1, 16 ], 1, Concat, [ 1 ] ], # cat head P5 + [ -1, 3, C3, [ 768, False ] ], # 29 (P5/32-large) + + [ -1, 1, Conv, [ 768, 3, 2 ] ], + [ [ -1, 12 ], 1, Concat, [ 1 ] ], # cat head P6 + [ -1, 3, C3, [ 1024, False ] ], # 32 (P6/64-xlarge) + + [ [ 23, 26, 29, 32 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5, P6) + ] diff --git a/data_processing/yolov5_crowdhuman/models/yolo.py b/data_processing/yolov5_crowdhuman/models/yolo.py new file mode 100644 index 0000000..4181709 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/yolo.py @@ -0,0 +1,272 @@ +import argparse +import logging +import sys +from copy import deepcopy + +sys.path.append('./') # to run '$ python *.py' files in subdirectories +logger = logging.getLogger(__name__) + +from models.common import * +from models.experimental import * +from utils.autoanchor import check_anchor_order +from utils.general import make_divisible, check_file, set_logging +from 
class Detect(nn.Module):
    """Detection head: one 1x1 conv per feature level plus inference-time box decoding.

    In training mode (or when ``export`` is set) ``forward`` returns the reshaped
    raw maps; otherwise it also returns decoded predictions concatenated over levels.
    """

    stride = None  # per-level strides; assigned from outside after construction
    export = False  # when True, forward() behaves as in training mode

    def __init__(self, nc=80, anchors=(), ch=()):
        """Build the head.

        Args:
            nc: number of classes.
            anchors: one flat sequence ``w0, h0, w1, h1, ...`` per detection level.
            ch: input channel count of each level's feature map.
        """
        super().__init__()
        num_levels = len(anchors)
        self.nc = nc  # number of classes
        self.no = nc + 5  # outputs per anchor (4 box values, 1 objectness, nc classes)
        self.nl = num_levels  # number of detection levels
        self.na = len(anchors[0]) // 2  # anchors per level
        self.grid = [torch.zeros(1)] * num_levels  # placeholders, rebuilt lazily in forward()
        anchor_pairs = torch.tensor(anchors).float().view(num_levels, -1, 2)
        self.register_buffer('anchors', anchor_pairs)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', anchor_pairs.clone().view(num_levels, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        outputs_per_cell = self.no * self.na
        self.m = nn.ModuleList(nn.Conv2d(in_ch, outputs_per_cell, 1) for in_ch in ch)  # output convs

    def forward(self, x):
        """Run the head on a list of feature maps; ``x`` is overwritten level by level.

        Returns the list ``x`` in training mode, else ``(predictions, x)`` where
        predictions has shape (bs, total_anchors, no).
        """
        decoded = []  # per-level inference outputs
        self.training = self.training | self.export
        for level in range(self.nl):
            raw = self.m[level](x[level])  # conv
            bs, _, ny, nx = raw.shape
            # (bs, na*no, ny, nx) -> (bs, na, ny, nx, no)
            raw = raw.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            x[level] = raw

            if self.training:
                continue

            # Rebuild the cached cell grid whenever the map size changes.
            if self.grid[level].shape[2:4] != raw.shape[2:4]:
                self.grid[level] = self._make_grid(nx, ny).to(raw.device)

            y = raw.sigmoid()
            cell_offsets = self.grid[level].to(raw.device)
            y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + cell_offsets) * self.stride[level]  # xy
            y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[level]  # wh
            decoded.append(y.view(bs, -1, self.no))

        if self.training:
            return x
        return torch.cat(decoded, 1), x

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a (1, 1, ny, nx, 2) float tensor holding (x, y) cell indices."""
        rows, cols = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        cells = torch.stack((cols, rows), 2)
        return cells.view((1, 1, ny, nx, 2)).float()
def forward_once(self, x, profile=False):
    """Run every layer of ``self.model`` once and return the last layer's output.

    Each layer carries ``f`` (where its input comes from: -1 for the previous
    layer, an int index, or a list of indices) and ``i`` (its own index).
    Outputs of layers listed in ``self.save`` are kept for later layers.
    With ``profile=True`` each layer is also timed and a per-layer line printed.
    """
    saved_outputs, layer_times = [], []
    for layer in self.model:
        source = layer.f
        if source != -1:  # input is not simply the previous layer's output
            if isinstance(source, int):
                x = saved_outputs[source]
            else:
                # -1 inside a list still means "the current running tensor"
                x = [x if idx == -1 else saved_outputs[idx] for idx in source]

        if profile:
            if thop:
                flops = thop.profile(layer, inputs=(x,), verbose=False)[0] / 1E9 * 2  # FLOPS
            else:
                flops = 0
            started = time_synchronized()
            for _ in range(10):
                _ = layer(x)
            layer_times.append((time_synchronized() - started) * 100)
            print('%10.1f%10.0f%10.1fms %-40s' % (flops, layer.np, layer_times[-1], layer.type))

        x = layer(x)  # run
        keep = layer.i in self.save
        saved_outputs.append(x if keep else None)  # placeholder keeps indices aligned

    if profile:
        print('%.1fms total' % sum(layer_times))
    return x
+ m = self.model[-1] # Detect() module + for mi, s in zip(m.m, m.stride): # from + b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85) + b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) + b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls + mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + def _print_biases(self): + m = self.model[-1] # Detect() module + for mi in m.m: # from + b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) + print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean())) + + # def _print_weights(self): + # for m in self.model.modules(): + # if type(m) is Bottleneck: + # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights + + def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers + print('Fusing layers... ') + for m in self.model.modules(): + if type(m) is Conv and hasattr(m, 'bn'): + m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv + delattr(m, 'bn') # remove batchnorm + m.forward = m.fuseforward # update forward + self.info() + return self + + def nms(self, mode=True): # add or remove NMS module + present = type(self.model[-1]) is NMS # last layer is NMS + if mode and not present: + print('Adding NMS... ') + m = NMS() # module + m.f = -1 # from + m.i = self.model[-1].i + 1 # index + self.model.add_module(name='%s' % m.i, module=m) # add + self.eval() + elif not mode and present: + print('Removing NMS... ') + self.model = self.model[:-1] # remove + return self + + def autoshape(self): # add autoShape module + print('Adding autoShape... 
def parse_model(d, ch):  # model_dict, input_channels(3)
    """Build the network described by a model dict.

    Args:
        d: model dict with keys 'anchors', 'nc', 'depth_multiple',
            'width_multiple', 'backbone' and 'head'. Each backbone/head entry is
            ``[from, number, module, args]``. The ``args`` lists are modified in
            place, so pass a copy if the dict is needed afterwards.
        ch: list of input channel counts; the last entry feeds the first layer.

    Returns:
        ``(nn.Sequential of layers, sorted list of layer indices whose outputs
        must be kept because a later layer reads them)``.
    """
    logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    # NOTE: `anchors` and `nc` must stay function locals: config strings such as
    # 'nc' and 'anchors' are resolved against this scope by eval() below.
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        # SECURITY NOTE(review): eval() executes arbitrary expressions from the
        # model config; only load configs from trusted sources.
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except Exception:
                # Best effort: strings that are not expressions (e.g. 'nearest')
                # stay as plain strings. Unlike a bare `except:`, this lets
                # KeyboardInterrupt/SystemExit propagate.
                pass

        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
                 C3]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3]:
                args.insert(2, n)  # number of repeats is handled inside the module
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum([ch[x] for x in f])
        elif m is Detect:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        n_params = sum([x.numel() for x in m_.parameters()])  # number params (named to avoid shadowing `np`)
        m_.i, m_.f, m_.type, m_.np = i, f, t, n_params  # attach index, 'from' index, type, number params
        logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, n_params, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []  # drop the input channel entry so ch[k] is the output width of layer k
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
[128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, C3, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/data_processing/yolov5_crowdhuman/models/yolov5m.yaml b/data_processing/yolov5_crowdhuman/models/yolov5m.yaml new file mode 100644 index 0000000..3c749c9 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/yolov5m.yaml @@ -0,0 +1,48 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, C3, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 
3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/data_processing/yolov5_crowdhuman/models/yolov5s.yaml b/data_processing/yolov5_crowdhuman/models/yolov5s.yaml new file mode 100644 index 0000000..aca669d --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/yolov5s.yaml @@ -0,0 +1,48 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, C3, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + 
[-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/data_processing/yolov5_crowdhuman/models/yolov5x.yaml b/data_processing/yolov5_crowdhuman/models/yolov5x.yaml new file mode 100644 index 0000000..d3babdf --- /dev/null +++ b/data_processing/yolov5_crowdhuman/models/yolov5x.yaml @@ -0,0 +1,48 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.33 # model depth multiple +width_multiple: 1.25 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, C3, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/data_processing/yolov5_crowdhuman/my_detect.py b/data_processing/yolov5_crowdhuman/my_detect.py new file mode 100644 index 0000000..3e5e9d9 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/my_detect.py @@ -0,0 +1,253 @@ +import argparse +import os.path +import time 
def _match_head_bboxes(coco_joint_list, head_bboxes, relax=0.1):
    """Return one head box ``[x, y, w, h]`` per entry of ``coco_joint_list``.

    For each person, the first five joints are used as face points (columns 0/1
    are compared against box x/y; column 2 is presumably a confidence score —
    TODO confirm against the pose estimator output). The first detected box
    whose ``relax``-enlarged extent contains every face point is chosen; if none
    does, a square box around the face points is synthesised instead.
    """
    matched = []
    for joints in coco_joint_list:
        coco_joint_img = np.asarray(joints)[:, :3]
        face_points = coco_joint_img[:5, :3]
        # sort face points by their third column, ascending
        face_points = face_points[np.argsort(face_points[:, 2])]
        if np.sum(face_points[:, 2]) < 0.5:
            # low total score: keep only the two highest-scoring points
            face_points = face_points[-2:, :]
        face_center = np.mean(face_points, axis=0, keepdims=True)

        chosen = None
        for bbox in head_bboxes:
            x_min = bbox[0] - bbox[2] * relax
            y_min = bbox[1] - bbox[3] * relax
            x_max = x_min + bbox[2] * (1 + 2 * relax)
            y_max = y_min + bbox[3] * (1 + 2 * relax)
            if all(x_min <= point[0] <= x_max and y_min <= point[1] <= y_max for point in face_points):
                chosen = bbox
                break

        if chosen is None:
            # No detection covers the face: fall back to a square 1.2x the
            # larger extent of the face points, centred on their mean.
            head_stride = max(np.max(face_points[:, 0]) - np.min(face_points[:, 0]),
                              np.max(face_points[:, 1]) - np.min(face_points[:, 1])) * 1.2
            chosen = [face_center[0][0] - head_stride / 2, face_center[0][1] - head_stride / 2,
                      head_stride, head_stride]
        matched.append(chosen)
    return matched


def detect(save_img=False):
    """Detect heads in ``<opt.source>/images`` and write one head box per posed person.

    Reads ``<opt.source>/2d_pose_result_hrnet.json`` (image name -> list of
    per-person joints), runs the detector on every image, matches 'head'
    detections to the persons in the pose file and writes the result to
    ``<opt.source>/head_bbox_yolov5_crowdhuman.json``. Annotated images (and
    optional label files) are saved under ``opt.project/opt.name``.

    All options are read from the module-level ``opt`` namespace.

    Raises:
        KeyError: if an image has no entry in the pose file.
        ValueError: if the number of matched boxes differs from the number of persons.
    """
    source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size

    skeleton_path = os.path.join(opt.source, '2d_pose_result_hrnet.json')
    source = os.path.join(opt.source, 'images')

    with open(skeleton_path) as f:
        pose2d_result = json.load(f)

    webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
        ('rtsp://', 'rtmp://', 'http://'))

    # Directories
    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Initialize
    set_logging()
    device = select_device(opt.device)
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    stride = int(model.stride.max())  # model stride
    imgsz = check_img_size(imgsz, s=stride)  # check img_size
    if half:
        model.half()  # to FP16

    # Second-stage classifier
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        # load_state_dict() returns a key report, not the module, so it must not
        # be chained with .to()/.eval().
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])
        modelc.to(device).eval()

    # Set Dataloader
    vid_path, vid_writer = None, None
    if webcam:
        view_img = check_imshow()
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz, stride=stride)
    else:
        save_img = True
        dataset = LoadImages(source, img_size=imgsz, stride=stride)

    # Get names and colors
    names = model.module.names if hasattr(model, 'module') else model.names
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]

    # Run inference
    if device.type != 'cpu':
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
    t0 = time.time()

    bbox_results = {}  # image name -> list of [x, y, w, h], one per posed person

    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        img_name = os.path.basename(path)
        coco_joint_list = pose2d_result[img_name]
        bbox_list_wo_sort = []  # head boxes in detection order, not yet matched to persons

        # Inference
        t1 = time_synchronized()
        pred = model(img, augment=opt.augment)[0]

        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_synchronized()

        # Apply Classifier
        if classify:
            pred = apply_classifier(pred, modelc, img, im0s)

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
            else:
                p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # img.jpg
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
            s += '%gx%g ' % img.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')

                    label = f'{names[int(cls)]} {conf:.2f}'
                    if 'head' in label:
                        # x, y, w, h in original-image pixels
                        bbox = [float(xyxy[0]), float(xyxy[1]),
                                float(xyxy[2] - xyxy[0]), float(xyxy[3] - xyxy[1])]
                        bbox_list_wo_sort.append(bbox)

                    if save_img or view_img:  # Add bbox to image
                        if opt.heads or opt.person:
                            # draw only the requested classes
                            if 'head' in label and opt.heads:
                                plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
                            if 'person' in label and opt.person:
                                plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
                        else:
                            plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)

            # Print time (inference + NMS)
            print(f'{s}Done. ({t2 - t1:.3f}s)')

            # Stream results
            if view_img:
                preview_size = (512, int(im0.shape[0] / im0.shape[1] * 512))
                print('resize to', preview_size)
                cv2.imshow(str(p), cv2.resize(im0, preview_size))
                cv2.waitKey(0)  # blocks until a key is pressed

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                else:  # 'video'
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer

                        fourcc = 'mp4v'  # output video codec
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
                    vid_writer.write(im0)

        # Order the head boxes so that entry k belongs to person k of the pose file.
        bbox_list_sort = _match_head_bboxes(coco_joint_list, bbox_list_wo_sort)

        if len(bbox_list_sort) != len(coco_joint_list):
            raise ValueError('bbox_list_sort and coco_joint_list have different length')

        bbox_results[img_name] = bbox_list_sort

    # save bbox (written once, after all images have been processed)
    with open(os.path.join(opt.source, 'head_bbox_yolov5_crowdhuman.json'), 'w') as f:
        json.dump(bbox_results, f)

    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        print(f"Results saved to {save_dir}{s}")

    print(f'Done. ({time.time() - t0:.3f}s)')
0 or 0,1,2,3 or cpu') + parser.add_argument('--view-img', action='store_true', help='display results') + parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') + parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') + parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') + parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') + parser.add_argument('--augment', action='store_true', help='augmented inference') + parser.add_argument('--update', action='store_true', help='update all models') + parser.add_argument('--project', default='runs/detect', help='save results to project/name') + parser.add_argument('--name', default='exp', help='save results to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--person', action='store_true', help='displays only person') + parser.add_argument('--heads', action='store_true', help='displays only person') + opt = parser.parse_args() + print(opt) + #check_requirements() + + with torch.no_grad(): + if opt.update: # update all models (to fix SourceChangeWarning) + for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']: + detect() + strip_optimizer(opt.weights) + else: + detect() diff --git a/data_processing/yolov5_crowdhuman/requirements.txt b/data_processing/yolov5_crowdhuman/requirements.txt new file mode 100644 index 0000000..cb50cf8 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/requirements.txt @@ -0,0 +1,30 @@ +# pip install -r requirements.txt + +# base ---------------------------------------- +Cython +matplotlib>=3.2.2 +numpy>=1.18.5 +opencv-python>=4.1.2 +Pillow +PyYAML>=5.3.1 +scipy>=1.4.1 +tensorboard>=2.2 +torch>=1.7.0 +torchvision>=0.8.1 +tqdm>=4.41.0 + +# logging ------------------------------------- +# wandb + +# plotting 
------------------------------------ +seaborn>=0.11.0 +pandas + +# export -------------------------------------- +# coremltools>=4.1 +# onnx>=1.8.1 +# scikit-learn==0.19.2 # for coreml quantization + +# extras -------------------------------------- +thop # FLOPS computation +pycocotools>=2.0 # COCO mAP diff --git a/data_processing/yolov5_crowdhuman/test.py b/data_processing/yolov5_crowdhuman/test.py new file mode 100644 index 0000000..ecd45f5 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/test.py @@ -0,0 +1,340 @@ +import argparse +import json +import os +from pathlib import Path +from threading import Thread + +import numpy as np +import torch +import yaml +from tqdm import tqdm + +from models.experimental import attempt_load +from utils.datasets import create_dataloader +from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \ + box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr +from utils.metrics import ap_per_class, ConfusionMatrix +from utils.plots import plot_images, output_to_target, plot_study_txt +from utils.torch_utils import select_device, time_synchronized + + +def test(data, + weights=None, + batch_size=32, + imgsz=640, + conf_thres=0.001, + iou_thres=0.6, # for NMS + save_json=False, + single_cls=False, + augment=False, + verbose=False, + model=None, + dataloader=None, + save_dir=Path(''), # for saving images + save_txt=False, # for auto-labelling + save_hybrid=False, # for hybrid auto-labelling + save_conf=False, # save auto-label confidences + plots=True, + log_imgs=0, # number of logged images + compute_loss=None): + # Initialize/load model and set device + training = model is not None + if training: # called by train.py + device = next(model.parameters()).device # get model device + + else: # called directly + set_logging() + device = select_device(opt.device, batch_size=batch_size) + + # Directories + save_dir = 
Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run + (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir + + # Load model + model = attempt_load(weights, map_location=device) # load FP32 model + gs = max(int(model.stride.max()), 32) # grid size (max stride) + imgsz = check_img_size(imgsz, s=gs) # check img_size + + # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99 + # if device.type != 'cpu' and torch.cuda.device_count() > 1: + # model = nn.DataParallel(model) + + # Half + half = device.type != 'cpu' # half precision only supported on CUDA + if half: + model.half() + + # Configure + model.eval() + is_coco = data.endswith('coco.yaml') # is COCO dataset + with open(data) as f: + data = yaml.load(f, Loader=yaml.SafeLoader) # model dict + check_dataset(data) # check + nc = 1 if single_cls else int(data['nc']) # number of classes + iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95 + niou = iouv.numel() + + # Logging + log_imgs, wandb = min(log_imgs, 100), None # ceil + try: + import wandb # Weights & Biases + except ImportError: + log_imgs = 0 + + # Dataloader + if not training: + if device.type != 'cpu': + model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once + path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images + dataloader = create_dataloader(path, imgsz, batch_size, gs, opt, pad=0.5, rect=True, + prefix=colorstr('test: ' if opt.task == 'test' else 'val: '))[0] + + seen = 0 + confusion_matrix = ConfusionMatrix(nc=nc) + names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)} + coco91class = coco80_to_coco91_class() + s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95') + p, r, f1, mp, mr, map50, map, t0, t1 
= 0., 0., 0., 0., 0., 0., 0., 0., 0. + loss = torch.zeros(3, device=device) + jdict, stats, ap, ap_class, wandb_images = [], [], [], [], [] + for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)): + img = img.to(device, non_blocking=True) + img = img.half() if half else img.float() # uint8 to fp16/32 + img /= 255.0 # 0 - 255 to 0.0 - 1.0 + targets = targets.to(device) + nb, _, height, width = img.shape # batch size, channels, height, width + + with torch.no_grad(): + # Run model + t = time_synchronized() + out, train_out = model(img, augment=augment) # inference and training outputs + t0 += time_synchronized() - t + + # Compute loss + if compute_loss: + loss += compute_loss([x.float() for x in train_out], targets)[1][:3] # box, obj, cls + + # Run NMS + targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels + lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling + t = time_synchronized() + out = non_max_suppression(out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb, multi_label=True) + t1 += time_synchronized() - t + + # Statistics per image + for si, pred in enumerate(out): + labels = targets[targets[:, 0] == si, 1:] + nl = len(labels) + tcls = labels[:, 0].tolist() if nl else [] # target class + path = Path(paths[si]) + seen += 1 + + if len(pred) == 0: + if nl: + stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls)) + continue + + # Predictions + predn = pred.clone() + scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1]) # native-space pred + + # Append to text file + if save_txt: + gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]] # normalization gain whwh + for *xyxy, conf, cls in predn.tolist(): + xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh + line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format + with open(save_dir / 'labels' 
/ (path.stem + '.txt'), 'a') as f: + f.write(('%g ' * len(line)).rstrip() % line + '\n') + + # W&B logging + if plots and len(wandb_images) < log_imgs: + box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]}, + "class_id": int(cls), + "box_caption": "%s %.3f" % (names[cls], conf), + "scores": {"class_score": conf}, + "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()] + boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space + wandb_images.append(wandb.Image(img[si], boxes=boxes, caption=path.name)) + + # Append to pycocotools JSON dictionary + if save_json: + # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ... + image_id = int(path.stem) if path.stem.isnumeric() else path.stem + box = xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + for p, b in zip(pred.tolist(), box.tolist()): + jdict.append({'image_id': image_id, + 'category_id': coco91class[int(p[5])] if is_coco else int(p[5]), + 'bbox': [round(x, 3) for x in b], + 'score': round(p[4], 5)}) + + # Assign all predictions as incorrect + correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device) + if nl: + detected = [] # target indices + tcls_tensor = labels[:, 0] + + # target boxes + tbox = xywh2xyxy(labels[:, 1:5]) + scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels + if plots: + confusion_matrix.process_batch(predn, torch.cat((labels[:, 0:1], tbox), 1)) + + # Per target class + for cls in torch.unique(tcls_tensor): + ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1) # prediction indices + pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1) # target indices + + # Search for detections + if pi.shape[0]: + # Prediction to target ious + ious, i = box_iou(predn[pi, :4], tbox[ti]).max(1) # best ious, indices + + # Append detections + detected_set = set() + for j in (ious > 
iouv[0]).nonzero(as_tuple=False): + d = ti[i[j]] # detected target + if d.item() not in detected_set: + detected_set.add(d.item()) + detected.append(d) + correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn + if len(detected) == nl: # all targets already located in image + break + + # Append statistics (correct, conf, pcls, tcls) + stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) + + # Plot images + if plots and batch_i < 3: + f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels + Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start() + f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions + Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start() + + # Compute statistics + stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy + if len(stats) and stats[0].any(): + p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names) + ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95 + mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() + nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class + else: + nt = torch.zeros(1) + + # Print results + pf = '%20s' + '%12.3g' * 6 # print format + print(pf % ('all', seen, nt.sum(), mp, mr, map50, map)) + + # Print results per class + if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats): + for i, c in enumerate(ap_class): + print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i])) + + # Print speeds + t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (imgsz, imgsz, batch_size) # tuple + if not training: + print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t) + + # Plots + if plots: + confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) + if wandb and wandb.run: + val_batches = [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))] + 
# Command-line entry point for test.py: parse options, then dispatch on --task.
#   'val' / 'test' : a single evaluation run with the given weights
#   'speed'        : one run per weights file at fixed conf/IoU thresholds (0.25 / 0.45)
#   'study'        : sweep image sizes per weights file, save the results to text files and plot them
if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog='test.py')
    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
    parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS')
    parser.add_argument('--task', default='val', help="'val', 'test', 'speed', 'study'")
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--verbose', action='store_true', help='report mAP by class')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
    parser.add_argument('--project', default='runs/test', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    opt = parser.parse_args()
    opt.save_json |= opt.data.endswith('coco.yaml')  # always emit JSON for *coco.yaml datasets
    opt.data = check_file(opt.data)  # check file
    print(opt)
    check_requirements()

    # argparse only wraps values in a list when --weights is given explicitly (nargs='+');
    # the default stays a bare string. Normalise so the per-weights loops below iterate
    # over paths instead of over the characters of 'yolov5s.pt'.
    weights_list = opt.weights if isinstance(opt.weights, list) else [opt.weights]

    if opt.task in ['val', 'test']:  # run normally
        test(opt.data,
             opt.weights,  # passed through unchanged (string or list), exactly as before
             opt.batch_size,
             opt.img_size,
             opt.conf_thres,
             opt.iou_thres,
             opt.save_json,
             opt.single_cls,
             opt.augment,
             opt.verbose,
             save_txt=opt.save_txt | opt.save_hybrid,  # hybrid labels imply text output
             save_hybrid=opt.save_hybrid,
             save_conf=opt.save_conf,
             )

    elif opt.task == 'speed':  # speed benchmarks
        for w in weights_list:
            test(opt.data, w, opt.batch_size, opt.img_size, 0.25, 0.45, save_json=False, plots=False)

    elif opt.task == 'study':  # run over a range of settings and save/plot
        x = list(range(256, 1536 + 128, 128))  # x axis (image sizes)
        for w in weights_list:
            f = f'study_{Path(opt.data).stem}_{Path(w).stem}.txt'  # filename to save to
            y = []  # y axis
            for i in x:  # img-size
                print(f'\nRunning {f} point {i}...')
                r, _, t = test(opt.data, w, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json,
                               plots=False)
                y.append(r + t)  # results and times
            np.savetxt(f, y, fmt='%10.4g')  # save
        os.system('zip -r study.zip study_*.txt')
        plot_study_txt(x=x)  # plot
def train(hyp, opt, device, tb_writer=None, wandb=None):
    """Train a model for ``opt.epochs`` epochs and return the last validation results.

    Args:
        hyp: dict of hyperparameters. It is modified in place ('weight_decay', 'box',
            'cls' and 'obj' are rescaled), so pass a copy if the caller needs the
            original values.
        opt: argparse-style namespace of run options (save_dir, epochs, batch_size,
            total_batch_size, weights, global_rank, ...). ``opt.hyp`` is set when a
            W&B run is initialised here.
        device: torch.device the model and batches are moved to.
        tb_writer: optional Tensorboard writer; only ``add_histogram`` and
            ``add_scalar`` are called on it.
        wandb: optional Weights & Biases module (``wandb.init``, ``wandb.run``,
            ``wandb.log``, ``wandb.Image`` are used), or None to disable W&B logging.

    Returns:
        The 7-tuple from the most recent ``test.test`` call
        (P, R, mAP@.5, mAP@.5-.95, val box/obj/cls loss), or the all-zero
        initial tuple if no test ran in this process.
    """
    logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
    save_dir, epochs, batch_size, total_batch_size, weights, rank = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check

    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        # NOTE(review): torch.load unpickles the checkpoint; only load weights from trusted sources.
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        if hyp.get('anchors'):
            ckpt['model'].yaml['anchors'] = round(hyp['anchors'])  # force autoanchor
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device)  # create
        exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else []  # exclude keys
        state_dict = ckpt['model'].float().state_dict()  # to FP32
        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(state_dict, strict=False)  # load
        logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=3, nc=nc).to(device)  # create

    # Freeze
    freeze = []  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print('freezing %s' % k)
            v.requires_grad = False

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay

    if opt.adam:
        optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    if opt.linear_lr:
        lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
    else:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)

    # Logging
    # Reuse an already-active W&B run if there is one, so that wandb_run is always bound
    # before the checkpoint writer below reads its id.
    wandb_run = wandb.run if wandb else None
    if rank in [-1, 0] and wandb and wandb_run is None:
        opt.hyp = hyp  # add hyperparameters
        wandb_run = wandb.init(config=opt, resume="allow",
                               project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
                               name=save_dir.stem,
                               id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
    loggers = {'wandb': wandb}  # loggers dict

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # Results
        if ckpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(ckpt['training_results'])  # write results.txt

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if opt.resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
        if epochs < start_epoch:
            logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                        (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

    # Image sizes
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    nl = model.model[-1].nl  # number of detection layers (used for scaling hyp['obj'])
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info('Using SyncBatchNorm()')

    # EMA
    ema = ModelEMA(model) if rank in [-1, 0] else None

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)

    # Trainloader
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                            world_size=opt.world_size, workers=opt.workers,
                                            image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

    # Process 0
    if rank in [-1, 0]:
        ema.updates = start_epoch * nb // accumulate  # set EMA updates
        testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                       world_size=opt.world_size, workers=opt.workers,
                                       pad=0.5, prefix=colorstr('val: '))[0]

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                plot_labels(labels, save_dir, loggers)
                if tb_writer:
                    tb_writer.add_histogram('classes', c, 0)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

    # Model parameters
    hyp['box'] *= 3. / nl  # scale to layers
    hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl  # scale to image size and layers
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    compute_loss = ComputeLoss(model)  # init loss class
    logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                f'Using {dataloader.num_workers} dataloader workers\n'
                f'Logging results to {save_dir}\n'
                f'Starting training for {epochs} epochs...')
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
                iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
        if rank in [-1, 0]:
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                # randrange() rejects float arguments on Python >= 3.12, so cast the bounds.
                # imgsz is a multiple of gs (>= 32), hence the sampled range is unchanged.
                sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5) + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 + '%10.4g' * 6) % (
                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
                pbar.set_description(s)

                # Plot
                if plots and ni < 3:
                    f = save_dir / f'train_batch{ni}.jpg'  # filename
                    Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
                    # if tb_writer:
                    #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                    #     tb_writer.add_graph(model, imgs)  # add model to tensorboard
                elif plots and ni == 10 and wandb:
                    wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')
                                           if x.exists()]}, commit=False)

            # end batch ------------------------------------------------------------------------------------------------
        # end epoch ----------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            if ema:
                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                results, maps, times = test.test(opt.data,
                                                 batch_size=batch_size * 2,
                                                 imgsz=imgsz_test,
                                                 model=ema.ema,
                                                 single_cls=opt.single_cls,
                                                 dataloader=testloader,
                                                 save_dir=save_dir,
                                                 verbose=nc < 50 and final_epoch,
                                                 plots=plots and final_epoch,
                                                 log_imgs=opt.log_imgs if wandb else 0,
                                                 compute_loss=compute_loss)

            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))

            # Log
            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
                    'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                    'x/lr0', 'x/lr1', 'x/lr2']  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb:
                    wandb.log({tag: x}, step=epoch, commit=tag == tags[-1])  # W&B

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            if fi > best_fitness:
                best_fitness = fi

            # Save model
            save = (not opt.nosave) or (final_epoch and not opt.evolve)
            if save:
                with open(results_file, 'r') as f:  # create checkpoint
                    ckpt = {'epoch': epoch,
                            'best_fitness': best_fitness,
                            'training_results': f.read(),
                            'model': ema.ema,
                            'optimizer': None if final_epoch else optimizer.state_dict(),
                            'wandb_id': wandb_run.id if wandb_run is not None else None}

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Strip optimizers
        final = best if best.exists() else last  # final model
        for f in [last, best]:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
        if opt.bucket:
            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload

        # Plots
        if plots:
            plot_results(save_dir=save_dir)  # save as results.png
            if wandb:
                files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
                wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
                                       if (save_dir / f).exists()]})
                if opt.log_artifacts:
                    wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)

        # Test best.pt
        logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
        if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
            for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]):  # speed, mAP tests
                results, _, _ = test.test(opt.data,
                                          batch_size=batch_size * 2,
                                          imgsz=imgsz_test,
                                          conf_thres=conf,
                                          iou_thres=iou,
                                          model=attempt_load(final, device).half(),
                                          single_cls=opt.single_cls,
                                          dataloader=testloader,
                                          save_dir=save_dir,
                                          save_json=save_json,
                                          plots=False)

    else:
        dist.destroy_process_group()

    wandb.run.finish() if wandb and wandb.run else None
    torch.cuda.empty_cache()
    return results
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') + parser.add_argument('--notest', action='store_true', help='only test final epoch') + parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check') + parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') + parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') + parser.add_argument('--cache-images', action='store_true', help='cache images for faster training') + parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') + parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class') + parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer') + parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') + parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') + parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100') + parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. 
final trained model') + parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers') + parser.add_argument('--project', default='runs/train', help='save to project/name') + parser.add_argument('--name', default='exp', help='save to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--quad', action='store_true', help='quad dataloader') + parser.add_argument('--linear-lr', action='store_true', help='linear LR') + opt = parser.parse_args() + + # Set DDP variables + opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1 + set_logging(opt.global_rank) + if opt.global_rank in [-1, 0]: + check_git_status() + check_requirements() + + # Resume + if opt.resume: # resume an interrupted run + ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path + assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' + apriori = opt.global_rank, opt.local_rank + with open(Path(ckpt).parent.parent / 'opt.yaml') as f: + opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader)) # replace + opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori # reinstate + logger.info('Resuming training from %s' % ckpt) + else: + # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml') + opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files + assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' + opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) + opt.name = 'evolve' if opt.evolve else opt.name + opt.save_dir = increment_path(Path(opt.project) / opt.name, 
exist_ok=opt.exist_ok | opt.evolve) # increment run + + # DDP mode + opt.total_batch_size = opt.batch_size + device = select_device(opt.device, batch_size=opt.batch_size) + if opt.local_rank != -1: + assert torch.cuda.device_count() > opt.local_rank + torch.cuda.set_device(opt.local_rank) + device = torch.device('cuda', opt.local_rank) + dist.init_process_group(backend='nccl', init_method='env://') # distributed backend + assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count' + opt.batch_size = opt.total_batch_size // opt.world_size + + # Hyperparameters + with open(opt.hyp) as f: + hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps + + # Train + logger.info(opt) + try: + import wandb + except ImportError: + wandb = None + prefix = colorstr('wandb: ') + logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)") + if not opt.evolve: + tb_writer = None # init loggers + if opt.global_rank in [-1, 0]: + logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/') + tb_writer = SummaryWriter(opt.save_dir) # Tensorboard + train(hyp, opt, device, tb_writer, wandb) + + # Evolve hyperparameters (optional) + else: + # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit) + meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3) + 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) + 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1 + 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay + 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok) + 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum + 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr + 'box': (1, 0.02, 0.2), # box loss gain + 'cls': (1, 0.2, 4.0), # cls loss gain + 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight + 
'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels) + 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight + 'iou_t': (0, 0.1, 0.7), # IoU training threshold + 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold + 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore) + 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5) + 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction) + 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction) + 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction) + 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg) + 'translate': (1, 0.0, 0.9), # image translation (+/- fraction) + 'scale': (1, 0.0, 0.9), # image scale (+/- gain) + 'shear': (1, 0.0, 10.0), # image shear (+/- deg) + 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 + 'flipud': (1, 0.0, 1.0), # image flip up-down (probability) + 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability) + 'mosaic': (1, 0.0, 1.0), # image mosaic (probability) + 'mixup': (1, 0.0, 1.0)} # image mixup (probability) + + assert opt.local_rank == -1, 'DDP mode not implemented for --evolve' + opt.notest, opt.nosave = True, True # only test/save final epoch + # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices + yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here + if opt.bucket: + os.system('gsutil cp gs://%s/evolve.txt .' 
% opt.bucket) # download evolve.txt if exists + + for _ in range(300): # generations to evolve + if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate + # Select parent(s) + parent = 'single' # parent selection method: 'single' or 'weighted' + x = np.loadtxt('evolve.txt', ndmin=2) + n = min(5, len(x)) # number of previous results to consider + x = x[np.argsort(-fitness(x))][:n] # top n mutations + w = fitness(x) - fitness(x).min() # weights + if parent == 'single' or len(x) == 1: + # x = x[random.randint(0, n - 1)] # random selection + x = x[random.choices(range(n), weights=w)[0]] # weighted selection + elif parent == 'weighted': + x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination + + # Mutate + mp, s = 0.8, 0.2 # mutation probability, sigma + npr = np.random + npr.seed(int(time.time())) + g = np.array([x[0] for x in meta.values()]) # gains 0-1 + ng = len(meta) + v = np.ones(ng) + while all(v == 1): # mutate until a change occurs (prevent duplicates) + v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0) + for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300) + hyp[k] = float(x[i + 7] * v[i]) # mutate + + # Constrain to limits + for k, v in meta.items(): + hyp[k] = max(hyp[k], v[1]) # lower limit + hyp[k] = min(hyp[k], v[2]) # upper limit + hyp[k] = round(hyp[k], 5) # significant digits + + # Train mutation + results = train(hyp.copy(), opt, device, wandb=wandb) + + # Write mutation results + print_mutation(hyp.copy(), results, yaml_file, opt.bucket) + + # Plot results + plot_evolution(yaml_file) + print(f'Hyperparameter evolution complete. 
Best results saved as: {yaml_file}\n' + f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}') diff --git a/data_processing/yolov5_crowdhuman/tutorial.ipynb b/data_processing/yolov5_crowdhuman/tutorial.ipynb new file mode 100644 index 0000000..7fce40c --- /dev/null +++ b/data_processing/yolov5_crowdhuman/tutorial.ipynb @@ -0,0 +1,1252 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "YOLOv5 Tutorial", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true, + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "1f8e9b8ebded4175b2eaa9f75c3ceb00": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_0a1246a73077468ab80e979cc0576cd2", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_d327cde5a85a4a51bb8b1b3e9cf06c97", + "IPY_MODEL_d5ef1cb2cbed4b87b3c5d292ff2b0da6" + ] + } + }, + "0a1246a73077468ab80e979cc0576cd2": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + 
"grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "d327cde5a85a4a51bb8b1b3e9cf06c97": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_8d5dff8bca14435a88fa1814533acd85", + "_dom_classes": [], + "description": "100%", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 819257867, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 819257867, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_3d5136c19e7645ca9bc8f51ceffb2be1" + } + }, + "d5ef1cb2cbed4b87b3c5d292ff2b0da6": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_2919396dbd4b4c8e821d12bd28665d8a", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 781M/781M [00:12<00:00, 65.5MB/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_6feb16f2b2fa4021b1a271e1dd442d04" + } + }, + "8d5dff8bca14435a88fa1814533acd85": { + "model_module": 
"@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "3d5136c19e7645ca9bc8f51ceffb2be1": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "2919396dbd4b4c8e821d12bd28665d8a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": 
"@jupyter-widgets/controls" + } + }, + "6feb16f2b2fa4021b1a271e1dd442d04": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "e6459e0bcee449b090fc9807672725bc": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_c341e1d3bf3b40d1821ce392eb966c68", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_660afee173694231a6dce3cd94df6cae", + "IPY_MODEL_261218485cef48df961519dde5edfcbe" + ] + } + }, + "c341e1d3bf3b40d1821ce392eb966c68": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": 
"LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "660afee173694231a6dce3cd94df6cae": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_32736d503c06497abfae8c0421918255", + "_dom_classes": [], + "description": "100%", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 22091032, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 22091032, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_e257738711f54d5280c8393d9d3dce1c" + } + }, + "261218485cef48df961519dde5edfcbe": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_beb7a6fe34b840899bb79c062681696f", + 
"_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 21.1M/21.1M [00:00<00:00, 33.5MB/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_e639132395d64d70b99d8b72c32f8fbb" + } + }, + "32736d503c06497abfae8c0421918255": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "e257738711f54d5280c8393d9d3dce1c": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": 
null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "beb7a6fe34b840899bb79c062681696f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "e639132395d64d70b99d8b72c32f8fbb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HvhYZrIZCEyo" + }, + "source": [ + "\n", + 
"\n", + "This notebook was written by Ultralytics LLC, and is freely available for redistribution under the [GPL-3.0 license](https://choosealicense.com/licenses/gpl-3.0/). \n", + "For more information please visit https://github.com/ultralytics/yolov5 and https://www.ultralytics.com." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7mGmQbAO5pQb" + }, + "source": [ + "# Setup\n", + "\n", + "Clone repo, install dependencies and check PyTorch and GPU." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "wbvMlHd_QwMG", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "ae8805a9-ce15-4e1c-f6b4-baa1c1033f56" + }, + "source": [ + "!git clone https://github.com/ultralytics/yolov5 # clone repo\n", + "%cd yolov5\n", + "%pip install -qr requirements.txt # install dependencies\n", + "\n", + "import torch\n", + "from IPython.display import Image, clear_output # to display images\n", + "\n", + "clear_output()\n", + "print('Setup complete. Using torch %s %s' % (torch.__version__, torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU'))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Setup complete. Using torch 1.7.0+cu101 _CudaDeviceProperties(name='Tesla V100-SXM2-16GB', major=7, minor=0, total_memory=16160MB, multi_processor_count=80)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4JnkELT0cIJg" + }, + "source": [ + "# 1. 
Inference\n", + "\n", + "`detect.py` runs inference on a variety of sources, downloading models automatically from the [latest YOLOv5 release](https://github.com/ultralytics/yolov5/releases)." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zR9ZbuQCH7FX", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 534 + }, + "outputId": "c9a308f7-2216-4805-8003-eca8dd0dc30d" + }, + "source": [ + "!python detect.py --weights yolov5s.pt --img 640 --conf 0.25 --source data/images/\n", + "Image(filename='runs/detect/exp/zidane.jpg', width=600)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Namespace(agnostic_nms=False, augment=False, classes=None, conf_thres=0.25, device='', exist_ok=False, img_size=640, iou_thres=0.45, name='exp', project='runs/detect', save_conf=False, save_txt=False, source='data/images/', update=False, view_img=False, weights=['yolov5s.pt'])\n", + "YOLOv5 v4.0-21-gb26a2f6 torch 1.7.0+cu101 CUDA:0 (Tesla V100-SXM2-16GB, 16130.5MB)\n", + "\n", + "Fusing layers... \n", + "Model Summary: 224 layers, 7266973 parameters, 0 gradients, 17.0 GFLOPS\n", + "image 1/2 /content/yolov5/data/images/bus.jpg: 640x480 4 persons, 1 buss, 1 skateboards, Done. (0.011s)\n", + "image 2/2 /content/yolov5/data/images/zidane.jpg: 384x640 2 persons, 2 ties, Done. (0.011s)\n", + "Results saved to runs/detect/exp\n", + "Done. (0.110s)\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "image/jpeg": "\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [], + "image/jpeg": { + "width": 600 + } + }, + "execution_count": 38 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4qbaa3iEcrcE" + }, + "source": [ + "Results are saved to `runs/detect`. 
A full list of available inference sources:\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0eq1SMWl6Sfn" + }, + "source": [ + "# 2. Test\n", + "Test a model on [COCO](https://cocodataset.org/#home) val or test-dev dataset to evaluate trained accuracy. Models are downloaded automatically from the [latest YOLOv5 release](https://github.com/ultralytics/yolov5/releases). To show results by class use the `--verbose` flag. Note that `pycocotools` metrics may be 1-2% better than the equivalent repo metrics, as is visible below, due to slight differences in mAP computation." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eyTZYGgRjnMc" + }, + "source": [ + "## COCO val2017\n", + "Download [COCO val 2017](https://github.com/ultralytics/yolov5/blob/74b34872fdf41941cddcf243951cdb090fbac17b/data/coco.yaml#L14) dataset (1GB - 5000 images), and test model accuracy." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WQPtK1QYVaD_", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 65, + "referenced_widgets": [ + "1f8e9b8ebded4175b2eaa9f75c3ceb00", + "0a1246a73077468ab80e979cc0576cd2", + "d327cde5a85a4a51bb8b1b3e9cf06c97", + "d5ef1cb2cbed4b87b3c5d292ff2b0da6", + "8d5dff8bca14435a88fa1814533acd85", + "3d5136c19e7645ca9bc8f51ceffb2be1", + "2919396dbd4b4c8e821d12bd28665d8a", + "6feb16f2b2fa4021b1a271e1dd442d04" + ] + }, + "outputId": "d6ace7c6-1be5-41ff-d607-1c716b88d298" + }, + "source": [ + "# Download COCO val2017\n", + "torch.hub.download_url_to_file('https://github.com/ultralytics/yolov5/releases/download/v1.0/coco2017val.zip', 'tmp.zip')\n", + "!unzip -q tmp.zip -d ../ && rm tmp.zip" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1f8e9b8ebded4175b2eaa9f75c3ceb00", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=819257867.0), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "X58w8JLpMnjH", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "cc25f70c-0a11-44f6-cc44-e92c5083488c" + }, + "source": [ + "# Run YOLOv5x on COCO val2017\n", + "!python test.py --weights yolov5x.pt --data coco.yaml --img 640 --iou 0.65" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Namespace(augment=False, batch_size=32, conf_thres=0.001, data='./data/coco.yaml', device='', exist_ok=False, img_size=640, iou_thres=0.65, name='exp', project='runs/test', save_conf=False, 
save_hybrid=False, save_json=True, save_txt=False, single_cls=False, task='val', verbose=False, weights=['yolov5x.pt'])\n", + "YOLOv5 v4.0-75-gbdd88e1 torch 1.7.0+cu101 CUDA:0 (Tesla V100-SXM2-16GB, 16160.5MB)\n", + "\n", + "Downloading https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5x.pt to yolov5x.pt...\n", + "100% 168M/168M [00:04<00:00, 39.7MB/s]\n", + "\n", + "Fusing layers... \n", + "Model Summary: 476 layers, 87730285 parameters, 0 gradients, 218.8 GFLOPS\n", + "\u001b[34m\u001b[1mval: \u001b[0mScanning '../coco/val2017' for images and labels... 4952 found, 48 missing, 0 empty, 0 corrupted: 100% 5000/5000 [00:01<00:00, 2824.78it/s]\n", + "\u001b[34m\u001b[1mval: \u001b[0mNew cache created: ../coco/val2017.cache\n", + " Class Images Targets P R mAP@.5 mAP@.5:.95: 100% 157/157 [01:33<00:00, 1.68it/s]\n", + " all 5e+03 3.63e+04 0.749 0.619 0.68 0.486\n", + "Speed: 5.2/2.0/7.3 ms inference/NMS/total per 640x640 image at batch-size 32\n", + "\n", + "Evaluating pycocotools mAP... 
saving runs/test/exp/yolov5x_predictions.json...\n", + "loading annotations into memory...\n", + "Done (t=0.44s)\n", + "creating index...\n", + "index created!\n", + "Loading and preparing results...\n", + "DONE (t=4.47s)\n", + "creating index...\n", + "index created!\n", + "Running per image evaluation...\n", + "Evaluate annotation type *bbox*\n", + "DONE (t=94.87s).\n", + "Accumulating evaluation results...\n", + "DONE (t=15.96s).\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.501\n", + " Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.687\n", + " Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.544\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.338\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.548\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.637\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.378\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.628\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.680\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.520\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.729\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.826\n", + "Results saved to runs/test/exp\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rc_KbFk0juX2" + }, + "source": [ + "## COCO test-dev2017\n", + "Download [COCO test2017](https://github.com/ultralytics/yolov5/blob/74b34872fdf41941cddcf243951cdb090fbac17b/data/coco.yaml#L15) dataset (7GB - 40,000 images), to test model accuracy on test-dev set (20,000 images). 
Results are saved to a `*.json` file which can be submitted to the evaluation server at https://competitions.codalab.org/competitions/20794." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "V0AJnSeCIHyJ" + }, + "source": [ + "# Download COCO test-dev2017\n", + "torch.hub.download_url_to_file('https://github.com/ultralytics/yolov5/releases/download/v1.0/coco2017labels.zip', 'tmp.zip')\n", + "!unzip -q tmp.zip -d ../ && rm tmp.zip # unzip labels\n", + "!f=\"test2017.zip\" && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f && rm $f # 7GB, 41k images\n", + "%mv ./test2017 ./coco/images && mv ./coco ../ # move images to /coco and move /coco next to /yolov5" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "29GJXAP_lPrt" + }, + "source": [ + "# Run YOLOv5s on COCO test-dev2017 using --task test\n", + "!python test.py --weights yolov5s.pt --data coco.yaml --task test" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VUOiNLtMP5aG" + }, + "source": [ + "# 3. Train\n", + "\n", + "Download [COCO128](https://www.kaggle.com/ultralytics/coco128), a small 128-image tutorial dataset, start tensorboard and train YOLOv5s from a pretrained checkpoint for 3 epochs (note actual training is typically much longer, around **300-1000 epochs**, depending on your dataset)." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Knxi2ncxWffW", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 65, + "referenced_widgets": [ + "e6459e0bcee449b090fc9807672725bc", + "c341e1d3bf3b40d1821ce392eb966c68", + "660afee173694231a6dce3cd94df6cae", + "261218485cef48df961519dde5edfcbe", + "32736d503c06497abfae8c0421918255", + "e257738711f54d5280c8393d9d3dce1c", + "beb7a6fe34b840899bb79c062681696f", + "e639132395d64d70b99d8b72c32f8fbb" + ] + }, + "outputId": "e8b7d5b3-a71e-4446-eec2-ad13419cf700" + }, + "source": [ + "# Download COCO128\n", + "torch.hub.download_url_to_file('https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', 'tmp.zip')\n", + "!unzip -q tmp.zip -d ../ && rm tmp.zip" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e6459e0bcee449b090fc9807672725bc", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=22091032.0), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_pOkGLv1dMqh" + }, + "source": [ + "Train a YOLOv5s model on [COCO128](https://www.kaggle.com/ultralytics/coco128) with `--data coco128.yaml`, starting from pretrained `--weights yolov5s.pt`, or from randomly initialized `--weights '' --cfg yolov5s.yaml`. 
Models are downloaded automatically from the [latest YOLOv5 release](https://github.com/ultralytics/yolov5/releases), and **COCO, COCO128, and VOC datasets are downloaded automatically** on first use.\n", + "\n", + "All training results are saved to `runs/train/` with incrementing run directories, i.e. `runs/train/exp2`, `runs/train/exp3` etc.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bOy5KI2ncnWd" + }, + "source": [ + "# Tensorboard (optional)\n", + "%load_ext tensorboard\n", + "%tensorboard --logdir runs/train" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "2fLAV42oNb7M" + }, + "source": [ + "# Weights & Biases (optional)\n", + "%pip install -q wandb \n", + "!wandb login # use 'wandb disabled' or 'wandb enabled' to disable or enable" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "1NcFxRcFdJ_O", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "38e51b29-2df4-4f00-cde8-5f6e4a34da9e" + }, + "source": [ + "# Train YOLOv5s on COCO128 for 3 epochs\n", + "!python train.py --img 640 --batch 16 --epochs 3 --data coco128.yaml --weights yolov5s.pt --nosave --cache" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mgithub: \u001b[0mup to date with https://github.com/ultralytics/yolov5 ✅\n", + "YOLOv5 v4.0-75-gbdd88e1 torch 1.7.0+cu101 CUDA:0 (Tesla V100-SXM2-16GB, 16160.5MB)\n", + "\n", + "Namespace(adam=False, batch_size=16, bucket='', cache_images=True, cfg='', data='./data/coco128.yaml', device='', epochs=3, evolve=False, exist_ok=False, global_rank=-1, hyp='data/hyp.scratch.yaml', image_weights=False, img_size=[640, 640], linear_lr=False, local_rank=-1, log_artifacts=False, log_imgs=16, 
multi_scale=False, name='exp', noautoanchor=False, nosave=True, notest=False, project='runs/train', quad=False, rect=False, resume=False, save_dir='runs/train/exp', single_cls=False, sync_bn=False, total_batch_size=16, weights='yolov5s.pt', workers=8, world_size=1)\n", + "\u001b[34m\u001b[1mwandb: \u001b[0mInstall Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)\n", + "Start Tensorboard with \"tensorboard --logdir runs/train\", view at http://localhost:6006/\n", + "2021-02-12 06:38:28.027271: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1\n", + "\u001b[34m\u001b[1mhyperparameters: \u001b[0mlr0=0.01, lrf=0.2, momentum=0.937, weight_decay=0.0005, warmup_epochs=3.0, warmup_momentum=0.8, warmup_bias_lr=0.1, box=0.05, cls=0.5, cls_pw=1.0, obj=1.0, obj_pw=1.0, iou_t=0.2, anchor_t=4.0, fl_gamma=0.0, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.5, mosaic=1.0, mixup=0.0\n", + "Downloading https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt to yolov5s.pt...\n", + "100% 14.1M/14.1M [00:01<00:00, 13.2MB/s]\n", + "\n", + "\n", + " from n params module arguments \n", + " 0 -1 1 3520 models.common.Focus [3, 32, 3] \n", + " 1 -1 1 18560 models.common.Conv [32, 64, 3, 2] \n", + " 2 -1 1 18816 models.common.C3 [64, 64, 1] \n", + " 3 -1 1 73984 models.common.Conv [64, 128, 3, 2] \n", + " 4 -1 1 156928 models.common.C3 [128, 128, 3] \n", + " 5 -1 1 295424 models.common.Conv [128, 256, 3, 2] \n", + " 6 -1 1 625152 models.common.C3 [256, 256, 3] \n", + " 7 -1 1 1180672 models.common.Conv [256, 512, 3, 2] \n", + " 8 -1 1 656896 models.common.SPP [512, 512, [5, 9, 13]] \n", + " 9 -1 1 1182720 models.common.C3 [512, 512, 1, False] \n", + " 10 -1 1 131584 models.common.Conv [512, 256, 1, 1] 
\n", + " 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 12 [-1, 6] 1 0 models.common.Concat [1] \n", + " 13 -1 1 361984 models.common.C3 [512, 256, 1, False] \n", + " 14 -1 1 33024 models.common.Conv [256, 128, 1, 1] \n", + " 15 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 16 [-1, 4] 1 0 models.common.Concat [1] \n", + " 17 -1 1 90880 models.common.C3 [256, 128, 1, False] \n", + " 18 -1 1 147712 models.common.Conv [128, 128, 3, 2] \n", + " 19 [-1, 14] 1 0 models.common.Concat [1] \n", + " 20 -1 1 296448 models.common.C3 [256, 256, 1, False] \n", + " 21 -1 1 590336 models.common.Conv [256, 256, 3, 2] \n", + " 22 [-1, 10] 1 0 models.common.Concat [1] \n", + " 23 -1 1 1182720 models.common.C3 [512, 512, 1, False] \n", + " 24 [17, 20, 23] 1 229245 models.yolo.Detect [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]\n", + "Model Summary: 283 layers, 7276605 parameters, 7276605 gradients, 17.1 GFLOPS\n", + "\n", + "Transferred 362/362 items from yolov5s.pt\n", + "Scaled weight_decay = 0.0005\n", + "Optimizer groups: 62 .bias, 62 conv.weight, 59 other\n", + "\u001b[34m\u001b[1mtrain: \u001b[0mScanning '../coco128/labels/train2017' for images and labels... 128 found, 0 missing, 2 empty, 0 corrupted: 100% 128/128 [00:00<00:00, 2566.00it/s]\n", + "\u001b[34m\u001b[1mtrain: \u001b[0mNew cache created: ../coco128/labels/train2017.cache\n", + "\u001b[34m\u001b[1mtrain: \u001b[0mCaching images (0.1GB): 100% 128/128 [00:00<00:00, 175.07it/s]\n", + "\u001b[34m\u001b[1mval: \u001b[0mScanning '../coco128/labels/train2017.cache' for images and labels... 128 found, 0 missing, 2 empty, 0 corrupted: 100% 128/128 [00:00<00:00, 764773.38it/s]\n", + "\u001b[34m\u001b[1mval: \u001b[0mCaching images (0.1GB): 100% 128/128 [00:00<00:00, 128.17it/s]\n", + "Plotting labels... \n", + "\n", + "\u001b[34m\u001b[1mautoanchor: \u001b[0mAnalyzing anchors... 
anchors/target = 4.26, Best Possible Recall (BPR) = 0.9946\n", + "Image sizes 640 train, 640 test\n", + "Using 2 dataloader workers\n", + "Logging results to runs/train/exp\n", + "Starting training for 3 epochs...\n", + "\n", + " Epoch gpu_mem box obj cls total targets img_size\n", + " 0/2 3.27G 0.04357 0.06781 0.01869 0.1301 207 640: 100% 8/8 [00:03<00:00, 2.03it/s]\n", + " Class Images Targets P R mAP@.5 mAP@.5:.95: 100% 4/4 [00:04<00:00, 1.14s/it]\n", + " all 128 929 0.646 0.627 0.659 0.431\n", + "\n", + " Epoch gpu_mem box obj cls total targets img_size\n", + " 1/2 7.75G 0.04308 0.06654 0.02083 0.1304 227 640: 100% 8/8 [00:01<00:00, 4.11it/s]\n", + " Class Images Targets P R mAP@.5 mAP@.5:.95: 100% 4/4 [00:01<00:00, 2.94it/s]\n", + " all 128 929 0.681 0.607 0.663 0.434\n", + "\n", + " Epoch gpu_mem box obj cls total targets img_size\n", + " 2/2 7.75G 0.04461 0.06896 0.01866 0.1322 191 640: 100% 8/8 [00:02<00:00, 3.94it/s]\n", + " Class Images Targets P R mAP@.5 mAP@.5:.95: 100% 4/4 [00:03<00:00, 1.22it/s]\n", + " all 128 929 0.642 0.632 0.662 0.432\n", + "Optimizer stripped from runs/train/exp/weights/last.pt, 14.8MB\n", + "3 epochs completed in 0.007 hours.\n", + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "15glLzbQx5u0" + }, + "source": [ + "# 4. Visualize" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DLI1JmHU7B0l" + }, + "source": [ + "## Weights & Biases Logging 🌟 NEW\n", + "\n", + "[Weights & Biases](https://www.wandb.com/) (W&B) is now integrated with YOLOv5 for real-time visualization and cloud logging of training runs. This allows for better run comparison and introspection, as well as improved visibility and collaboration for teams. To enable W&B `pip install wandb`, and then train normally (you will be guided through setup on first use). 
\n", + "\n", + "During training you will see live updates at [https://wandb.ai/home](https://wandb.ai/home), and you can create and share detailed [Reports](https://wandb.ai/glenn-jocher/yolov5_tutorial/reports/YOLOv5-COCO128-Tutorial-Results--VmlldzozMDI5OTY) of your results. For more information see the [YOLOv5 Weights & Biases Tutorial](https://github.com/ultralytics/yolov5/issues/1289). \n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-WPvRbS5Swl6" + }, + "source": [ + "## Local Logging\n", + "\n", + "All results are logged by default to `runs/train`, with a new experiment directory created for each new training as `runs/train/exp2`, `runs/train/exp3`, etc. View train and test jpgs to see mosaics, labels, predictions and augmentation effects. Note a **Mosaic Dataloader** is used for training (shown below), a new concept developed by Ultralytics and first featured in [YOLOv4](https://arxiv.org/abs/2004.10934)." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "riPdhraOTCO0" + }, + "source": [ + "Image(filename='runs/train/exp/train_batch0.jpg', width=800) # train batch 0 mosaics and labels\n", + "Image(filename='runs/train/exp/test_batch0_labels.jpg', width=800) # test batch 0 labels\n", + "Image(filename='runs/train/exp/test_batch0_pred.jpg', width=800) # test batch 0 predictions" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OYG4WFEnTVrI" + }, + "source": [ + "> \n", + "`train_batch0.jpg` shows train batch 0 mosaics and labels\n", + "\n", + "> \n", + "`test_batch0_labels.jpg` shows test batch 0 labels\n", + "\n", + "> \n", + "`test_batch0_pred.jpg` shows test batch 0 _predictions_\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7KN5ghjE6ZWh" + }, + "source": [ + "Training losses and performance metrics are also logged to [Tensorboard](https://www.tensorflow.org/tensorboard) and a custom `results.txt` logfile which is plotted as `results.png` (below) after training completes. Here we show YOLOv5s trained on COCO128 to 300 epochs, starting from scratch (blue), and from pretrained `--weights yolov5s.pt` (orange)." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "MDznIqPF7nk3" + }, + "source": [ + "from utils.plots import plot_results \n", + "plot_results(save_dir='runs/train/exp') # plot all results*.txt as results.png\n", + "Image(filename='runs/train/exp/results.png', width=800)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lfrEegCSW3fK" + }, + "source": [ + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Zelyeqbyt3GD" + }, + "source": [ + "# Environments\n", + "\n", + "YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including [CUDA](https://developer.nvidia.com/cuda)/[CUDNN](https://developer.nvidia.com/cudnn), [Python](https://www.python.org/) and [PyTorch](https://pytorch.org/) preinstalled):\n", + "\n", + "- **Google Colab and Kaggle** notebooks with free GPU: \"Open \"Open\n", + "- **Google Cloud** Deep Learning VM. See [GCP Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart)\n", + "- **Amazon** Deep Learning AMI. See [AWS Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/AWS-Quickstart)\n", + "- **Docker Image**. 
See [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart) \"Docker\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6Qu7Iesl0p54" + }, + "source": [ + "# Status\n", + "\n", + "![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg)\n", + "\n", + "If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/models/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IEijrePND_2I" + }, + "source": [ + "# Appendix\n", + "\n", + "Optional extras below. 
Unit tests validate repo functionality and should be run on any PRs submitted.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "gI6NoBev8Ib1" + }, + "source": [ + "# Re-clone repo\n", + "%cd ..\n", + "%rm -rf yolov5 && git clone https://github.com/ultralytics/yolov5\n", + "%cd yolov5" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "mcKoSIK2WSzj" + }, + "source": [ + "# Reproduce\n", + "%%shell\n", + "for x in yolov5s yolov5m yolov5l yolov5x; do\n", + " python test.py --weights $x.pt --data coco.yaml --img 640 --conf 0.25 --iou 0.45 # speed\n", + " python test.py --weights $x.pt --data coco.yaml --img 640 --conf 0.001 --iou 0.65 # mAP\n", + "done" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "FGH0ZjkGjejy" + }, + "source": [ + "# Unit tests\n", + "%%shell\n", + "export PYTHONPATH=\"$PWD\" # to run *.py. files in subdirectories\n", + "\n", + "rm -rf runs # remove runs/\n", + "for m in yolov5s; do # models\n", + " python train.py --weights $m.pt --epochs 3 --img 320 --device 0 # train pretrained\n", + " python train.py --weights '' --cfg $m.yaml --epochs 3 --img 320 --device 0 # train scratch\n", + " for d in 0 cpu; do # devices\n", + " python detect.py --weights $m.pt --device $d # detect official\n", + " python detect.py --weights runs/train/exp/weights/best.pt --device $d # detect custom\n", + " python test.py --weights $m.pt --device $d # test official\n", + " python test.py --weights runs/train/exp/weights/best.pt --device $d # test custom\n", + " done\n", + " python hubconf.py # hub\n", + " python models/yolo.py --cfg $m.yaml # inspect\n", + " python models/export.py --weights $m.pt --img 640 --batch 1 # export\n", + "done" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "gogI-kwi3Tye" + }, + "source": [ + "# Profile\n", + 
"from utils.torch_utils import profile \n", + "\n", + "m1 = lambda x: x * torch.sigmoid(x)\n", + "m2 = torch.nn.SiLU()\n", + "profile(x=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "RVRSOhEvUdb5" + }, + "source": [ + "# Evolve\n", + "!python train.py --img 640 --batch 64 --epochs 100 --data coco128.yaml --weights yolov5s.pt --cache --noautoanchor --evolve\n", + "!d=runs/train/evolve && cp evolve.* $d && zip -r evolve.zip $d && gsutil mv evolve.zip gs://bucket # upload results (optional)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "BSgFCAcMbk1R" + }, + "source": [ + "# VOC\n", + "for b, m in zip([64, 48, 32, 16], ['yolov5s', 'yolov5m', 'yolov5l', 'yolov5x']): # zip(batch_size, model)\n", + " !python train.py --batch {b} --weights {m}.pt --data voc.yaml --epochs 50 --cache --img 512 --nosave --hyp hyp.finetune.yaml --project VOC --name {m}" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/data_processing/yolov5_crowdhuman/utils/__init__.py b/data_processing/yolov5_crowdhuman/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_processing/yolov5_crowdhuman/utils/activations.py b/data_processing/yolov5_crowdhuman/utils/activations.py new file mode 100644 index 0000000..aa3ddf0 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/utils/activations.py @@ -0,0 +1,72 @@ +# Activation functions + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# SiLU https://arxiv.org/pdf/1606.08415.pdf ---------------------------------------------------------------------------- +class SiLU(nn.Module): # export-friendly version of nn.SiLU() + @staticmethod + def forward(x): + return x * torch.sigmoid(x) + + +class Hardswish(nn.Module): # export-friendly version of 
class Hardswish(nn.Module):  # export-friendly version of nn.Hardswish()
    """Hard-swish activation expressed through ``hardtanh``.

    Computes ``x * hardtanh(x + 3, 0, 6) / 6`` element-wise.
    """

    @staticmethod
    def forward(x):
        # return x * F.hardsigmoid(x)  # for torchscript and CoreML
        return x * F.hardtanh(x + 3, 0., 6.) / 6.  # for torchscript, CoreML and ONNX


class MemoryEfficientSwish(nn.Module):
    """Swish / SiLU (``x * sigmoid(x)``) with a hand-written backward pass.

    Only the input tensor is saved for backward; the sigmoid is recomputed
    there instead of being kept alive by autograd.
    """

    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * torch.sigmoid(x)

        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            sx = torch.sigmoid(x)
            # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
            return grad_output * (sx * (1 + x * (1 - sx)))

    def forward(self, x):
        return self.F.apply(x)


# Mish https://github.com/digantamisra98/Mish --------------------------------------------------------------------------
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    @staticmethod
    def forward(x):
        return x * F.softplus(x).tanh()


class MemoryEfficientMish(nn.Module):
    """Mish with a hand-written backward pass that saves only the input."""

    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            # ``F`` below is the module-level torch.nn.functional alias: class
            # scope is not visible from inside a method body.
            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            sx = torch.sigmoid(x)
            fx = F.softplus(x).tanh()
            # d/dx [x * tanh(sp(x))] = tanh(sp(x)) + x * sigmoid(x) * (1 - tanh(sp(x))^2)
            return grad_output * (fx + x * sx * (1 - fx * fx))

    def forward(self, x):
        return self.F.apply(x)


# FReLU https://arxiv.org/abs/2007.11824 -------------------------------------------------------------------------------
class FReLU(nn.Module):
    """Funnel activation: ``max(x, BN(depthwise_conv(x)))``.

    Arguments:
        c1: number of input (and output) channels
        k: depthwise kernel size; must be odd so the convolution can keep the
            spatial size of ``x`` (default 3)
    """

    def __init__(self, c1, k=3):  # ch_in, kernel
        super().__init__()
        # Padding follows the kernel size ("same" padding for odd k). It used to be
        # fixed at 1, which only matches k=3: any other k changed the output size
        # and made torch.max() in forward() fail on mismatched shapes.
        self.conv = nn.Conv2d(c1, c1, k, 1, k // 2, groups=c1, bias=False)
        self.bn = nn.BatchNorm2d(c1)

    def forward(self, x):
        return torch.max(x, self.bn(self.conv(x)))
def check_anchor_order(m):
    # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
    # m must expose `anchor_grid`, `anchors` and `stride`; both anchor tensors are flipped in place
    # along dim 0 when the check below fails.
    a = m.anchor_grid.prod(-1).view(-1)  # anchor area
    da = a[-1] - a[0]  # delta a
    ds = m.stride[-1] - m.stride[0]  # delta s
    if da.sign() != ds.sign():  # anchor areas and strides run in opposite directions
        print('Reversing anchor order')
        m.anchors[:] = m.anchors.flip(0)
        m.anchor_grid[:] = m.anchor_grid.flip(0)


def check_anchors(dataset, model, thr=4.0, imgsz=640):
    # Check anchor fit to data, recompute if necessary
    #   dataset: object providing `.shapes` and `.labels`
    #            (assumes `.shapes` is a numpy array, one row per image - TODO confirm)
    #   model:   model (optionally wrapped in `.module`) whose last layer is Detect()
    #   thr:     anchor/label wh ratio threshold; a label counts as matched when ratio metric > 1 / thr
    #   imgsz:   image size the longest image side is scaled to before measuring
    # Side effects: prints a report; may overwrite m.anchors / m.anchor_grid in place.
    prefix = colorstr('autoanchor: ')
    print(f'\n{prefix}Analyzing anchors... ', end='')
    m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]  # Detect()
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh

    def metric(k):  # compute metric
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        aat = (x > 1. / thr).float().sum(1).mean()  # anchors above threshold
        bpr = (best > 1. / thr).float().mean()  # best possible recall
        return bpr, aat

    bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2))
    print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
    if bpr < 0.98:  # threshold to recompute
        print('. Attempting to improve anchors, please wait...')
        na = m.anchor_grid.numel() // 2  # number of anchors
        new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        new_bpr = metric(new_anchors.reshape(-1, 2))[0]
        if new_bpr > bpr:  # replace anchors
            new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid)  # for inference
            m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1)  # loss
            check_anchor_order(m)
            print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
        else:
            print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
    print('')  # newline


def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
    """ Creates kmeans-evolved anchors from training dataset

        Arguments:
            path: path to dataset *.yaml, or a loaded dataset
            n: number of anchors
            img_size: image size used for training
            thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
            gen: generations to evolve anchors using genetic algorithm
            verbose: print all results

        Return:
            k: kmeans evolved anchors

        Usage:
            from utils.autoanchor import *; _ = kmean_anchors()
    """
    thr = 1. / thr  # from here on `thr` is the inverted ratio that the metric is compared against
    prefix = colorstr('autoanchor: ')

    def metric(k, wh):  # compute metrics
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        # x = wh_iou(wh, torch.tensor(k))  # iou metric
        return x, x.max(1)[0]  # x, best_x

    def anchor_fitness(k):  # mutation fitness
        # uses the filtered `wh` tensor assigned further down in the enclosing function
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k):
        # prints a summary computed against the unfiltered `wh0` and returns k sorted by area
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
        print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
              f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
        for i, x in enumerate(k):  # note: `x` is rebound to one anchor row here
            print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n')  # use in *.cfg
        return k

    if isinstance(path, str):  # *.yaml file
        with open(path) as f:
            data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # model dict
        from utils.datasets import LoadImagesAndLabels
        dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
    else:
        dataset = path  # dataset

    # Get label wh
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh

    # Filter
    i = (wh0 < 3.0).any(1).sum()
    if i:
        print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
    wh = wh0[(wh0 >= 2.0).any(1)]  # filter > 2 pixels
    # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1)  # multiply by random scale 0-1

    # Kmeans calculation
    print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
    s = wh.std(0)  # sigmas for whitening
    k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
    k *= s  # undo the whitening so anchors are back in pixel units
    wh = torch.tensor(wh, dtype=torch.float32)  # filtered
    wh0 = torch.tensor(wh0, dtype=torch.float32)  # unfiltered
    k = print_results(k)

    # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0]<100, 0],400)
    # ax[1].hist(wh[wh[:, 1]<100, 1],400)
    # fig.savefig('wh.png', dpi=200)

    # Evolve
    npr = np.random
    f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, anchor array shape, mutation prob, sigma
    pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:')  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = anchor_fitness(kg)
        if fg > f:  # keep the mutated anchors only when fitness improves
            f, k = fg, kg.copy()
            pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
            if verbose:
                print_results(k)

    return print_results(k)
# Resume all interrupted trainings in yolov5/ dir including DDP trainings
# Usage: $ python utils/aws/resume.py
#
# Walks the current working directory for `last.pt` checkpoints, skips the ones whose
# optimizer state is None, and relaunches `train.py --resume` for each remaining one
# as a detached background shell command.

import os
import sys
from pathlib import Path

import torch
import yaml

sys.path.append('./')  # to run '$ python *.py' files in subdirectories

port = 0  # --master_port
path = Path('').resolve()
for last in path.rglob('*/**/last.pt'):
    # Only the 'optimizer' entry is inspected, so keep the tensors on the CPU instead
    # of restoring them onto whatever device they were saved from.
    ckpt = torch.load(last, map_location='cpu')
    if ckpt['optimizer'] is None:
        continue

    # Load opt.yaml
    with open(last.parent.parent / 'opt.yaml') as f:
        opt = yaml.load(f, Loader=yaml.SafeLoader)

    # Get device count
    # ''.split(',') returns [''], so empty entries are dropped explicitly; otherwise an
    # empty device string counts as one device and the auto-select case is never reached.
    d = [x.strip() for x in opt['device'].split(',') if x.strip()]  # devices
    nd = len(d)  # number of devices explicitly requested
    nproc = nd if nd else torch.cuda.device_count()  # no explicit devices -> use every visible GPU
    ddp = nproc > 1  # distributed data parallel

    if ddp:  # multi-GPU
        port += 1
        cmd = f'python -m torch.distributed.launch --nproc_per_node {nproc} --master_port {port} train.py --resume {last}'
    else:  # single-GPU
        cmd = f'python train.py --resume {last}'

    cmd += ' > /dev/null 2>&1 &'  # redirect output to dev/null and run in the background
    print(cmd)
    os.system(cmd)
def get_hash(files):
    # Returns a single hash value of a list of files
    # (the sum of the sizes of the paths that exist as regular files; others are ignored)
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    # Width and height are swapped when the EXIF orientation tag is 6 or 8. Reading the
    # tag is best effort: any failure (no EXIF data, tag missing, unreadable metadata)
    # falls back to the stored size.
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
    except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt / SystemExit
        return s

    if rotation in (6, 8):  # rotation 270 or 90
        s = (s[1], s[0])
    return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    """ Builds a LoadImagesAndLabels dataset and a dataloader over it

    Returns:
        (dataloader, dataset)
    """
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    # os.cpu_count() may return None when the count cannot be determined; treat that as 1
    n_cpu = os.cpu_count() or 1
    nw = min([n_cpu // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset


class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Plain attribute assignment is bypassed via object.__setattr__ to swap the batch
        # sampler for a never-ending wrapper after DataLoader.__init__ has run.
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()  # one iterator, created once and reused by every pass

    def __len__(self):
        # length of the wrapped (original) batch sampler
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # one pass = len(self) batches pulled from the shared iterator
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:  # restart the wrapped sampler every time it is exhausted
            yield from iter(self.sampler)
forever + + Args: + sampler (Sampler) + """ + + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + while True: + yield from iter(self.sampler) + + +class LoadImages: # for inference + def __init__(self, path, bbox_results, img_size=640, stride=32): + p = str(Path(path).absolute()) # os-agnostic absolute path + if '*' in p: + files_ = sorted(glob.glob(p, recursive=True)) # glob + elif os.path.isdir(p): + files_ = sorted(glob.glob(os.path.join(p, '*.*'))) # dir + elif os.path.isfile(p): + files_ = [p] # files + else: + raise Exception(f'ERROR: {p} does not exist') + + files = [] + for f in files_: + if os.path.basename(f) in bbox_results: + continue + else: + files.append(f) + + + images = [x for x in files if x.split('.')[-1].lower() in img_formats] + videos = [x for x in files if x.split('.')[-1].lower() in vid_formats] + ni, nv = len(images), len(videos) + + self.img_size = img_size + self.stride = stride + self.files = images + videos + self.nf = ni + nv # number of files + self.video_flag = [False] * ni + [True] * nv + self.mode = 'image' + if any(videos): + self.new_video(videos[0]) # new video + else: + self.cap = None + assert self.nf > 0, f'No images or videos found in {p}. 
class LoadWebcam:  # for inference
    """Iterator yielding letterboxed frames from one local webcam or IP camera.

    Each iteration returns ``(img_path, img, img0, None)`` where ``img0`` is the
    frame as read from OpenCV and ``img`` is the letterboxed, channel-reversed,
    channel-first contiguous copy. Iteration stops when 'q' is pressed in an
    OpenCV window.
    """

    def __init__(self, pipe='0', img_size=640, stride=32):
        """Open the capture device.

        Args:
            pipe: camera index given as a numeric string (e.g. '0'), or a stream
                URL such as an rtsp:// or http:// address.
            img_size: target size passed to letterbox().
            stride: stride passed to letterbox().
        """
        self.img_size = img_size
        self.stride = stride

        if pipe.isnumeric():
            # int() rather than eval(): same result for a camera index, without executing caller-supplied text
            pipe = int(pipe)  # local camera

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            if ret_val:
                # only flip a frame that was actually read; a failed read is reported by the assert below
                img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, f'Camera Error {self.pipe}'
        img_path = 'webcam.jpg'
        print(f'webcam {self.count}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        # a live stream has no known length
        return 0
def img2label_paths(img_paths):
    """Map image paths to their label-file paths.

    The first ``<sep>images<sep>`` directory component is replaced by
    ``<sep>labels<sep>`` and the file extension is replaced by ``.txt``.
    Only the final extension is touched, so the same text occurring earlier in
    the path (e.g. a directory named ``x.jpg.d``) is left alone. A path without
    an extension gets ``.txt`` appended.

    Args:
        img_paths: iterable of image path strings.

    Returns:
        List of label path strings, in the same order as the input.
    """
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
    return [os.path.splitext(x.replace(sa, sb, 1))[0] + '.txt' for x in img_paths]
exists = self.cache_labels(cache_path, prefix), False # re-cache + else: + cache, exists = self.cache_labels(cache_path, prefix), False # cache + + # Display cache + nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total + if exists: + d = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted" + tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results + assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}' + + # Read cache + cache.pop('hash') # remove hash + cache.pop('version') # remove version + labels, shapes, self.segments = zip(*cache.values()) + self.labels = list(labels) + self.shapes = np.array(shapes, dtype=np.float64) + self.img_files = list(cache.keys()) # update + self.label_files = img2label_paths(cache.keys()) # update + if single_cls: + for x in self.labels: + x[:, 0] = 0 + + n = len(shapes) # number of images + bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index + nb = bi[-1] + 1 # number of batches + self.batch = bi # batch index of image + self.n = n + self.indices = range(n) + + # Rectangular Training + if self.rect: + # Sort by aspect ratio + s = self.shapes # wh + ar = s[:, 1] / s[:, 0] # aspect ratio + irect = ar.argsort() + self.img_files = [self.img_files[i] for i in irect] + self.label_files = [self.label_files[i] for i in irect] + self.labels = [self.labels[i] for i in irect] + self.shapes = s[irect] # wh + ar = ar[irect] + + # Set training image shapes + shapes = [[1, 1]] * nb + for i in range(nb): + ari = ar[bi == i] + mini, maxi = ari.min(), ari.max() + if maxi < 1: + shapes[i] = [maxi, 1] + elif mini > 1: + shapes[i] = [1, 1 / mini] + + self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride + + # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM) + self.imgs = [None] * n + if 
def cache_labels(self, path=Path('./labels.cache'), prefix=''):
    """Verify every image/label pair, collect the labels and save them to a cache file.

    For each pair from ``self.img_files`` / ``self.label_files`` the image is
    opened and verified with PIL and its label file (if present) is parsed into
    an (n, 5) float32 array. Pairs that raise during checking are counted as
    corrupted and left out of the cache.

    Args:
        path: destination of the cache file written with ``torch.save``.
        prefix: text prepended to printed and logged messages.

    Returns:
        Dict mapping image file -> ``[labels, shape, segments]`` plus the
        bookkeeping keys ``'hash'``, ``'results'`` and ``'version'``.
    """
    x = {}  # dict
    nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupted
    pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
    for im_file, lb_file in pbar:
        try:
            # verify images; the context manager releases the file handle that Image.open keeps
            with Image.open(im_file) as im:
                im.verify()  # PIL verify
                shape = exif_size(im)  # image size
                im_format = im.format
            segments = []  # instance segments
            assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
            assert im_format.lower() in img_formats, f'invalid image format {im_format}'

            # verify labels
            if os.path.isfile(lb_file):
                nf += 1  # label found
                with open(lb_file, 'r') as f:
                    l = [row.split() for row in f.read().strip().splitlines()]
                if any(len(row) > 8 for row in l):  # is segment
                    classes = np.array([row[0] for row in l], dtype=np.float32)
                    segments = [np.array(row[1:], dtype=np.float32).reshape(-1, 2) for row in l]  # (cls, xy1...)
                    l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                l = np.array(l, dtype=np.float32)
                if len(l):
                    assert l.shape[1] == 5, 'labels require 5 columns each'
                    assert (l >= 0).all(), 'negative labels'
                    assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                    assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
                else:
                    ne += 1  # label empty
                    l = np.zeros((0, 5), dtype=np.float32)
            else:
                nm += 1  # label missing
                l = np.zeros((0, 5), dtype=np.float32)
            x[im_file] = [l, shape, segments]
        except Exception as e:
            # best effort: one bad pair must not abort the whole scan
            nc += 1
            print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

        pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... " \
                    f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

    if nf == 0:
        print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

    x['hash'] = get_hash(self.label_files + self.img_files)
    # total scanned; taken from the list length so an empty dataset does not hit an undefined loop index
    x['results'] = nf, nm, ne, nc, len(self.img_files)
    x['version'] = 0.1  # cache version
    torch.save(x, path)  # save for next time
    logging.info(f'{prefix}New cache created: {path}')
    return x
r)).astype(np.uint8) + labels = np.concatenate((labels, labels2), 0) + + else: + # Load image + img, (h0, w0), (h, w) = load_image(self, index) + + # Letterbox + shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape + img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment) + shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling + + labels = self.labels[index].copy() + if labels.size: # normalized xywh to pixel xyxy format + labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1]) + + if self.augment: + # Augment imagespace + if not mosaic: + img, labels = random_perspective(img, labels, + degrees=hyp['degrees'], + translate=hyp['translate'], + scale=hyp['scale'], + shear=hyp['shear'], + perspective=hyp['perspective']) + + # Augment colorspace + augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v']) + + # Apply cutouts + # if random.random() < 0.9: + # labels = cutout(img, labels) + + nL = len(labels) # number of labels + if nL: + labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh + labels[:, [2, 4]] /= img.shape[0] # normalized height 0-1 + labels[:, [1, 3]] /= img.shape[1] # normalized width 0-1 + + if self.augment: + # flip up-down + if random.random() < hyp['flipud']: + img = np.flipud(img) + if nL: + labels[:, 2] = 1 - labels[:, 2] + + # flip left-right + if random.random() < hyp['fliplr']: + img = np.fliplr(img) + if nL: + labels[:, 1] = 1 - labels[:, 1] + + labels_out = torch.zeros((nL, 6)) + if nL: + labels_out[:, 1:] = torch.from_numpy(labels) + + # Convert + img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 + img = np.ascontiguousarray(img) + + return torch.from_numpy(img), labels_out, self.img_files[index], shapes + + @staticmethod + def collate_fn(batch): + img, label, path, shapes = zip(*batch) # transposed + for i, l in enumerate(label): + l[:, 0] = i # add target image index for 
def load_image(self, index):
    """Return ``(img, (h0, w0), (h, w))`` for dataset entry ``index``.

    A cached entry in ``self.imgs`` is returned together with its stored
    original and resized sizes. Otherwise the file is read with OpenCV and
    scaled so that its longer side equals ``self.img_size``.
    """
    cached = self.imgs[index]
    if cached is not None:
        return cached, self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized

    # not cached: read from disk
    path = self.img_files[index]
    img = cv2.imread(path)  # BGR
    assert img is not None, 'Image Not Found ' + path
    h0, w0 = img.shape[:2]  # orig hw
    r = self.img_size / max(h0, w0)  # resize image to img_size
    if r != 1:  # always resize down, only resize up if training with augmentation
        if r < 1 and not self.augment:
            interp = cv2.INTER_AREA
        else:
            interp = cv2.INTER_LINEAR
        new_wh = (int(w0 * r), int(h0 * r))
        img = cv2.resize(img, new_wh, interpolation=interp)
    return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
def hist_equalize(img, clahe=True, bgr=False):
    """Equalize the histogram of the Y (first YUV) channel of a 3-channel image.

    Args:
        img: image array; read as BGR when ``bgr`` is true, otherwise as RGB.
        clahe: use CLAHE (clip limit 2.0, 8x8 tiles) instead of global equalization.
        bgr: selects the colour conversion used on the way in and out.

    Returns:
        The equalized image, converted back to the input channel order.
    """
    if bgr:
        to_yuv, from_yuv = cv2.COLOR_BGR2YUV, cv2.COLOR_YUV2BGR
    else:
        to_yuv, from_yuv = cv2.COLOR_RGB2YUV, cv2.COLOR_YUV2RGB

    yuv = cv2.cvtColor(img, to_yuv)
    luma = yuv[:, :, 0]
    if clahe:
        equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = equalizer.apply(luma)
    else:
        yuv[:, :, 0] = cv2.equalizeHist(luma)  # equalize Y channel histogram
    return cv2.cvtColor(yuv, from_yuv)
def load_mosaic9(self, index):
    """Load a 9-image mosaic built around dataset entry ``index``.

    The requested image is placed at the centre of a (3s, 3s) canvas, where
    ``s = self.img_size``, and eight randomly chosen images are placed around
    it. A random (2s, 2s) window is then cropped, labels and segments are
    shifted and clipped to that window, and the result is passed through
    ``random_perspective``.

    Returns:
        Tuple ``(img9, labels9)`` as returned by ``random_perspective``.
    """

    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(8)]  # 8 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        # NOTE: each branch depends on h0/w0 (centre tile size) and hp/wp (previous tile size),
        # so the tiles must be processed in exactly this order
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset: top-left corner of the random (2s, 2s) crop window
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    """Resize ``img`` to fit ``new_shape`` and pad it with a constant-colour border.

    Args:
        img: image array whose first two dimensions are (height, width).
        new_shape: target (height, width), or a single int for a square target.
        color: border colour passed to ``cv2.copyMakeBorder``.
        auto: reduce the padding to its remainder modulo ``stride``.
        scaleFill: when ``auto`` is false, stretch to ``new_shape`` with no padding.
        scaleup: allow enlarging; when false the scale ratio is capped at 1.0.
        stride: modulus used when ``auto`` is true.

    Returns:
        Tuple ``(img, ratio, (dw, dh))``: the padded image, the (width, height)
        scale ratios, and the per-side padding in pixels.
    """
    src_h, src_w = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    # Scale ratio (new / old)
    gain = min(dst_h / src_h, dst_w / src_w)
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        gain = min(gain, 1.0)

    # Compute padding
    ratio = (gain, gain)  # width, height ratios
    new_unpad = (int(round(src_w * gain)), int(round(src_h * gain)))
    dw = dst_w - new_unpad[0]  # width padding
    dh = dst_h - new_unpad[1]  # height padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (dst_w, dst_h)
        ratio = (dst_w / src_w, dst_h / src_h)  # width, height ratios

    # divide padding into 2 sides
    dw, dh = dw / 2, dh / 2

    if (src_w, src_h) != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    # the -0.1 / +0.1 offsets send an odd padding pixel to the bottom/right side
    top = int(round(dh - 0.1))
    bottom = int(round(dh + 0.1))
    left = int(round(dw - 0.1))
    right = int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
shear=(-10, 10)) + # targets = [cls, xyxy] + + height = img.shape[0] + border[0] * 2 # shape(h,w,c) + width = img.shape[1] + border[1] * 2 + + # Center + C = np.eye(3) + C[0, 2] = -img.shape[1] / 2 # x translation (pixels) + C[1, 2] = -img.shape[0] / 2 # y translation (pixels) + + # Perspective + P = np.eye(3) + P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y) + P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x) + + # Rotation and Scale + R = np.eye(3) + a = random.uniform(-degrees, degrees) + # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations + s = random.uniform(1 - scale, 1 + scale) + # s = 2 ** random.uniform(-scale, scale) + R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s) + + # Shear + S = np.eye(3) + S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg) + S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg) + + # Translation + T = np.eye(3) + T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels) + T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels) + + # Combined rotation matrix + M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT + if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed + if perspective: + img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114)) + else: # affine + img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114)) + + # Visualize + # import matplotlib.pyplot as plt + # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel() + # ax[0].imshow(img[:, :, ::-1]) # base + # ax[1].imshow(img2[:, :, ::-1]) # warped + + # Transform label coordinates + n = len(targets) + if n: + use_segments = any(x.any() for x in segments) + new = np.zeros((n, 4)) + if use_segments: # warp segments + segments = 
def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
    """Select the augmented boxes that are still usable.

    ``box1`` holds the boxes before augmentation and ``box2`` the boxes after,
    both as (4, n) arrays of x1, y1, x2, y2. A box is kept when its new width
    and height both exceed ``wh_thr``, its new/old area ratio exceeds
    ``area_thr``, and its aspect ratio is below ``ar_thr``.

    Returns:
        Boolean array of length n.
    """
    old_w = box1[2] - box1[0]
    old_h = box1[3] - box1[1]
    new_w = box2[2] - box2[0]
    new_h = box2[3] - box2[1]

    aspect = np.maximum(new_w / (new_h + eps), new_h / (new_w + eps))  # aspect ratio
    big_enough = (new_w > wh_thr) & (new_h > wh_thr)
    area_kept = new_w * new_h / (old_w * old_h + eps) > area_thr
    return big_enough & area_kept & (aspect < ar_thr)  # candidates
def extract_boxes(path='../coco128/'):  # from utils.datasets import *; extract_boxes('../coco128')
    """Convert a detection dataset into a classification dataset, one directory per class.

    Every labelled box of every image found under ``path`` is padded, cropped
    and written to ``path/classifier/<class>/<dataset>_<image>_<n>.jpg``. An
    existing ``classifier`` directory is removed first.

    Args:
        path: dataset root directory that is searched recursively for images.

    Raises:
        AssertionError: if a crop cannot be written.
    """
    path = Path(path)  # images dir
    out_root = path / 'classifier'
    if out_root.is_dir():
        shutil.rmtree(out_root)  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        # lower-case the suffix so '.JPG' is accepted, as in the other loaders in this file
        if im_file.suffix[1:].lower() not in img_formats:
            continue

        # image
        im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
        h, w = im.shape[:2]

        # labels
        lb_file = Path(img2label_paths([str(im_file)])[0])
        if not lb_file.exists():
            continue
        with open(lb_file, 'r') as f:
            lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

        for j, x in enumerate(lb):
            c = int(x[0])  # class
            out_file = out_root / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
            out_file.parent.mkdir(parents=True, exist_ok=True)

            b = x[1:] * [w, h, w, h]  # box
            # b[2:] = b[2:].max()  # rectangle to square
            b[2:] = b[2:] * 1.2 + 3  # pad
            # builtin int: the np.int alias was removed in NumPy 1.24
            b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

            b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
            b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
            assert cv2.imwrite(str(out_file), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {out_file}'
def set_logging(rank=-1):
    """Configure root logging with a bare-message format.

    Ranks -1 and 0 log at INFO; every other rank logs at WARN.
    """
    is_primary = rank in (-1, 0)
    level = logging.INFO if is_primary else logging.WARN
    logging.basicConfig(format="%(message)s", level=level)
def check_online():
    """Return True if a TCP connection to 1.1.1.1:443 succeeds within 5 seconds.

    Returns:
        bool: True when the host is reachable, False on any OSError
        (timeout, unreachable network, refused connection, ...).
    """
    import socket
    try:
        conn = socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility
    except OSError:
        return False
    # Close the probe socket explicitly; the original left it open until garbage collection
    conn.close()
    return True
def check_requirements(file='requirements.txt', exclude=()):
    """Check that installed packages satisfy the requirements listed in `file`.

    Arguments:
        file:    path to a pip-style requirements file
        exclude: package names to leave out of the check

    Raises:
        pkg_resources.DistributionNotFound or pkg_resources.VersionConflict
        if a requirement is not met.
    """
    import pkg_resources
    # parse_requirements() reads lazily, so consume it while the file is still open;
    # the `with` block guarantees the handle is closed (the original never closed it)
    with Path(file).open() as f:
        requirements = [f'{x.name}{x.specifier}' for x in pkg_resources.parse_requirements(f)
                        if x.name not in exclude]
    pkg_resources.require(requirements)  # DistributionNotFound or VersionConflict exception if requirements not met
def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Build a half-cosine ramp function going from y1 to y2.

    The returned callable maps x = 0 to y1 and x = steps to y2, following
    (1 - cos(pi * x / steps)) / 2 in between.
    """
    span = y2 - y1

    def ramp(x):
        progress = (1 - math.cos(x * math.pi / steps)) / 2  # 0 at x=0, 1 at x=steps
        return progress * span + y1

    return ramp
def labels_to_class_weights(labels, nc=80):
    """Get class weights (inverse frequency) from training labels.

    Arguments:
        labels: sequence of per-image arrays whose column 0 holds the class index
                (labels = [class xywh]); e.g. the concatenation is (866643, 5) for COCO
        nc:     number of classes (minimum length of the returned weight vector)

    Returns:
        torch.Tensor of normalized inverse-frequency weights (sums to 1), or an
        empty tensor when no labels are loaded.
    """
    if len(labels) == 0 or labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # stack all images' labels into one (n, 5) array
    # Builtin int replaces the np.int alias, which was removed in NumPy 1.24
    classes = labels[:, 0].astype(int)
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1 to avoid division by zero
    weights = 1 / weights  # inverse of the number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)
def xyxy2xywh(x):
    """Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h].

    xy1 is the top-left corner and xy2 the bottom-right corner; the output
    xy is the box center. Accepts a torch.Tensor or anything np.copy accepts
    and returns a new object of the same kind; the input is not modified.
    """
    if isinstance(x, torch.Tensor):
        out = x.clone()
    else:
        out = np.copy(x)
    x1, y1, x2, y2 = (x[:, k] for k in range(4))
    out[:, 0] = (x1 + x2) / 2  # x center
    out[:, 1] = (y1 + y2) / 2  # y center
    out[:, 2] = x2 - x1  # width
    out[:, 3] = y2 - y1  # height
    return out
def segment2box(segment, width=640, height=640):
    """Convert 1 segment label to 1 box label, applying inside-image constraint.

    Arguments:
        segment: (n, 2) array of xy points, i.e. (xy1, xy2, ...)
        width:   image width; points with x outside [0, width] are ignored
        height:  image height; points with y outside [0, height] are ignored

    Returns:
        np.ndarray [x_min, y_min, x_max, y_max] of the in-image points, or
        np.zeros((1, 4)) when no point lies inside the image.
    """
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    # Test the number of surviving points, not their truthiness: any(x) was
    # False for valid points that all sit on the x == 0 edge.
    if len(x) == 0:
        return np.zeros((1, 4))  # shape kept as (1, 4) for backward compatibility
    return np.array([x.min(), y.min(), x.max(), y.max()])  # xyxy
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9):
    """Return the IoU (or GIoU / DIoU / CIoU) of box1 against each box in box2.

    box1 is 4, box2 is nx4. Boxes are corner format [x1, y1, x2, y2] when
    x1y1x2y2 is True, otherwise center format [x, y, w, h]. If several flags
    are set, DIoU or CIoU takes precedence over GIoU, and DIoU over CIoU.
    """
    box2 = box2.T  # nx4 -> 4xn so every coordinate is a single row

    # Corner coordinates of both boxes
    if x1y1x2y2:
        left1, top1, right1, bottom1 = box1[0], box1[1], box1[2], box1[3]
        left2, top2, right2, bottom2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        left1, right1 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        top1, bottom1 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        left2, right2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        top2, bottom2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area (clamped so disjoint boxes give 0)
    overlap_w = (torch.min(right1, right2) - torch.max(left1, left2)).clamp(0)
    overlap_h = (torch.min(bottom1, bottom2) - torch.max(top1, top2)).clamp(0)
    inter = overlap_w * overlap_h

    # Union area; eps keeps heights and the union strictly positive
    w1 = right1 - left1
    h1 = bottom1 - top1 + eps
    w2 = right2 - left2
    h2 = bottom2 - top2 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if not (GIoU or DIoU or CIoU):
        return iou  # IoU

    cw = torch.max(right1, right2) - torch.min(left1, left2)  # convex (smallest enclosing box) width
    ch = torch.max(bottom1, bottom2) - torch.min(top1, top2)  # convex height

    if not (CIoU or DIoU):  # GIoU https://arxiv.org/pdf/1902.09630.pdf
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU

    # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
    c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
    rho2 = ((left2 + right2 - left1 - right1) ** 2 +
            (top2 + bottom2 - top1 - bottom1) ** 2) / 4  # center distance squared
    if DIoU:
        return iou - rho2 / c2  # DIoU

    # CIoU https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
    v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)  # aspect-ratio term
    with torch.no_grad():  # alpha is used as a constant weight, not differentiated
        alpha = v / ((1 + eps) - iou + v)
    return iou - (rho2 / c2 + v * alpha)  # CIoU
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Arguments:
        prediction:  3-D tensor indexed as (image, box, 5 + nc); as implied by the
                     indexing below, columns 0-3 are the box passed to xywh2xyxy(),
                     column 4 is the objectness confidence and columns 5+ are
                     per-class scores
        conf_thres:  confidence threshold applied to objectness and to the final score
        iou_thres:   IoU threshold handed to torchvision.ops.nms()
        classes:     optional class ids to keep; None keeps all classes
        agnostic:    if True, boxes of different classes can suppress each other
        multi_label: if True (and nc > 1), a box may yield one detection per class
                     whose score exceeds conf_thres
        labels:      optional per-image apriori labels (column 0 class, columns 1-4 box)
                     that are appended to the predictions with confidence 1.0

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    # Every slot starts as the same empty (0, 6) tensor; slots are replaced (not mutated) below,
    # so images that are skipped or not reached before the time limit keep the empty result
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # one row per (box, class) pair whose combined score exceeds the threshold
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # Shifting each box by class_id * max_wh keeps boxes of different classes from overlapping,
        # so a single nms() call suppresses within each class only (no shift when agnostic)
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    """Print mutation results to evolve.txt (for use with train.py --evolve)

    Appends one row (results followed by hyperparameter values) to ./evolve.txt,
    rewrites that file de-duplicated and sorted by fitness(), then saves the
    best row's hyperparameters to `yaml_file`.

    Arguments:
        hyp:       dict of hyperparameters; NOTE it is overwritten in place with the
                   values of the best row in evolve.txt
        results:   result metrics; used as the right-hand side of a %-format, so
                   presumably a tuple (7 values are read back per row) - TODO confirm
        yaml_file: output path for the best hyperparameters
        bucket:    optional GCS bucket name; when set, evolve.txt is synced via gsutil
    """
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')  # row layout: results first, then hyperparameters
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort
    np.savetxt('evolve.txt', x, '%10.3g')  # save sort by fitness

    # Save yaml
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])  # columns 0-6 hold the results, hyperparameters start at column 7
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])  # metrics of the best (first) row
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload
def increment_path(path, exist_ok=True, sep=''):
    """Return `path`, or a numbered variant of it if `path` is already taken.

    i.e. with exist_ok=False and runs/exp existing:
    runs/exp --> runs/exp{sep}2, runs/exp{sep}3 etc. (numbering starts at 2).

    Arguments:
        path:     desired path (str or Path)
        exist_ok: if True, an existing path is returned unchanged
        sep:      separator placed between the path and the number

    Returns:
        str: the original path if it is free or exist_ok is set, otherwise
        f"{path}{sep}{n}" with n one above the highest number already in use.
    """
    path = Path(path)  # os-agnostic
    if exist_ok or not path.exists():
        return str(path)

    dirs = glob.glob(f"{path}{sep}*")  # similar paths
    # Escape the stem and separator so names containing regex metacharacters
    # (e.g. '+', '(', '.') are matched literally instead of miscounting or raising
    pattern = re.compile(rf"{re.escape(path.stem)}{re.escape(sep)}(\d+)")
    indices = [int(m.group(1)) for m in map(pattern.search, dirs) if m]
    n = max(indices) + 1 if indices else 2  # increment number
    return f"{path}{sep}{n}"  # update path
def gsutil_getsize(url=''):
    """Return the size in bytes reported by `gsutil du` for a gs://bucket/file URL.

    https://cloud.google.com/storage/docs/gsutil/commands/du

    Returns:
        int: the first whitespace-separated field of the command output,
        or 0 when the command prints nothing.

    Raises:
        subprocess.CalledProcessError: if the gsutil command exits non-zero.
    """
    s = subprocess.check_output(f'gsutil du {url}', shell=True).decode('utf-8')
    fields = s.split()
    # Parse with int() rather than eval(): the text comes from an external
    # process and must never be executed as Python code
    return int(fields[0]) if fields else 0  # bytes
from yolov5.utils.google_utils import *; gdrive_download() + t = time.time() + file = Path(file) + cookie = Path('cookie') # gdrive cookie + print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='') + file.unlink(missing_ok=True) # remove existing file + cookie.unlink(missing_ok=True) # remove existing cookie + + # Attempt file download + out = "NUL" if platform.system() == "Windows" else "/dev/null" + os.system(f'curl -c ./cookie -s -L "drive.google.com/uc?export=download&id={id}" > {out}') + if os.path.exists('cookie'): # large file + s = f'curl -Lb ./cookie "drive.google.com/uc?export=download&confirm={get_token()}&id={id}" -o {file}' + else: # small file + s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"' + r = os.system(s) # execute, capture return + cookie.unlink(missing_ok=True) # remove existing cookie + + # Error check + if r != 0: + file.unlink(missing_ok=True) # remove partial + print('Download error ') # raise Exception('Download error') + return r + + # Unzip if archive + if file.suffix == '.zip': + print('unzipping... 
def get_token(cookie="./cookie"):
    """Read the download confirmation token out of a curl cookie file.

    The file is scanned line by line; for the first line containing the text
    "download" the last whitespace-separated field is returned. When no line
    matches, an empty string is returned.
    """
    with open(cookie) as jar:
        hit = next((row for row in jar if "download" in row), None)
    return "" if hit is None else hit.split()[-1]
class FocalLoss(nn.Module):
    """Focal-loss wrapper around an existing element-wise criterion.

    Usage: ``criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)``.
    The wrapped criterion's ``reduction`` is remembered and then overwritten
    with ``'none'`` so the focal weighting can be applied per element; the
    remembered reduction is applied to the weighted loss in ``forward``.
    """

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super().__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma, self.alpha = gamma, alpha
        self.reduction = loss_fcn.reduction  # reduction requested by the caller
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, true):
        """Return the focal-weighted loss of logits `pred` against targets `true`."""
        raw = self.loss_fcn(pred, true)

        # Weighting follows the TF-addons formulation:
        # https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        prob = torch.sigmoid(pred)  # prob from logits
        p_t = true * prob + (1 - true) * (1 - prob)
        alpha_t = true * self.alpha + (1 - true) * (1 - self.alpha)
        focal = raw * (alpha_t * (1.0 - p_t) ** self.gamma)

        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal  # 'none'
class ComputeLoss:
    """Callable that computes box, objectness and classification losses for a detection model.

    Construction reads the hyperparameter dict ``model.hyp``, the ratio ``model.gr`` and the
    last module of the model (``model.model[-1]``, or ``model.module.model[-1]`` when
    ``is_parallel(model)`` is true), from which ``na``, ``nc``, ``nl`` and ``anchors`` are copied.
    """

    # Compute losses
    def __init__(self, model, autobalance=False):
        """Set up criteria and per-layer objectness weights.

        Args:
            model: Model exposing ``parameters()``, ``hyp``, ``gr`` and a final detection module.
            autobalance: When true, ``self.balance`` is updated on every call and normalised by
                its entry at ``self.ssi`` (the index of stride 16 in ``det.stride``).
        """
        super(ComputeLoss, self).__init__()
        device = next(model.parameters()).device  # get model device
        h = model.hyp  # hyperparameters

        # Define criteria
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))

        # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
        # eps=0.0 means smoothing is effectively off here: cp, cn = 1.0, 0.0
        self.cp, self.cn = smooth_BCE(eps=0.0)

        # Focal loss
        g = h['fl_gamma']  # focal loss gamma
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

        det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
        # Per-layer objectness loss weights: a 3-entry list when det.nl == 3, otherwise the 5-entry list
        self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02])  # P3-P7
        self.ssi = list(det.stride).index(16) if autobalance else 0  # stride 16 index
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
        # Mirror the detection module's attributes onto this object
        for k in 'na', 'nc', 'nl', 'anchors':
            setattr(self, k, getattr(det, k))

    def __call__(self, p, targets):  # predictions, targets, model
        """Compute the total loss for one batch.

        Args:
            p: Sequence of per-layer prediction tensors. Each is indexed as
                ``pi[image, anchor, gridy, gridx]`` and its last axis is read as
                ``[x, y, w, h, objectness, class logits...]``.
            targets: Target rows ``(image, class, x, y, w, h)``; also supplies the device.

        Returns:
            Tuple ``(loss * bs, items)`` where ``bs`` is the first dimension of the last layer's
            objectness target and ``items`` is the detached concatenation
            ``(lbox, lobj, lcls, loss)``.
        """
        device = targets.device
        lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
        tcls, tbox, indices, anchors = self.build_targets(p, targets)  # targets

        # Losses
        for i, pi in enumerate(p):  # layer index, layer predictions
            b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
            tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj

            n = b.shape[0]  # number of targets
            if n:
                ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets

                # Regression
                pxy = ps[:, :2].sigmoid() * 2. - 0.5
                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1)  # predicted box
                iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
                lbox += (1.0 - iou).mean()  # iou loss

                # Objectness
                # Target is a blend of 1.0 and the (detached, non-negative) IoU, weighted by self.gr
                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio

                # Classification
                if self.nc > 1:  # cls loss (only if multiple classes)
                    t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
                    t[range(n), tcls[i]] = self.cp
                    lcls += self.BCEcls(ps[:, 5:], t)  # BCE

                # Append targets to text file
                # with open('targets.txt', 'a') as file:
                #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]  # obj loss
            if self.autobalance:
                # Exponential moving update of the layer weight towards 1 / objectness loss
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()

        if self.autobalance:
            # Rescale so that the weight of the reference layer (index self.ssi) is 1
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        lbox *= self.hyp['box']
        lobj *= self.hyp['obj']
        lcls *= self.hyp['cls']
        bs = tobj.shape[0]  # batch size (taken from the last layer processed in the loop above)

        loss = lbox + lobj + lcls
        return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()

    def build_targets(self, p, targets):
        """Match targets to anchors and grid cells for every detection layer.

        Args:
            p: Sequence of per-layer prediction tensors; only their shapes are read here.
            targets: Target rows ``(image, class, x, y, w, h)``.
                NOTE(review): x, y, w, h are presumably normalised to 0-1, since they are
                multiplied by the grid size below - confirm against the data loader.

        Returns:
            Tuple ``(tcls, tbox, indices, anch)`` of lists with one entry per layer: class indices,
            boxes ``(xy offset within the cell, wh)``, ``(image, anchor, gridy, gridx)`` index
            tuples, and the matched anchors.
        """
        # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
        na, nt = self.na, targets.shape[0]  # number of anchors, targets
        tcls, tbox, indices, anch = [], [], [], []
        gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
        # Repeat the targets once per anchor and add the anchor index as a 7th column
        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices

        g = 0.5  # bias
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                            # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                            ], device=targets.device).float() * g  # offsets

        for i in range(self.nl):
            anchors = self.anchors[i]
            gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain

            # Match targets to anchors
            t = targets * gain
            if nt:
                # Matches
                r = t[:, :, 4:6] / anchors[:, None]  # wh ratio
                # Keep a target/anchor pair only if neither side ratio exceeds hyp['anchor_t']
                j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t']  # compare
                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
                t = t[j]  # filter

                # Offsets
                gxy = t[:, 2:4]  # grid xy
                gxi = gain[[2, 3]] - gxy  # inverse
                # j, k: centre lies within g of the cell's low x / y edge (and not in the first cell);
                # l, m: the same test measured from the opposite grid border
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
                # Row 0 (all ones) keeps every target in its own cell; rows 1-4 select neighbour copies
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                # No targets in the batch
                t = targets[0]
                offsets = 0

            # Define
            b, c = t[:, :2].long().T  # image, class
            gxy = t[:, 2:4]  # grid xy
            gwh = t[:, 4:6]  # grid wh
            gij = (gxy - offsets).long()
            gi, gj = gij.T  # grid xy indices

            # Append
            a = t[:, 6].long()  # anchor indices
            # clamp_ keeps the cell indices inside the grid (gain[3] and gain[2] are p[i].shape[2] and [3])
            indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))  # image, anchor, grid indices
            tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
            anch.append(anchors[a])  # anchors
            tcls.append(c)  # class

        return tcls, tbox, indices, anch
+ """ + + # Sort by objectness + i = np.argsort(-conf) + tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] + + # Find unique classes + unique_classes = np.unique(target_cls) + nc = unique_classes.shape[0] # number of classes, number of detections + + # Create Precision-Recall curve and compute AP for each class + px, py = np.linspace(0, 1, 1000), [] # for plotting + ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) + for ci, c in enumerate(unique_classes): + i = pred_cls == c + n_l = (target_cls == c).sum() # number of labels + n_p = i.sum() # number of predictions + + if n_p == 0 or n_l == 0: + continue + else: + # Accumulate FPs and TPs + fpc = (1 - tp[i]).cumsum(0) + tpc = tp[i].cumsum(0) + + # Recall + recall = tpc / (n_l + 1e-16) # recall curve + r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases + + # Precision + precision = tpc / (tpc + fpc) # precision curve + p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score + + # AP from recall-precision curve + for j in range(tp.shape[1]): + ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j]) + if plot and j == 0: + py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5 + + # Compute F1 (harmonic mean of precision and recall) + f1 = 2 * p * r / (p + r + 1e-16) + if plot: + plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names) + plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1') + plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision') + plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall') + + i = f1.mean(0).argmax() # max F1 index + return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32') + + +def compute_ap(recall, precision): + """ Compute the average precision, given the recall and precision curves + # Arguments + recall: The recall curve (list) + precision: The precision curve (list) + # Returns + Average 
class ConfusionMatrix:
    """Accumulate a detection confusion matrix over batches.

    Updated version of https://github.com/kaanakan/object_detection_confusion_matrix

    ``self.matrix`` has shape ``(nc + 1, nc + 1)``; index ``nc`` is the extra "background" slot.
    ``process_batch`` indexes it as ``matrix[ground-truth class, detected class]``.

    NOTE(review): ``plot`` labels the x axis 'True' and the y axis 'Predicted', which looks
    transposed relative to that indexing - confirm before relying on the axis labels.
    """

    def __init__(self, nc, conf=0.25, iou_thres=0.45):
        """Create an empty matrix for `nc` classes.

        Args:
            nc: Number of classes.
            conf: Detections with confidence not above this value are ignored.
            iou_thres: A label and a detection can only match when their IoU exceeds this value.
        """
        self.matrix = np.zeros((nc + 1, nc + 1))
        self.nc = nc  # number of classes
        self.conf = conf
        self.iou_thres = iou_thres

    def process_batch(self, detections, labels):
        """Update the confusion matrix with the detections and labels of one image.

        Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
        Arguments:
            detections (Array[N, 6]), x1, y1, x2, y2, conf, class
            labels (Array[M, 5]), class, x1, y1, x2, y2
        Returns:
            None, updates confusion matrix accordingly
        """
        detections = detections[detections[:, 4] > self.conf]
        gt_classes = labels[:, 0].int()
        detection_classes = detections[:, 5].int()
        iou = general.box_iou(labels[:, 1:], detections[:, :4])

        x = torch.where(iou > self.iou_thres)
        if x[0].shape[0]:
            # Rows of (label index, detection index, iou) for every pair above the threshold
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Sort by IoU (descending) and keep one row per detection, then one row per label
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        n = matches.shape[0] > 0
        m0, m1, _ = matches.transpose().astype(np.int16)
        for i, gc in enumerate(gt_classes):
            j = m0 == i
            if n and sum(j) == 1:
                self.matrix[gc, detection_classes[m1[j]]] += 1  # correct
            else:
                self.matrix[gc, self.nc] += 1  # background FP

        if n:
            for i, dc in enumerate(detection_classes):
                if not any(m1 == i):
                    self.matrix[self.nc, dc] += 1  # background FN

    def matrix(self):
        # NOTE: on instances this method is shadowed by the ``self.matrix`` array assigned in
        # __init__, so ``cm.matrix`` is the array itself. Kept only for interface compatibility.
        return self.matrix

    def plot(self, save_dir='', names=()):
        """Save a column-normalised heatmap to ``<save_dir>/confusion_matrix.png``.

        Plotting is best-effort: any failure (e.g. seaborn not installed) is reported as a
        warning instead of being raised, so evaluation is not interrupted.

        Args:
            save_dir: Output directory.
            names: Class names (any sequence); used as tick labels only when there are exactly
                ``nc`` of them and fewer than 99.
        """
        try:
            import seaborn as sn

            array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6)  # normalize
            array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

            fig = plt.figure(figsize=(12, 9), tight_layout=True)
            sn.set(font_scale=1.0 if self.nc < 50 else 0.8)  # for label size
            labels = (0 < len(names) < 99) and len(names) == self.nc  # apply names to ticklabels
            ticks = list(names)  # accept tuples too: tuple + list would raise TypeError
            sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
                       xticklabels=ticks + ['background FN'] if labels else "auto",
                       yticklabels=ticks + ['background FP'] if labels else "auto").set_facecolor((1, 1, 1))
            fig.axes[0].set_xlabel('True')
            fig.axes[0].set_ylabel('Predicted')
            fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
        except Exception as e:
            # Deliberately non-fatal, but no longer silent
            print(f'WARNING: ConfusionMatrix plot failure: {e}')

    def print(self):
        """Print the matrix, one space-separated row per line."""
        for i in range(self.nc + 1):
            print(' '.join(map(str, self.matrix[i])))
file mode 100644 index 0000000..ca54fdc --- /dev/null +++ b/data_processing/yolov5_crowdhuman/utils/plots.py @@ -0,0 +1,429 @@ +# Plotting utils + +import glob +import math +import os +import random +from copy import copy +from pathlib import Path + +import cv2 +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import torch +import yaml +from PIL import Image, ImageDraw, ImageFont +from scipy.signal import butter, filtfilt + +from utils.general import xywh2xyxy, xyxy2xywh +from utils.metrics import fitness + +# Settings +matplotlib.rc('font', **{'size': 11}) +matplotlib.use('Agg') # for writing to files only + + +def color_list(): + # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb + def hex2rgb(h): + return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) + + return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()] # or BASE_ (8), CSS4_ (148), XKCD_ (949) + + +def hist2d(x, y, n=100): + # 2d histogram used in labels.png and evolve.png + xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n) + hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges)) + xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1) + yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1) + return np.log(hist[xidx, yidx]) + + +def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): + # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy + def butter_lowpass(cutoff, fs, order): + nyq = 0.5 * fs + normal_cutoff = cutoff / nyq + return butter(order, normal_cutoff, btype='low', analog=False) + + b, a = butter_lowpass(cutoff, fs, order=order) + return filtfilt(b, a, data) # forward-backward filter + + +def plot_one_box(x, img, color=None, label=None, 
def plot_one_box_PIL(box, img, color=None, label=None, line_thickness=None):
    """Draw one bounding box (and optional label) on an image using PIL.

    Args:
        box: Box coordinates as accepted by ``ImageDraw.rectangle``; ``box[0]``, ``box[1]`` are
            used as the anchor point of the label.
        img: Image as a numpy array (converted with ``Image.fromarray``).
        color: RGB sequence. When omitted a random colour is used, as in ``plot_one_box``
            (previously the default ``None`` raised ``TypeError`` in ``tuple(color)``).
        label: Optional text drawn on a filled background above the box.
        line_thickness: Outline width; defaults to a value derived from the image size.

    Returns:
        The annotated image as a numpy array.
    """
    color = tuple(color or [random.randint(0, 255) for _ in range(3)])
    img = Image.fromarray(img)
    draw = ImageDraw.Draw(img)
    line_thickness = line_thickness or max(int(min(img.size) / 200), 2)
    draw.rectangle(box, width=line_thickness, outline=color)  # plot
    if label:
        fontsize = max(round(max(img.size) / 40), 12)
        font = ImageFont.truetype("Arial.ttf", fontsize)
        if hasattr(font, 'getsize'):
            txt_width, txt_height = font.getsize(label)
        else:
            # Pillow >= 10 removed getsize(); right/bottom of getbbox() give the same extent
            _, _, txt_width, txt_height = font.getbbox(label)
        draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
        draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
    return np.asarray(img)
def output_to_target(output):
    """Convert model output to target format [batch_id, class_id, x, y, w, h, conf].

    `output` is iterated per image; each element's rows are unpacked as
    ``(*box, conf, cls)`` and the box is converted with ``xyxy2xywh``.
    Returns a numpy array with one row per detection.
    """
    rows = []
    for batch_id, dets in enumerate(output):
        for *xyxy, conf, cls in dets.cpu().numpy():
            (xywh,) = xyxy2xywh(np.array(xyxy)[None])  # single-row batch in, single row out
            rows.append([batch_id, cls, *list(xywh), conf])
    return np.array(rows)
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
    """Step the scheduler `epochs` times and save the LR curve to ``<save_dir>/LR.png``.

    The learning rate recorded after each step is that of the first parameter group.

    NOTE(review): copy() is shallow, so the copies presumably still share
    ``param_groups`` with the originals - confirm whether callers rely on the
    original optimizer being left untouched.
    """
    sim_optimizer = copy(optimizer)  # do not modify originals
    sim_scheduler = copy(scheduler)
    lr_history = []
    for _ in range(epochs):
        sim_scheduler.step()
        lr_history.append(sim_optimizer.param_groups[0]['lr'])

    plt.plot(lr_history, '.-', label='LR')
    plt.xlabel('epoch')
    plt.ylabel('LR')
    plt.grid()
    plt.xlim(0, epochs)
    plt.ylim(0)
    out_file = Path(save_dir) / 'LR.png'
    plt.savefig(out_file, dpi=200)
    plt.close()
def plot_targets_txt():  # from utils.plots import *; plot_targets_txt()
    """Plot histograms of the first four columns of targets.txt and save them as targets.jpg."""
    columns = np.loadtxt('targets.txt', dtype=np.float32).T
    titles = ['x targets', 'y targets', 'width targets', 'height targets']
    _, axes = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    axes = axes.ravel()
    for idx, title in enumerate(titles):
        values, axis = columns[idx], axes[idx]
        axis.hist(values, bins=100, label='%.3g +/- %.3g' % (values.mean(), values.std()))
        axis.legend()
        axis.set_title(title)
    plt.savefig('targets.jpg', dpi=200)
def plot_labels(labels, save_dir=Path(''), loggers=None):
    """Plot dataset label statistics to labels_correlogram.jpg and labels.jpg in `save_dir`.

    Args:
        labels: Numpy array with one row per label: class, x, y, width, height.
            NOTE(review): boxes are presumably normalised to 0-1, since they are scaled
            by 2000 onto a 2000x2000 canvas - confirm against the caller.
            The array is no longer modified in place.
        save_dir: Output directory (``Path`` or ``str``).
        loggers: Optional dict of loggers; a truthy ``'wandb'`` entry receives the saved images.
            ``None`` (the default) is now accepted instead of raising AttributeError.
    """
    print('Plotting labels... ')
    save_dir = Path(save_dir)  # also accept plain strings
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
    nc = int(c.max() + 1)  # number of classes
    colors = color_list()
    x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

    # seaborn correlogram
    sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
    plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
    plt.close()

    # matplotlib labels
    matplotlib.use('svg')  # faster
    try:
        ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
        ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
        ax[0].set_xlabel('classes')
        sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
        sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)

        # rectangles: drawn from a private copy so the caller's labels stay intact
        rects = labels.copy()
        rects[:, 1:3] = 0.5  # center
        rects[:, 1:] = xywh2xyxy(rects[:, 1:]) * 2000
        img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
        for cls, *box in rects[:1000]:
            ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % len(colors)])  # plot
        ax[1].imshow(img)
        ax[1].axis('off')

        for a in [0, 1, 2, 3]:
            for s in ['top', 'right', 'left', 'bottom']:
                ax[a].spines[s].set_visible(False)

        plt.savefig(save_dir / 'labels.jpg', dpi=200)
    finally:
        # Always restore the file-only backend, even if plotting failed
        matplotlib.use('Agg')
        plt.close()

    # loggers
    for k, v in (loggers or {}).items():
        if k == 'wandb' and v:
            v.log({"Labels": [v.Image(str(f), caption=f.name) for f in save_dir.glob('*labels*.jpg')]}, commit=False)
hyp = yaml.load(f, Loader=yaml.SafeLoader) + x = np.loadtxt('evolve.txt', ndmin=2) + f = fitness(x) + # weights = (f - f.min()) ** 2 # for weighted results + plt.figure(figsize=(10, 12), tight_layout=True) + matplotlib.rc('font', **{'size': 8}) + for i, (k, v) in enumerate(hyp.items()): + y = x[:, i + 7] + # mu = (y * weights).sum() / weights.sum() # best weighted result + mu = y[f.argmax()] # best single result + plt.subplot(6, 5, i + 1) + plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none') + plt.plot(mu, f.max(), 'k+', markersize=15) + plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters + if i % 5 != 0: + plt.yticks([]) + print('%15s: %.3g' % (k, mu)) + plt.savefig('evolve.png', dpi=200) + print('\nPlot saved as evolve.png') + + +def profile_idetection(start=0, stop=0, labels=(), save_dir=''): + # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection() + ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel() + s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS'] + files = list(Path(save_dir).glob('frames*.txt')) + for fi, f in enumerate(files): + try: + results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows + n = results.shape[1] # number of rows + x = np.arange(start, min(stop, n) if stop else n) + results = results[:, x] + t = (results[0] - results[0].min()) # set t0=0s + results[0] = x + for i, a in enumerate(ax): + if i < len(results): + label = labels[fi] if len(labels) else f.stem.replace('frames_', '') + a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5) + a.set_title(s[i]) + a.set_xlabel('time (s)') + # if fi == len(files) - 1: + # a.set_ylim(bottom=0) + for side in ['top', 'right']: + a.spines[side].set_visible(False) + else: + a.remove() + except Exception as e: + print('Warning: Plotting error for %s; %s' % (f, e)) + + ax[1].legend() + 
def plot_results_overlay(start=0, stop=0):  # from utils.plots import *; plot_results_overlay()
    """Plot training 'results*.txt' logs, overlaying train and val curves.

    Every 'results*.txt' found in the current directory and in
    '../../Downloads' gets its own 1x5 figure. Panel ``i`` overlays loaded
    column ``i`` with loaded column ``i + 5``, and the figure is saved beside
    the log file with the extension swapped to '.png'.

    Args:
        start: Index of the first log row to plot.
        stop: Index to stop before; 0 (the default) plots through the last row.
    """
    s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
    t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
    for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
        results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
        n = results.shape[1]  # number of rows
        x = range(start, min(stop, n) if stop else n)
        fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
        ax = ax.ravel()
        for i in range(5):
            for j in (i, i + 5):
                ax[i].plot(x, results[j, x], marker='.', label=s[j])
            ax[i].set_title(t[i])
            ax[i].legend()
            if i == 0:
                ax[i].set_ylabel(f)  # add filename
        # Swap only the trailing extension; str.replace would also rewrite a '.txt' inside the stem.
        fig.savefig(os.path.splitext(f)[0] + '.png', dpi=200)
        plt.close(fig)  # one figure per log file: release it so figures do not accumulate


def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
    """Plot training 'results*.txt' logs into a single 2x5 figure.

    Usage: from utils.plots import *; plot_results(save_dir='runs/train/exp')

    Args:
        start: Index of the first log row to plot.
        stop: Index to stop before; 0 plots through the last row.
        bucket: If non-empty, 'results<id>.txt' files are first copied from
            'gs://<bucket>/' into the current directory with ``gsutil cp``.
        id: Numeric ids selecting the bucket files (only used with ``bucket``).
        labels: Optional legend labels, one per file; defaults to the file stem.
        save_dir: Directory searched for 'results*.txt' when ``bucket`` is empty,
            and the directory 'results.png' is written to.

    Raises:
        AssertionError: If there are no result files to plot.
    """
    s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
         'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
    if bucket:
        # Keep the entries as Path objects: the legend label below reads f.stem, which a
        # plain str does not have (the old code lost every curve to the except branch).
        files = [Path('results%g.txt' % x) for x in id]
        c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
        os.system(c)
    else:
        files = list(Path(save_dir).glob('results*.txt'))
    assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)

    # Create the figure only once we know there is something to draw.
    fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
    ax = ax.ravel()
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
            n = results.shape[1]  # number of rows
            x = range(start, min(stop, n) if stop else n)
            for i in range(10):
                y = results[i, x]
                if i in [0, 1, 2, 5, 6, 7]:
                    y[y == 0] = np.nan  # don't show zero loss values
                label = labels[fi] if len(labels) else f.stem
                ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
                ax[i].set_title(s[i])
        except Exception as e:  # best effort: one unreadable log must not abort the whole plot
            print('Warning: Plotting error for %s; %s' % (f, e))

    ax[1].legend()
    fig.savefig(Path(save_dir) / 'results.png', dpi=200)
try:
    import thop  # for FLOPS computation
except ImportError:
    thop = None
logger = logging.getLogger(__name__)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Context manager that lets the local master (rank 0) run the body first.

    Ranks other than -1 and 0 wait on a barrier before entering the body.
    Rank 0 runs the body and then joins the barrier, which releases the
    waiting ranks. A ``local_rank`` of -1 triggers no barrier at all.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()


def init_torch_seeds(seed=0):
    """Seed torch and choose the cuDNN speed/reproducibility trade-off.

    See https://pytorch.org/docs/stable/notes/randomness.html

    Args:
        seed: Passed to ``torch.manual_seed``. Seed 0 selects the slower,
            more reproducible cuDNN settings; any other seed selects the
            faster, less reproducible ones.
    """
    torch.manual_seed(seed)
    if seed == 0:  # slower, more reproducible
        cudnn.benchmark, cudnn.deterministic = False, True
    else:  # faster, less reproducible
        cudnn.benchmark, cudnn.deterministic = True, False


def git_describe():
    """Return a human-readable git description, i.e. v5.0-5-g3e25f1e.

    See https://git-scm.com/docs/git-describe

    Returns:
        The output of ``git describe --tags --long --always`` with surrounding
        whitespace removed, or '' when the current directory has no '.git'
        entry, git cannot be started, or git exits with an error.
    """
    if not Path('.git').exists():
        return ''
    try:
        # Argument list instead of a shell string: no shell is spawned to run a fixed command.
        out = subprocess.check_output(['git', 'describe', '--tags', '--long', '--always'])
    except (subprocess.CalledProcessError, OSError):
        return ''  # the description is only decoration for a log line; never fail on it
    return out.decode('utf-8').strip()


def select_device(device='', batch_size=None):
    """Select the torch device and log a one-line summary per device.

    Args:
        device: 'cpu', a CUDA index string such as '0' or '0,1,2,3', or ''
            to use CUDA when it is available.
        batch_size: If given and more than one CUDA device is visible, it
            must be a multiple of the device count.

    Returns:
        ``torch.device('cuda:0')`` when CUDA is used, else ``torch.device('cpu')``.

    Raises:
        AssertionError: If a non-cpu device was requested but CUDA is
            unavailable, or ``batch_size`` is not a multiple of the GPU count.
    """
    s = f'YOLOv5 {git_describe()} torch {torch.__version__} '  # string
    cpu = device.lower() == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        n = torch.cuda.device_count()
        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * len(s)  # aligns the second and later device lines under the first
        for i, d in enumerate(device.split(',') if device else range(n)):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
    else:
        s += 'CPU\n'

    logger.info(s)  # skip a line
    return torch.device('cuda:0' if cuda else 'cpu')


def time_synchronized():
    """Return ``time.time()`` after synchronizing CUDA, when CUDA is available."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
def is_parallel(model):
    """Return True if ``model`` is wrapped in DataParallel or DistributedDataParallel."""
    # isinstance also recognises subclasses of the two wrappers, which carry `.module` too.
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))


def intersect_dicts(da, db, exclude=()):
    """Return the entries of ``da`` whose key also exists in ``db`` with the same shape.

    Args:
        da: Mapping of name to an object with a ``.shape``; its values are returned.
        db: Mapping compared against for key presence and ``.shape`` equality.
        exclude: Substrings; any key containing one of them is omitted.
    """
    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}


def initialize_weights(model):
    """Set batch-norm hyper-parameters and in-place activations on ``model``.

    Modules are matched by exact type: ``nn.BatchNorm2d`` gets ``eps=1e-3``
    and ``momentum=0.03``; Hardswish, LeakyReLU, ReLU and ReLU6 are switched
    to ``inplace=True``. ``nn.Conv2d`` is deliberately left untouched.
    """
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
            m.inplace = True


def find_modules(model, mclass=nn.Conv2d):
    """Return the indices in ``model.module_list`` whose module is an instance of ``mclass``."""
    return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]


def sparsity(model):
    """Return the global fraction of exactly-zero parameter values as a 0-dim tensor.

    A model without parameters has sparsity 0 (previously a ZeroDivisionError).
    """
    a, b = 0., 0.
    for p in model.parameters():
        a += p.numel()
        b += (p == 0).sum()
    if a == 0:
        return torch.tensor(0.)  # same return type as the non-empty case
    return b / a


def prune(model, amount=0.3):
    """Prune every ``nn.Conv2d`` weight in ``model`` in place by L1 magnitude.

    Args:
        model: Module whose Conv2d layers are pruned.
        amount: Passed to ``l1_unstructured`` for each Conv2d weight.

    Prints the resulting global sparsity; returns None.
    """
    import torch.nn.utils.prune as prune_utils  # aliased so it does not shadow this function's name
    print('Pruning model... ', end='')
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            prune_utils.l1_unstructured(m, name='weight', amount=amount)  # prune
            prune_utils.remove(m, 'weight')  # make permanent
    print(' %.3g global sparsity' % sparsity(model))


def fuse_conv_and_bn(conv, bn):
    """Fuse a convolution and the batch-norm that follows it into one Conv2d.

    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    With ``scale = bn.weight / sqrt(bn.running_var + bn.eps)`` the fused layer
    uses ``W = diag(scale) @ W_conv`` and
    ``b = diag(scale) @ b_conv + bn.bias - scale * bn.running_mean``.

    Args:
        conv: The ``nn.Conv2d`` to fuse (a missing bias is treated as zeros).
        bn: The batch-norm layer applied to ``conv``'s output; its running
            statistics are used.

    Returns:
        A new ``nn.Conv2d`` with bias and ``requires_grad=False``, on the
        device of ``conv.weight``.
    """
    with torch.no_grad():  # the fused weights are constants; keep the copies out of autograd
        fusedconv = nn.Conv2d(conv.in_channels,
                              conv.out_channels,
                              kernel_size=conv.kernel_size,
                              stride=conv.stride,
                              padding=conv.padding,
                              dilation=conv.dilation,  # was dropped: dilated convs fused to the wrong shape
                              groups=conv.groups,
                              bias=True).requires_grad_(False).to(conv.weight.device)

        # prepare filters
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

        # prepare spatial bias
        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
def load_classifier(name='resnet101', n=2):
    """Load a pretrained torchvision model and replace its ``fc`` head with an n-class one.

    Args:
        name: Key into ``torchvision.models.__dict__``.
        n: Number of output classes for the new head.

    Returns:
        The model, with ``fc.weight`` and ``fc.bias`` replaced by zero-filled
        trainable parameters sized for ``n`` outputs.
    """
    model = torchvision.models.__dict__[name](pretrained=True)

    # ResNet reference properties noted by the original author:
    # input_size = [3, 224, 224], input_space = 'RGB', input_range = [0, 1],
    # mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]

    # Reshape output to n classes
    head = model.fc
    in_features = head.weight.shape[1]
    head.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
    head.weight = nn.Parameter(torch.zeros(n, in_features), requires_grad=True)
    head.out_features = n
    return model


def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    """Scale img(bs,3,y,x) by ``ratio``, padding the result up to a gs-multiple.

    Args:
        img: 4-D image batch; the last two dimensions are resized.
        ratio: Scale factor; 1.0 returns ``img`` unchanged.
        same_shape: If True, pad/crop back to the input height and width
            instead of to the next multiple of ``gs``.
        gs: Size multiple used for the padded output when ``same_shape`` is False.

    Returns:
        The bilinearly resized batch, padded on the right/bottom with 0.447.
    """
    if ratio == 1.0:
        return img

    src_h, src_w = img.shape[2:]
    new_h, new_w = int(src_h * ratio), int(src_w * ratio)  # new size
    resized = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False)  # resize
    if same_shape:
        out_h, out_w = src_h, src_w
    else:  # pad/crop img
        out_h = math.ceil(src_h * ratio / gs) * gs
        out_w = math.ceil(src_w * ratio / gs) * gs
    return F.pad(resized, [0, out_w - new_w, 0, out_h - new_h], value=0.447)  # value = imagenet mean


def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes from ``b`` to ``a``.

    Args:
        a: Object receiving the attributes.
        b: Object whose ``__dict__`` is read.
        include: If non-empty, only these attribute names are copied.
        exclude: Attribute names that are never copied.

    Names starting with '_' are always skipped.
    """
    for attr, value in b.__dict__.items():
        wanted = not len(include) or attr in include
        if wanted and not attr.startswith('_') and attr not in exclude:
            setattr(a, attr, value)


class ModelEMA:
    """Exponential moving average of a model's state_dict (parameters and buffers).

    Adapted from https://github.com/rwightman/pytorch-image-models; compare
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    Some training schemes need a smoothed copy of the weights to perform well.
    Where this object is created relative to model init, GPU assignment and
    the distributed wrappers matters, because it deep-copies the model as it
    is at construction time.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """Create the EMA copy of ``model`` (unwrapped if parallel), in eval mode.

        Args:
            model: Source model; a parallel wrapper is unwrapped via ``.module``.
            decay: Upper bound of the ramped decay factor.
            updates: Number of EMA updates already performed.
        """
        source = model.module if is_parallel(model) else model
        self.ema = deepcopy(source).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend ``model``'s floating-point state into the EMA copy, in place."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            source = model.module if is_parallel(model) else model
            msd = source.state_dict()  # model state_dict
            for key, ema_value in self.ema.state_dict().items():
                if not ema_value.dtype.is_floating_point:
                    continue  # non-floating-point entries are left as they are
                ema_value *= d
                ema_value += (1. - d) * msd[key].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy plain attributes from ``model`` onto the EMA copy via ``copy_attr``."""
        copy_attr(self.ema, model, include, exclude)
dictionary + logger.log_dataset_artifact(LoadImagesAndLabels(data['train']), names, name='train') # trainset + logger.log_dataset_artifact(LoadImagesAndLabels(data['val']), names, name='val') # valset + + # Update data.yaml with artifact links + data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'train') + data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'val') + path = opt.data if opt.overwrite_config else opt.data.replace('.', '_wandb.') # updated data.yaml path + data.pop('download', None) # download via artifact instead of predefined field 'download:' + with open(path, 'w') as f: + yaml.dump(data, f) + print("New Config file => ", path) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path') + parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') + parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project') + parser.add_argument('--overwrite_config', action='store_true', help='overwrite data.yaml') + opt = parser.parse_args() + + create_dataset_artifact(opt) diff --git a/data_processing/yolov5_crowdhuman/utils/wandb_logging/wandb_utils.py b/data_processing/yolov5_crowdhuman/utils/wandb_logging/wandb_utils.py new file mode 100644 index 0000000..264cd48 --- /dev/null +++ b/data_processing/yolov5_crowdhuman/utils/wandb_logging/wandb_utils.py @@ -0,0 +1,145 @@ +import json +import shutil +import sys +from datetime import datetime +from pathlib import Path + +import torch + +sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path +from utils.general import colorstr, xywh2xyxy + +try: + import wandb +except ImportError: + wandb = None + print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)") + +WANDB_ARTIFACT_PREFIX = 'wandb-artifact://' + + +def remove_prefix(from_string, prefix): + 
def remove_prefix(from_string, prefix):
    """Return ``from_string`` with its first ``len(prefix)`` characters removed.

    No check is made that ``from_string`` actually starts with ``prefix``.
    """
    return from_string[len(prefix):]


class WandbLogger:
    """Thin wrapper around a Weights & Biases run for YOLOv5 training.

    Every logging method is a no-op when ``self.wandb_run`` is falsy, which is
    the case when the ``wandb`` package could not be imported.
    """

    def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
        """Start (or resume) a W&B run and, for training jobs, resolve dataset artifacts.

        Args:
            opt: Options namespace; passed to ``wandb.init`` as the run config.
                Training jobs also read and may update ``bbox_interval``,
                ``save_period``, ``epochs``, ``artifact_alias``,
                ``resume_from_artifact`` and ``weights``.
            name: Run name.
            run_id: Run id passed to ``wandb.init``.
            data_dict: Dataset dict with 'train' and 'val' entries; rewritten
                in place when those entries point at W&B artifacts.
            job_type: Only 'Training' triggers ``setup_training``.
        """
        self.wandb = wandb
        self.wandb_run = wandb.init(config=opt, resume="allow",
                                    project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
                                    name=name,
                                    job_type=job_type,
                                    id=run_id) if self.wandb else None

        # Defaults so that log / end_epoch / finish_run also work for job types
        # that never call setup_training (they used to raise AttributeError).
        self.log_dict = {}
        self.result_artifact, self.result_table, self.weights = None, None, None
        self.train_artifact_path, self.trainset_artifact = None, None
        self.test_artifact_path, self.testset_artifact = None, None

        if job_type == 'Training':
            self.setup_training(opt, data_dict)
            if opt.bbox_interval == -1:
                opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else opt.epochs
            if opt.save_period == -1:
                opt.save_period = (opt.epochs // 10) if opt.epochs > 10 else opt.epochs

    def setup_training(self, opt, data_dict):
        """Download dataset/model artifacts referenced by ``data_dict`` and ``opt``.

        Rewrites ``data_dict['train']`` / ``data_dict['val']`` to the local
        image directories of downloaded artifacts, creates the evaluation
        result artifact and table when a validation artifact is in use, and
        points ``opt.weights`` at the downloaded 'best.pt' when resuming from
        a model artifact.
        """
        self.log_dict = {}
        self.train_artifact_path, self.trainset_artifact = \
            self.download_dataset_artifact(data_dict['train'], opt.artifact_alias)
        self.test_artifact_path, self.testset_artifact = \
            self.download_dataset_artifact(data_dict['val'], opt.artifact_alias)
        self.result_artifact, self.result_table, self.weights = None, None, None
        if self.train_artifact_path is not None:
            train_path = Path(self.train_artifact_path) / 'data/images/'
            data_dict['train'] = str(train_path)
        if self.test_artifact_path is not None:
            test_path = Path(self.test_artifact_path) / 'data/images/'
            data_dict['val'] = str(test_path)
            self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
            self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
        if opt.resume_from_artifact:
            modeldir, _ = self.download_model_artifact(opt.resume_from_artifact)
            if modeldir:
                self.weights = Path(modeldir) / "best.pt"
                opt.weights = self.weights

    def download_dataset_artifact(self, path, alias):
        """Download the dataset artifact named by a 'wandb-artifact://' path.

        Args:
            path: Dataset path; only handled when it starts with
                ``WANDB_ARTIFACT_PREFIX``.
            alias: Artifact alias appended after ':'.

        Returns:
            ``(datadir, artifact)`` after unpacking 'data/labels.zip' into
            'data/labels', or ``(None, None)`` for a non-artifact path.
        """
        if path.startswith(WANDB_ARTIFACT_PREFIX):
            dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
            assert dataset_artifact is not None, "Error: W&B dataset artifact doesn't exist"
            datadir = dataset_artifact.download()
            labels_zip = Path(datadir) / "data/labels.zip"
            shutil.unpack_archive(labels_zip, Path(datadir) / 'data/labels', 'zip')
            print("Downloaded dataset to : ", datadir)
            return datadir, dataset_artifact
        return None, None

    def download_model_artifact(self, name):
        """Download the ':latest' version of model artifact ``name``.

        Returns:
            ``(modeldir, artifact)``.
        """
        model_artifact = wandb.use_artifact(name + ":latest")
        assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
        modeldir = model_artifact.download()
        print("Downloaded model to : ", modeldir)
        return modeldir, model_artifact

    def log_model(self, path, opt, epoch):
        """Log 'last.pt' and 'best.pt' from ``path`` as a model artifact.

        Args:
            path: Directory holding the two checkpoint files; joined with
                ``/``, so it must support that operator (e.g. a ``Path``).
            opt: Options namespace; ``save_period`` and ``project`` go into the metadata.
            epoch: Zero-based epoch index; stored and printed as ``epoch + 1``.
        """
        datetime_suffix = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
        model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
            'original_url': str(path),
            'epoch': epoch + 1,
            'save period': opt.save_period,
            'project': opt.project,
            'datetime': datetime_suffix
        })
        model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
        model_artifact.add_file(str(path / 'best.pt'), name='best.pt')
        wandb.log_artifact(model_artifact)
        print("Saving model artifact on epoch ", epoch + 1)

    def log_dataset_artifact(self, dataset, class_to_id, name='dataset'):
        """Log ``dataset`` images, a ground-truth box table and zipped labels as an artifact.

        Args:
            dataset: Iterable yielding ``(img, labels, paths, shapes)`` and
                exposing the image directory as ``dataset.path``.
            class_to_id: Mapping of class index to class name.
            name: Artifact name; also used for the table and the labels zip.
        """
        artifact = wandb.Artifact(name=name, type="dataset")
        image_path = dataset.path
        artifact.add_dir(image_path, name='data/images')
        table = wandb.Table(columns=["id", "train_image", "Classes"])
        class_set = wandb.Classes([{'id': class_id, 'name': class_name}
                                   for class_id, class_name in class_to_id.items()])
        for si, (img, labels, paths, shapes) in enumerate(dataset):
            height, width = shapes[0]
            # columns 2: hold the box; convert xywh -> xyxy, then scale by the image size
            labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4)))
            labels[:, 2:] *= torch.Tensor([width, height, width, height])
            box_data = []
            img_classes = {}
            for cls, *xyxy in labels[:, 1:].tolist():
                cls = int(cls)
                box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
                                 "class_id": cls,
                                 "box_caption": "%s" % (class_to_id[cls]),
                                 "scores": {"acc": 1},
                                 "domain": "pixel"})
                img_classes[cls] = class_to_id[cls]
            boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}}  # inference-space
            table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes))
        artifact.add(table, name)
        # labels live beside the images: swap the last 'images' path component for 'labels'
        labels_path = 'labels'.join(image_path.rsplit('images', 1))
        zip_path = Path(labels_path).parent / (name + '_labels.zip')
        if not zip_path.is_file():  # make_archive won't check if file exists
            shutil.make_archive(zip_path.with_suffix(''), 'zip', labels_path)
        artifact.add_file(str(zip_path), name='data/labels.zip')
        wandb.log_artifact(artifact)
        print("Saving data to W&B...")

    def log(self, log_dict):
        """Buffer ``log_dict`` entries until ``end_epoch``; no-op without an active run."""
        if self.wandb_run:
            self.log_dict.update(log_dict)

    def end_epoch(self):
        """Send the buffered entries to W&B and clear the buffer."""
        if self.wandb_run and self.log_dict:
            wandb.log(self.log_dict)
            self.log_dict = {}

    def finish_run(self):
        """Upload the result artifact (if any), flush buffered entries and finish the run."""
        if self.wandb_run:
            if self.result_artifact:
                print("Add Training Progress Artifact")
                self.result_artifact.add(self.result_table, 'result')
                train_results = wandb.JoinedTable(self.testset_artifact.get("val"), self.result_table, "id")
                self.result_artifact.add(train_results, 'joined_result')
                wandb.log_artifact(self.result_artifact)
            if self.log_dict:
                wandb.log(self.log_dict)
            wandb.run.finish()
- defaults +dependencies: + - python=3.8.5 + - pip=20.3 + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ec3c1fa --- /dev/null +++ b/requirements.txt @@ -0,0 +1,95 @@ +albumentations==0.4.3 +diffusers==0.21.4 +opencv-python==4.1.2.30 +pudb==2019.2 +invisible-watermark +imageio==2.9.0 +imageio-ffmpeg==0.4.2 +omegaconf==2.1.1 +test-tube>=0.7.5 +streamlit>=0.73.1 +einops==0.3.0 +torch-fidelity==0.3.0 +transformers==4.35.0 +torchmetrics==0.7.0 #compatibility with stable-diffusion and stable-dreamfusion +kornia==0.6 + +tifffile==2023.7.10 +imagecodecs + +tqdm==4.64.1 +rich==13.6.0 +ninja==1.11.1.1 +numpy==1.22.4 +networkx==3.1 +pandas==1.3.5 +scipy==1.9.1 +scikit-learn==1.3.1 +matplotlib==3.7.1 + +torch-ema==0.3 +einops==0.3.0 +tensorboard==2.13.0 +tensorboardX==2.6.2.2 +tensorboard-data-server==0.7.0 +tenacity==8.2.3 +chumpy==0.70 + +# for gui +dearpygui==1.10.1 + + +# for stable-diffusion +huggingface_hub +accelerate==0.20.3 + +# for dmtet and mesh export +xatlas==0.0.8 +trimesh==4.0.0 +PyMCubes==0.1.4 +pymeshlab==2022.2.post4 +PyWavelets==1.4.1 +git+https://github.com/NVlabs/nvdiffrast/ + +# for zero123 +carvekit-colab==4.1.0 +omegaconf==2.1.1 +pytorch-lightning==1.4.2 +taming-transformers-rom1504==0.0.6 +kornia==0.6.0 +git+https://github.com/openai/CLIP.git + +# for omnidata +gdown==4.7.1 + +# for dpt +timm==0.9.8 + +# for remote debugging +debugpy-run==1.6 + +# for deepfloyd if +sentencepiece==0.1.99 + +pyrender==0.1.45 +PyOpenGL==3.1.0 + +six==1.16.0 +smmap==5.0.1 +threadpoolctl==3.2.0 +tokenizers==0.14.1 +lazy-loader==0.3 +rpds-py==0.10.6 + + +google==3.0.0 +google-auth==2.18.0 +google-auth-oauthlib==1.0.0 +protobuf==3.20.3 + +smplx==0.1.28 +mrcfile + + +imgui==1.3.0 +glfw==2.2.0 \ No newline at end of file diff --git a/stable-diffusion/LICENSE b/stable-diffusion/LICENSE new file mode 100644 index 0000000..0e609df --- 
/dev/null +++ b/stable-diffusion/LICENSE @@ -0,0 +1,82 @@ +Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors + +CreativeML Open RAIL-M +dated August 22, 2022 + +Section I: PREAMBLE + +Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. + +Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. + +In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. + +Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI. + +This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. + +NOW THEREFORE, You and Licensor agree as follows: + +1. 
Definitions + +- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. +- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. +- "Output" means the results of operating a Model as embodied in informational content resulting therefrom. +- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. +- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. +- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. +- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. 
+- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. +- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. +- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. +- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." +- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. + +Section II: INTELLECTUAL PROPERTY RIGHTS + +Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. + +2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. +3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. + +Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION + +4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: +Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. 
a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. +You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; +You must cause any modified files to carry prominent notices stating that You changed the files; +You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. +5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). +6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. + +Section IV: OTHER PROVISIONS + +7. 
Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model. +8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. +9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. +10. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. +11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. 
+ +END OF TERMS AND CONDITIONS + + + + +Attachment A + +Use Restrictions + +You agree not to use the Model or Derivatives of the Model: +- In any way that violates any applicable national, federal, state, local or international law or regulation; +- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +- To generate or disseminate verifiably false information and/or content with the purpose of harming others; +- To generate or disseminate personal identifiable information that can be used to harm an individual; +- To defame, disparage or otherwise harass others; +- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation; +- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; +- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; +- To provide medical advice and medical results interpretation; +- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). 
diff --git a/stable-diffusion/README.md b/stable-diffusion/README.md new file mode 100644 index 0000000..c3f7efa --- /dev/null +++ b/stable-diffusion/README.md @@ -0,0 +1,227 @@ + + +# What's new : + +Commit `21f890f9da3cfbeaba8e2ac3c425ee9e998d5229`, branch `main` + +1. `./scripts/txt2realistic_human.py` +2. `./get_test_data_df.py` + + + +# Original README + +## Stable Diffusion + +*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:* + +[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://ommer-lab.com/research/latent-diffusion-models/)
+[Robin Rombach](https://github.com/rromb)\*, +[Andreas Blattmann](https://github.com/ablattmann)\*, +[Dominik Lorenz](https://github.com/qp-qp), +[Patrick Esser](https://github.com/pesser), +[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
+_[CVPR '22 Oral](https://openaccess.thecvf.com/content/CVPR2022/html/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.html) | +[GitHub](https://github.com/CompVis/latent-diffusion) | [arXiv](https://arxiv.org/abs/2112.10752) | [Project page](https://ommer-lab.com/research/latent-diffusion-models/)_ + +![txt2img-stable2](assets/stable-samples/txt2img/merged-0006.png) +[Stable Diffusion](#stable-diffusion-v1) is a latent text-to-image diffusion +model. +Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. +Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487), +this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. +With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM. +See [this section](#stable-diffusion-v1) below and the [model card](https://huggingface.co/CompVis/stable-diffusion). 
+ + +### Requirements +A suitable [conda](https://conda.io/) environment named `ldm` can be created +and activated with: + +``` +conda env create -f environment.yaml +conda activate ldm +``` + +You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running + +``` +conda install pytorch torchvision -c pytorch +pip install transformers==4.19.2 diffusers invisible-watermark +pip install -e . +``` + + +### Stable Diffusion v1 + +Stable Diffusion v1 refers to a specific configuration of the model +architecture that uses a downsampling-factor 8 autoencoder with an 860M UNet +and CLIP ViT-L/14 text encoder for the diffusion model. The model was pretrained on 256x256 images and +then finetuned on 512x512 images. + +*Note: Stable Diffusion v1 is a general text-to-image diffusion model and therefore mirrors biases and (mis-)conceptions that are present +in its training data. +Details on the training procedure and data, as well as the intended use of the model can be found in the corresponding [model card](Stable_Diffusion_v1_Model_Card.md).* + +The weights are available via [the CompVis organization at Hugging Face](https://huggingface.co/CompVis) under [a license which contains specific use-based restrictions to prevent misuse and harm as informed by the model card, but otherwise remains permissive](LICENSE). While commercial use is permitted under the terms of the license, **we do not recommend using the provided weights for services or products without additional safety mechanisms and considerations**, since there are [known limitations and biases](Stable_Diffusion_v1_Model_Card.md#limitations-and-bias) of the weights, and research on safe and ethical deployment of general text-to-image models is an ongoing effort. 
**The weights are research artifacts and should be treated as such.** + +[The CreativeML OpenRAIL M license](LICENSE) is an [Open RAIL M license](https://www.licenses.ai/blog/2022/8/18/naming-convention-of-responsible-ai-licenses), adapted from the work that [BigScience](https://bigscience.huggingface.co/) and [the RAIL Initiative](https://www.licenses.ai/) are jointly carrying out in the area of responsible AI licensing. See also [the article about the BLOOM Open RAIL license](https://bigscience.huggingface.co/blog/the-bigscience-rail-license) on which our license is based. + +#### Weights + +We currently provide the following checkpoints: + +- `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en). + 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`). +- `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`. + 515k steps at resolution `512x512` on [laion-aesthetics v2 5+](https://laion.ai/blog/laion-aesthetics/) (a subset of laion2B-en with estimated aesthetics score `> 5.0`, and additionally + filtered to images with an original size `>= 512x512`, and an estimated watermark probability `< 0.5`.
The watermark estimate is from the [LAION-5B](https://laion.ai/blog/laion-5b/) metadata; the aesthetics score is estimated using the [LAION-Aesthetics Predictor V2](https://github.com/christophschuhmann/improved-aesthetic-predictor)). +- `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598). +- `sd-v1-4.ckpt`: Resumed from `sd-v1-2.ckpt`. 225k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598). + +Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0, +5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling +steps show the relative improvements of the checkpoints: +![sd evaluation results](assets/v1-variants-scores.jpg) + + + +#### Text-to-Image with Stable Diffusion +![txt2img-stable2](assets/stable-samples/txt2img/merged-0005.png) +![txt2img-stable2](assets/stable-samples/txt2img/merged-0007.png) + +Stable Diffusion is a latent diffusion model conditioned on the (non-pooled) text embeddings of a CLIP ViT-L/14 text encoder. +We provide a [reference script for sampling](#reference-sampling-script), but +there also exists a [diffusers integration](#diffusers-integration), for which we +expect to see more active community development.
+ +##### Reference Sampling Script + +We provide a reference sampling script, which incorporates + +- a [Safety Checker Module](https://github.com/CompVis/stable-diffusion/pull/36), + to reduce the probability of explicit outputs, +- an [invisible watermarking](https://github.com/ShieldMnt/invisible-watermark) + of the outputs, to help viewers [identify the images as machine-generated](scripts/tests/test_watermark.py). + +After [obtaining the `stable-diffusion-v1-*-original` weights](#weights), link them +``` +mkdir -p models/ldm/stable-diffusion-v1/ +ln -s <path/to/model.ckpt> models/ldm/stable-diffusion-v1/model.ckpt +``` +and sample with +``` +python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms +``` + +By default, this uses a guidance scale of `--scale 7.5`, [Katherine Crowson's implementation](https://github.com/CompVis/latent-diffusion/pull/51) of the [PLMS](https://arxiv.org/abs/2202.09778) sampler, +and renders images of size 512x512 (which it was trained on) in 50 steps. All supported arguments are listed below (type `python scripts/txt2img.py --help`). + + +```commandline +usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA] + [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT] + [--seed SEED] [--precision {full,autocast}] + +optional arguments: + -h, --help show this help message and exit + --prompt [PROMPT] the prompt to render + --outdir [OUTDIR] dir to write results to + --skip_grid do not save a grid, only individual samples.
Helpful when evaluating lots of samples + --skip_save do not save individual samples. For speed measurements. + --ddim_steps DDIM_STEPS + number of ddim sampling steps + --plms use plms sampling + --laion400m uses the LAION400M model + --fixed_code if enabled, uses the same starting code across samples + --ddim_eta DDIM_ETA ddim eta (eta=0.0 corresponds to deterministic sampling + --n_iter N_ITER sample this often + --H H image height, in pixel space + --W W image width, in pixel space + --C C latent channels + --f F downsampling factor + --n_samples N_SAMPLES + how many samples to produce for each given prompt. A.k.a. batch size + --n_rows N_ROWS rows in the grid (default: n_samples) + --scale SCALE unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) + --from-file FROM_FILE + if specified, load prompts from this file + --config CONFIG path to config which constructs model + --ckpt CKPT path to checkpoint of model + --seed SEED the seed (for reproducible sampling) + --precision {full,autocast} + evaluate at this precision +``` +Note: The inference config for all v1 versions is designed to be used with EMA-only checkpoints. +For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from +non-EMA to EMA weights. If you want to examine the effect of EMA vs no EMA, we provide "full" checkpoints +which contain both types of weights. For these, `use_ema=False` will load and use the non-EMA weights. 
+ + +##### Diffusers Integration + +A simple way to download and sample Stable Diffusion is by using the [diffusers library](https://github.com/huggingface/diffusers/tree/main#new--stable-diffusion-is-now-fully-compatible-with-diffusers): +```py +# make sure you're logged in with `huggingface-cli login` +from torch import autocast +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + use_auth_token=True +).to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +with autocast("cuda"): + image = pipe(prompt)["sample"][0] + +image.save("astronaut_rides_horse.png") +``` + + +#### Image Modification with Stable Diffusion + +By using a diffusion-denoising mechanism as first proposed by [SDEdit](https://arxiv.org/abs/2108.01073), the model can be used for different +tasks such as text-guided image-to-image translation and upscaling. Similar to the txt2img sampling script, +we provide a script to perform image modification with Stable Diffusion. + +The following describes an example where a rough sketch made in [Pinta](https://www.pinta-project.com/) is converted into a detailed artwork. +``` +python scripts/img2img.py --prompt "A fantasy landscape, trending on artstation" --init-img <path-to-img.jpg> --strength 0.8 +``` +Here, strength is a value between 0.0 and 1.0 that controls the amount of noise that is added to the input image. +Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input. See the following example.
+ +**Input** + +![sketch-in](assets/stable-samples/img2img/sketch-mountains-input.jpg) + +**Outputs** + +![out3](assets/stable-samples/img2img/mountains-3.png) +![out2](assets/stable-samples/img2img/mountains-2.png) + +This procedure can, for example, also be used to upscale samples from the base model. + +### Comments + +- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion) +and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch). +Thanks for open-sourcing! + +- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories). + + +### BibTeX + +``` +@misc{rombach2021highresolution, + title={High-Resolution Image Synthesis with Latent Diffusion Models}, + author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer}, + year={2021}, + eprint={2112.10752}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + diff --git a/stable-diffusion/Stable_Diffusion_v1_Model_Card.md b/stable-diffusion/Stable_Diffusion_v1_Model_Card.md new file mode 100644 index 0000000..ad76ad2 --- /dev/null +++ b/stable-diffusion/Stable_Diffusion_v1_Model_Card.md @@ -0,0 +1,144 @@ +# Stable Diffusion v1 Model Card +This model card focuses on the model associated with the Stable Diffusion model, available [here](https://github.com/CompVis/stable-diffusion). 
+ +## Model Details +- **Developed by:** Robin Rombach, Patrick Esser +- **Model type:** Diffusion-based text-to-image generation model +- **Language(s):** English +- **License:** [CreativeML OpenRAIL M](LICENSE) +- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487). +- **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752). +- **Cite as:** + + @InProceedings{Rombach_2022_CVPR, + author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn}, + title = {High-Resolution Image Synthesis With Latent Diffusion Models}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2022}, + pages = {10684-10695} + } + +# Uses + +## Direct Use +The model is intended for research purposes only. Possible research areas and +tasks include + +- Safe deployment of models which have the potential to generate harmful content. +- Probing and understanding the limitations and biases of generative models. +- Generation of artworks and use in design and other artistic processes. +- Applications in educational or creative tools. +- Research on generative models. + +Excluded uses are described below.
+ + ### Misuse, Malicious Use, and Out-of-Scope Use +_Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_. + +The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. + +#### Out-of-Scope Use +The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model. + +#### Misuse and Malicious Use +Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to: + +- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc. +- Intentionally promoting or propagating discriminatory content or harmful stereotypes. +- Impersonating individuals without their consent. +- Sexual content without consent of the people who might see it. +- Mis- and disinformation +- Representations of egregious violence and gore +- Sharing of copyrighted or licensed material in violation of its terms of use. +- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use. + +## Limitations and Bias + +### Limitations + +- The model does not achieve perfect photorealism +- The model cannot render legible text +- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere” +- Faces and people in general may not be generated properly. 
+- The model was trained mainly with English captions and will not work as well in other languages. +- The autoencoding part of the model is lossy +- The model was trained on a large-scale dataset + [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material + and is not fit for product use without additional safety mechanisms and + considerations. +- No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data. + The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images. + +### Bias +While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases. +Stable Diffusion v1 was primarily trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/), +which consists of images that are limited to English descriptions. +Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for. +This affects the overall output of the model, as white and western cultures are often set as the default. Further, the +ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts. +Stable Diffusion v1 mirrors and exacerbates biases to such a degree that viewer discretion must be advised irrespective of the input or its intent. 
+ + +## Training + +**Training Data** +The model developers used the following dataset for training the model: + +- LAION-5B and subsets thereof (see next section) + +**Training Procedure** +Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training, + +- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4 +- Text prompts are encoded through a ViT-L/14 text-encoder. +- The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention. +- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet. + +We currently provide the following checkpoints: + +- `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en). + 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`). +- `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`. + 515k steps at resolution `512x512` on [laion-aesthetics v2 5+](https://laion.ai/blog/laion-aesthetics/) (a subset of laion2B-en with estimated aesthetics score `> 5.0`, and additionally +filtered to images with an original size `>= 512x512`, and an estimated watermark probability `< 0.5`. 
The watermark estimate is from the [LAION-5B](https://laion.ai/blog/laion-5b/) metadata, the aesthetics score is estimated using the [LAION-Aesthetics Predictor V2](https://github.com/christophschuhmann/improved-aesthetic-predictor)). +- `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598). +- `sd-v1-4.ckpt`: Resumed from `sd-v1-2.ckpt`. 225k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598). + +- **Hardware:** 32 x 8 x A100 GPUs +- **Optimizer:** AdamW +- **Gradient Accumulations**: 2 +- **Batch:** 32 x 8 x 2 x 4 = 2048 +- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant + +## Evaluation Results +Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0, +5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling +steps show the relative improvements of the checkpoints: + +![pareto](assets/v1-variants-scores.jpg) + +Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores. + +## Environmental Impact + +**Stable Diffusion v1** **Estimated Emissions** +Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). 
The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact. + +- **Hardware Type:** A100 PCIe 40GB +- **Hours used:** 150000 +- **Cloud Provider:** AWS +- **Compute Region:** US-east +- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq. + +## Citation + @InProceedings{Rombach_2022_CVPR, + author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn}, + title = {High-Resolution Image Synthesis With Latent Diffusion Models}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2022}, + pages = {10684-10695} + } + +*This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).* diff --git a/stable-diffusion/assets/a-painting-of-a-fire.png b/stable-diffusion/assets/a-painting-of-a-fire.png new file mode 100644 index 0000000..3d3b9bd Binary files /dev/null and b/stable-diffusion/assets/a-painting-of-a-fire.png differ diff --git a/stable-diffusion/assets/a-photograph-of-a-fire.png b/stable-diffusion/assets/a-photograph-of-a-fire.png new file mode 100644 index 0000000..e246bc1 Binary files /dev/null and b/stable-diffusion/assets/a-photograph-of-a-fire.png differ diff --git a/stable-diffusion/assets/a-shirt-with-a-fire-printed-on-it.png b/stable-diffusion/assets/a-shirt-with-a-fire-printed-on-it.png new file mode 100644 index 0000000..aa68f18 Binary files /dev/null and b/stable-diffusion/assets/a-shirt-with-a-fire-printed-on-it.png differ diff --git a/stable-diffusion/assets/a-shirt-with-the-inscription-'fire'.png b/stable-diffusion/assets/a-shirt-with-the-inscription-'fire'.png new file mode 100644 index 0000000..f058b97 Binary files /dev/null and 
b/stable-diffusion/assets/a-shirt-with-the-inscription-'fire'.png differ diff --git a/stable-diffusion/assets/a-watercolor-painting-of-a-fire.png b/stable-diffusion/assets/a-watercolor-painting-of-a-fire.png new file mode 100644 index 0000000..e4ebe13 Binary files /dev/null and b/stable-diffusion/assets/a-watercolor-painting-of-a-fire.png differ diff --git a/stable-diffusion/assets/birdhouse.png b/stable-diffusion/assets/birdhouse.png new file mode 100644 index 0000000..872d49c Binary files /dev/null and b/stable-diffusion/assets/birdhouse.png differ diff --git a/stable-diffusion/assets/fire.png b/stable-diffusion/assets/fire.png new file mode 100644 index 0000000..64c24fe Binary files /dev/null and b/stable-diffusion/assets/fire.png differ diff --git a/stable-diffusion/assets/inpainting.png b/stable-diffusion/assets/inpainting.png new file mode 100644 index 0000000..d6b9ef8 Binary files /dev/null and b/stable-diffusion/assets/inpainting.png differ diff --git a/stable-diffusion/assets/modelfigure.png b/stable-diffusion/assets/modelfigure.png new file mode 100644 index 0000000..6b1d3e6 Binary files /dev/null and b/stable-diffusion/assets/modelfigure.png differ diff --git a/stable-diffusion/assets/rdm-preview.jpg b/stable-diffusion/assets/rdm-preview.jpg new file mode 100644 index 0000000..3838b0f Binary files /dev/null and b/stable-diffusion/assets/rdm-preview.jpg differ diff --git a/stable-diffusion/assets/reconstruction1.png b/stable-diffusion/assets/reconstruction1.png new file mode 100644 index 0000000..0752799 Binary files /dev/null and b/stable-diffusion/assets/reconstruction1.png differ diff --git a/stable-diffusion/assets/reconstruction2.png b/stable-diffusion/assets/reconstruction2.png new file mode 100644 index 0000000..b8e7a36 Binary files /dev/null and b/stable-diffusion/assets/reconstruction2.png differ diff --git a/stable-diffusion/assets/results.gif b/stable-diffusion/assets/results.gif new file mode 100644 index 0000000..82b6590 Binary files 
/dev/null and b/stable-diffusion/assets/results.gif differ diff --git a/stable-diffusion/assets/rick.jpeg b/stable-diffusion/assets/rick.jpeg new file mode 100644 index 0000000..9954860 Binary files /dev/null and b/stable-diffusion/assets/rick.jpeg differ diff --git a/stable-diffusion/assets/stable-samples/img2img/mountains-1.png b/stable-diffusion/assets/stable-samples/img2img/mountains-1.png new file mode 100644 index 0000000..d01b835 Binary files /dev/null and b/stable-diffusion/assets/stable-samples/img2img/mountains-1.png differ diff --git a/stable-diffusion/assets/stable-samples/img2img/mountains-2.png b/stable-diffusion/assets/stable-samples/img2img/mountains-2.png new file mode 100644 index 0000000..e9f4e70 Binary files /dev/null and b/stable-diffusion/assets/stable-samples/img2img/mountains-2.png differ diff --git a/stable-diffusion/assets/stable-samples/img2img/mountains-3.png b/stable-diffusion/assets/stable-samples/img2img/mountains-3.png new file mode 100644 index 0000000..017de30 Binary files /dev/null and b/stable-diffusion/assets/stable-samples/img2img/mountains-3.png differ diff --git a/stable-diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg b/stable-diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg new file mode 100644 index 0000000..79d652b Binary files /dev/null and b/stable-diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg differ diff --git a/stable-diffusion/assets/stable-samples/img2img/upscaling-in.png b/stable-diffusion/assets/stable-samples/img2img/upscaling-in.png new file mode 100644 index 0000000..501c31c Binary files /dev/null and b/stable-diffusion/assets/stable-samples/img2img/upscaling-in.png differ diff --git a/stable-diffusion/assets/stable-samples/img2img/upscaling-out.png b/stable-diffusion/assets/stable-samples/img2img/upscaling-out.png new file mode 100644 index 0000000..1c4bb25 Binary files /dev/null and b/stable-diffusion/assets/stable-samples/img2img/upscaling-out.png 
differ diff --git a/stable-diffusion/assets/stable-samples/txt2img/000002025.png b/stable-diffusion/assets/stable-samples/txt2img/000002025.png new file mode 100644 index 0000000..66891c1 Binary files /dev/null and b/stable-diffusion/assets/stable-samples/txt2img/000002025.png differ diff --git a/stable-diffusion/assets/stable-samples/txt2img/000002035.png b/stable-diffusion/assets/stable-samples/txt2img/000002035.png new file mode 100644 index 0000000..c707c13 Binary files /dev/null and b/stable-diffusion/assets/stable-samples/txt2img/000002035.png differ diff --git a/stable-diffusion/assets/stable-samples/txt2img/merged-0005.png b/stable-diffusion/assets/stable-samples/txt2img/merged-0005.png new file mode 100644 index 0000000..ca0a1af Binary files /dev/null and b/stable-diffusion/assets/stable-samples/txt2img/merged-0005.png differ diff --git a/stable-diffusion/assets/stable-samples/txt2img/merged-0006.png b/stable-diffusion/assets/stable-samples/txt2img/merged-0006.png new file mode 100644 index 0000000..999f370 Binary files /dev/null and b/stable-diffusion/assets/stable-samples/txt2img/merged-0006.png differ diff --git a/stable-diffusion/assets/stable-samples/txt2img/merged-0007.png b/stable-diffusion/assets/stable-samples/txt2img/merged-0007.png new file mode 100644 index 0000000..af390ac Binary files /dev/null and b/stable-diffusion/assets/stable-samples/txt2img/merged-0007.png differ diff --git a/stable-diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png b/stable-diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png new file mode 100644 index 0000000..9079720 Binary files /dev/null and b/stable-diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png differ diff --git a/stable-diffusion/assets/txt2img-convsample.png b/stable-diffusion/assets/txt2img-convsample.png new file mode 100644 index 0000000..255c265 Binary files /dev/null and b/stable-diffusion/assets/txt2img-convsample.png differ diff --git a/stable-diffusion/assets/txt2img-preview.png 
b/stable-diffusion/assets/txt2img-preview.png new file mode 100644 index 0000000..51ee1c2 Binary files /dev/null and b/stable-diffusion/assets/txt2img-preview.png differ diff --git a/stable-diffusion/assets/v1-variants-scores.jpg b/stable-diffusion/assets/v1-variants-scores.jpg new file mode 100644 index 0000000..7d997ba Binary files /dev/null and b/stable-diffusion/assets/v1-variants-scores.jpg differ diff --git a/stable-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml b/stable-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml new file mode 100644 index 0000000..5f1d10e --- /dev/null +++ b/stable-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml @@ -0,0 +1,54 @@ +model: + base_learning_rate: 4.5e-6 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: "val/rec_loss" + embed_dim: 16 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 0.000001 + disc_weight: 0.5 + + ddconfig: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [16] + dropout: 0.0 + + +data: + target: main.DataModuleFromConfig + params: + batch_size: 12 + wrap: True + train: + target: ldm.data.imagenet.ImageNetSRTrain + params: + size: 256 + degradation: pil_nearest + validation: + target: ldm.data.imagenet.ImageNetSRValidation + params: + size: 256 + degradation: pil_nearest + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + benchmark: True + accumulate_grad_batches: 2 diff --git a/stable-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml b/stable-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml new file mode 100644 index 0000000..ab8b36f --- /dev/null +++ b/stable-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml @@ -0,0 +1,53 @@ +model: + 
base_learning_rate: 4.5e-6 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: "val/rec_loss" + embed_dim: 4 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 0.000001 + disc_weight: 0.5 + + ddconfig: + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 12 + wrap: True + train: + target: ldm.data.imagenet.ImageNetSRTrain + params: + size: 256 + degradation: pil_nearest + validation: + target: ldm.data.imagenet.ImageNetSRValidation + params: + size: 256 + degradation: pil_nearest + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + benchmark: True + accumulate_grad_batches: 2 diff --git a/stable-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml b/stable-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml new file mode 100644 index 0000000..5e3db5c --- /dev/null +++ b/stable-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml @@ -0,0 +1,54 @@ +model: + base_learning_rate: 4.5e-6 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: "val/rec_loss" + embed_dim: 3 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 0.000001 + disc_weight: 0.5 + + ddconfig: + double_z: True + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + + +data: + target: main.DataModuleFromConfig + params: + batch_size: 12 + wrap: True + train: + target: ldm.data.imagenet.ImageNetSRTrain + params: + size: 256 + degradation: pil_nearest + validation: + target: 
ldm.data.imagenet.ImageNetSRValidation + params: + size: 256 + degradation: pil_nearest + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + benchmark: True + accumulate_grad_batches: 2 diff --git a/stable-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml b/stable-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml new file mode 100644 index 0000000..5ccd09d --- /dev/null +++ b/stable-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml @@ -0,0 +1,53 @@ +model: + base_learning_rate: 4.5e-6 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: "val/rec_loss" + embed_dim: 64 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 0.000001 + disc_weight: 0.5 + + ddconfig: + double_z: True + z_channels: 64 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [16,8] + dropout: 0.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 12 + wrap: True + train: + target: ldm.data.imagenet.ImageNetSRTrain + params: + size: 256 + degradation: pil_nearest + validation: + target: ldm.data.imagenet.ImageNetSRValidation + params: + size: 256 + degradation: pil_nearest + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + benchmark: True + accumulate_grad_batches: 2 diff --git a/stable-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml b/stable-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml new file mode 100644 index 0000000..89b3df4 --- /dev/null +++ b/stable-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml @@ -0,0 +1,86 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + 
linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + image_size: 64 + channels: 3 + monitor: val/loss_simple_ema + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + # note: this isn't actually the resolution but + # the downsampling factor, i.e. this corresponds to + # attention on spatial resolution 8,16,32, as the + # spatial resolution of the latents is 64 for f4 + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ckpt_path: models/first_stage_models/vq-f4/model.ckpt + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 48 + num_workers: 5 + wrap: false + train: + target: taming.data.faceshq.CelebAHQTrain + params: + size: 256 + validation: + target: taming.data.faceshq.CelebAHQValidation + params: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml b/stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml new file mode 100644 index 0000000..b8cd9e2 --- /dev/null +++ b/stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml @@ -0,0 +1,98 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + 
num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 32 + channels: 4 + cond_stage_trainable: true + conditioning_key: crossattn + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 4 + out_channels: 4 + model_channels: 256 + attention_resolutions: + # note: this isn't actually the resolution but + # the downsampling factor, i.e. this corresponds to + # attention on spatial resolution 8,16,32, as the + # spatial resolution of the latents is 32 for f8 + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + num_head_channels: 32 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 512 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 4 + n_embed: 16384 + ckpt_path: configs/first_stage_models/vq-f8/model.yaml + ddconfig: + double_z: false + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.ClassEmbedder + params: + embed_dim: 512 + key: class_label +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + num_workers: 12 + wrap: false + train: + target: ldm.data.imagenet.ImageNetTrain + params: + config: + size: 256 + validation: + target: ldm.data.imagenet.ImageNetValidation + params: + config: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/stable-diffusion/configs/latent-diffusion/cin256-v2.yaml b/stable-diffusion/configs/latent-diffusion/cin256-v2.yaml new file mode 100644 index 0000000..b7c1aa2 --- 
/dev/null +++ b/stable-diffusion/configs/latent-diffusion/cin256-v2.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 0.0001 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 64 + channels: 3 + cond_stage_trainable: true + conditioning_key: crossattn + monitor: val/loss + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 192 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 5 + num_heads: 1 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 512 + + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.ClassEmbedder + params: + n_classes: 1001 + embed_dim: 512 + key: class_label diff --git a/stable-diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml b/stable-diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml new file mode 100644 index 0000000..1899e30 --- /dev/null +++ b/stable-diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml @@ -0,0 +1,85 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + image_size: 64 + channels: 3 + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + 
in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + # note: this isn't actually the resolution but + # the downsampling factor, i.e. this corresponds to + # attention on spatial resolution 8,16,32, as the + # spatial resolution of the latents is 64 for f4 + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ckpt_path: configs/first_stage_models/vq-f4/model.yaml + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 42 + num_workers: 5 + wrap: false + train: + target: taming.data.faceshq.FFHQTrain + params: + size: 256 + validation: + target: taming.data.faceshq.FFHQValidation + params: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/stable-diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml b/stable-diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml new file mode 100644 index 0000000..c4ca66c --- /dev/null +++ b/stable-diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml @@ -0,0 +1,85 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + image_size: 64 + channels: 3 + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + 
in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + # note: this isn't actually the resolution but + # the downsampling factor, i.e. this corresponds to + # attention on spatial resolution 8,16,32, as the + # spatial resolution of the latents is 64 for f4 + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + ckpt_path: configs/first_stage_models/vq-f4/model.yaml + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 48 + num_workers: 5 + wrap: false + train: + target: ldm.data.lsun.LSUNBedroomsTrain + params: + size: 256 + validation: + target: ldm.data.lsun.LSUNBedroomsValidation + params: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/stable-diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml b/stable-diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml new file mode 100644 index 0000000..18dc8c2 --- /dev/null +++ b/stable-diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml @@ -0,0 +1,91 @@ +model: + base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0155 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + loss_type: l1 + first_stage_key: "image" + cond_stage_key: "image" + image_size: 32 + channels: 4 + cond_stage_trainable: 
False + concat_mode: False + scale_by_std: True + monitor: 'val/loss_simple_ema' + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [10000] + cycle_lengths: [10000000000000] + f_start: [1.e-6] + f_max: [1.] + f_min: [ 1.] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 4 + out_channels: 4 + model_channels: 192 + attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 + num_res_blocks: 2 + channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 + num_heads: 8 + use_scale_shift_norm: True + resblock_updown: True + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: "val/rec_loss" + ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" + ddconfig: + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: "__is_unconditional__" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 96 + num_workers: 5 + wrap: False + train: + target: ldm.data.lsun.LSUNChurchesTrain + params: + size: 256 + validation: + target: ldm.data.lsun.LSUNChurchesValidation + params: + size: 256 + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + + trainer: + benchmark: True \ No newline at end of file diff --git a/stable-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml b/stable-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml new file mode 100644 index 0000000..8e331cb --- /dev/null +++ b/stable-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml @@ -0,0 +1,71 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 
0.00085 + linear_end: 0.012 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 32 + channels: 4 + cond_stage_trainable: true + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + num_heads: 8 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 1280 + use_checkpoint: true + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.BERTEmbedder + params: + n_embed: 1280 + n_layer: 32 diff --git a/stable-diffusion/configs/retrieval-augmented-diffusion/768x768.yaml b/stable-diffusion/configs/retrieval-augmented-diffusion/768x768.yaml new file mode 100644 index 0000000..b51b1d8 --- /dev/null +++ b/stable-diffusion/configs/retrieval-augmented-diffusion/768x768.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 0.0001 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.015 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: jpg + cond_stage_key: nix + image_size: 48 + channels: 16 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_by_std: false + scale_factor: 0.22765929 + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 48 + 
in_channels: 16 + out_channels: 16 + model_channels: 448 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + use_scale_shift_norm: false + resblock_updown: false + num_head_channels: 32 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 768 + use_checkpoint: true + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: val/rec_loss + embed_dim: 16 + ddconfig: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 16 + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: torch.nn.Identity \ No newline at end of file diff --git a/stable-diffusion/configs/stable-diffusion/v1-inference.yaml b/stable-diffusion/configs/stable-diffusion/v1-inference.yaml new file mode 100644 index 0000000..d4effe5 --- /dev/null +++ b/stable-diffusion/configs/stable-diffusion/v1-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/stable-diffusion/data/DejaVuSans.ttf b/stable-diffusion/data/DejaVuSans.ttf new file mode 100644 index 0000000..e5f7eec Binary files /dev/null and b/stable-diffusion/data/DejaVuSans.ttf differ diff --git a/stable-diffusion/data/example_conditioning/superresolution/sample_0.jpg b/stable-diffusion/data/example_conditioning/superresolution/sample_0.jpg new file mode 100644 index 0000000..09abe80 Binary files /dev/null and b/stable-diffusion/data/example_conditioning/superresolution/sample_0.jpg differ diff --git a/stable-diffusion/data/example_conditioning/text_conditional/sample_0.txt b/stable-diffusion/data/example_conditioning/text_conditional/sample_0.txt new file mode 100644 index 0000000..de60c5c --- /dev/null +++ b/stable-diffusion/data/example_conditioning/text_conditional/sample_0.txt @@ -0,0 +1 @@ +A basket of cerries diff --git a/stable-diffusion/data/imagenet_clsidx_to_label.txt b/stable-diffusion/data/imagenet_clsidx_to_label.txt new file mode 100644 index 0000000..e2fe435 --- /dev/null +++ b/stable-diffusion/data/imagenet_clsidx_to_label.txt @@ -0,0 +1,1000 @@ + 0: 'tench, Tinca tinca', + 1: 'goldfish, Carassius auratus', + 2: 
'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', + 3: 'tiger shark, Galeocerdo cuvieri', + 4: 'hammerhead, hammerhead shark', + 5: 'electric ray, crampfish, numbfish, torpedo', + 6: 'stingray', + 7: 'cock', + 8: 'hen', + 9: 'ostrich, Struthio camelus', + 10: 'brambling, Fringilla montifringilla', + 11: 'goldfinch, Carduelis carduelis', + 12: 'house finch, linnet, Carpodacus mexicanus', + 13: 'junco, snowbird', + 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', + 15: 'robin, American robin, Turdus migratorius', + 16: 'bulbul', + 17: 'jay', + 18: 'magpie', + 19: 'chickadee', + 20: 'water ouzel, dipper', + 21: 'kite', + 22: 'bald eagle, American eagle, Haliaeetus leucocephalus', + 23: 'vulture', + 24: 'great grey owl, great gray owl, Strix nebulosa', + 25: 'European fire salamander, Salamandra salamandra', + 26: 'common newt, Triturus vulgaris', + 27: 'eft', + 28: 'spotted salamander, Ambystoma maculatum', + 29: 'axolotl, mud puppy, Ambystoma mexicanum', + 30: 'bullfrog, Rana catesbeiana', + 31: 'tree frog, tree-frog', + 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', + 33: 'loggerhead, loggerhead turtle, Caretta caretta', + 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', + 35: 'mud turtle', + 36: 'terrapin', + 37: 'box turtle, box tortoise', + 38: 'banded gecko', + 39: 'common iguana, iguana, Iguana iguana', + 40: 'American chameleon, anole, Anolis carolinensis', + 41: 'whiptail, whiptail lizard', + 42: 'agama', + 43: 'frilled lizard, Chlamydosaurus kingi', + 44: 'alligator lizard', + 45: 'Gila monster, Heloderma suspectum', + 46: 'green lizard, Lacerta viridis', + 47: 'African chameleon, Chamaeleo chamaeleon', + 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', + 49: 'African crocodile, Nile crocodile, Crocodylus niloticus', + 50: 'American alligator, Alligator mississipiensis', + 51: 'triceratops', + 52: 'thunder snake, 
worm snake, Carphophis amoenus', + 53: 'ringneck snake, ring-necked snake, ring snake', + 54: 'hognose snake, puff adder, sand viper', + 55: 'green snake, grass snake', + 56: 'king snake, kingsnake', + 57: 'garter snake, grass snake', + 58: 'water snake', + 59: 'vine snake', + 60: 'night snake, Hypsiglena torquata', + 61: 'boa constrictor, Constrictor constrictor', + 62: 'rock python, rock snake, Python sebae', + 63: 'Indian cobra, Naja naja', + 64: 'green mamba', + 65: 'sea snake', + 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', + 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus', + 68: 'sidewinder, horned rattlesnake, Crotalus cerastes', + 69: 'trilobite', + 70: 'harvestman, daddy longlegs, Phalangium opilio', + 71: 'scorpion', + 72: 'black and gold garden spider, Argiope aurantia', + 73: 'barn spider, Araneus cavaticus', + 74: 'garden spider, Aranea diademata', + 75: 'black widow, Latrodectus mactans', + 76: 'tarantula', + 77: 'wolf spider, hunting spider', + 78: 'tick', + 79: 'centipede', + 80: 'black grouse', + 81: 'ptarmigan', + 82: 'ruffed grouse, partridge, Bonasa umbellus', + 83: 'prairie chicken, prairie grouse, prairie fowl', + 84: 'peacock', + 85: 'quail', + 86: 'partridge', + 87: 'African grey, African gray, Psittacus erithacus', + 88: 'macaw', + 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', + 90: 'lorikeet', + 91: 'coucal', + 92: 'bee eater', + 93: 'hornbill', + 94: 'hummingbird', + 95: 'jacamar', + 96: 'toucan', + 97: 'drake', + 98: 'red-breasted merganser, Mergus serrator', + 99: 'goose', + 100: 'black swan, Cygnus atratus', + 101: 'tusker', + 102: 'echidna, spiny anteater, anteater', + 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', + 104: 'wallaby, brush kangaroo', + 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', + 106: 'wombat', + 107: 'jellyfish', + 108: 'sea anemone, anemone', + 109: 'brain coral', + 110: 
'flatworm, platyhelminth', + 111: 'nematode, nematode worm, roundworm', + 112: 'conch', + 113: 'snail', + 114: 'slug', + 115: 'sea slug, nudibranch', + 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore', + 117: 'chambered nautilus, pearly nautilus, nautilus', + 118: 'Dungeness crab, Cancer magister', + 119: 'rock crab, Cancer irroratus', + 120: 'fiddler crab', + 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', + 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus', + 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', + 124: 'crayfish, crawfish, crawdad, crawdaddy', + 125: 'hermit crab', + 126: 'isopod', + 127: 'white stork, Ciconia ciconia', + 128: 'black stork, Ciconia nigra', + 129: 'spoonbill', + 130: 'flamingo', + 131: 'little blue heron, Egretta caerulea', + 132: 'American egret, great white heron, Egretta albus', + 133: 'bittern', + 134: 'crane', + 135: 'limpkin, Aramus pictus', + 136: 'European gallinule, Porphyrio porphyrio', + 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana', + 138: 'bustard', + 139: 'ruddy turnstone, Arenaria interpres', + 140: 'red-backed sandpiper, dunlin, Erolia alpina', + 141: 'redshank, Tringa totanus', + 142: 'dowitcher', + 143: 'oystercatcher, oyster catcher', + 144: 'pelican', + 145: 'king penguin, Aptenodytes patagonica', + 146: 'albatross, mollymawk', + 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', + 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', + 149: 'dugong, Dugong dugon', + 150: 'sea lion', + 151: 'Chihuahua', + 152: 'Japanese spaniel', + 153: 'Maltese dog, Maltese terrier, Maltese', + 154: 'Pekinese, Pekingese, Peke', + 155: 'Shih-Tzu', + 156: 'Blenheim spaniel', + 157: 'papillon', + 158: 'toy terrier', + 159: 'Rhodesian ridgeback', + 160: 'Afghan hound, Afghan', + 161: 'basset, basset hound', + 162: 'beagle', + 163: 'bloodhound, 
sleuthhound', + 164: 'bluetick', + 165: 'black-and-tan coonhound', + 166: 'Walker hound, Walker foxhound', + 167: 'English foxhound', + 168: 'redbone', + 169: 'borzoi, Russian wolfhound', + 170: 'Irish wolfhound', + 171: 'Italian greyhound', + 172: 'whippet', + 173: 'Ibizan hound, Ibizan Podenco', + 174: 'Norwegian elkhound, elkhound', + 175: 'otterhound, otter hound', + 176: 'Saluki, gazelle hound', + 177: 'Scottish deerhound, deerhound', + 178: 'Weimaraner', + 179: 'Staffordshire bullterrier, Staffordshire bull terrier', + 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', + 181: 'Bedlington terrier', + 182: 'Border terrier', + 183: 'Kerry blue terrier', + 184: 'Irish terrier', + 185: 'Norfolk terrier', + 186: 'Norwich terrier', + 187: 'Yorkshire terrier', + 188: 'wire-haired fox terrier', + 189: 'Lakeland terrier', + 190: 'Sealyham terrier, Sealyham', + 191: 'Airedale, Airedale terrier', + 192: 'cairn, cairn terrier', + 193: 'Australian terrier', + 194: 'Dandie Dinmont, Dandie Dinmont terrier', + 195: 'Boston bull, Boston terrier', + 196: 'miniature schnauzer', + 197: 'giant schnauzer', + 198: 'standard schnauzer', + 199: 'Scotch terrier, Scottish terrier, Scottie', + 200: 'Tibetan terrier, chrysanthemum dog', + 201: 'silky terrier, Sydney silky', + 202: 'soft-coated wheaten terrier', + 203: 'West Highland white terrier', + 204: 'Lhasa, Lhasa apso', + 205: 'flat-coated retriever', + 206: 'curly-coated retriever', + 207: 'golden retriever', + 208: 'Labrador retriever', + 209: 'Chesapeake Bay retriever', + 210: 'German short-haired pointer', + 211: 'vizsla, Hungarian pointer', + 212: 'English setter', + 213: 'Irish setter, red setter', + 214: 'Gordon setter', + 215: 'Brittany spaniel', + 216: 'clumber, clumber spaniel', + 217: 'English springer, English springer spaniel', + 218: 'Welsh springer spaniel', + 219: 'cocker spaniel, English cocker spaniel, cocker', + 220: 'Sussex spaniel', + 221: 'Irish water 
spaniel', + 222: 'kuvasz', + 223: 'schipperke', + 224: 'groenendael', + 225: 'malinois', + 226: 'briard', + 227: 'kelpie', + 228: 'komondor', + 229: 'Old English sheepdog, bobtail', + 230: 'Shetland sheepdog, Shetland sheep dog, Shetland', + 231: 'collie', + 232: 'Border collie', + 233: 'Bouvier des Flandres, Bouviers des Flandres', + 234: 'Rottweiler', + 235: 'German shepherd, German shepherd dog, German police dog, alsatian', + 236: 'Doberman, Doberman pinscher', + 237: 'miniature pinscher', + 238: 'Greater Swiss Mountain dog', + 239: 'Bernese mountain dog', + 240: 'Appenzeller', + 241: 'EntleBucher', + 242: 'boxer', + 243: 'bull mastiff', + 244: 'Tibetan mastiff', + 245: 'French bulldog', + 246: 'Great Dane', + 247: 'Saint Bernard, St Bernard', + 248: 'Eskimo dog, husky', + 249: 'malamute, malemute, Alaskan malamute', + 250: 'Siberian husky', + 251: 'dalmatian, coach dog, carriage dog', + 252: 'affenpinscher, monkey pinscher, monkey dog', + 253: 'basenji', + 254: 'pug, pug-dog', + 255: 'Leonberg', + 256: 'Newfoundland, Newfoundland dog', + 257: 'Great Pyrenees', + 258: 'Samoyed, Samoyede', + 259: 'Pomeranian', + 260: 'chow, chow chow', + 261: 'keeshond', + 262: 'Brabancon griffon', + 263: 'Pembroke, Pembroke Welsh corgi', + 264: 'Cardigan, Cardigan Welsh corgi', + 265: 'toy poodle', + 266: 'miniature poodle', + 267: 'standard poodle', + 268: 'Mexican hairless', + 269: 'timber wolf, grey wolf, gray wolf, Canis lupus', + 270: 'white wolf, Arctic wolf, Canis lupus tundrarum', + 271: 'red wolf, maned wolf, Canis rufus, Canis niger', + 272: 'coyote, prairie wolf, brush wolf, Canis latrans', + 273: 'dingo, warrigal, warragal, Canis dingo', + 274: 'dhole, Cuon alpinus', + 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', + 276: 'hyena, hyaena', + 277: 'red fox, Vulpes vulpes', + 278: 'kit fox, Vulpes macrotis', + 279: 'Arctic fox, white fox, Alopex lagopus', + 280: 'grey fox, gray fox, Urocyon cinereoargenteus', + 281: 'tabby, tabby cat', + 282: 
'tiger cat', + 283: 'Persian cat', + 284: 'Siamese cat, Siamese', + 285: 'Egyptian cat', + 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', + 287: 'lynx, catamount', + 288: 'leopard, Panthera pardus', + 289: 'snow leopard, ounce, Panthera uncia', + 290: 'jaguar, panther, Panthera onca, Felis onca', + 291: 'lion, king of beasts, Panthera leo', + 292: 'tiger, Panthera tigris', + 293: 'cheetah, chetah, Acinonyx jubatus', + 294: 'brown bear, bruin, Ursus arctos', + 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus', + 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', + 297: 'sloth bear, Melursus ursinus, Ursus ursinus', + 298: 'mongoose', + 299: 'meerkat, mierkat', + 300: 'tiger beetle', + 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', + 302: 'ground beetle, carabid beetle', + 303: 'long-horned beetle, longicorn, longicorn beetle', + 304: 'leaf beetle, chrysomelid', + 305: 'dung beetle', + 306: 'rhinoceros beetle', + 307: 'weevil', + 308: 'fly', + 309: 'bee', + 310: 'ant, emmet, pismire', + 311: 'grasshopper, hopper', + 312: 'cricket', + 313: 'walking stick, walkingstick, stick insect', + 314: 'cockroach, roach', + 315: 'mantis, mantid', + 316: 'cicada, cicala', + 317: 'leafhopper', + 318: 'lacewing, lacewing fly', + 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + 320: 'damselfly', + 321: 'admiral', + 322: 'ringlet, ringlet butterfly', + 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', + 324: 'cabbage butterfly', + 325: 'sulphur butterfly, sulfur butterfly', + 326: 'lycaenid, lycaenid butterfly', + 327: 'starfish, sea star', + 328: 'sea urchin', + 329: 'sea cucumber, holothurian', + 330: 'wood rabbit, cottontail, cottontail rabbit', + 331: 'hare', + 332: 'Angora, Angora rabbit', + 333: 'hamster', + 334: 'porcupine, hedgehog', + 335: 'fox squirrel, eastern fox 
squirrel, Sciurus niger', + 336: 'marmot', + 337: 'beaver', + 338: 'guinea pig, Cavia cobaya', + 339: 'sorrel', + 340: 'zebra', + 341: 'hog, pig, grunter, squealer, Sus scrofa', + 342: 'wild boar, boar, Sus scrofa', + 343: 'warthog', + 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius', + 345: 'ox', + 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', + 347: 'bison', + 348: 'ram, tup', + 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', + 350: 'ibex, Capra ibex', + 351: 'hartebeest', + 352: 'impala, Aepyceros melampus', + 353: 'gazelle', + 354: 'Arabian camel, dromedary, Camelus dromedarius', + 355: 'llama', + 356: 'weasel', + 357: 'mink', + 358: 'polecat, fitch, foulmart, foumart, Mustela putorius', + 359: 'black-footed ferret, ferret, Mustela nigripes', + 360: 'otter', + 361: 'skunk, polecat, wood pussy', + 362: 'badger', + 363: 'armadillo', + 364: 'three-toed sloth, ai, Bradypus tridactylus', + 365: 'orangutan, orang, orangutang, Pongo pygmaeus', + 366: 'gorilla, Gorilla gorilla', + 367: 'chimpanzee, chimp, Pan troglodytes', + 368: 'gibbon, Hylobates lar', + 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus', + 370: 'guenon, guenon monkey', + 371: 'patas, hussar monkey, Erythrocebus patas', + 372: 'baboon', + 373: 'macaque', + 374: 'langur', + 375: 'colobus, colobus monkey', + 376: 'proboscis monkey, Nasalis larvatus', + 377: 'marmoset', + 378: 'capuchin, ringtail, Cebus capucinus', + 379: 'howler monkey, howler', + 380: 'titi, titi monkey', + 381: 'spider monkey, Ateles geoffroyi', + 382: 'squirrel monkey, Saimiri sciureus', + 383: 'Madagascar cat, ring-tailed lemur, Lemur catta', + 384: 'indri, indris, Indri indri, Indri brevicaudatus', + 385: 'Indian elephant, Elephas maximus', + 386: 'African elephant, Loxodonta africana', + 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', + 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda 
melanoleuca', + 389: 'barracouta, snoek', + 390: 'eel', + 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', + 392: 'rock beauty, Holocanthus tricolor', + 393: 'anemone fish', + 394: 'sturgeon', + 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus', + 396: 'lionfish', + 397: 'puffer, pufferfish, blowfish, globefish', + 398: 'abacus', + 399: 'abaya', + 400: "academic gown, academic robe, judge's robe", + 401: 'accordion, piano accordion, squeeze box', + 402: 'acoustic guitar', + 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier', + 404: 'airliner', + 405: 'airship, dirigible', + 406: 'altar', + 407: 'ambulance', + 408: 'amphibian, amphibious vehicle', + 409: 'analog clock', + 410: 'apiary, bee house', + 411: 'apron', + 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', + 413: 'assault rifle, assault gun', + 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack', + 415: 'bakery, bakeshop, bakehouse', + 416: 'balance beam, beam', + 417: 'balloon', + 418: 'ballpoint, ballpoint pen, ballpen, Biro', + 419: 'Band Aid', + 420: 'banjo', + 421: 'bannister, banister, balustrade, balusters, handrail', + 422: 'barbell', + 423: 'barber chair', + 424: 'barbershop', + 425: 'barn', + 426: 'barometer', + 427: 'barrel, cask', + 428: 'barrow, garden cart, lawn cart, wheelbarrow', + 429: 'baseball', + 430: 'basketball', + 431: 'bassinet', + 432: 'bassoon', + 433: 'bathing cap, swimming cap', + 434: 'bath towel', + 435: 'bathtub, bathing tub, bath, tub', + 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', + 437: 'beacon, lighthouse, beacon light, pharos', + 438: 'beaker', + 439: 'bearskin, busby, shako', + 440: 'beer bottle', + 441: 'beer glass', + 442: 'bell cote, bell cot', + 443: 'bib', + 444: 'bicycle-built-for-two, tandem bicycle, tandem', + 445: 'bikini, two-piece', + 446: 'binder, ring-binder', + 447: 'binoculars, 
field glasses, opera glasses', + 448: 'birdhouse', + 449: 'boathouse', + 450: 'bobsled, bobsleigh, bob', + 451: 'bolo tie, bolo, bola tie, bola', + 452: 'bonnet, poke bonnet', + 453: 'bookcase', + 454: 'bookshop, bookstore, bookstall', + 455: 'bottlecap', + 456: 'bow', + 457: 'bow tie, bow-tie, bowtie', + 458: 'brass, memorial tablet, plaque', + 459: 'brassiere, bra, bandeau', + 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', + 461: 'breastplate, aegis, egis', + 462: 'broom', + 463: 'bucket, pail', + 464: 'buckle', + 465: 'bulletproof vest', + 466: 'bullet train, bullet', + 467: 'butcher shop, meat market', + 468: 'cab, hack, taxi, taxicab', + 469: 'caldron, cauldron', + 470: 'candle, taper, wax light', + 471: 'cannon', + 472: 'canoe', + 473: 'can opener, tin opener', + 474: 'cardigan', + 475: 'car mirror', + 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig', + 477: "carpenter's kit, tool kit", + 478: 'carton', + 479: 'car wheel', + 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', + 481: 'cassette', + 482: 'cassette player', + 483: 'castle', + 484: 'catamaran', + 485: 'CD player', + 486: 'cello, violoncello', + 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone', + 488: 'chain', + 489: 'chainlink fence', + 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', + 491: 'chain saw, chainsaw', + 492: 'chest', + 493: 'chiffonier, commode', + 494: 'chime, bell, gong', + 495: 'china cabinet, china closet', + 496: 'Christmas stocking', + 497: 'church, church building', + 498: 'cinema, movie theater, movie theatre, movie house, picture palace', + 499: 'cleaver, meat cleaver, chopper', + 500: 'cliff dwelling', + 501: 'cloak', + 502: 'clog, geta, patten, sabot', + 503: 'cocktail shaker', + 504: 'coffee mug', + 505: 'coffeepot', + 506: 'coil, spiral, volute, whorl, helix', + 507: 'combination lock', + 508: 
'computer keyboard, keypad', + 509: 'confectionery, confectionary, candy store', + 510: 'container ship, containership, container vessel', + 511: 'convertible', + 512: 'corkscrew, bottle screw', + 513: 'cornet, horn, trumpet, trump', + 514: 'cowboy boot', + 515: 'cowboy hat, ten-gallon hat', + 516: 'cradle', + 517: 'crane', + 518: 'crash helmet', + 519: 'crate', + 520: 'crib, cot', + 521: 'Crock Pot', + 522: 'croquet ball', + 523: 'crutch', + 524: 'cuirass', + 525: 'dam, dike, dyke', + 526: 'desk', + 527: 'desktop computer', + 528: 'dial telephone, dial phone', + 529: 'diaper, nappy, napkin', + 530: 'digital clock', + 531: 'digital watch', + 532: 'dining table, board', + 533: 'dishrag, dishcloth', + 534: 'dishwasher, dish washer, dishwashing machine', + 535: 'disk brake, disc brake', + 536: 'dock, dockage, docking facility', + 537: 'dogsled, dog sled, dog sleigh', + 538: 'dome', + 539: 'doormat, welcome mat', + 540: 'drilling platform, offshore rig', + 541: 'drum, membranophone, tympan', + 542: 'drumstick', + 543: 'dumbbell', + 544: 'Dutch oven', + 545: 'electric fan, blower', + 546: 'electric guitar', + 547: 'electric locomotive', + 548: 'entertainment center', + 549: 'envelope', + 550: 'espresso maker', + 551: 'face powder', + 552: 'feather boa, boa', + 553: 'file, file cabinet, filing cabinet', + 554: 'fireboat', + 555: 'fire engine, fire truck', + 556: 'fire screen, fireguard', + 557: 'flagpole, flagstaff', + 558: 'flute, transverse flute', + 559: 'folding chair', + 560: 'football helmet', + 561: 'forklift', + 562: 'fountain', + 563: 'fountain pen', + 564: 'four-poster', + 565: 'freight car', + 566: 'French horn, horn', + 567: 'frying pan, frypan, skillet', + 568: 'fur coat', + 569: 'garbage truck, dustcart', + 570: 'gasmask, respirator, gas helmet', + 571: 'gas pump, gasoline pump, petrol pump, island dispenser', + 572: 'goblet', + 573: 'go-kart', + 574: 'golf ball', + 575: 'golfcart, golf cart', + 576: 'gondola', + 577: 'gong, tam-tam', + 578: 'gown', + 579: 
'grand piano, grand', + 580: 'greenhouse, nursery, glasshouse', + 581: 'grille, radiator grille', + 582: 'grocery store, grocery, food market, market', + 583: 'guillotine', + 584: 'hair slide', + 585: 'hair spray', + 586: 'half track', + 587: 'hammer', + 588: 'hamper', + 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier', + 590: 'hand-held computer, hand-held microcomputer', + 591: 'handkerchief, hankie, hanky, hankey', + 592: 'hard disc, hard disk, fixed disk', + 593: 'harmonica, mouth organ, harp, mouth harp', + 594: 'harp', + 595: 'harvester, reaper', + 596: 'hatchet', + 597: 'holster', + 598: 'home theater, home theatre', + 599: 'honeycomb', + 600: 'hook, claw', + 601: 'hoopskirt, crinoline', + 602: 'horizontal bar, high bar', + 603: 'horse cart, horse-cart', + 604: 'hourglass', + 605: 'iPod', + 606: 'iron, smoothing iron', + 607: "jack-o'-lantern", + 608: 'jean, blue jean, denim', + 609: 'jeep, landrover', + 610: 'jersey, T-shirt, tee shirt', + 611: 'jigsaw puzzle', + 612: 'jinrikisha, ricksha, rickshaw', + 613: 'joystick', + 614: 'kimono', + 615: 'knee pad', + 616: 'knot', + 617: 'lab coat, laboratory coat', + 618: 'ladle', + 619: 'lampshade, lamp shade', + 620: 'laptop, laptop computer', + 621: 'lawn mower, mower', + 622: 'lens cap, lens cover', + 623: 'letter opener, paper knife, paperknife', + 624: 'library', + 625: 'lifeboat', + 626: 'lighter, light, igniter, ignitor', + 627: 'limousine, limo', + 628: 'liner, ocean liner', + 629: 'lipstick, lip rouge', + 630: 'Loafer', + 631: 'lotion', + 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', + 633: "loupe, jeweler's loupe", + 634: 'lumbermill, sawmill', + 635: 'magnetic compass', + 636: 'mailbag, postbag', + 637: 'mailbox, letter box', + 638: 'maillot', + 639: 'maillot, tank suit', + 640: 'manhole cover', + 641: 'maraca', + 642: 'marimba, xylophone', + 643: 'mask', + 644: 'matchstick', + 645: 'maypole', + 646: 'maze, labyrinth', + 647: 'measuring cup', + 648: 
'medicine chest, medicine cabinet', + 649: 'megalith, megalithic structure', + 650: 'microphone, mike', + 651: 'microwave, microwave oven', + 652: 'military uniform', + 653: 'milk can', + 654: 'minibus', + 655: 'miniskirt, mini', + 656: 'minivan', + 657: 'missile', + 658: 'mitten', + 659: 'mixing bowl', + 660: 'mobile home, manufactured home', + 661: 'Model T', + 662: 'modem', + 663: 'monastery', + 664: 'monitor', + 665: 'moped', + 666: 'mortar', + 667: 'mortarboard', + 668: 'mosque', + 669: 'mosquito net', + 670: 'motor scooter, scooter', + 671: 'mountain bike, all-terrain bike, off-roader', + 672: 'mountain tent', + 673: 'mouse, computer mouse', + 674: 'mousetrap', + 675: 'moving van', + 676: 'muzzle', + 677: 'nail', + 678: 'neck brace', + 679: 'necklace', + 680: 'nipple', + 681: 'notebook, notebook computer', + 682: 'obelisk', + 683: 'oboe, hautboy, hautbois', + 684: 'ocarina, sweet potato', + 685: 'odometer, hodometer, mileometer, milometer', + 686: 'oil filter', + 687: 'organ, pipe organ', + 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO', + 689: 'overskirt', + 690: 'oxcart', + 691: 'oxygen mask', + 692: 'packet', + 693: 'paddle, boat paddle', + 694: 'paddlewheel, paddle wheel', + 695: 'padlock', + 696: 'paintbrush', + 697: "pajama, pyjama, pj's, jammies", + 698: 'palace', + 699: 'panpipe, pandean pipe, syrinx', + 700: 'paper towel', + 701: 'parachute, chute', + 702: 'parallel bars, bars', + 703: 'park bench', + 704: 'parking meter', + 705: 'passenger car, coach, carriage', + 706: 'patio, terrace', + 707: 'pay-phone, pay-station', + 708: 'pedestal, plinth, footstall', + 709: 'pencil box, pencil case', + 710: 'pencil sharpener', + 711: 'perfume, essence', + 712: 'Petri dish', + 713: 'photocopier', + 714: 'pick, plectrum, plectron', + 715: 'pickelhaube', + 716: 'picket fence, paling', + 717: 'pickup, pickup truck', + 718: 'pier', + 719: 'piggy bank, penny bank', + 720: 'pill bottle', + 721: 'pillow', + 722: 'ping-pong ball', + 723: 'pinwheel', + 724: 
'pirate, pirate ship', + 725: 'pitcher, ewer', + 726: "plane, carpenter's plane, woodworking plane", + 727: 'planetarium', + 728: 'plastic bag', + 729: 'plate rack', + 730: 'plow, plough', + 731: "plunger, plumber's helper", + 732: 'Polaroid camera, Polaroid Land camera', + 733: 'pole', + 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', + 735: 'poncho', + 736: 'pool table, billiard table, snooker table', + 737: 'pop bottle, soda bottle', + 738: 'pot, flowerpot', + 739: "potter's wheel", + 740: 'power drill', + 741: 'prayer rug, prayer mat', + 742: 'printer', + 743: 'prison, prison house', + 744: 'projectile, missile', + 745: 'projector', + 746: 'puck, hockey puck', + 747: 'punching bag, punch bag, punching ball, punchball', + 748: 'purse', + 749: 'quill, quill pen', + 750: 'quilt, comforter, comfort, puff', + 751: 'racer, race car, racing car', + 752: 'racket, racquet', + 753: 'radiator', + 754: 'radio, wireless', + 755: 'radio telescope, radio reflector', + 756: 'rain barrel', + 757: 'recreational vehicle, RV, R.V.', + 758: 'reel', + 759: 'reflex camera', + 760: 'refrigerator, icebox', + 761: 'remote control, remote', + 762: 'restaurant, eating house, eating place, eatery', + 763: 'revolver, six-gun, six-shooter', + 764: 'rifle', + 765: 'rocking chair, rocker', + 766: 'rotisserie', + 767: 'rubber eraser, rubber, pencil eraser', + 768: 'rugby ball', + 769: 'rule, ruler', + 770: 'running shoe', + 771: 'safe', + 772: 'safety pin', + 773: 'saltshaker, salt shaker', + 774: 'sandal', + 775: 'sarong', + 776: 'sax, saxophone', + 777: 'scabbard', + 778: 'scale, weighing machine', + 779: 'school bus', + 780: 'schooner', + 781: 'scoreboard', + 782: 'screen, CRT screen', + 783: 'screw', + 784: 'screwdriver', + 785: 'seat belt, seatbelt', + 786: 'sewing machine', + 787: 'shield, buckler', + 788: 'shoe shop, shoe-shop, shoe store', + 789: 'shoji', + 790: 'shopping basket', + 791: 'shopping cart', + 792: 'shovel', + 793: 'shower cap', + 794: 
'shower curtain', + 795: 'ski', + 796: 'ski mask', + 797: 'sleeping bag', + 798: 'slide rule, slipstick', + 799: 'sliding door', + 800: 'slot, one-armed bandit', + 801: 'snorkel', + 802: 'snowmobile', + 803: 'snowplow, snowplough', + 804: 'soap dispenser', + 805: 'soccer ball', + 806: 'sock', + 807: 'solar dish, solar collector, solar furnace', + 808: 'sombrero', + 809: 'soup bowl', + 810: 'space bar', + 811: 'space heater', + 812: 'space shuttle', + 813: 'spatula', + 814: 'speedboat', + 815: "spider web, spider's web", + 816: 'spindle', + 817: 'sports car, sport car', + 818: 'spotlight, spot', + 819: 'stage', + 820: 'steam locomotive', + 821: 'steel arch bridge', + 822: 'steel drum', + 823: 'stethoscope', + 824: 'stole', + 825: 'stone wall', + 826: 'stopwatch, stop watch', + 827: 'stove', + 828: 'strainer', + 829: 'streetcar, tram, tramcar, trolley, trolley car', + 830: 'stretcher', + 831: 'studio couch, day bed', + 832: 'stupa, tope', + 833: 'submarine, pigboat, sub, U-boat', + 834: 'suit, suit of clothes', + 835: 'sundial', + 836: 'sunglass', + 837: 'sunglasses, dark glasses, shades', + 838: 'sunscreen, sunblock, sun blocker', + 839: 'suspension bridge', + 840: 'swab, swob, mop', + 841: 'sweatshirt', + 842: 'swimming trunks, bathing trunks', + 843: 'swing', + 844: 'switch, electric switch, electrical switch', + 845: 'syringe', + 846: 'table lamp', + 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle', + 848: 'tape player', + 849: 'teapot', + 850: 'teddy, teddy bear', + 851: 'television, television system', + 852: 'tennis ball', + 853: 'thatch, thatched roof', + 854: 'theater curtain, theatre curtain', + 855: 'thimble', + 856: 'thresher, thrasher, threshing machine', + 857: 'throne', + 858: 'tile roof', + 859: 'toaster', + 860: 'tobacco shop, tobacconist shop, tobacconist', + 861: 'toilet seat', + 862: 'torch', + 863: 'totem pole', + 864: 'tow truck, tow car, wrecker', + 865: 'toyshop', + 866: 'tractor', + 867: 'trailer truck, tractor trailer, 
trucking rig, rig, articulated lorry, semi', + 868: 'tray', + 869: 'trench coat', + 870: 'tricycle, trike, velocipede', + 871: 'trimaran', + 872: 'tripod', + 873: 'triumphal arch', + 874: 'trolleybus, trolley coach, trackless trolley', + 875: 'trombone', + 876: 'tub, vat', + 877: 'turnstile', + 878: 'typewriter keyboard', + 879: 'umbrella', + 880: 'unicycle, monocycle', + 881: 'upright, upright piano', + 882: 'vacuum, vacuum cleaner', + 883: 'vase', + 884: 'vault', + 885: 'velvet', + 886: 'vending machine', + 887: 'vestment', + 888: 'viaduct', + 889: 'violin, fiddle', + 890: 'volleyball', + 891: 'waffle iron', + 892: 'wall clock', + 893: 'wallet, billfold, notecase, pocketbook', + 894: 'wardrobe, closet, press', + 895: 'warplane, military plane', + 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', + 897: 'washer, automatic washer, washing machine', + 898: 'water bottle', + 899: 'water jug', + 900: 'water tower', + 901: 'whiskey jug', + 902: 'whistle', + 903: 'wig', + 904: 'window screen', + 905: 'window shade', + 906: 'Windsor tie', + 907: 'wine bottle', + 908: 'wing', + 909: 'wok', + 910: 'wooden spoon', + 911: 'wool, woolen, woollen', + 912: 'worm fence, snake fence, snake-rail fence, Virginia fence', + 913: 'wreck', + 914: 'yawl', + 915: 'yurt', + 916: 'web site, website, internet site, site', + 917: 'comic book', + 918: 'crossword puzzle, crossword', + 919: 'street sign', + 920: 'traffic light, traffic signal, stoplight', + 921: 'book jacket, dust cover, dust jacket, dust wrapper', + 922: 'menu', + 923: 'plate', + 924: 'guacamole', + 925: 'consomme', + 926: 'hot pot, hotpot', + 927: 'trifle', + 928: 'ice cream, icecream', + 929: 'ice lolly, lolly, lollipop, popsicle', + 930: 'French loaf', + 931: 'bagel, beigel', + 932: 'pretzel', + 933: 'cheeseburger', + 934: 'hotdog, hot dog, red hot', + 935: 'mashed potato', + 936: 'head cabbage', + 937: 'broccoli', + 938: 'cauliflower', + 939: 'zucchini, courgette', + 940: 'spaghetti squash', + 941: 'acorn 
squash', + 942: 'butternut squash', + 943: 'cucumber, cuke', + 944: 'artichoke, globe artichoke', + 945: 'bell pepper', + 946: 'cardoon', + 947: 'mushroom', + 948: 'Granny Smith', + 949: 'strawberry', + 950: 'orange', + 951: 'lemon', + 952: 'fig', + 953: 'pineapple, ananas', + 954: 'banana', + 955: 'jackfruit, jak, jack', + 956: 'custard apple', + 957: 'pomegranate', + 958: 'hay', + 959: 'carbonara', + 960: 'chocolate sauce, chocolate syrup', + 961: 'dough', + 962: 'meat loaf, meatloaf', + 963: 'pizza, pizza pie', + 964: 'potpie', + 965: 'burrito', + 966: 'red wine', + 967: 'espresso', + 968: 'cup', + 969: 'eggnog', + 970: 'alp', + 971: 'bubble', + 972: 'cliff, drop, drop-off', + 973: 'coral reef', + 974: 'geyser', + 975: 'lakeside, lakeshore', + 976: 'promontory, headland, head, foreland', + 977: 'sandbar, sand bar', + 978: 'seashore, coast, seacoast, sea-coast', + 979: 'valley, vale', + 980: 'volcano', + 981: 'ballplayer, baseball player', + 982: 'groom, bridegroom', + 983: 'scuba diver', + 984: 'rapeseed', + 985: 'daisy', + 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + 987: 'corn', + 988: 'acorn', + 989: 'hip, rose hip, rosehip', + 990: 'buckeye, horse chestnut, conker', + 991: 'coral fungus', + 992: 'agaric', + 993: 'gyromitra', + 994: 'stinkhorn, carrion fungus', + 995: 'earthstar', + 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', + 997: 'bolete', + 998: 'ear, spike, capitulum', + 999: 'toilet tissue, toilet paper, bathroom tissue' \ No newline at end of file diff --git a/stable-diffusion/data/imagenet_train_hr_indices.p b/stable-diffusion/data/imagenet_train_hr_indices.p new file mode 100644 index 0000000..b8d6d46 Binary files /dev/null and b/stable-diffusion/data/imagenet_train_hr_indices.p differ diff --git a/stable-diffusion/data/imagenet_val_hr_indices.p b/stable-diffusion/data/imagenet_val_hr_indices.p new file mode 100644 index 0000000..744ad64 Binary files 
/dev/null and b/stable-diffusion/data/imagenet_val_hr_indices.p differ diff --git a/stable-diffusion/data/index_synset.yaml b/stable-diffusion/data/index_synset.yaml new file mode 100644 index 0000000..635ea71 --- /dev/null +++ b/stable-diffusion/data/index_synset.yaml @@ -0,0 +1,1000 @@ +0: n01440764 +1: n01443537 +2: n01484850 +3: n01491361 +4: n01494475 +5: n01496331 +6: n01498041 +7: n01514668 +8: n07646067 +9: n01518878 +10: n01530575 +11: n01531178 +12: n01532829 +13: n01534433 +14: n01537544 +15: n01558993 +16: n01560419 +17: n01580077 +18: n01582220 +19: n01592084 +20: n01601694 +21: n13382471 +22: n01614925 +23: n01616318 +24: n01622779 +25: n01629819 +26: n01630670 +27: n01631663 +28: n01632458 +29: n01632777 +30: n01641577 +31: n01644373 +32: n01644900 +33: n01664065 +34: n01665541 +35: n01667114 +36: n01667778 +37: n01669191 +38: n01675722 +39: n01677366 +40: n01682714 +41: n01685808 +42: n01687978 +43: n01688243 +44: n01689811 +45: n01692333 +46: n01693334 +47: n01694178 +48: n01695060 +49: n01697457 +50: n01698640 +51: n01704323 +52: n01728572 +53: n01728920 +54: n01729322 +55: n01729977 +56: n01734418 +57: n01735189 +58: n01737021 +59: n01739381 +60: n01740131 +61: n01742172 +62: n01744401 +63: n01748264 +64: n01749939 +65: n01751748 +66: n01753488 +67: n01755581 +68: n01756291 +69: n01768244 +70: n01770081 +71: n01770393 +72: n01773157 +73: n01773549 +74: n01773797 +75: n01774384 +76: n01774750 +77: n01775062 +78: n04432308 +79: n01784675 +80: n01795545 +81: n01796340 +82: n01797886 +83: n01798484 +84: n01806143 +85: n07647321 +86: n07647496 +87: n01817953 +88: n01818515 +89: n01819313 +90: n01820546 +91: n01824575 +92: n01828970 +93: n01829413 +94: n01833805 +95: n01843065 +96: n01843383 +97: n01847000 +98: n01855032 +99: n07646821 +100: n01860187 +101: n01871265 +102: n01872772 +103: n01873310 +104: n01877812 +105: n01882714 +106: n01883070 +107: n01910747 +108: n01914609 +109: n01917289 +110: n01924916 +111: n01930112 +112: n01943899 +113: 
n01944390 +114: n13719102 +115: n01950731 +116: n01955084 +117: n01968897 +118: n01978287 +119: n01978455 +120: n01980166 +121: n01981276 +122: n01983481 +123: n01984695 +124: n01985128 +125: n01986214 +126: n01990800 +127: n02002556 +128: n02002724 +129: n02006656 +130: n02007558 +131: n02009229 +132: n02009912 +133: n02011460 +134: n03126707 +135: n02013706 +136: n02017213 +137: n02018207 +138: n02018795 +139: n02025239 +140: n02027492 +141: n02028035 +142: n02033041 +143: n02037110 +144: n02051845 +145: n02056570 +146: n02058221 +147: n02066245 +148: n02071294 +149: n02074367 +150: n02077923 +151: n08742578 +152: n02085782 +153: n02085936 +154: n02086079 +155: n02086240 +156: n02086646 +157: n02086910 +158: n02087046 +159: n02087394 +160: n02088094 +161: n02088238 +162: n02088364 +163: n02088466 +164: n02088632 +165: n02089078 +166: n02089867 +167: n02089973 +168: n02090379 +169: n02090622 +170: n02090721 +171: n02091032 +172: n02091134 +173: n02091244 +174: n02091467 +175: n02091635 +176: n02091831 +177: n02092002 +178: n02092339 +179: n02093256 +180: n02093428 +181: n02093647 +182: n02093754 +183: n02093859 +184: n02093991 +185: n02094114 +186: n02094258 +187: n02094433 +188: n02095314 +189: n02095570 +190: n02095889 +191: n02096051 +192: n02096177 +193: n02096294 +194: n02096437 +195: n02096585 +196: n02097047 +197: n02097130 +198: n02097209 +199: n02097298 +200: n02097474 +201: n02097658 +202: n02098105 +203: n02098286 +204: n02098413 +205: n02099267 +206: n02099429 +207: n02099601 +208: n02099712 +209: n02099849 +210: n02100236 +211: n02100583 +212: n02100735 +213: n02100877 +214: n02101006 +215: n02101388 +216: n02101556 +217: n02102040 +218: n02102177 +219: n02102318 +220: n02102480 +221: n02102973 +222: n02104029 +223: n02104365 +224: n02105056 +225: n02105162 +226: n02105251 +227: n02105412 +228: n02105505 +229: n02105641 +230: n02105855 +231: n02106030 +232: n02106166 +233: n02106382 +234: n02106550 +235: n02106662 +236: n02107142 +237: n02107312 +238: 
n02107574 +239: n02107683 +240: n02107908 +241: n02108000 +242: n02108089 +243: n02108422 +244: n02108551 +245: n02108915 +246: n02109047 +247: n02109525 +248: n02109961 +249: n02110063 +250: n02110185 +251: n02110341 +252: n02110627 +253: n02110806 +254: n02110958 +255: n02111129 +256: n02111277 +257: n02111500 +258: n02111889 +259: n02112018 +260: n02112137 +261: n02112350 +262: n02112706 +263: n02113023 +264: n02113186 +265: n02113624 +266: n02113712 +267: n02113799 +268: n02113978 +269: n02114367 +270: n02114548 +271: n02114712 +272: n02114855 +273: n02115641 +274: n02115913 +275: n02116738 +276: n02117135 +277: n02119022 +278: n02119789 +279: n02120079 +280: n02120505 +281: n02123045 +282: n02123159 +283: n02123394 +284: n02123597 +285: n02124075 +286: n02125311 +287: n02127052 +288: n02128385 +289: n02128757 +290: n02128925 +291: n02129165 +292: n02129604 +293: n02130308 +294: n02132136 +295: n02133161 +296: n02134084 +297: n02134418 +298: n02137549 +299: n02138441 +300: n02165105 +301: n02165456 +302: n02167151 +303: n02168699 +304: n02169497 +305: n02172182 +306: n02174001 +307: n02177972 +308: n03373237 +309: n07975909 +310: n02219486 +311: n02226429 +312: n02229544 +313: n02231487 +314: n02233338 +315: n02236044 +316: n02256656 +317: n02259212 +318: n02264363 +319: n02268443 +320: n02268853 +321: n02276258 +322: n02277742 +323: n02279972 +324: n02280649 +325: n02281406 +326: n02281787 +327: n02317335 +328: n02319095 +329: n02321529 +330: n02325366 +331: n02326432 +332: n02328150 +333: n02342885 +334: n02346627 +335: n02356798 +336: n02361337 +337: n05262120 +338: n02364673 +339: n02389026 +340: n02391049 +341: n02395406 +342: n02396427 +343: n02397096 +344: n02398521 +345: n02403003 +346: n02408429 +347: n02410509 +348: n02412080 +349: n02415577 +350: n02417914 +351: n02422106 +352: n02422699 +353: n02423022 +354: n02437312 +355: n02437616 +356: n10771990 +357: n14765497 +358: n02443114 +359: n02443484 +360: n14765785 +361: n02445715 +362: n02447366 +363: 
n02454379 +364: n02457408 +365: n02480495 +366: n02480855 +367: n02481823 +368: n02483362 +369: n02483708 +370: n02484975 +371: n02486261 +372: n02486410 +373: n02487347 +374: n02488291 +375: n02488702 +376: n02489166 +377: n02490219 +378: n02492035 +379: n02492660 +380: n02493509 +381: n02493793 +382: n02494079 +383: n02497673 +384: n02500267 +385: n02504013 +386: n02504458 +387: n02509815 +388: n02510455 +389: n02514041 +390: n07783967 +391: n02536864 +392: n02606052 +393: n02607072 +394: n02640242 +395: n02641379 +396: n02643566 +397: n02655020 +398: n02666347 +399: n02667093 +400: n02669723 +401: n02672831 +402: n02676566 +403: n02687172 +404: n02690373 +405: n02692877 +406: n02699494 +407: n02701002 +408: n02704792 +409: n02708093 +410: n02727426 +411: n08496334 +412: n02747177 +413: n02749479 +414: n02769748 +415: n02776631 +416: n02777292 +417: n02782329 +418: n02783161 +419: n02786058 +420: n02787622 +421: n02788148 +422: n02790996 +423: n02791124 +424: n02791270 +425: n02793495 +426: n02794156 +427: n02795169 +428: n02797295 +429: n02799071 +430: n02802426 +431: n02804515 +432: n02804610 +433: n02807133 +434: n02808304 +435: n02808440 +436: n02814533 +437: n02814860 +438: n02815834 +439: n02817516 +440: n02823428 +441: n02823750 +442: n02825657 +443: n02834397 +444: n02835271 +445: n02837789 +446: n02840245 +447: n02841315 +448: n02843684 +449: n02859443 +450: n02860847 +451: n02865351 +452: n02869837 +453: n02870880 +454: n02871525 +455: n02877765 +456: n02880308 +457: n02883205 +458: n02892201 +459: n02892767 +460: n02894605 +461: n02895154 +462: n12520864 +463: n02909870 +464: n02910353 +465: n02916936 +466: n02917067 +467: n02927161 +468: n02930766 +469: n02939185 +470: n02948072 +471: n02950826 +472: n02951358 +473: n02951585 +474: n02963159 +475: n02965783 +476: n02966193 +477: n02966687 +478: n02971356 +479: n02974003 +480: n02977058 +481: n02978881 +482: n02979186 +483: n02980441 +484: n02981792 +485: n02988304 +486: n02992211 +487: n02992529 +488: 
n13652994 +489: n03000134 +490: n03000247 +491: n03000684 +492: n03014705 +493: n03016953 +494: n03017168 +495: n03018349 +496: n03026506 +497: n03028079 +498: n03032252 +499: n03041632 +500: n03042490 +501: n03045698 +502: n03047690 +503: n03062245 +504: n03063599 +505: n03063689 +506: n03065424 +507: n03075370 +508: n03085013 +509: n03089624 +510: n03095699 +511: n03100240 +512: n03109150 +513: n03110669 +514: n03124043 +515: n03124170 +516: n15142452 +517: n03126707 +518: n03127747 +519: n03127925 +520: n03131574 +521: n03133878 +522: n03134739 +523: n03141823 +524: n03146219 +525: n03160309 +526: n03179701 +527: n03180011 +528: n03187595 +529: n03188531 +530: n03196217 +531: n03197337 +532: n03201208 +533: n03207743 +534: n03207941 +535: n03208938 +536: n03216828 +537: n03218198 +538: n13872072 +539: n03223299 +540: n03240683 +541: n03249569 +542: n07647870 +543: n03255030 +544: n03259401 +545: n03271574 +546: n03272010 +547: n03272562 +548: n03290653 +549: n13869788 +550: n03297495 +551: n03314780 +552: n03325584 +553: n03337140 +554: n03344393 +555: n03345487 +556: n03347037 +557: n03355925 +558: n03372029 +559: n03376595 +560: n03379051 +561: n03384352 +562: n03388043 +563: n03388183 +564: n03388549 +565: n03393912 +566: n03394916 +567: n03400231 +568: n03404251 +569: n03417042 +570: n03424325 +571: n03425413 +572: n03443371 +573: n03444034 +574: n03445777 +575: n03445924 +576: n03447447 +577: n03447721 +578: n08286342 +579: n03452741 +580: n03457902 +581: n03459775 +582: n03461385 +583: n03467068 +584: n03476684 +585: n03476991 +586: n03478589 +587: n03482001 +588: n03482405 +589: n03483316 +590: n03485407 +591: n03485794 +592: n03492542 +593: n03494278 +594: n03495570 +595: n10161363 +596: n03498962 +597: n03527565 +598: n03529860 +599: n09218315 +600: n03532672 +601: n03534580 +602: n03535780 +603: n03538406 +604: n03544143 +605: n03584254 +606: n03584829 +607: n03590841 +608: n03594734 +609: n03594945 +610: n03595614 +611: n03598930 +612: n03599486 +613: 
n03602883 +614: n03617480 +615: n03623198 +616: n15102712 +617: n03630383 +618: n03633091 +619: n03637318 +620: n03642806 +621: n03649909 +622: n03657121 +623: n03658185 +624: n07977870 +625: n03662601 +626: n03666591 +627: n03670208 +628: n03673027 +629: n03676483 +630: n03680355 +631: n03690938 +632: n03691459 +633: n03692522 +634: n03697007 +635: n03706229 +636: n03709823 +637: n03710193 +638: n03710637 +639: n03710721 +640: n03717622 +641: n03720891 +642: n03721384 +643: n03725035 +644: n03729826 +645: n03733131 +646: n03733281 +647: n03733805 +648: n03742115 +649: n03743016 +650: n03759954 +651: n03761084 +652: n03763968 +653: n03764736 +654: n03769881 +655: n03770439 +656: n03770679 +657: n03773504 +658: n03775071 +659: n03775546 +660: n03776460 +661: n03777568 +662: n03777754 +663: n03781244 +664: n03782006 +665: n03785016 +666: n14955889 +667: n03787032 +668: n03788195 +669: n03788365 +670: n03791053 +671: n03792782 +672: n03792972 +673: n03793489 +674: n03794056 +675: n03796401 +676: n03803284 +677: n13652335 +678: n03814639 +679: n03814906 +680: n03825788 +681: n03832673 +682: n03837869 +683: n03838899 +684: n03840681 +685: n03841143 +686: n03843555 +687: n03854065 +688: n03857828 +689: n03866082 +690: n03868242 +691: n03868863 +692: n07281099 +693: n03873416 +694: n03874293 +695: n03874599 +696: n03876231 +697: n03877472 +698: n08053121 +699: n03884397 +700: n03887697 +701: n03888257 +702: n03888605 +703: n03891251 +704: n03891332 +705: n03895866 +706: n03899768 +707: n03902125 +708: n03903868 +709: n03908618 +710: n03908714 +711: n03916031 +712: n03920288 +713: n03924679 +714: n03929660 +715: n03929855 +716: n03930313 +717: n03930630 +718: n03934042 +719: n03935335 +720: n03937543 +721: n03938244 +722: n03942813 +723: n03944341 +724: n03947888 +725: n03950228 +726: n03954731 +727: n03956157 +728: n03958227 +729: n03961711 +730: n03967562 +731: n03970156 +732: n03976467 +733: n08620881 +734: n03977966 +735: n03980874 +736: n03982430 +737: n03983396 +738: 
n03991062 +739: n03992509 +740: n03995372 +741: n03998194 +742: n04004767 +743: n13937284 +744: n04008634 +745: n04009801 +746: n04019541 +747: n04023962 +748: n13413294 +749: n04033901 +750: n04033995 +751: n04037443 +752: n04039381 +753: n09403211 +754: n04041544 +755: n04044716 +756: n04049303 +757: n04065272 +758: n07056680 +759: n04069434 +760: n04070727 +761: n04074963 +762: n04081281 +763: n04086273 +764: n04090263 +765: n04099969 +766: n04111531 +767: n04116512 +768: n04118538 +769: n04118776 +770: n04120489 +771: n04125116 +772: n04127249 +773: n04131690 +774: n04133789 +775: n04136333 +776: n04141076 +777: n04141327 +778: n04141975 +779: n04146614 +780: n04147291 +781: n04149813 +782: n04152593 +783: n04154340 +784: n07917272 +785: n04162706 +786: n04179913 +787: n04192698 +788: n04200800 +789: n04201297 +790: n04204238 +791: n04204347 +792: n04208427 +793: n04209133 +794: n04209239 +795: n04228054 +796: n04229816 +797: n04235860 +798: n04238763 +799: n04239074 +800: n04243546 +801: n04251144 +802: n04252077 +803: n04252225 +804: n04254120 +805: n04254680 +806: n04254777 +807: n04258138 +808: n04259630 +809: n04263257 +810: n04264628 +811: n04265275 +812: n04266014 +813: n04270147 +814: n04273569 +815: n04275363 +816: n05605498 +817: n04285008 +818: n04286575 +819: n08646566 +820: n04310018 +821: n04311004 +822: n04311174 +823: n04317175 +824: n04325704 +825: n04326547 +826: n04328186 +827: n04330267 +828: n04332243 +829: n04335435 +830: n04337157 +831: n04344873 +832: n04346328 +833: n04347754 +834: n04350905 +835: n04355338 +836: n04355933 +837: n04356056 +838: n04357314 +839: n04366367 +840: n04367480 +841: n04370456 +842: n04371430 +843: n14009946 +844: n04372370 +845: n04376876 +846: n04380533 +847: n04389033 +848: n04392985 +849: n04398044 +850: n04399382 +851: n04404412 +852: n04409515 +853: n04417672 +854: n04418357 +855: n04423845 +856: n04428191 +857: n04429376 +858: n04435653 +859: n04442312 +860: n04443257 +861: n04447861 +862: n04456115 +863: 
n04458633 +864: n04461696 +865: n04462240 +866: n04465666 +867: n04467665 +868: n04476259 +869: n04479046 +870: n04482393 +871: n04483307 +872: n04485082 +873: n04486054 +874: n04487081 +875: n04487394 +876: n04493381 +877: n04501370 +878: n04505470 +879: n04507155 +880: n04509417 +881: n04515003 +882: n04517823 +883: n04522168 +884: n04523525 +885: n04525038 +886: n04525305 +887: n04532106 +888: n04532670 +889: n04536866 +890: n04540053 +891: n04542943 +892: n04548280 +893: n04548362 +894: n04550184 +895: n04552348 +896: n04553703 +897: n04554684 +898: n04557648 +899: n04560804 +900: n04562935 +901: n04579145 +902: n04579667 +903: n04584207 +904: n04589890 +905: n04590129 +906: n04591157 +907: n04591713 +908: n10782135 +909: n04596742 +910: n04598010 +911: n04599235 +912: n04604644 +913: n14423870 +914: n04612504 +915: n04613696 +916: n06359193 +917: n06596364 +918: n06785654 +919: n06794110 +920: n06874185 +921: n07248320 +922: n07565083 +923: n07657664 +924: n07583066 +925: n07584110 +926: n07590611 +927: n07613480 +928: n07614500 +929: n07615774 +930: n07684084 +931: n07693725 +932: n07695742 +933: n07697313 +934: n07697537 +935: n07711569 +936: n07714571 +937: n07714990 +938: n07715103 +939: n12159804 +940: n12160303 +941: n12160857 +942: n07717556 +943: n07718472 +944: n07718747 +945: n07720875 +946: n07730033 +947: n13001041 +948: n07742313 +949: n12630144 +950: n14991210 +951: n07749582 +952: n07753113 +953: n07753275 +954: n07753592 +955: n07754684 +956: n07760859 +957: n07768694 +958: n07802026 +959: n07831146 +960: n07836838 +961: n07860988 +962: n07871810 +963: n07873807 +964: n07875152 +965: n07880968 +966: n07892512 +967: n07920052 +968: n13904665 +969: n07932039 +970: n09193705 +971: n09229709 +972: n09246464 +973: n09256479 +974: n09288635 +975: n09332890 +976: n09399592 +977: n09421951 +978: n09428293 +979: n09468604 +980: n09472597 +981: n09835506 +982: n10148035 +983: n10565667 +984: n11879895 +985: n11939491 +986: n12057211 +987: n12144580 +988: 
n12267677 +989: n12620546 +990: n12768682 +991: n12985857 +992: n12998815 +993: n13037406 +994: n13040303 +995: n13044778 +996: n13052670 +997: n13054560 +998: n13133613 +999: n15075141 diff --git a/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png b/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png new file mode 100644 index 0000000..3eb5a22 Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png differ diff --git a/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png b/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png new file mode 100644 index 0000000..6c77130 Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png differ diff --git a/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png b/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png new file mode 100644 index 0000000..63ac989 Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png differ diff --git a/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png b/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png new file mode 100644 index 0000000..7eb67e4 Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png differ diff --git a/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png b/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png new file mode 100644 index 0000000..7714a1f Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png differ diff --git a/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png b/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png new file mode 100644 index 0000000..0324f67 Binary files /dev/null and 
b/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png differ diff --git a/stable-diffusion/data/inpainting_examples/bench2.png b/stable-diffusion/data/inpainting_examples/bench2.png new file mode 100644 index 0000000..09be46d Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/bench2.png differ diff --git a/stable-diffusion/data/inpainting_examples/bench2_mask.png b/stable-diffusion/data/inpainting_examples/bench2_mask.png new file mode 100644 index 0000000..bacadfa Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/bench2_mask.png differ diff --git a/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png b/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png new file mode 100644 index 0000000..618f200 Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png differ diff --git a/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png b/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png new file mode 100644 index 0000000..fd18be9 Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png differ diff --git a/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png b/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png new file mode 100644 index 0000000..cbd246e Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png differ diff --git a/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png b/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png new file mode 100644 index 0000000..7e51214 Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png differ diff --git a/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png 
b/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png new file mode 100644 index 0000000..e84dfc8 Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png differ diff --git a/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png b/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png new file mode 100644 index 0000000..7f3c753 Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png differ diff --git a/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png b/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png new file mode 100644 index 0000000..e8999de Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png differ diff --git a/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png b/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png new file mode 100644 index 0000000..093d0c1 Binary files /dev/null and b/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png differ diff --git a/stable-diffusion/environment.yaml b/stable-diffusion/environment.yaml new file mode 100644 index 0000000..025ced8 --- /dev/null +++ b/stable-diffusion/environment.yaml @@ -0,0 +1,31 @@ +name: ldm +channels: + - pytorch + - defaults +dependencies: + - python=3.8.5 + - pip=20.3 + - cudatoolkit=11.3 + - pytorch=1.11.0 + - torchvision=0.12.0 + - numpy=1.19.2 + - pip: + - albumentations==0.4.3 + - diffusers + - opencv-python==4.1.2.30 + - pudb==2019.2 + - invisible-watermark + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.4.2 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit>=0.73.1 + - einops==0.3.0 + - torch-fidelity==0.3.0 + - transformers==4.19.2 + - torchmetrics==0.6.0 + - kornia==0.6 + - -e 
"""Generate reference samples for every prompt directory under ``--test_data_dir``.

Each sub-directory that contains a ``prompt.txt`` and does not yet contain a
``samples`` directory gets one run of ``scripts/txt2realistic_human.py``.
"""
import argparse
import glob
import os
import random
import shlex
import subprocess


def build_command(sub_dir, prompt, seed, scale, sample_num, df_ckpt):
    """Return the argv list that runs the sampler for one prompt directory.

    The command is returned as a list (not a shell string) so that the prompt
    text is passed through verbatim: quotes, ``$`` or ``;`` inside the prompt
    can neither break the command line nor be interpreted by a shell.
    """
    return [
        "python", "scripts/txt2realistic_human.py",
        "--outdir", str(sub_dir),
        "--seed", str(seed),
        "--H", "512",
        "--W", "512",
        "--n_samples", "1",
        "--scale", str(scale),
        "--n_iter", str(sample_num),
        "--prompt", prompt,
        "--plms",
        "--ckpt", str(df_ckpt),
    ]


def main(argv=None):
    """Parse the command line and launch the sampler for every pending prompt.

    :param argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--test_data_dir', type=str, default='../test_data')
    parser.add_argument('--df_ckpt', type=str, default='SG161222/Realistic_Vision_V5.1_noVAE')
    parser.add_argument('--sample_num', type=int, default=6)
    parser.add_argument(
        "--scale",
        type=float,
        default=5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )

    opt = parser.parse_args(argv)

    for sub_dir in glob.glob(os.path.join(opt.test_data_dir, '*')):
        prompt_path = os.path.join(sub_dir, 'prompt.txt')
        if not os.path.exists(prompt_path):
            # No prompt for this entry: never fall through with a stale or
            # undefined prompt from a previous iteration.
            continue
        with open(prompt_path, 'r') as f:
            prompt = f.read().strip()

        seed = random.randint(0, 100000)
        samples_dir = os.path.join(sub_dir, 'samples')
        if os.path.exists(samples_dir):
            # Samples were already generated for this prompt.
            continue

        cmd = build_command(sub_dir, prompt, seed, opt.scale, opt.sample_num, opt.df_ckpt)
        # Print a copy-pasteable (shell-quoted) form of what is being run.
        print(" ".join(shlex.quote(part) for part in cmd))
        # Best effort, like the previous os.system call: a failing run does
        # not abort the remaining directories.
        subprocess.run(cmd, check=False)


if __name__ == "__main__":
    main()
def synset2idx(path_to_yaml="data/index_synset.yaml"):
    """Return a ``{synset: index}`` dict by inverting the index->synset YAML file."""
    with open(path_to_yaml) as f:
        # The file is a plain mapping, so the safe loader is sufficient.
        # yaml.load() without an explicit Loader is rejected by PyYAML >= 6.
        di2s = yaml.safe_load(f)
    return {v: k for k, v in di2s.items()}


class ImageNetBase(Dataset):
    """Common loading logic for the ImageNet splits.

    Subclasses implement ``_prepare`` and are expected to set the attributes
    this class reads afterwards: ``root``, ``datadir``, ``txt_filelist`` and
    ``random_crop``.
    """

    def __init__(self, config=None):
        """
        :param config: dict or OmegaConf config; the keys read here are
            ``keep_orig_class_label``, ``sub_indices`` and ``size``.
        """
        self.config = config or OmegaConf.create()
        if not isinstance(self.config, dict):
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        # if False we skip loading & processing images and self.data contains filepaths
        # NOTE(review): this unconditionally overwrites any value assigned
        # before this constructor runs - confirm that is intended.
        self.process_images = True
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        """Locate / download / unpack the split. Must be provided by subclasses."""
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        """Drop ignored files and, if ``sub_indices`` is configured, keep only those synsets."""
        ignore = {
            "n06596364_9591.JPEG",
        }
        relpaths = [rpath for rpath in relpaths if rpath.split("/")[-1] not in ignore]
        if "sub_indices" not in self.config:
            return relpaths

        indices = str_to_indices(self.config["sub_indices"])
        # a set makes the per-file membership test below O(1)
        synsets = set(give_synsets_from_indices(indices, path_to_yaml=self.idx2syn))
        self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
        return [rpath for rpath in relpaths if rpath.split("/")[0] in synsets]

    def _prepare_synset_to_human(self):
        """Make sure ``synset_human.txt`` exists in ``root`` with the expected size."""
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict) == SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        """Make sure ``index_synset.yaml`` exists in ``root``."""
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if not os.path.exists(self.idx2syn):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        """Fetch (if needed) and parse the ``index:label`` file into ``human2integer_dict``."""
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if not os.path.exists(self.human2integer):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
        self.human2integer_dict = dict()
        for line in lines:
            # split on the first colon only, so a label that itself contains
            # ':' does not break the two-value unpacking
            value, key = line.split(":", 1)
            self.human2integer_dict[key] = int(value)

    def _load(self):
        """Read the file list, derive all label arrays and build ``self.data``."""
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            # _filter_relpaths only builds the mapping when "sub_indices" is
            # configured; load it here otherwise instead of failing with
            # AttributeError.
            if not hasattr(self, "synset2idx"):
                self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(self.abspaths,
                                   labels=labels,
                                   size=self.size,
                                   random_crop=self.random_crop,
                                   )
        else:
            self.data = self.abspaths
class ImageNetValidation(ImageNetBase):
    """ILSVRC2012 validation split; ``_prepare`` downloads, unpacks and indexes it on first use."""
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    # hash handed to academictorrents to fetch FILES[0]
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    # download location of FILES[1] (image file name -> synset listing)
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    # expected byte sizes of FILES; a mismatch triggers a fresh download
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        """
        :param process_images: stored on the instance before the base constructor runs
        :param data_root: directory that holds (or will hold) the ``NAME`` folder;
            when falsy the XDG cache directory is used instead
        :param kwargs: forwarded to ``ImageNetBase.__init__``
        """
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        """Resolve all paths and, unless the root is already marked prepared, build the dataset on disk."""
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                # fetch the archive only if it is missing or has the wrong size
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                # NOTE(review): extractall() trusts the member paths stored in
                # the archive - acceptable only because the tar comes from the
                # pinned torrent hash above; confirm.
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
                    download(self.VS_URL, vspath)

                # each line is split on whitespace into a (key, value) pair;
                # below, the key is used as a file name and the value as its
                # target folder
                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                # move every extracted file into the folder named after its synset
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            # write the sorted list of image paths, relative to datadir,
            # one per line
            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)
+ self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = 
class LSUNBase(Dataset):
    """LSUN images listed in a text file, center-cropped to a square and scaled to [-1, 1]."""

    def __init__(self,
                 txt_file,
                 data_root,
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5
                 ):
        """
        :param txt_file: text file with one image path (relative to ``data_root``) per line
        :param data_root: directory the listed paths are joined onto
        :param size: side length to resize the square crop to; ``None`` keeps the crop size
        :param interpolation: one of "linear", "bilinear", "bicubic", "lanczos"
        :param flip_p: probability of a random horizontal flip
        :raises KeyError: if ``interpolation`` is not one of the names above
        """
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, "r") as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            "relative_file_path_": list(self.image_paths),
            "file_path_": [os.path.join(self.data_root, l)
                           for l in self.image_paths],
        }

        self.size = size
        # "linear" maps to BILINEAR: PIL.Image.LINEAR was only an alias for
        # it and no longer exists in Pillow >= 10. Because this dict is built
        # eagerly, referencing the removed name would break every
        # construction, whatever ``interpolation`` is requested.
        self.interpolation = {"linear": PIL.Image.BILINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        """Return a dict with both path entries plus ``"image"`` as a float32 HxWx3 array in [-1, 1]."""
        example = {k: v[i] for k, v in self.labels.items()}

        # The context manager closes the underlying file as soon as the
        # pixels have been copied into a numpy array.
        with Image.open(example["file_path_"]) as src:
            rgb = src if src.mode == "RGB" else src.convert("RGB")
            img = np.array(rgb).astype(np.uint8)

        # default to score-sde preprocessing: central square crop
        h, w = img.shape[0], img.shape[1]
        crop = min(h, w)
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example
index 0000000..be39da9 --- /dev/null +++ b/stable-diffusion/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. 
+ self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/stable-diffusion/ldm/models/autoencoder.py b/stable-diffusion/ldm/models/autoencoder.py new file mode 100644 index 0000000..6a9c4f4 --- /dev/null +++ b/stable-diffusion/ldm/models/autoencoder.py @@ -0,0 +1,443 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from 
taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + 
if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = 
np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + 
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + 
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + 
posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", 
log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/stable-diffusion/ldm/models/diffusion/__init__.py b/stable-diffusion/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stable-diffusion/ldm/models/diffusion/classifier.py b/stable-diffusion/ldm/models/diffusion/classifier.py new file mode 100644 index 0000000..67e98b9 --- /dev/null +++ b/stable-diffusion/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + 
self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, 
ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = 
torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, 
self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/stable-diffusion/ldm/models/diffusion/ddim.py b/stable-diffusion/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000..fb31215 --- /dev/null +++ 
b/stable-diffusion/ldm/models/diffusion/ddim.py @@ -0,0 +1,241 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + 
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + 
assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec \ No newline at end of file diff --git a/stable-diffusion/ldm/models/diffusion/ddpm.py b/stable-diffusion/ldm/models/diffusion/ddpm.py new 
file mode 100644 index 0000000..bbedd04 --- /dev/null +++ b/stable-diffusion/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1445 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + 
def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? 
+ self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. 
- alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
+ """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) 
+ + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = 
torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, 
logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = 
self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + 
self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = 
rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * 
Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + 
unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': 
pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. 
reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. 
apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / 
crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly 
rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = 
[self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. 
+ """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if 
return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * 
model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, 
ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. 
- mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + 
log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also 
display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner 
params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, 
**kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/stable-diffusion/ldm/models/diffusion/dpm_solver/__init__.py b/stable-diffusion/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 0000000..7427f38 --- /dev/null +++ b/stable-diffusion/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/stable-diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py b/stable-diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 0000000..bdb64e0 --- /dev/null +++ b/stable-diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1184 @@ +import torch +import torch.nn.functional as F +import math + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. 
+ *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. 
For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. 
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. 
* log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. 
"Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). 
+ + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. 
+ else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. 
+ """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. 
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. 
We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. 
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: An `int`. The max order for the solver (1, 2 or 3). + steps: An `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + timesteps_outer, orders: The outer time steps (a pytorch tensor) and a list of the solver order of each step.
+ """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) 
/ h + 1.), dims) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). 
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. 
/ r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. 
* (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. 
/ r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. 
The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'):
    """
    Adaptive step size solver built on the singlestep DPM-Solver updates.

    Each iteration proposes a step of size `h` in logSNR, evaluates a lower-order and a
    higher-order update, and accepts the higher-order result only when the scaled error
    estimate is at most 1. The step size is then rescaled from the error estimate and
    capped by the remaining logSNR distance to `t_0`.

    Args:
        x: A pytorch tensor. The initial value at time `t_T`.
        order: A `int`. The (higher) order of the solver. Only 2 or 3 is supported.
        t_T: A `float`. The starting time of the sampling.
        t_0: A `float`. The ending time of the sampling.
        h_init: A `float`. The initial step size (for logSNR).
        atol: A `float`. The absolute tolerance of the solver.
        rtol: A `float`. The relative tolerance of the solver.
        theta: A `float`. The safety factor applied when adapting the step size.
        t_err: A `float`. The loop stops once the mean absolute difference between the
            current time and `t_0` is no larger than `t_err`.
        solver_type: either 'dpm_solver' or 'taylor'. Forwarded to the high-order solvers.
    Returns:
        x_0: A pytorch tensor. The approximated solution at time `t_0`.
    Raises:
        ValueError: if `order` is neither 2 nor 3.

    [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
    """
    schedule = self.noise_schedule
    s = t_T * torch.ones((x.shape[0],)).to(x)
    lambda_s = schedule.marginal_lambda(s)
    lambda_0 = schedule.marginal_lambda(t_0 * torch.ones_like(s).to(x))
    h = h_init * torch.ones_like(s).to(x)
    x_prev = x
    nfe = 0

    if order == 2:
        r1 = 0.5

        def lower_update(x_cur, t_from, t_to):
            return self.dpm_solver_first_update(x_cur, t_from, t_to, return_intermediate=True)

        def higher_update(x_cur, t_from, t_to, **reused):
            return self.singlestep_dpm_solver_second_update(
                x_cur, t_from, t_to, r1=r1, solver_type=solver_type, **reused)
    elif order == 3:
        r1, r2 = 1. / 3., 2. / 3.

        def lower_update(x_cur, t_from, t_to):
            return self.singlestep_dpm_solver_second_update(
                x_cur, t_from, t_to, r1=r1, return_intermediate=True, solver_type=solver_type)

        def higher_update(x_cur, t_from, t_to, **reused):
            return self.singlestep_dpm_solver_third_update(
                x_cur, t_from, t_to, r1=r1, r2=r2, solver_type=solver_type, **reused)
    else:
        raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))

    def rms_per_sample(v):
        # Root-mean-square over all non-batch entries, one value per batch element.
        flat = v.reshape((v.shape[0], -1))
        return torch.sqrt(torch.square(flat).mean(dim=-1, keepdim=True))

    while torch.abs((s - t_0)).mean() > t_err:
        t = schedule.inverse_lambda(lambda_s + h)
        # The lower-order solver's intermediates are handed to the higher-order solver as keyword arguments.
        x_lower, lower_noise_kwargs = lower_update(x, s, t)
        x_higher = higher_update(x, s, t, **lower_noise_kwargs)
        magnitude = torch.max(torch.abs(x_lower), torch.abs(x_prev))
        delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * magnitude)
        E = rms_per_sample((x_higher - x_lower) / delta).max()
        if torch.all(E <= 1.):
            # Step accepted: advance the state and the current logSNR.
            x, s, x_prev = x_higher, t, x_lower
            lambda_s = schedule.marginal_lambda(s)
        # The step size is updated whether or not the step was accepted.
        h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
        nfe += order
    print('adaptive solver nfe', nfe)
    return x
+ + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. 
+ - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. 
+ + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. 
We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. 
def interpolate_fn(x, xp, yp):
    """
    Evaluate a piecewise linear function y = f(x) defined by the keypoints (xp, yp).

    Only tensor operations are used, so the result is differentiable with autograd.
    For x outside the range of xp, the outermost segment of xp is extended linearly.

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    device = x.device

    # Put each query in front of its channel's keypoints and sort; the query has original
    # index 0, so argmin over the sort permutation gives its rank among the keypoints.
    merged = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    merged_sorted, permutation = torch.sort(merged, dim=2)
    query_rank = torch.argmin(permutation, dim=2)

    left_of_query = query_rank - 1
    below_all = torch.eq(query_rank, 0)
    above_all = torch.eq(query_rank, num_keys)
    last_segment = torch.tensor(num_keys - 2, device=device)

    # Positions, within the sorted sequence (which still contains the query), of the two
    # keypoints that bound the segment used for interpolation.
    lo_sorted = torch.where(
        below_all,
        torch.tensor(1, device=device),
        torch.where(above_all, last_segment, left_of_query),
    )
    # When the query lies between the two keypoints, jump over it (+2) to reach the right one.
    hi_sorted = torch.where(torch.eq(lo_sorted, left_of_query), lo_sorted + 2, lo_sorted + 1)
    x_lo = torch.gather(merged_sorted, dim=2, index=lo_sorted.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(merged_sorted, dim=2, index=hi_sorted.unsqueeze(2)).squeeze(2)

    # The same segment expressed as an index into yp, which does not contain the query.
    lo_key = torch.where(
        below_all,
        torch.tensor(0, device=device),
        torch.where(above_all, last_segment, left_of_query),
    )
    yp_batched = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(yp_batched, dim=2, index=lo_key.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(yp_batched, dim=2, index=(lo_key + 1).unsqueeze(2)).squeeze(2)

    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
class DPMSolverSampler(object):
    """Sampling-only wrapper that runs the multistep DPM-Solver on a diffusion model."""

    def __init__(self, model, **kwargs):
        """
        Args:
            model: the diffusion model. This class reads `model.device`, `model.alphas_cumprod`,
                `model.betas` and calls `model.apply_model`.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.model = model
        alphas_cumprod = model.alphas_cumprod.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer('alphas_cumprod', alphas_cumprod)

    def register_buffer(self, name, attr):
        """
        Store `attr` on the sampler under `name`, moving tensors to the model's device.

        The previous version moved every tensor to `torch.device("cuda")`, which fails on
        CPU-only hosts and can move a buffer off the model's own device (e.g. `cuda:1`).
        """
        if isinstance(attr, torch.Tensor):
            # `.to` returns the tensor unchanged when it already lives on that device.
            attr = attr.to(self.model.device)
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """
        Draw samples with the second-order multistep DPM-Solver and classifier-free guidance.

        Args:
            S: number of solver steps.
            batch_size: number of samples to draw.
            shape: `(C, H, W)` of one sample.
            conditioning: conditioning passed to the model; a tensor, or a dict whose first
                value is used for the batch-size check.
            x_T: optional initial noise; drawn from a standard normal when None.
            unconditional_guidance_scale: classifier-free guidance scale.
            unconditional_conditioning: the unconditional conditioning, in the same format
                as `conditioning`.
            The remaining parameters are accepted but not used by this implementation.
        Returns:
            A tuple `(samples, None)`.
        """
        if conditioning is not None:
            # Only warn on a batch-size mismatch; sampling proceeds regardless.
            if isinstance(conditioning, dict):
                cbs = next(iter(conditioning.values())).shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif conditioning.shape[0] != batch_size:
                print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        device = self.model.betas.device
        img = torch.randn(size, device=device) if x_T is None else x_T

        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

        model_fn = model_wrapper(
            lambda x, t, c: self.model.apply_model(x, t, c),
            ns,
            model_type="noise",
            guidance_type="classifier-free",
            condition=conditioning,
            unconditional_condition=unconditional_conditioning,
            guidance_scale=unconditional_guidance_scale,
        )

        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)

        return x.to(device), None
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
    """
    Build the DDIM timestep subset and register the schedule buffers on the sampler.

    Args:
        ddim_num_steps: number of sampling timesteps requested from `make_ddim_timesteps`.
        ddim_discretize: discretization method forwarded to `make_ddim_timesteps`.
        ddim_eta: must be 0 for PLMS.
        verbose: forwarded to the schedule helper functions.
    Raises:
        ValueError: if `ddim_eta` is not 0.
    """
    if ddim_eta != 0:
        raise ValueError('ddim_eta must be 0 for PLMS')

    self.ddim_timesteps = make_ddim_timesteps(
        ddim_discr_method=ddim_discretize,
        num_ddim_timesteps=ddim_num_steps,
        num_ddpm_timesteps=self.ddpm_num_timesteps,
        verbose=verbose,
    )
    alphas_cumprod = self.model.alphas_cumprod
    assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'

    def register_float32(name, tensor):
        # Detached float32 copy placed on the model's device, then stored on the sampler.
        self.register_buffer(name, tensor.clone().detach().to(torch.float32).to(self.model.device))

    register_float32('betas', self.model.betas)
    register_float32('alphas_cumprod', alphas_cumprod)
    register_float32('alphas_cumprod_prev', self.model.alphas_cumprod_prev)

    # calculations for diffusion q(x_t | x_{t-1}) and others
    cumprod_cpu = alphas_cumprod.cpu()
    register_float32('sqrt_alphas_cumprod', np.sqrt(cumprod_cpu))
    register_float32('sqrt_one_minus_alphas_cumprod', np.sqrt(1. - cumprod_cpu))
    register_float32('log_one_minus_alphas_cumprod', np.log(1. - cumprod_cpu))
    register_float32('sqrt_recip_alphas_cumprod', np.sqrt(1. / cumprod_cpu))
    register_float32('sqrt_recipm1_alphas_cumprod', np.sqrt(1. / cumprod_cpu - 1))

    # ddim sampling parameters
    ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
        alphacums=alphas_cumprod.cpu(),
        ddim_timesteps=self.ddim_timesteps,
        eta=ddim_eta,
        verbose=verbose,
    )
    self.register_buffer('ddim_sigmas', ddim_sigmas)
    self.register_buffer('ddim_alphas', ddim_alphas)
    self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
    self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))

    variance_ratio = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
    step_term = 1 - self.alphas_cumprod / self.alphas_cumprod_prev
    sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(variance_ratio * step_term)
    self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+ callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
    """
    Perform one PLMS (pseudo linear multistep) sampling step.

    Args:
        x: current sample.
        c: conditioning passed to `self.model.apply_model`.
        t: current timestep tensor.
        index: position in the schedule arrays used to look up alpha/sigma values.
        repeat_noise: forwarded to `noise_like`.
        use_original_steps: if true, read the schedule of the original timesteps instead of
            the DDIM subset.
        quantize_denoised: if true, pass the x_0 prediction through
            `self.model.first_stage_model.quantize`.
        temperature: scale applied to the sampled noise.
        noise_dropout: dropout probability applied to the noise when > 0.
        score_corrector: optional object whose `modify_score` adjusts the model output.
        corrector_kwargs: keyword arguments for `score_corrector.modify_score`; None means none.
        unconditional_guidance_scale: classifier-free guidance scale.
        unconditional_conditioning: unconditional conditioning for guidance.
        old_eps: model outputs of previous steps, most recent last; None means no history.
        t_next: next timestep tensor, used by the first step when there is no history.
    Returns:
        A tuple `(x_prev, pred_x0, e_t)`, where `e_t` is the model output at `(x, t)`.
    """
    b, *_, device = *x.shape, x.device
    # The signature defaults both of these to None, which used to crash on `len(None)` / `**None`.
    old_eps = [] if old_eps is None else old_eps
    corrector_kwargs = {} if corrector_kwargs is None else corrector_kwargs

    def get_model_output(x, t):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            # Evaluate the unconditional and conditional branches in one batched call.
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        return e_t

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    # `make_schedule` registers this buffer on the sampler, not on the model, so it is read from `self`.
    sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    e_t = get_model_output(x, t)
    if len(old_eps) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        e_t_next = get_model_output(x_prev, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(old_eps) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - old_eps[-1]) / 2
    elif len(old_eps) == 2:
        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    else:
        # 4th order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

    return x_prev, pred_x0, e_t
class GEGLU(nn.Module):
    """Linear projection followed by a GELU-gated product of its two halves."""

    def __init__(self, dim_in, dim_out):
        """
        Args:
            dim_in: size of the last dimension of the input.
            dim_out: size of the last dimension of the output; the projection produces
                twice this many features (value half and gate half).
        """
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        """Project `x`, split the result in two along the last dim, return value * gelu(gate)."""
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
def Normalize(in_channels, num_groups=32):
    """
    Build the GroupNorm layer used by the attention blocks.

    Args:
        in_channels: number of channels to normalize.
        num_groups: number of groups. Defaults to 32, the value that used to be hard-coded.
    Returns:
        A `torch.nn.GroupNorm` with eps=1e-6 and learnable affine parameters.
    """
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class CrossAttention(nn.Module):
    """Multi-head attention of `x` over `context`; attends over `x` itself when no context is given."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        """
        Args:
            query_dim: feature size of the query input (and of the output).
            context_dim: feature size of the context; falls back to `query_dim` when None.
            heads: number of attention heads.
            dim_head: feature size per head.
            dropout: dropout probability applied after the output projection.
        """
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        """
        Args:
            x: query input, laid out as 'b n (h d)' after projection.
            context: optional key/value source; `x` is used when it is None.
            mask: optional mask, flattened to 'b j'; positions where it is False are
                filled with the most negative finite value before the softmax.
        Returns:
            The attended features projected back to `query_dim`.
        """
        h = self.heads

        def split_heads(t):
            # Fold the head dimension into the batch dimension.
            return rearrange(t, 'b n (h d) -> (b h) n d', h=h)

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            fill_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, fill_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        merged = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(merged)
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.

    The input feature map is normalized, projected with a 1x1 convolution,
    flattened from ``b c h w`` to ``b (h w) c``, passed through ``depth``
    transformer blocks, reshaped back to an image and projected with a
    second 1x1 convolution (wrapped in ``zero_module``). The original input
    is added back as a residual.
    """

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels, inner_dim,
                                 kernel_size=1, stride=1, padding=0)

        blocks = []
        for _ in range(depth):
            blocks.append(
                BasicTransformerBlock(inner_dim, n_heads, d_head,
                                      dropout=dropout, context_dim=context_dim)
            )
        self.transformer_blocks = nn.ModuleList(blocks)

        out_conv = nn.Conv2d(inner_dim, in_channels,
                             kernel_size=1, stride=1, padding=0)
        self.proj_out = zero_module(out_conv)

    def forward(self, x, context=None):
        """Apply the transformer stack to a 4-D feature map.

        :param x: feature map; its shape is unpacked as ``(b, c, h, w)``.
        :param context: optional conditioning passed to every block; if no
            context is given, cross-attention defaults to self-attention.
        :return: a tensor with the same layout as ``x``.
        """
        _, _, height, width = x.shape
        residual = x

        tokens = self.proj_in(self.norm(x))
        tokens = rearrange(tokens, 'b c h w -> b (h w) c')
        for layer in self.transformer_blocks:
            tokens = layer(tokens, context=context)
        feature_map = rearrange(tokens, 'b (h w) c -> b c h w', h=height, w=width)
        feature_map = self.proj_out(feature_map)
        return feature_map + residual
def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal timestep embeddings.

    The first ``embedding_dim // 2`` columns hold ``sin(t * f_k)`` and the
    next ``embedding_dim // 2`` columns hold ``cos(t * f_k)``, where the
    frequencies ``f_k`` decay geometrically from 1 down to 1/10000. An odd
    ``embedding_dim`` is zero-padded with one trailing column.

    :param timesteps: 1-D tensor of timesteps.
    :param embedding_dim: width of the returned embedding.
    :return: float tensor of shape ``(len(timesteps), embedding_dim)`` on the
        same device as ``timesteps``.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # With half_dim == 1 (embedding_dim 2 or 3) the original divisor
    # ``half_dim - 1`` is zero. Clamp it to 1: the single frequency is then
    # exp(0) == 1, which is what the geometric series starts at anyway.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_step)
    freqs = freqs.to(device=timesteps.device)
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
class ResnetBlock(nn.Module):
    """Residual block: two norm -> nonlinearity -> 3x3 conv stages, with an
    optional additive timestep-embedding projection between them.

    When ``in_channels`` and ``out_channels`` differ, the skip path goes
    through a 3x3 convolution (``conv_shortcut=True``) or a 1x1 convolution
    (default). ``temb_proj`` is only created when ``temb_channels > 0``.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels,
                                     kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels,
                                     kernel_size=3, stride=1, padding=1)

        # A shortcut layer is only needed when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                     kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                    kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        """Run the block.

        :param x: input feature map.
        :param temb: timestep embedding, or ``None`` to skip the embedding
            term. It is broadcast over the two trailing spatial axes.
        :return: shortcut(``x``) plus the residual branch.
        """
        residual = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            residual = residual + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        residual = nonlinearity(self.norm2(residual))
        residual = self.conv2(self.dropout(residual))

        if self.in_channels == self.out_channels:
            return x + residual

        shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
        return shortcut(x) + residual
def make_attn(in_channels, attn_type="vanilla"):
    """Create the attention layer selected by ``attn_type``.

    :param in_channels: channel count handed to the attention constructor.
    :param attn_type: ``"vanilla"`` -> ``AttnBlock``, ``"linear"`` ->
        ``LinAttnBlock``, ``"none"`` -> ``nn.Identity``.
    :return: the constructed module.
    :raises AssertionError: if ``attn_type`` is not one of the three names.
    """
    known_types = ["vanilla", "linear", "none"]
    assert attn_type in known_types, f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")

    if attn_type == "none":
        return nn.Identity(in_channels)
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    return LinAttnBlock(in_channels)
self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn 
class Encoder(nn.Module):
    """Convolutional encoder.

    Layout, in the order the layers are registered and applied:
    an input 3x3 conv; for every entry of ``ch_mult`` a level made of
    ``num_res_blocks`` ResnetBlocks (each followed by an attention layer when
    the current resolution is listed in ``attn_resolutions``) with a
    Downsample after every level except the last; a middle
    ResnetBlock / attention / ResnetBlock stack; and a final
    norm -> nonlinearity -> 3x3 conv.

    :param ch: base channel count; level ``i`` uses ``ch * ch_mult[i]``.
    :param out_ch: required by the signature but not referenced in this class.
    :param ch_mult: per-level channel multipliers; its length is the number
        of resolution levels.
    :param num_res_blocks: ResnetBlocks per level.
    :param attn_resolutions: container of resolutions at which attention is
        added (checked with ``in``).
    :param dropout: dropout rate forwarded to every ResnetBlock.
    :param resamp_with_conv: forwarded to ``Downsample``.
    :param in_channels: channels of the input image.
    :param resolution: input resolution; halved after each downsample to
        decide where attention goes.
    :param z_channels: channels of the output; doubled when ``double_z``.
    :param double_z: if True the final conv emits ``2 * z_channels`` channels.
    :param use_linear_attn: if True, overrides ``attn_type`` with "linear".
    :param attn_type: attention kind passed to ``make_attn``.
    :param ignore_kwargs: extra keyword arguments are accepted and ignored.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        # No timestep embedding in the encoder: ResnetBlocks are built with
        # temb_channels=0 and forward() passes temb=None.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # Prepending 1 makes in_ch_mult[i] the multiplier of the level that
        # feeds level i (the input conv for level 0).
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                # One attention layer per ResnetBlock, so forward() can index
                # attn with the same i_block.
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            # A bare nn.Module is used as a namespace so the sub-layers are
            # registered under down.<i>.block / .attn / .downsample.
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        """Encode ``x`` and return the output of the final convolution."""
        # timestep embedding
        temb = None

        # downsampling
        # hs collects every intermediate activation; only hs[-1] is read here.
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = 
class UpsampleDecoder(nn.Module):
    """Stack of ResnetBlock levels with a learned 2x Upsample between levels,
    followed by norm -> nonlinearity -> 3x3 conv.

    :param in_channels: channels of the input feature map.
    :param out_channels: channels produced by the final convolution.
    :param ch: base channel count; level ``i`` uses ``ch * ch_mult[i]``.
    :param num_res_blocks: each level holds ``num_res_blocks + 1`` ResnetBlocks.
    :param resolution: kept for interface compatibility; it does not
        influence the layers that are built.
    :param ch_mult: per-level channel multipliers; its length is the number
        of levels.
    :param dropout: dropout rate forwarded to every ResnetBlock.
    """

    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2,2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            level_blocks = []
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                level_blocks.append(ResnetBlock(in_channels=block_in,
                                                out_channels=block_out,
                                                temb_channels=self.temb_ch,
                                                dropout=dropout))
                # Only the first block of a level changes the channel count.
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(level_blocks))
            # No upsampling after the last level.
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        """Run every level (ResnetBlocks, then Upsample except after the last
        level) and the output head; return the result of the final conv."""
        h = x
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[i_level](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class MergedRescaleEncoder(nn.Module):
    """An ``Encoder`` followed by a ``LatentRescaler``.

    The encoder is built with ``double_z=False`` and emits
    ``ch * ch_mult[-1]`` channels; the rescaler keeps that width internally
    and maps it to ``out_ch`` channels.
    """

    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        latent_width = ch * ch_mult[-1]

        encoder_kwargs = dict(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=latent_width,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        self.encoder = Encoder(**encoder_kwargs)

        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=latent_width,
            mid_channels=latent_width,
            out_channels=out_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        """Encode ``x`` and then rescale the resulting latent."""
        latent = self.encoder(x)
        return self.rescaler(latent)
class Resize(nn.Module):
    """Resize a feature map by interpolation.

    :param in_channels: only relevant to the learned variant, which is not
        implemented; unused otherwise.
    :param learned: requesting learned resizing raises ``NotImplementedError``.
    :param mode: interpolation mode passed to
        ``torch.nn.functional.interpolate``.
    :raises NotImplementedError: if ``learned`` is true.
    """

    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # ``__name__`` (not ``__name``): inside a class body the latter is
            # name-mangled and raised AttributeError before the intended
            # NotImplementedError could be reached.
            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            raise NotImplementedError(
                f"{self.__class__.__name__} does not implement learned resizing"
            )

    def forward(self, x, scale_factor=1.0):
        """Return ``x`` unchanged for ``scale_factor == 1.0``, otherwise the
        interpolated tensor (``align_corners=False``)."""
        if scale_factor == 1.0:
            return x
        return torch.nn.functional.interpolate(
            x, mode=self.mode, align_corners=False, scale_factor=scale_factor
        )
@torch.no_grad()
def encode_with_pretrained(self, x):
    """Encode ``x`` with the frozen pretrained model, without gradients.

    :param x: input passed to ``self.pretrained_model.encode``.
    :return: the encoder output; if that output is a
        ``DiagonalGaussianDistribution``, its ``mode()`` is returned instead.
    """
    # This module does not import DiagonalGaussianDistribution at file level,
    # so the isinstance check below raised NameError. Import it locally.
    # NOTE(review): the path follows the usual ldm package layout -- confirm
    # it matches this repository.
    from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

    c = self.pretrained_model.encode(x)
    if isinstance(c, DiagonalGaussianDistribution):
        c = c.mode()
    return c
class AttentionPool2d(nn.Module):
    """
    Attention pooling over the spatial positions of a feature map.

    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # One embedding column per spatial position plus one for the mean token.
        n_tokens = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, n_tokens) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        """Pool ``x`` to one vector per sample.

        The spatial axes are flattened, the spatial mean is prepended as an
        extra token, positional embeddings are added, and after attention
        and projection only token 0 (the mean token's slot) is returned.
        """
        batch, channels, *_ = x.shape
        tokens = x.reshape(batch, channels, -1)  # NC(HW)
        mean_token = tokens.mean(dim=-1, keepdim=True)
        tokens = th.cat([mean_token, tokens], dim=-1)  # NC(HW+1)
        pos = self.positional_embedding[None, :, :].to(tokens.dtype)
        tokens = tokens + pos  # NC(HW+1)
        attended = self.attention(self.qkv_proj(tokens))
        projected = self.c_proj(attended)
        return projected[:, :, 0]
class TransposedUpsample(nn.Module):
    """Learned 2x upsampling without padding.

    A single stride-2 ``nn.ConvTranspose2d`` with kernel size ``ks``.

    :param channels: input channels.
    :param out_channels: output channels; falls back to ``channels`` when
        falsy.
    :param ks: kernel size of the transposed convolution.
    """

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels if out_channels else channels

        self.up = nn.ConvTranspose2d(
            self.channels,
            self.out_channels,
            kernel_size=ks,
            stride=2,
        )

    def forward(self, x):
        """Return ``x`` passed through the transposed convolution."""
        upsampled = self.up(x)
        return upsampled
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding yields a scale and a
        shift applied after the output normalization instead of being added.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        input_conv = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        self.in_layers = nn.Sequential(normalization(channels), nn.SiLU(), input_conv)

        self.updown = up or down

        # `up` takes precedence over `down`; both paths resample without conv.
        resampler_cls = Upsample if up else (Downsample if down else None)
        if resampler_cls is None:
            self.h_upd = self.x_upd = nn.Identity()
        else:
            self.h_upd = resampler_cls(channels, False, dims)
            self.x_upd = resampler_cls(channels, False, dims)

        # Scale-shift conditioning needs two values per output channel.
        if use_scale_shift_norm:
            emb_width = 2 * self.out_channels
        else:
            emb_width = self.out_channels
        self.emb_layers = nn.Sequential(nn.SiLU(), linear(emb_channels, emb_width))

        output_conv = zero_module(
            conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            output_conv,
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        inputs = (x, emb)
        return checkpoint(self._forward, inputs, self.parameters(), self.use_checkpoint)

    def _forward(self, x, emb):
        """Un-checkpointed body of :meth:`forward`."""
        if not self.updown:
            h = self.in_layers(x)
        else:
            # Resample between the norm/activation and the convolution, and
            # resample the skip input the same way.
            pre_conv, conv = self.in_layers[:-1], self.in_layers[-1]
            h = self.h_upd(pre_conv(x))
            x = self.x_upd(x)
            h = conv(h)

        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton axes until emb_out has as many axes as h.
        for _ in range(len(h.shape) - len(emb_out.shape)):
            emb_out = emb_out[..., None]

        if not self.use_scale_shift_norm:
            h = self.out_layers(h + emb_out)
        else:
            norm, remainder = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = remainder(norm(h) * (1 + scale) + shift)

        return self.skip_connection(x) + h
class QKVAttentionLegacy(nn.Module):
    """
    Multi-head QKV attention using the legacy head/channel layout.

    The channel axis of the input is first cut into ``n_heads`` groups and
    each group is then cut into its query, key and value parts.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, channels, seq_len = qkv.shape
        heads = self.n_heads
        assert channels % (3 * heads) == 0
        head_dim = channels // (3 * heads)

        # Fold the heads into the batch axis, then separate q, k and v.
        per_head = qkv.reshape(batch * heads, head_dim * 3, seq_len)
        query, key, value = per_head.split(head_dim, dim=1)

        # Scaling q and k separately by head_dim ** -0.25 is more stable with
        # f16 than dividing the product afterwards.
        factor = 1 / math.sqrt(math.sqrt(head_dim))
        logits = th.einsum("bct,bcs->bts", query * factor, key * factor)

        # The softmax is evaluated in float32 and cast back.
        probs = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = th.einsum("bts,bcs->bct", probs, value)
        return attended.reshape(batch, -1, seq_len)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param image_size: stored on the instance; this class does not read it.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param use_fp16: if True, `forward` casts its input to float16 before the
        first block (float32 otherwise).
    :param num_heads: the number of attention heads in each attention layer.
        Either this or `num_head_channels` has to be set (i.e. not -1).
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    :param use_spatial_transformer: build `SpatialTransformer` layers instead
        of `AttentionBlock` layers; requires `context_dim`.
    :param transformer_depth: passed as `depth` to every `SpatialTransformer`.
    :param context_dim: passed as `context_dim` to every `SpatialTransformer`
        (an omegaconf `ListConfig` is converted to a plain list); requires
        `use_spatial_transformer`.
    :param n_embed: if not None, an extra `id_predictor` head with `n_embed`
        output channels is built and `forward` returns its output instead of
        the regular output head.
    :param legacy: if True, the per-head width is overridden to
        `ch // num_heads` with the spatial transformer and to
        `num_head_channels` otherwise.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            # Imported lazily so omegaconf is only needed when a context is configured.
            from omegaconf.listconfig import ListConfig
            if isinstance(context_dim, ListConfig):
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # ---- encoder -------------------------------------------------------
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        # Channel count of every encoder output, consumed in reverse by the
        # decoder to size its skip-connection inputs.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling factor
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Every level but the last ends with a downsampling step.
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---- bottleneck ----------------------------------------------------
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # ---- decoder -------------------------------------------------------
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # Each decoder block receives the running features concatenated
                # with one stored encoder output, hence ch + ich input channels.
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                if level and i == num_res_blocks:
                    # Last block of every level except level 0 upsamples.
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # ---- heads ---------------------------------------------------------
        # The decoder ends with ch == model_channels * channel_mult[0]
        # channels, so the head convolutions must consume `ch`. Using
        # `model_channels` only matched when channel_mult[0] == 1.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            # Produces non-normalized logits over the n_embed codebook entries.
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, ch, n_embed, 1),
            )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn; handed to every
            block together with the timestep embedding.
        :param y: an [N] Tensor of labels, if class-conditional.
        :param kwargs: accepted and ignored.
        :return: an [N x C x ...] Tensor of outputs (the `id_predictor`
            output when the model was built with `n_embed`).
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []  # encoder outputs, popped in reverse as decoder skip inputs
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        # Back to the caller's dtype before the output head.
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + 
use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + 
def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/stable-diffusion/ldm/modules/diffusionmodules/util.py b/stable-diffusion/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000..a952e6c --- /dev/null +++ b/stable-diffusion/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,267 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build the per-timestep beta values of a diffusion noise schedule.

    :param schedule: one of "linear", "cosine", "sqrt_linear" or "sqrt".
        "linear" interpolates linearly between the square roots of
        `linear_start` and `linear_end` and squares the result;
        "sqrt_linear" interpolates linearly between the two values;
        "sqrt" takes the square root of that linear interpolation;
        "cosine" derives betas from a squared-cosine alpha curve.
    :param n_timestep: the number of betas to produce.
    :param linear_start: first value of the interpolated schedules.
    :param linear_end: last value of the interpolated schedules.
    :param cosine_s: offset added to the normalized time of the cosine curve.
    :return: a float64 numpy array of shape [n_timestep].
    :raises ValueError: if `schedule` is not one of the names above.
    """
    if schedule == "linear":
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch: `betas` is a torch tensor and `.numpy()` is called
        # on it below. np.clip on a tensor goes through NumPy's fallback
        # dispatch, whose return type is not guaranteed to stay a tensor.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar(t) curve into a beta schedule.

    alpha_bar describes the cumulative product of (1 - beta) over t in
    [0, 1]; each beta is recovered from the ratio of alpha_bar at the two
    ends of its time step.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t in [0, 1] and returning the
                      cumulative product of (1 - beta) up to that point.
    :param max_beta: upper bound applied to every beta; values below 1
                     prevent singularities.
    :return: a numpy array with one beta per timestep.
    """

    def beta_at(step):
        # Beta for the interval [step / N, (step + 1) / N].
        interval_start = step / num_diffusion_timesteps
        interval_end = (step + 1) / num_diffusion_timesteps
        return min(1 - alpha_bar(interval_end) / alpha_bar(interval_start), max_beta)

    return np.array([beta_at(step) for step in range(num_diffusion_timesteps)])
class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function implementing gradient checkpointing.

    `forward` runs the wrapped function without recording a graph;
    `backward` re-runs it with gradients enabled and differentiates the
    recomputed outputs. Call it as
    ``CheckpointFunction.apply(func, n_inputs, *inputs, *params)`` where the
    first ``n_inputs`` positional values are passed to ``func`` and the
    remaining ones are tensors ``func`` depends on implicitly.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        """
        Evaluate `run_function` on the first `length` values of `args`
        under `torch.no_grad()` and remember everything needed to recompute.
        """
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        """
        Recompute the forward pass with gradients enabled and return one
        gradient per argument of `forward`.

        Only floating-point/complex tensor inputs and parameters with
        `requires_grad=True` are differentiated. `None`, non-tensor and
        integer inputs are passed to the function unchanged, and frozen
        parameters are skipped; all of these receive a `None` gradient.
        """

        def can_carry_grad(value):
            return isinstance(value, torch.Tensor) and (
                value.is_floating_point() or value.is_complex()
            )

        inputs = [
            x.detach().requires_grad_(True) if can_carry_grad(x) else x
            for x in ctx.input_tensors
        ]
        params = ctx.input_params

        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [
                x.view_as(x) if isinstance(x, torch.Tensor) else x for x in inputs
            ]
            output_tensors = ctx.run_function(*shallow_copies)

        # torch.autograd.grad rejects targets that do not require grad, so
        # differentiate only the eligible ones and scatter the results back.
        input_slots = [i for i, x in enumerate(inputs) if can_carry_grad(x)]
        param_slots = [
            i for i, p in enumerate(params)
            if isinstance(p, torch.Tensor) and p.requires_grad
        ]
        targets = [inputs[i] for i in input_slots] + [params[i] for i in param_slots]
        grads = torch.autograd.grad(
            output_tensors,
            targets,
            output_grads,
            allow_unused=True,
        )

        input_grads = [None] * len(inputs)
        param_grads = [None] * len(params)
        for slot, grad in zip(input_slots, grads[:len(input_slots)]):
            input_grads[slot] = grad
        for slot, grad in zip(param_slots, grads[len(input_slots):]):
            param_grads[slot] = grad

        # Drop the stored references so the recomputed graph can be freed.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # The two leading Nones are for `run_function` and `length`.
        return (None, None) + tuple(input_grads) + tuple(param_grads)
def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution layer for 1D, 2D or 3D signals.

    :param dims: 1, 2 or 3, selecting nn.Conv1d, nn.Conv2d or nn.Conv3d.
    :param args: positional arguments forwarded to the layer constructor.
    :param kwargs: keyword arguments forwarded to the layer constructor.
    :return: the constructed convolution module.
    :raises ValueError: if `dims` is not 1, 2 or 3.
    """
    conv_by_rank = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in conv_by_rank:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
class DiagonalGaussianDistribution(object):
    """
    A Gaussian with diagonal covariance, parameterized by one tensor.

    `parameters` is split in two halves along dim 1: the first half is the
    mean, the second the log-variance (clamped to [-30, 20]). With
    `deterministic=True` the variance and standard deviation are zero, so
    `sample()` returns the mean and `kl()`/`nll()` return a single zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp so that exp() below can neither overflow nor underflow.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw one sample: mean + std * standard normal noise."""
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """
        KL divergence to `other`, or to a standard normal if `other` is None.

        The result is summed over dims 1, 2 and 3, so 4-D parameters are
        required. In deterministic mode a one-element zero tensor on the
        parameters' device is returned.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """
        Negative log-likelihood of `sample`, summed over `dims`.

        In deterministic mode a one-element zero tensor on the parameters'
        device is returned.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean
class LitEma(nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    Every parameter with ``requires_grad=True`` gets a shadow copy stored as
    a buffer of this module. The buffer name is the parameter name with all
    '.' characters removed, because '.' is not allowed in buffer names.
    Calling the module with the model performs one EMA step.

    Args:
        model: module whose trainable parameters are tracked.
        decay: EMA decay, must lie in ``[0, 1]``.
        use_num_upates: if True, the number of updates is counted and the
            effective decay is ``min(decay, (1 + n) / (10 + n))``; if False
            the counter is held at -1 and ``decay`` is used as is.

    Raises:
        ValueError: if ``decay`` is outside ``[0, 1]``.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name[name] = s_name
                self.register_buffer(s_name, p.detach().clone())

        self.collected_params = []

    def forward(self, model):
        """Perform one EMA update of the shadow buffers from ``model``."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key, param in m_param.items():
                if param.requires_grad:
                    shadow = shadow_params[self.m_name2s_name[key]]
                    # Update the registered buffer itself. The previous
                    # `type_as` rebinding produced a temporary copy whenever
                    # the dtypes differed, so the in-place update was lost.
                    shadow.sub_(one_minus_decay * (shadow - param.to(shadow)))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        """Overwrite the trainable parameters of ``model`` with the EMA values."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key, param in m_param.items():
            if param.requires_grad:
                param.data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
class ClassEmbedder(nn.Module):
    """Look up a learned embedding for the class labels found in a batch dict.

    ``key`` selects which batch entry holds the labels; it can be overridden
    per call.
    """

    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        label_key = self.key if key is None else key
        # a singleton axis is inserted after the batch axis
        # (this is for use in crossattn)
        labels = batch[label_key][:, None]
        return self.embedding(labels)
class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer model and adds some transformer encoder layers.

    Args:
        n_embed: embedding width, passed to ``Encoder`` as ``dim``.
        n_layer: number of encoder layers, passed to ``Encoder`` as ``depth``.
        vocab_size: number of tokens for the ``TransformerWrapper``.
        max_seq_len: maximum sequence length; also the tokenizer's max length.
        device: device the tokenizer moves the token ids to.
        use_tokenizer: if True, ``forward`` receives raw text and tokenizes
            it; otherwise ``forward`` receives token ids directly.
        embedding_dropout: passed to ``TransformerWrapper`` as ``emb_dropout``.
    """

    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            # Forward `device`. Previously the tokenizer always fell back to
            # its own "cuda" default, so the `device` argument had no effect
            # on where the tokens ended up.
            self.tknz_fn = BERTTokenizer(device=device, vq_interface=False, max_length=max_seq_len)
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, text):
        """Return transformer embeddings for ``text`` (or for token ids)."""
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)
        else:
            tokens = text
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, text):
        # output of length 77
        return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
    """Text encoder built on the Hugging Face CLIP tokenizer and text model.

    All parameters are frozen during construction. ``forward`` returns the
    ``last_hidden_state`` of the text model for the tokenized input.
    """

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        """Put the text model in eval mode and disable gradients everywhere."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        # pad/truncate every prompt to exactly `max_length` tokens
        tokenizer_options = dict(truncation=True, max_length=self.max_length, return_length=True,
                                 return_overflowing_tokens=False, padding="max_length",
                                 return_tensors="pt")
        encoded = self.tokenizer(text, **tokenizer_options)
        input_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=input_ids).last_hidden_state

    def encode(self, text):
        return self(text)
class FrozenClipImageEmbedder(nn.Module):
    """
    Image encoder that wraps the image branch of a CLIP model.
    """

    def __init__(
            self,
            model,
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=False,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)

        self.antialias = antialias

        # per-channel statistics used by `preprocess`; registered as
        # non-persistent buffers
        clip_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
        clip_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
        self.register_buffer('mean', clip_mean, persistent=False)
        self.register_buffer('std', clip_std, persistent=False)

    def preprocess(self, x):
        """Resize to 224x224, map [-1, 1] to [0, 1], then apply mean/std normalization."""
        resized = kornia.geometry.resize(x, (224, 224),
                                         interpolation='bicubic', align_corners=True,
                                         antialias=self.antialias)
        # normalize to [0,1]
        unit_range = (resized + 1.) / 2.
        # renormalize according to clip
        return kornia.enhance.normalize(unit_range, self.mean, self.std)

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        clip_input = self.preprocess(x)
        return self.model.encode_image(clip_input)
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper).

    Args:
        k: square 2-D kernel of size ``k_size`` x ``k_size``.
    Returns:
        The composed kernel, cropped by ``k_size // 2`` on every side and
        normalized to sum to 1.
    """
    k_size = k.shape[0]
    # Calculate the big kernel's size
    big_size = 3 * k_size - 2
    big_k = np.zeros((big_size, big_size))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and reduce the run time of SR.
    # Explicit end indices are used because `[crop:-crop]` is an empty slice when crop == 0
    # (1x1 kernel).
    crop = k_size // 2
    cropped_big_k = big_k[crop:big_size - crop, crop:big_size - crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors

    The array is resampled with bilinear interpolation on a grid moved by
    ``(sf - 1) / 2`` pixels along both of the first two axes; sample
    positions are clipped to the array bounds.

    Args:
        x: HxW or HxWxC array (the first two axes are resampled)
        sf: scale factor
        upper_left: shift direction; True samples at ``index + shift``,
            False at ``index - shift``
    Returns:
        A new array; the input is not modified. 2-D results are float64,
        3-D results keep the dtype of ``x``. Arrays with any other number
        of dimensions are returned unchanged.
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    if not upper_left:
        shift = -shift

    cols = np.clip(np.arange(0, w, 1.0) + shift, 0, w - 1)
    rows = np.clip(np.arange(0, h, 1.0) + shift, 0, h - 1)
    coords = np.meshgrid(rows, cols, indexing='ij')

    def _resample(plane):
        # order=1 is linear interpolation. It replaces scipy.interpolate.interp2d
        # (linear by default), which is no longer available in current SciPy.
        return ndimage.map_coordinates(np.asarray(plane, dtype=np.float64), coords,
                                       order=1, mode='nearest')

    if x.ndim == 2:
        return _resample(x)
    if x.ndim == 3:
        # write into a fresh array instead of overwriting the caller's data
        out = np.empty_like(x)
        for i in range(x.shape[-1]):
            out[:, :, i] = _resample(x[:, :, i])
        return out
    return x
def fspecial_gaussian(hsize, sigma):
    """Return a normalized ``hsize`` x ``hsize`` Gaussian kernel.

    Args:
        hsize: side length of the square kernel.
        sigma: standard deviation of the Gaussian. ``sigma == 0`` yields the
            zero-width limit: all weight on the sample(s) closest to the centre.
    Returns:
        2-D float array that sums to 1 (unless every entry is zero).
    """
    half = (hsize - 1.0) / 2.0
    axis = np.arange(-half, half + 1)
    x, y = np.meshgrid(axis, axis)
    dist2 = x * x + y * y
    if sigma == 0:
        # avoid the 0/0 that previously turned the whole kernel into NaN
        h = (dist2 == dist2.min()).astype(float)
    else:
        h = np.exp(-dist2 / (2 * sigma * sigma))
        # np.finfo replaces the deprecated `scipy.finfo` alias
        h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
def fspecial(filter_type, *args, **kwargs):
    '''
    Build a predefined 2-D filter kernel, dispatching on ``filter_type``.

    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py

    Args:
        filter_type: 'gaussian' or 'laplacian'
        *args, **kwargs: forwarded unchanged to fspecial_gaussian / fspecial_laplacian
    Raises:
        ValueError: if ``filter_type`` is not supported (previously ``None``
            was returned silently)
    '''
    if filter_type == 'gaussian':
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == 'laplacian':
        return fspecial_laplacian(*args, **kwargs)
    raise ValueError(f"unsupported filter type: {filter_type!r}")
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image: the blurred image sampled at every sf-th
        row and column, starting at index 0
    '''
    # ndimage.convolve is the supported entry point; the `ndimage.filters`
    # namespace is deprecated. The 2-D kernel gets a singleton channel axis
    # so that each channel is blurred independently.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x[::sf, ::sf, ...]
def add_blur(img, sf=4):
    """Blur ``img`` with a randomly drawn Gaussian kernel.

    With probability 0.5 an anisotropic Gaussian is used (eigenvalue scalings
    drawn from ``[0, 4 + sf)``, random rotation), otherwise an isotropic one
    (sigma drawn from ``[0, 2 + 0.2 * sf)``). The kernel side length is
    ``2 * randint(2, 11) + 3``.

    Args:
        img: image with a trailing channel axis (the 2-D kernel is expanded
            with a singleton third axis before convolving).
        sf: scale factor; only controls the range of the kernel widths.
    Returns:
        The blurred image.
    """
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    # ndimage.convolve replaces the deprecated `ndimage.filters.convolve` alias
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add one of three kinds of Gaussian noise to ``img`` and clip to [0, 1].

    The noise level is drawn uniformly from the integers
    ``noise_level1..noise_level2`` (on a 0-255 scale). A second random draw
    selects the noise type: independent per channel (> 0.6), shared across
    channels (< 0.4), or channel-correlated with a random 3x3 covariance
    scaled by ``noise_level2`` (otherwise).
    """
    sigma = random.randint(noise_level1, noise_level2) / 255.0
    selector = np.random.rand()
    if selector > 0.6:  # add color Gaussian noise
        noise = np.random.normal(0, sigma, img.shape)
    elif selector < 0.4:  # add grayscale Gaussian noise
        noise = np.random.normal(0, sigma, (*img.shape[:2], 1))
    else:  # add noise
        scale = noise_level2 / 255.
        eigvals = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        cov = np.dot(np.dot(np.transpose(basis), eigvals), basis)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(scale ** 2 * cov), img.shape[:2])
    noisy = img + noise.astype(np.float32)
    return np.clip(noisy, 0.0, 1.0)
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Cut a random patch from ``lq`` and the matching patch from ``hq``.

    The low-quality patch is ``lq_patchsize`` pixels square; the
    high-quality patch covers the same region scaled by ``sf``.
    Returns the pair ``(lq_patch, hq_patch)``.
    """
    rows, cols = lq.shape[:2]
    top = random.randint(0, rows - lq_patchsize)
    left = random.randint(0, cols - lq_patchsize)
    lq_patch = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]

    # same window in high-quality coordinates
    hq_top, hq_left = int(top * sf), int(left * sf)
    hq_size = lq_patchsize * sf
    hq_patch = hq[hq_top:hq_top + hq_size, hq_left:hq_left + hq_size, :]
    return lq_patch, hq_patch
# mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] 
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    Apply a randomly ordered BSRGAN-style degradation pipeline to one image.

    This is a variant of the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    image: HxWxC image; it is passed through util.uint2single first, so it is
           presumably uint8 in [0, 255] -- TODO confirm against utils_image
    sf: scale factor
    isp_model: camera ISP model; kept for signature compatibility only, the
               ISP step is disabled in this variant
    Returns
    -------
    dict with the single key "image": the degraded image after util.single2uint
    """
    image = util.uint2single(image)
    jpeg_prob, scale2_prob = 0.9, 0.25

    # mod crop: make height and width multiples of sf
    # (the original sliced rows by the width and columns by the height,
    # which is only correct for square images)
    h1, w1 = image.shape[:2]
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            image = add_blur(image, sf=sf)

        elif i == 1:
            image = add_blur(image, sf=sf)

        elif i == 2:
            # remember the size before downsample2; step 3 always runs after
            # step 2 (see the swap above) and resizes relative to this size
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                # ndimage.convolve replaces the deprecated `ndimage.filters.convolve` alias
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # i == 6 (processed camera sensor noise) is intentionally a no-op in this variant

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example
if __name__ == '__main__':
    # Demo: degrade a test image 20 times and save each result next to a
    # bicubic downscale and the original for visual comparison.
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    print(img)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        # degradation_bsrgan_variant applies util.uint2single itself and
        # returns {"image": <uint image>}, so it is fed the uint image and
        # the result is unpacked from the dict (the original passed a float
        # image, called .shape on the dict and used an undefined `img_hq`)
        img_hq = util.uint2single(img)
        img_lq = util.uint2single(deg_fn(img)["image"])
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. 
+ Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of 
https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def 
fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE 
Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. 
+ threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. 
+# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. 
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] 
# nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] 
# nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/stable-diffusion/ldm/modules/image_degradation/utils/test.png b/stable-diffusion/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 0000000..4249b43 Binary files 
/dev/null and b/stable-diffusion/ldm/modules/image_degradation/utils/test.png differ diff --git a/stable-diffusion/ldm/modules/image_degradation/utils_image.py b/stable-diffusion/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 0000000..0175f15 --- /dev/null +++ b/stable-diffusion/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# 
-------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. 
+ Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. 
Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. 
+ if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) 
+ + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = 
tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return 
def augment_img_np3(img, mode=0):
    """Flip/transpose augmentation for a 3-D (H, W, C) NumPy image.

    The eight modes are the combinations of three independent steps, applied
    in this order: reverse axis 1 (bit value 4), reverse axis 0 (bit value
    2), swap axes 0 and 1 (bit value 1). So `mode == 4*h + 2*v + t`.

    Args:
        img: 3-D NumPy array.
        mode: integer in 0..7.

    Returns:
        `img` itself for mode 0, otherwise a view with the steps applied.

    Raises:
        ValueError: if `mode` is not one of 0..7 (previously no branch
            matched and the function silently returned None).
    """
    if mode not in range(8):
        raise ValueError('mode must be an integer in [0, 7], got {!r}'.format(mode))
    bits = int(mode)

    if bits & 4:
        img = img[:, ::-1, :]
    if bits & 2:
        img = img[::-1, :, :]
    if bits & 1:
        img = img.transpose(1, 0, 2)
    return img
def rgb2ycbcr(img, only_y=True):
    '''Convert an RGB image to YCbCr (same as matlab rgb2ycbcr).

    Args:
        img: NumPy array with RGB in the last axis; either
            uint8 in [0, 255] or float in [0, 1].
        only_y: if True return only the Y channel (last axis is dropped),
            otherwise return all three Y, Cb, Cr channels.

    Returns:
        Array with the same dtype and value range convention as the input.
        The input array is never modified.
    '''
    in_img_type = img.dtype
    # `astype` returns a NEW array. The previous code discarded that result
    # and then ran `img *= 255.` on the caller's own array, rescaling float
    # inputs in place. Work on a private float copy instead (at least
    # float32; float64 inputs keep their precision).
    img = img.astype(np.result_type(in_img_type, np.float32))
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def calculate_ssim(img1, img2, border=0):
    '''Calculate SSIM between two images (the same outputs as MATLAB's).

    Args:
        img1, img2: NumPy images of identical shape, value range [0, 255];
            either 2-D (H, W) or 3-D (H, W, C) with C equal to 1 or 3.
        border: number of pixels cropped from every side before scoring.

    Returns:
        The SSIM value; for 3-channel input, the mean of the per-channel
        SSIM values.

    Raises:
        ValueError: if the shapes differ, or if the input is neither 2-D
            nor 3-D with 1 or 3 channels. (A 3-D input with another channel
            count previously fell through and returned None.)
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        channels = img1.shape[2]
        if channels == 3:
            per_channel = [ssim(img1[:, :, c], img2[:, :, c]) for c in range(3)]
            return np.array(per_channel).mean()
        if channels == 1:
            # Index the channel explicitly; np.squeeze would also drop a
            # spatial axis of size 1.
            return ssim(img1[:, :, 0], img2[:, :, 0])
    raise ValueError('Wrong input image dimensions.')
kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. 
# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    """Resize a tensor image with the MATLAB-style bicubic kernel.

    The same scale is used for H and W. The resize is separable: the H axis
    is processed first, then the W axis. Before each pass the image is
    extended at both ends by mirrored copies of its border rows/columns, as
    many as `calculate_weights_indices` reports are needed.

    Args:
        img: torch tensor, CHW or HW, values in [0, 1].
        scale: scale factor applied to both H and W; output sizes are
            ceil(in_H * scale) and ceil(in_W * scale).
        antialiasing: forwarded to `calculate_weights_indices`.

    Returns:
        float32 tensor, CHW (or HW when the input was HW), not rounded.
        The input tensor is left untouched.
    """
    need_squeeze = img.dim() == 2
    if need_squeeze:
        # Out-of-place: `unsqueeze_` would permanently turn the caller's
        # HW tensor into a 1xHxW tensor.
        img = img.unsqueeze(0)
    in_C, in_H, in_W = img.size()
    out_H, out_W = math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)

    # process H dimension
    # symmetric copying: [mirrored top rows | image | mirrored bottom rows]
    img_aug = torch.empty(in_C, in_H + sym_len_Hs + sym_len_He, in_W, dtype=torch.float32)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(img[:, :sym_len_Hs, :].flip(1))
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(img[:, -sym_len_He:, :].flip(1))

    out_1 = torch.empty(in_C, out_H, in_W, dtype=torch.float32)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(in_C):
            # weighted sum over the `kernel_width` source rows for output row i
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying: [mirrored left cols | image | mirrored right cols]
    out_1_aug = torch.empty(in_C, out_H, in_W + sym_len_Ws + sym_len_We, dtype=torch.float32)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(out_1[:, :, :sym_len_Ws].flip(2))
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(out_1[:, :, -sym_len_We:].flip(2))

    out_2 = torch.empty(in_C, out_H, out_W, dtype=torch.float32)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(in_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])

    if need_squeeze:
        # Drop only the channel axis added above; a bare squeeze() would also
        # collapse an output height/width of 1.
        out_2 = out_2.squeeze(0)
    return out_2
j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/stable-diffusion/ldm/modules/losses/__init__.py b/stable-diffusion/ldm/modules/losses/__init__.py new file mode 100644 index 0000000..876d7c5 --- /dev/null +++ b/stable-diffusion/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/stable-diffusion/ldm/modules/losses/contperceptual.py b/stable-diffusion/ldm/modules/losses/contperceptual.py new file mode 100644 index 0000000..672c1e3 --- /dev/null +++ b/stable-diffusion/ldm/modules/losses/contperceptual.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
class LPIPSWithDiscriminator(nn.Module):
    """Training loss combining reconstruction, KL and adversarial terms.

    The reconstruction term is the absolute error plus (optionally) a
    weighted LPIPS perceptual term, scaled by a learned log-variance. The
    adversarial term comes from an `NLayerDiscriminator`, which this module
    owns and also trains (via `optimizer_idx == 1`).

    Note: `pixelloss_weight` is stored as `pixel_weight` but is not used by
    `forward`.
    """

    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):
        """Build the perceptual network, the discriminator and the weights.

        Args:
            disc_start: global step before which the discriminator factor is
                replaced by 0 (see `adopt_weight` calls in `forward`).
            logvar_init: initial value of the learned scalar log-variance.
            kl_weight: multiplier of the KL term.
            pixelloss_weight: stored only (unused in `forward`).
            disc_num_layers, disc_in_channels, use_actnorm: forwarded to
                `NLayerDiscriminator`.
            disc_factor: global multiplier of the adversarial terms.
            disc_weight: multiplier applied to the adaptive weight.
            perceptual_weight: multiplier of the LPIPS term; 0 disables it.
            disc_conditional: whether `forward` must receive `cond`.
            disc_loss: "hinge" or "vanilla".
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the generator loss against the reconstruction loss.

        Returns `|d nll / d last_layer| / (|d g / d last_layer| + 1e-4)`,
        clamped to [0, 1e4], detached and multiplied by
        `self.discriminator_weight`.

        NOTE(review): when `last_layer` is None this reads
        `self.last_layer`, which this class never assigns — callers are
        presumably expected to always pass `last_layer`; confirm.
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        """Compute the loss for the generator (0) or discriminator (1) step.

        Args:
            inputs, reconstructions: image batches of identical shape.
            posteriors: object exposing `.kl()` returning per-sample KL.
            optimizer_idx: 0 for the autoencoder update, 1 for the
                discriminator update.
            global_step: current step, used to gate the adversarial term.
            last_layer: parameter used for the adaptive weight.
            cond: optional conditioning concatenated on dim 1 before the
                discriminator.
            split: prefix for the keys of the returned log dict.
            weights: optional multiplier applied to the NLL before summing.

        Returns:
            (loss, log) where `log` maps "<split>/<name>" to detached values.

        Raises:
            ValueError: if `optimizer_idx` is neither 0 nor 1 (previously
                the method fell off the end and returned None).
        """
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights*nll_loss
        # sums are normalised by batch size, not by element count
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    # no graph to differentiate through; only acceptable
                    # outside of training
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log

        raise ValueError("optimizer_idx must be 0 (generator) or 1 (discriminator), got {!r}".format(optimizer_idx))
class VQLPIPSWithDiscriminator(nn.Module):
    """Training loss for a VQ autoencoder: pixel + LPIPS + codebook + GAN.

    The reconstruction term is an L1 or L2 pixel loss plus (optionally) a
    weighted LPIPS term, averaged over all elements. The adversarial term
    comes from an `NLayerDiscriminator` owned and trained by this module
    (via `optimizer_idx == 1`).

    Note: `pixelloss_weight` is stored as `pixel_weight` but is not used by
    `forward`.
    """

    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        """Build the perceptual network, the discriminator and the weights.

        Args:
            disc_start: global step before which the discriminator factor is
                replaced by 0 (see `adopt_weight` calls in `forward`).
            codebook_weight: multiplier of the codebook (quantisation) loss.
            pixelloss_weight: stored only (unused in `forward`).
            disc_num_layers, disc_in_channels, use_actnorm, disc_ndf:
                forwarded to `NLayerDiscriminator`.
            disc_factor: global multiplier of the adversarial terms.
            disc_weight: multiplier applied to the adaptive weight.
            perceptual_weight: multiplier of the LPIPS term; 0 disables it.
            disc_conditional: whether `forward` must receive `cond`.
            disc_loss: "hinge" or "vanilla".
            n_classes: codebook size, required when `forward` is given
                `predicted_indices`.
            perceptual_loss: only "lpips" is implemented; "clips" and
                "dists" pass the assertion but raise ValueError.
            pixel_loss: "l1" or "l2".
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the generator loss against the reconstruction loss.

        Returns `|d nll / d last_layer| / (|d g / d last_layer| + 1e-4)`,
        clamped to [0, 1e4], detached and multiplied by
        `self.discriminator_weight`.

        NOTE(review): when `last_layer` is None this reads
        `self.last_layer`, which this class never assigns — callers are
        presumably expected to always pass `last_layer`; confirm.
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
        """Compute the loss for the generator (0) or discriminator (1) step.

        Args:
            codebook_loss: quantisation loss tensor, or None to treat it as 0.
            inputs, reconstructions: image batches of identical shape.
            optimizer_idx: 0 for the autoencoder update, 1 for the
                discriminator update.
            global_step: current step, used to gate the adversarial term.
            last_layer: parameter used for the adaptive weight.
            cond: optional conditioning concatenated on dim 1 before the
                discriminator.
            split: prefix for the keys of the returned log dict.
            predicted_indices: optional code indices; when given, perplexity
                and cluster usage are added to the log (needs `n_classes`).

        Returns:
            (loss, log) where `log` maps "<split>/<name>" to detached values.

        Raises:
            ValueError: if `optimizer_idx` is neither 0 nor 1 (previously
                the method fell off the end and returned None).
        """
        # Plain None test: the former `exists(codebook_loss)` helper is not
        # imported or defined in this module, so every call raised NameError.
        if codebook_loss is None:
            codebook_loss = torch.tensor([0.]).to(inputs.device)
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                # no graph to differentiate through; only acceptable outside
                # of training
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log

        raise ValueError(f"optimizer_idx must be 0 (generator) or 1 (discriminator), got {optimizer_idx!r}")
class AbsolutePositionalEmbedding(nn.Module):
    """Learned position embedding: one `dim`-sized vector per position index."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, dim)
        self.init_()

    def init_(self):
        """Re-initialise the table from a normal distribution with std 0.02."""
        nn.init.normal_(self.emb.weight, std=0.02)

    def forward(self, x):
        """Return embeddings for positions 0 .. x.shape[1]-1, with a leading axis of size 1.

        Only the size of dimension 1 and the device of `x` are used.
        """
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device)
        table_rows = self.emb(positions)
        return table_rows.unsqueeze(0)
def group_dict_by_key(cond, d):
    """Split `d` into two dicts according to a predicate on the keys.

    Returns a 2-tuple `(matching, others)`: entries whose key makes
    `cond(key)` truthy go to the first dict, all remaining entries to the
    second. Insertion order is preserved within each dict.
    """
    matching, others = {}, {}
    for key, value in d.items():
        bucket = matching if cond(key) else others
        bucket[key] = value
    return matching, others
class RMSNorm(nn.Module):
    """Normalise the last axis by its root-mean-square, with a learned per-feature gain."""

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        # L2 norm * dim**-0.5 == sqrt(mean(x**2)) over `dim` features
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return `x / max(rms(x), eps) * g`, with rms taken over the last axis."""
        rms = x.norm(dim=-1, keepdim=True) * self.scale
        # lower bound avoids division by zero for all-zero rows
        denom = rms.clamp(min=self.eps)
        return x / denom * self.g
# attention.
class Attention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Optional extras, each switched on by a constructor argument: causal
    masking, talking-heads mixing before/after the softmax, top-k sparse
    attention, learned memory key/values, and a GLU output projection
    ("attention on attention").
    """

    def __init__(
            self,
            dim,
            dim_head=DEFAULT_DIM_HEAD,
            heads=8,
            causal=False,
            mask=None,
            talking_heads=False,
            sparse_topk=None,
            use_entmax15=False,
            num_mem_kv=0,
            dropout=0.,
            on_attn=False
    ):
        """Build the projections and optional parameters.

        ``use_entmax15=True`` is rejected with ``NotImplementedError``; softmax
        is the only attention activation available here.
        """
        super().__init__()
        if use_entmax15:
            raise NotImplementedError("Check out entmax activation instead of softmax activation!")
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.causal = causal
        # stored only; nothing in this class reads self.mask
        self.mask = mask

        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        # talking heads
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        # entmax
        #self.attn_fn = entmax15 if use_entmax15 else F.softmax
        self.attn_fn = F.softmax

        # add memory key / values
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            rel_pos=None,
            sinusoidal_emb=None,
            prev_attn=None,
            mem=None
    ):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context).

        ``x`` must be 3-D (its shape is unpacked as batch, sequence, feature).
        ``mask`` / ``context_mask`` are boolean masks over query / key
        positions. ``mem`` is concatenated in front of the key/value input.
        ``prev_attn`` is added to the raw attention scores. ``rel_pos`` and
        ``sinusoidal_emb`` are optional callables applied to the scores and
        to the query/key inputs respectively.

        Returns ``(output, Intermediates)`` where the intermediates carry the
        pre- and post-softmax attention tensors.
        """
        b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        if exists(mem):
            k_input = torch.cat((mem, k_input), dim=-2)
            v_input = torch.cat((mem, v_input), dim=-2)

        if exists(sinusoidal_emb):
            # in shortformer, the query would start at a position offset depending on the past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        # split the projected feature axis into heads: (b, n, h*d) -> (b, h, n, d)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        input_mask = None
        if any(map(exists, (mask, context_mask))):
            # a missing mask defaults to all-True; for self-attention the key
            # mask is the query mask
            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            # prepend the learned memory key/values and keep them unmasked
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if exists(input_mask):
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            dots = dots + prev_attn

        # NOTE: the masked_fill_ calls below are in-place, so unless talking
        # heads or rel_pos rebinds `dots`, this reference also sees the masking.
        pre_softmax_attn = dots

        if talking_heads:
            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

        if exists(rel_pos):
            dots = rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if self.causal:
            # mask out keys that lie after the query position; the left padding
            # of width (j - i) leaves the extra leading key columns visible
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
            mask = F.pad(mask, (j - i, 0), value=False)
            dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            # keep only the top-k scores per query; vk is the k-th largest score
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, dim=-1)
        post_softmax_attn = attn

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # merge heads back into the feature axis
        out = rearrange(out, 'b h n d -> b n (h d)')

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn,
            post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates
cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + 
class TransformerWrapper(nn.Module):
    """Token embedding, attention stack, final norm and logits head.

    Args (keyword-only):
        num_tokens: vocabulary size of the embedding and of the logits.
        max_seq_len: passed to the absolute positional embedding.
        attn_layers: an ``AttentionLayers`` instance run on the embeddings.
        emb_dim: embedding width; defaults to ``attn_layers.dim``.
        max_mem_len: number of trailing positions kept per layer when
            ``return_mems`` is requested. Truncated to ``int`` at use; note
            that ``0`` yields the slice ``[-0:]``, i.e. everything is kept.
        emb_dropout: dropout probability applied to the embeddings.
        num_memory_tokens: number of learned tokens prepended to the input.
        tie_embedding: reuse the embedding matrix as the logits projection.
        use_pos_emb: add an absolute positional embedding unless the
            attention stack reports ``has_pos_emb``.
    """

    def __init__(
            self,
            *,
            num_tokens,
            max_seq_len,
            attn_layers,
            emb_dim=None,
            max_mem_len=0.,
            emb_dropout=0.,
            num_memory_tokens=None,
            tie_embedding=False,
            use_pos_emb=True
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
                use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

            # let funnel encoder know number of memory tokens, if specified
            if hasattr(attn_layers, 'num_memory_tokens'):
                attn_layers.num_memory_tokens = num_memory_tokens

    def init_(self):
        """Initialise the token embedding weights from N(0, 0.02)."""
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
            self,
            x,
            return_embeddings=False,
            mask=None,
            return_mems=False,
            return_attn=False,
            mems=None,
            **kwargs
    ):
        """Run the model on a 2-D tensor of token ids.

        Returns the logits (or the normed hidden states when
        ``return_embeddings`` is true). With ``return_mems`` the result is
        ``(out, new_mems)``; otherwise with ``return_attn`` it is
        ``(out, attn_maps)``. Extra ``kwargs`` are forwarded to the
        attention stack.
        """
        # x must be 2-D; only the batch size is needed below
        b, _ = x.shape
        num_mem = self.num_memory_tokens

        x = self.token_emb(x)
        x = x + self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if mask is not None:
                mask = F.pad(mask, (num_mem, 0), value=True)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        # drop the memory-token positions before producing the output
        x = x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_mems:
            hiddens = intermediates.hiddens
            if mems is not None:
                new_mems = [torch.cat(pair, dim=-2) for pair in zip(mems, hiddens)]
            else:
                new_mems = hiddens
            # Slice bounds must be integers; the default max_mem_len is the
            # float 0., which previously raised TypeError here.
            mem_len = int(self.max_mem_len)
            new_mems = [t[..., -mem_len:, :].detach() for t in new_mems]
            return out, new_mems

        if return_attn:
            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
            return out, attn_maps

        return out
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    """Worker body: apply ``func`` to one chunk and post the result on ``Q``.

    Puts ``[idx, result]`` followed by the sentinel ``"Done"``. When
    ``idx_to_fn`` is true the chunk index is passed to ``func`` as
    ``worker_id``.
    """
    if idx_to_fn:
        res = func(data, worker_id=idx)
    else:
        res = func(data)
    Q.put([idx, res])
    Q.put("Done")


def parallel_data_prefetch(
        func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
    """Split ``data`` into chunks and run ``func`` on each chunk in parallel.

    Args:
        func: called once per chunk (with ``worker_id=<chunk index>`` as well
            when ``use_worker_id`` is true).
        data: an ``np.ndarray`` or any iterable; a dict contributes only its
            values.
        n_proc: requested number of workers. In list mode fewer workers are
            started when the data splits into fewer chunks.
        target_data_type: ``"ndarray"`` concatenates the results along axis 0,
            ``"list"`` flattens them into one list, anything else returns the
            per-chunk results as a list of length ``n_proc``.
        cpu_intensive: use processes when true, threads otherwise.
        use_worker_id: forward the chunk index to ``func``.

    Raises:
        ValueError: an ndarray was passed with ``target_data_type="list"``.
        TypeError: ``data`` is neither an ndarray nor iterable.
    """
    # if target_data_type not in ["ndarray", "list"]:
    #     raise ValueError(
    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
    #     )
    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
        )

    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread

    # split the data into per-worker chunks
    if target_data_type == "ndarray":
        parts = np.array_split(data, n_proc)
    elif len(data) == 0:
        # nothing to do; avoids a zero step in range() below
        parts = []
    else:
        step = -(-len(data) // n_proc)  # ceil(len(data) / n_proc)
        parts = [data[i: i + step] for i in range(0, len(data), step)]

    # One worker per chunk. Ceil-sized chunks can yield fewer than n_proc
    # parts (e.g. 9 items over 4 workers -> 3 chunks), so the worker count
    # follows the chunks rather than n_proc.
    arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(parts)]
    processes = [proc(target=_do_parallel_data_prefetch, args=args) for args in arguments]

    # start processes
    print(f"Start prefetching...")
    import time

    start = time.time()
    gather_res = [[] for _ in range(n_proc)]
    started = []
    try:
        for p in processes:
            p.start()
            started.append(p)

        finished = 0
        while finished < len(processes):
            # each worker posts its result and then one "Done" marker
            res = Q.get()
            if isinstance(res, str) and res == "Done":
                finished += 1
            else:
                gather_res[res[0]] = res[1]

    except Exception as e:
        print("Exception: ", e)
        for p in started:
            # threads have no terminate(); only processes can be killed
            if hasattr(p, "terminate"):
                p.terminate()

        raise
    finally:
        # only workers that were actually started can be joined
        for p in started:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == 'ndarray':
        if not isinstance(gather_res[0], np.ndarray):
            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)

        # order outputs
        return np.concatenate(gather_res, axis=0)
    elif target_data_type == 'list':
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res
def nondefault_trainer_args(opt):
    """Return the sorted names of Trainer options whose value in ``opt``
    differs from the Trainer argparse default."""
    default_parser = Trainer.add_argparse_args(argparse.ArgumentParser())
    defaults = vars(default_parser.parse_args([]))
    changed = [name for name, value in defaults.items() if getattr(opt, name) != value]
    return sorted(changed)


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        # the wrapped object is kept as-is; length and indexing delegate to it
        self.data = dataset

    def __len__(self):
        wrapped = self.data
        return len(wrapped)

    def __getitem__(self, idx):
        wrapped = self.data
        return wrapped[idx]
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are built from config objects.

    Each of ``train`` / ``validation`` / ``test`` / ``predict`` is an optional
    config passed to ``instantiate_from_config``. A dataloader hook is only
    attached for the splits that were actually configured.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = {}
        # default worker count is twice the batch size
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        """Instantiate every configured dataset once and discard the result."""
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        """Build ``self.datasets``; wrap each in ``WrappedDataset`` if requested."""
        self.datasets = {
            k: instantiate_from_config(cfg)
            for k, cfg in self.dataset_configs.items()
        }
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        """Training loader; shuffled unless the dataset is iterable-style."""
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["train"], batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=not is_iterable_dataset,
                          worker_init_fn=init_fn)

    def _val_dataloader(self, shuffle=False):
        """Validation loader."""
        if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["validation"],
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=init_fn,
                          shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        """Test loader; never shuffled when the test dataset is iterable-style."""
        # Inspect the *test* dataset. Looking at 'train' here raised KeyError
        # when only a test split was configured and could pick the wrong
        # shuffle / worker-init behaviour otherwise.
        is_iterable_dataset = isinstance(self.datasets['test'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoader(self.datasets["test"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        """Prediction loader; ``shuffle`` is accepted but not used."""
        if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn)
class ImageLogger(Callback):
    """Callback that periodically saves image grids produced by the module.

    Images come from ``pl_module.log_images(...)``; they are written as PNG
    files under ``<logger.save_dir>/images/<split>`` and also forwarded to
    the experiment logger when its type has an entry in
    ``self.logger_log_images``.
    """

    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
                 log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.TestTubeLogger: self._testtube,
        }
        # extra early logging steps at powers of two up to batch_freq
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        """Send one grid per image key to the logger's ``add_image``."""
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid,
                global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self, save_dir, split, images,
                  global_step, current_epoch, batch_idx):
        """Write one PNG grid per image key below ``save_dir/images/split``."""
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            # c,h,w -> h,w,c for PIL
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k,
                global_step,
                current_epoch,
                batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        """Log images for this batch if the logging schedule says it is due."""
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        # check_frequency stays first: it consumes the schedule as a side effect
        if not (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            return

        logger = type(pl_module.logger)

        is_train = pl_module.training
        if is_train:
            pl_module.eval()

        try:
            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)

            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)
        finally:
            # restore training mode even when image logging fails
            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        """Return True when ``check_idx`` is a logging step.

        A hit also drops the first pending entry of ``self.log_steps``.
        """
        due = (check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)
        if due and (check_idx > 0 or self.log_first_step):
            # once the early schedule is used up there is nothing left to pop
            if self.log_steps:
                self.log_steps.pop(0)
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                # This class defines no log_gradients; call it only when a
                # subclass (or the base class) provides one.
                # NOTE(review): confirm whether gradient logging should live here.
                log_gradients = getattr(self, "log_gradients", None)
                if callable(log_gradients):
                    log_gradients(trainer, pl_module, batch_idx=batch_idx)
if __name__ == "__main__":
    # Training / evaluation entry point.
    #
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            # A checkpoint file is taken to live two levels below the log
            # directory, so the last two path components are dropped.
            paths = opt.resume.split("/")
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        # Configs stored with the run come first so that configs given on
        # the command line override them in the left-to-right merge below.
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            # Fall back to the stem of the first config file as run name.
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

    # Bound before the try block so the except/finally clauses can tell
    # whether the trainer was ever created. Without this, any failure before
    # `Trainer.from_argparse_args` raised a NameError there and hid the
    # original exception.
    trainer = None

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to ddp
        trainer_config["accelerator"] = "ddp"
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if "gpus" not in trainer_config:
            # No GPUs requested: drop the ddp default again and run on CPU.
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                }
            },
            "testtube": {
                "target": "pytorch_lightning.loggers.TestTubeLogger",
                "params": {
                    "name": "testtube",
                    "save_dir": logdir,
                }
            },
        }
        default_logger_cfg = default_logger_cfgs["testtube"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            }
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
        # Before lightning 1.4.0 the checkpoint callback is handed to the
        # Trainer via its own keyword; from 1.4.0 on it is registered as a
        # regular callback further below.
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                }
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {
                    "batch_frequency": 750,
                    "max_images": 4,
                    "clamp": True
                }
            },
            "learning_rate_logger": {
                "target": "main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    # "log_momentum": True
                }
            },
            "cuda_callback": {
                "target": "main.CUDACallback"
            },
        }
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
            print(
                'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
            default_metrics_over_trainsteps_ckpt_dict = {
                'metrics_over_trainsteps_checkpoint':
                    {"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
                     'params': {
                         "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                         "filename": "{epoch:06}-{step:09}",
                         "verbose": True,
                         'save_top_k': -1,
                         'every_n_train_steps': 10000,
                         'save_weights_only': True
                     }
                     }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        # The ignore-keys callback is only kept when a resume checkpoint is
        # known, because it is given that checkpoint's path.
        if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
            callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
        elif 'ignore_keys_callback' in callbacks_cfg:
            del callbacks_cfg['ignore_keys_callback']

        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir  ###

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()
        print("#### Data #####")
        for k in data.datasets:
            print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # configure learning rate
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        if not cpu:
            # `gpus` is handled as a comma separated string here (e.g. "0,1,").
            ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
        else:
            ngpu = 1
        if 'accumulate_grad_batches' in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
        else:
            model.learning_rate = base_lr
            print("++++ NOT USING LR SCALING ++++")
            print(f"Setting learning rate to {model.learning_rate:.2e}")


        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            """Save ``last.ckpt`` on global rank 0 (signal handler / crash hook)."""
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)


        def divein(*args, **kwargs):
            """Drop into the pudb debugger on global rank 0 (signal handler)."""
            if trainer.global_rank == 0:
                import pudb
                pudb.set_trace()


        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # run
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                # Try to preserve progress before propagating the failure.
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        # Post-mortem debugging needs the trainer to determine the rank; if
        # the failure happened before it existed, just re-raise.
        if opt.debug and trainer is not None and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        if trainer is not None and trainer.global_rank == 0:
            # move newly created debug project to debug_runs
            # (skipped when the log directory was never created, so cleanup
            # cannot raise and replace the exception being propagated)
            if opt.debug and not opt.resume and os.path.isdir(logdir):
                dst, name = os.path.split(logdir)
                dst = os.path.join(dst, "debug_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                os.rename(logdir, dst)
            print(trainer.profiler.summary())
double_z: true + z_channels: 64 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 1 + - 2 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 16 + - 8 + dropout: 0.0 +data: + target: main.DataModuleFromConfig + params: + batch_size: 6 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/stable-diffusion/models/first_stage_models/kl-f4/config.yaml b/stable-diffusion/models/first_stage_models/kl-f4/config.yaml new file mode 100644 index 0000000..85cfb3e --- /dev/null +++ b/stable-diffusion/models/first_stage_models/kl-f4/config.yaml @@ -0,0 +1,41 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: val/rec_loss + embed_dim: 3 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 1.0e-06 + disc_weight: 0.5 + ddconfig: + double_z: true + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 +data: + target: main.DataModuleFromConfig + params: + batch_size: 10 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/stable-diffusion/models/first_stage_models/kl-f8/config.yaml b/stable-diffusion/models/first_stage_models/kl-f8/config.yaml new file mode 100644 index 0000000..921aa42 --- /dev/null +++ b/stable-diffusion/models/first_stage_models/kl-f8/config.yaml @@ -0,0 +1,42 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: val/rec_loss + embed_dim: 4 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + 
params: + disc_start: 50001 + kl_weight: 1.0e-06 + disc_weight: 0.5 + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 +data: + target: main.DataModuleFromConfig + params: + batch_size: 4 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/stable-diffusion/models/first_stage_models/vq-f16/config.yaml b/stable-diffusion/models/first_stage_models/vq-f16/config.yaml new file mode 100644 index 0000000..91c7454 --- /dev/null +++ b/stable-diffusion/models/first_stage_models/vq-f16/config.yaml @@ -0,0 +1,49 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.VQModel + params: + embed_dim: 8 + n_embed: 16384 + ddconfig: + double_z: false + z_channels: 8 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 16 + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_start: 250001 + disc_weight: 0.75 + disc_num_layers: 2 + codebook_weight: 1.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 14 + num_workers: 20 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/stable-diffusion/models/first_stage_models/vq-f4-noattn/config.yaml b/stable-diffusion/models/first_stage_models/vq-f4-noattn/config.yaml new file mode 100644 index 0000000..f8e499f --- /dev/null +++ b/stable-diffusion/models/first_stage_models/vq-f4-noattn/config.yaml @@ -0,0 +1,46 @@ 
+model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.VQModel + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + + ddconfig: + attn_type: none + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_start: 11 + disc_weight: 0.75 + codebook_weight: 1.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 8 + num_workers: 12 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + crop_size: 256 diff --git a/stable-diffusion/models/first_stage_models/vq-f4/config.yaml b/stable-diffusion/models/first_stage_models/vq-f4/config.yaml new file mode 100644 index 0000000..7d8cef3 --- /dev/null +++ b/stable-diffusion/models/first_stage_models/vq-f4/config.yaml @@ -0,0 +1,45 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.VQModel + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_start: 0 + disc_weight: 0.75 + codebook_weight: 1.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 8 + num_workers: 16 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + crop_size: 256 diff --git 
a/stable-diffusion/models/first_stage_models/vq-f8-n256/config.yaml b/stable-diffusion/models/first_stage_models/vq-f8-n256/config.yaml new file mode 100644 index 0000000..8519e13 --- /dev/null +++ b/stable-diffusion/models/first_stage_models/vq-f8-n256/config.yaml @@ -0,0 +1,48 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.VQModel + params: + embed_dim: 4 + n_embed: 256 + monitor: val/rec_loss + ddconfig: + double_z: false + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_start: 250001 + disc_weight: 0.75 + codebook_weight: 1.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 10 + num_workers: 20 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/stable-diffusion/models/first_stage_models/vq-f8/config.yaml b/stable-diffusion/models/first_stage_models/vq-f8/config.yaml new file mode 100644 index 0000000..efd6801 --- /dev/null +++ b/stable-diffusion/models/first_stage_models/vq-f8/config.yaml @@ -0,0 +1,48 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.VQModel + params: + embed_dim: 4 + n_embed: 16384 + monitor: val/rec_loss + ddconfig: + double_z: false + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_num_layers: 2 + disc_start: 1 + disc_weight: 0.6 + codebook_weight: 1.0 +data: + 
target: main.DataModuleFromConfig + params: + batch_size: 10 + num_workers: 20 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/stable-diffusion/models/ldm/bsr_sr/config.yaml b/stable-diffusion/models/ldm/bsr_sr/config.yaml new file mode 100644 index 0000000..861692a --- /dev/null +++ b/stable-diffusion/models/ldm/bsr_sr/config.yaml @@ -0,0 +1,80 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0155 + log_every_t: 100 + timesteps: 1000 + loss_type: l2 + first_stage_key: image + cond_stage_key: LR_image + image_size: 64 + channels: 3 + concat_mode: true + cond_stage_trainable: false + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 6 + out_channels: 3 + model_channels: 160 + attention_resolutions: + - 16 + - 8 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 2 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: torch.nn.Identity +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + wrap: false + num_workers: 12 + train: + target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain + params: + size: 256 + degradation: bsrgan_light + downscale_f: 4 + min_crop_f: 0.5 + max_crop_f: 1.0 + random_crop: true + validation: + target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation + params: + size: 256 + degradation: bsrgan_light + 
downscale_f: 4 + min_crop_f: 0.5 + max_crop_f: 1.0 + random_crop: true diff --git a/stable-diffusion/models/ldm/celeba256/config.yaml b/stable-diffusion/models/ldm/celeba256/config.yaml new file mode 100644 index 0000000..a12f4e9 --- /dev/null +++ b/stable-diffusion/models/ldm/celeba256/config.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 64 + channels: 3 + cond_stage_trainable: false + concat_mode: false + monitor: val/loss + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 48 + num_workers: 5 + wrap: false + train: + target: ldm.data.faceshq.CelebAHQTrain + params: + size: 256 + validation: + target: ldm.data.faceshq.CelebAHQValidation + params: + size: 256 diff --git a/stable-diffusion/models/ldm/cin256/config.yaml b/stable-diffusion/models/ldm/cin256/config.yaml new file mode 100644 index 0000000..9bc1b45 --- /dev/null +++ b/stable-diffusion/models/ldm/cin256/config.yaml @@ -0,0 +1,80 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + 
num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 32 + channels: 4 + cond_stage_trainable: true + conditioning_key: crossattn + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 4 + out_channels: 4 + model_channels: 256 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + num_head_channels: 32 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 512 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 4 + n_embed: 16384 + ddconfig: + double_z: false + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.ClassEmbedder + params: + embed_dim: 512 + key: class_label +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + num_workers: 12 + wrap: false + train: + target: ldm.data.imagenet.ImageNetTrain + params: + config: + size: 256 + validation: + target: ldm.data.imagenet.ImageNetValidation + params: + config: + size: 256 diff --git a/stable-diffusion/models/ldm/ffhq256/config.yaml b/stable-diffusion/models/ldm/ffhq256/config.yaml new file mode 100644 index 0000000..0ddfd1b --- /dev/null +++ b/stable-diffusion/models/ldm/ffhq256/config.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 64 + channels: 3 + cond_stage_trainable: false + concat_mode: false + monitor: val/loss + unet_config: + target: 
ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 42 + num_workers: 5 + wrap: false + train: + target: ldm.data.faceshq.FFHQTrain + params: + size: 256 + validation: + target: ldm.data.faceshq.FFHQValidation + params: + size: 256 diff --git a/stable-diffusion/models/ldm/inpainting_big/config.yaml b/stable-diffusion/models/ldm/inpainting_big/config.yaml new file mode 100644 index 0000000..da5fd5e --- /dev/null +++ b/stable-diffusion/models/ldm/inpainting_big/config.yaml @@ -0,0 +1,67 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0205 + log_every_t: 100 + timesteps: 1000 + loss_type: l1 + first_stage_key: image + cond_stage_key: masked_image + image_size: 64 + channels: 3 + concat_mode: true + monitor: val/loss + scheduler_config: + target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler + params: + verbosity_interval: 0 + warm_up_steps: 1000 + max_decay_steps: 50000 + lr_start: 0.001 + lr_max: 0.1 + lr_min: 0.0001 + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 7 + out_channels: 3 + model_channels: 256 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_heads: 8 + resblock_updown: true + 
first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + ddconfig: + attn_type: none + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: ldm.modules.losses.contperceptual.DummyLoss + cond_stage_config: __is_first_stage__ diff --git a/stable-diffusion/models/ldm/layout2img-openimages256/config.yaml b/stable-diffusion/models/ldm/layout2img-openimages256/config.yaml new file mode 100644 index 0000000..9e1dc15 --- /dev/null +++ b/stable-diffusion/models/ldm/layout2img-openimages256/config.yaml @@ -0,0 +1,81 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0205 + log_every_t: 100 + timesteps: 1000 + loss_type: l1 + first_stage_key: image + cond_stage_key: coordinates_bbox + image_size: 64 + channels: 3 + conditioning_key: crossattn + cond_stage_trainable: true + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 128 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + use_spatial_transformer: true + transformer_depth: 3 + context_dim: 512 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.BERTEmbedder + params: + n_embed: 512 + n_layer: 16 + vocab_size: 8192 + max_seq_len: 92 + use_tokenizer: false + 
monitor: val/loss_simple_ema +data: + target: main.DataModuleFromConfig + params: + batch_size: 24 + wrap: false + num_workers: 10 + train: + target: ldm.data.openimages.OpenImagesBBoxTrain + params: + size: 256 + validation: + target: ldm.data.openimages.OpenImagesBBoxValidation + params: + size: 256 diff --git a/stable-diffusion/models/ldm/lsun_beds256/config.yaml b/stable-diffusion/models/ldm/lsun_beds256/config.yaml new file mode 100644 index 0000000..1a50c76 --- /dev/null +++ b/stable-diffusion/models/ldm/lsun_beds256/config.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 64 + channels: 3 + cond_stage_trainable: false + concat_mode: false + monitor: val/loss + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 48 + num_workers: 5 + wrap: false + train: + target: ldm.data.lsun.LSUNBedroomsTrain + params: + size: 256 + validation: + target: ldm.data.lsun.LSUNBedroomsValidation + params: + size: 256 diff --git a/stable-diffusion/models/ldm/lsun_churches256/config.yaml b/stable-diffusion/models/ldm/lsun_churches256/config.yaml new file mode 
100644 index 0000000..424d091 --- /dev/null +++ b/stable-diffusion/models/ldm/lsun_churches256/config.yaml @@ -0,0 +1,92 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0155 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + loss_type: l1 + first_stage_key: image + cond_stage_key: image + image_size: 32 + channels: 4 + cond_stage_trainable: false + concat_mode: false + scale_by_std: true + monitor: val/loss_simple_ema + scheduler_config: + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: + - 10000 + cycle_lengths: + - 10000000000000 + f_start: + - 1.0e-06 + f_max: + - 1.0 + f_min: + - 1.0 + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 4 + out_channels: 4 + model_channels: 192 + attention_resolutions: + - 1 + - 2 + - 4 + - 8 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 2 + - 4 + - 4 + num_heads: 8 + use_scale_shift_norm: true + resblock_updown: true + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: '__is_unconditional__' + +data: + target: main.DataModuleFromConfig + params: + batch_size: 96 + num_workers: 5 + wrap: false + train: + target: ldm.data.lsun.LSUNChurchesTrain + params: + size: 256 + validation: + target: ldm.data.lsun.LSUNChurchesValidation + params: + size: 256 diff --git a/stable-diffusion/models/ldm/semantic_synthesis256/config.yaml b/stable-diffusion/models/ldm/semantic_synthesis256/config.yaml new file mode 100644 index 0000000..1a721cf --- /dev/null +++ b/stable-diffusion/models/ldm/semantic_synthesis256/config.yaml @@ 
-0,0 +1,59 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0205 + log_every_t: 100 + timesteps: 1000 + loss_type: l1 + first_stage_key: image + cond_stage_key: segmentation + image_size: 64 + channels: 3 + concat_mode: true + cond_stage_trainable: true + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 6 + out_channels: 3 + model_channels: 128 + attention_resolutions: + - 32 + - 16 + - 8 + num_res_blocks: 2 + channel_mult: + - 1 + - 4 + - 8 + num_heads: 8 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.SpatialRescaler + params: + n_stages: 2 + in_channels: 182 + out_channels: 3 diff --git a/stable-diffusion/models/ldm/semantic_synthesis512/config.yaml b/stable-diffusion/models/ldm/semantic_synthesis512/config.yaml new file mode 100644 index 0000000..8faded2 --- /dev/null +++ b/stable-diffusion/models/ldm/semantic_synthesis512/config.yaml @@ -0,0 +1,78 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0205 + log_every_t: 100 + timesteps: 1000 + loss_type: l1 + first_stage_key: image + cond_stage_key: segmentation + image_size: 128 + channels: 3 + concat_mode: true + cond_stage_trainable: true + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 128 + in_channels: 6 + out_channels: 3 + model_channels: 128 + attention_resolutions: + - 32 + - 16 + - 8 + num_res_blocks: 2 + channel_mult: + - 1 + - 4 + - 8 + num_heads: 8 + 
first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.SpatialRescaler + params: + n_stages: 2 + in_channels: 182 + out_channels: 3 +data: + target: main.DataModuleFromConfig + params: + batch_size: 8 + wrap: false + num_workers: 10 + train: + target: ldm.data.landscapes.RFWTrain + params: + size: 768 + crop_size: 512 + segmentation_to_float32: true + validation: + target: ldm.data.landscapes.RFWValidation + params: + size: 768 + crop_size: 512 + segmentation_to_float32: true diff --git a/stable-diffusion/models/ldm/text2img256/config.yaml b/stable-diffusion/models/ldm/text2img256/config.yaml new file mode 100644 index 0000000..3f54a01 --- /dev/null +++ b/stable-diffusion/models/ldm/text2img256/config.yaml @@ -0,0 +1,77 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 64 + channels: 3 + cond_stage_trainable: true + conditioning_key: crossattn + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 192 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 5 + num_head_channels: 32 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 640 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 
3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.BERTEmbedder + params: + n_embed: 640 + n_layer: 32 +data: + target: main.DataModuleFromConfig + params: + batch_size: 28 + num_workers: 10 + wrap: false + train: + target: ldm.data.previews.pytorch_dataset.PreviewsTrain + params: + size: 256 + validation: + target: ldm.data.previews.pytorch_dataset.PreviewsValidation + params: + size: 256 diff --git a/stable-diffusion/notebook_helpers.py b/stable-diffusion/notebook_helpers.py new file mode 100644 index 0000000..5d0ebd7 --- /dev/null +++ b/stable-diffusion/notebook_helpers.py @@ -0,0 +1,270 @@ +from torchvision.datasets.utils import download_url +from ldm.util import instantiate_from_config +import torch +import os +# todo ? +from google.colab import files +from IPython.display import Image as ipyimg +import ipywidgets as widgets +from PIL import Image +from numpy import asarray +from einops import rearrange, repeat +import torch, torchvision +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import ismap +import time +from omegaconf import OmegaConf + + +def download_models(mode): + + if mode == "superresolution": + # this is the small bsr light model + url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1' + url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1' + + path_conf = 'logs/diffusion/superresolution_bsr/configs/project.yaml' + path_ckpt = 'logs/diffusion/superresolution_bsr/checkpoints/last.ckpt' + + download_url(url_conf, path_conf) + download_url(url_ckpt, path_ckpt) + + path_conf = path_conf + '/?dl=1' # fix it + path_ckpt = path_ckpt + '/?dl=1' # fix it + return path_conf, path_ckpt + + else: + 
def get_custom_cond(mode):
    """Store a user-supplied conditioning input under ``data/example_conditioning/<mode>/``.

    Args:
        mode: ``"superresolution"`` asks for a file upload and moves the first
            uploaded file into the conditioning folder as ``custom_<name><ext>``;
            ``"text_conditional"`` writes the text widget's value to a ``.txt``
            file named after its first 20 characters; ``"class_conditional"``
            writes the slider value to ``custom.txt``.

    Raises:
        NotImplementedError: If ``mode`` is none of the values above.
    """
    dest = "data/example_conditioning"

    if mode == "superresolution":
        uploaded_img = files.upload()
        filename = next(iter(uploaded_img))
        # splitext only splits at the last dot, so names containing several
        # dots (or none) no longer break the two-value unpacking.
        name, extension = os.path.splitext(filename)
        os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}{extension}")

    elif mode == "text_conditional":
        w = widgets.Text(value='A cake with cream!', disabled=True)
        display(w)

        with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f:
            f.write(w.value)

    elif mode == "class_conditional":
        w = widgets.IntSlider(min=0, max=1000)
        display(w)
        with open(f"{dest}/{mode}/custom.txt", 'w') as f:
            # The slider value is an integer; a text-mode file only accepts str.
            f.write(str(w.value))

    else:
        raise NotImplementedError(f"cond not implemented for mode{mode}")
visualize_cond_img(selected_path) + + c = Image.open(selected_path) + c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) + c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True) + c_up = rearrange(c_up, '1 c h w -> 1 h w c') + c = rearrange(c, '1 c h w -> 1 h w c') + c = 2. * c - 1. + + c = c.to(torch.device("cuda")) + example["LR_image"] = c + example["image"] = c_up + + return example + + +def visualize_cond_img(path): + display(ipyimg(filename=path)) + + +def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None): + + example = get_cond(task, selected_path) + + save_intermediate_vid = False + n_runs = 1 + masked = False + guider = None + ckwargs = None + mode = 'ddim' + ddim_use_x0_pred = False + temperature = 1. + eta = 1. + make_progrow = True + custom_shape = None + + height, width = example["image"].shape[1:3] + split_input = height >= 128 and width >= 128 + + if split_input: + ks = 128 + stride = 64 + vqf = 4 # + model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), + "vqf": vqf, + "patch_distributed_vq": True, + "tie_braker": False, + "clip_max_weight": 0.5, + "clip_min_weight": 0.01, + "clip_max_tie_weight": 0.5, + "clip_min_tie_weight": 0.01} + else: + if hasattr(model, "split_input_params"): + delattr(model, "split_input_params") + + invert_mask = False + + x_T = None + for n in range(n_runs): + if custom_shape is not None: + x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) + x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0]) + + logs = make_convolutional_sample(example, model, + mode=mode, custom_steps=custom_steps, + eta=eta, swap_mode=False , masked=masked, + invert_mask=invert_mask, quantize_x0=False, + custom_schedule=None, decode_interval=10, + resize_enabled=resize_enabled, custom_shape=custom_shape, + temperature=temperature, noise_dropout=0., + corrector=guider, 
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, img_callback=None,
                    temperature=1., noise_dropout=0., score_corrector=None,
                    corrector_kwargs=None, x_T=None, log_every_t=None
                    ):
    """Draw samples from ``model`` using a ``DDIMSampler``.

    ``shape`` includes the batch dimension: its first entry is handed to the
    sampler as ``batch_size`` and the remaining entries as the sample shape.
    ``img_callback``, ``noise_dropout`` and ``log_every_t`` are accepted for
    signature compatibility but are not passed on to the sampler.

    Returns:
        The ``(samples, intermediates)`` pair produced by ``DDIMSampler.sample``.
    """
    sampler = DDIMSampler(model)
    # Separate the leading batch dimension from the per-sample shape.
    batch_size, sample_shape = shape[0], shape[1:]
    print(f"Sampling with eta = {eta}; steps: {steps}")

    sampler_options = dict(
        batch_size=batch_size,
        shape=sample_shape,
        conditioning=cond,
        callback=callback,
        normals_sequence=normals_sequence,
        quantize_x0=quantize_x0,
        eta=eta,
        mask=mask,
        x0=x0,
        temperature=temperature,
        verbose=False,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
        x_T=x_T,
    )
    samples, intermediates = sampler.sample(steps, **sampler_options)
    return samples, intermediates
#!/bin/bash
# Download and unpack the pretrained first-stage models.
# Each archive is saved as models/first_stage_models/<name>/model.zip
# (relative to the current directory) and extracted in place.

base_url="https://ommer-lab.com/files/latent-diffusion"
dest_root="models/first_stage_models"
models=(kl-f4 kl-f8 kl-f16 kl-f32 vq-f4 vq-f4-noattn vq-f8 vq-f8-n256 vq-f16)
failed=()

for name in "${models[@]}"; do
    dest_dir="${dest_root}/${name}"
    mkdir -p "${dest_dir}"

    # Keep going after a failure so one unavailable archive does not block the rest.
    if ! wget -O "${dest_dir}/model.zip" "${base_url}/${name}.zip"; then
        failed+=("${name}")
        continue
    fi

    # Run in a subshell: if cd fails, unzip is never executed in the wrong directory.
    if ! (cd "${dest_dir}" && unzip -o model.zip); then
        failed+=("${name}")
    fi
done

if [ "${#failed[@]}" -gt 0 ]; then
    echo "Failed to download or unpack: ${failed[*]}" >&2
    exit 1
fi
#!/bin/bash
# Download and unpack the pretrained latent-diffusion models.
# Each archive is saved under models/ldm/<dir>/ (relative to the current
# directory) and extracted in place.

base_url="https://ommer-lab.com/files/latent-diffusion"
dest_root="models/ldm"

# One entry per model: "<target dir>:<local archive name>:<remote archive name>"
models=(
    "celeba256:celeba-256.zip:celeba.zip"
    "ffhq256:ffhq-256.zip:ffhq.zip"
    "lsun_churches256:lsun_churches-256.zip:lsun_churches.zip"
    "lsun_beds256:lsun_beds-256.zip:lsun_bedrooms.zip"
    "text2img256:model.zip:text2img.zip"
    "cin256:model.zip:cin.zip"
    "semantic_synthesis512:model.zip:semantic_synthesis.zip"
    "semantic_synthesis256:model.zip:semantic_synthesis256.zip"
    "bsr_sr:model.zip:sr_bsr.zip"
    "layout2img-openimages256:model.zip:layout2img_model.zip"
    "inpainting_big:model.zip:inpainting_big.zip"
)
failed=()

for entry in "${models[@]}"; do
    IFS=: read -r dir archive remote <<< "${entry}"
    dest_dir="${dest_root}/${dir}"
    mkdir -p "${dest_dir}"

    # Keep going after a failure so one unavailable archive does not block the rest.
    if ! wget -O "${dest_dir}/${archive}" "${base_url}/${remote}"; then
        failed+=("${dir}")
        continue
    fi

    # Run in a subshell: if cd fails, unzip is never executed in the wrong directory.
    if ! (cd "${dest_dir}" && unzip -o "${archive}"); then
        failed+=("${dir}")
    fi
done

if [ "${#failed[@]}" -gt 0 ]; then
    echo "Failed to download or unpack: ${failed[*]}" >&2
    exit 1
fi
def load_img(path):
    """Load an image file as a ``(1, 3, H, W)`` float32 tensor scaled to ``[-1, 1]``.

    The image is converted to RGB and resized so that each side becomes the
    largest multiple of 32 that does not exceed its original length.

    Args:
        path: Location of the image, passed to ``PIL.Image.open``.

    Raises:
        ValueError: If either side of the image is shorter than 32 pixels, in
            which case rounding down would produce an empty image.
    """
    image = Image.open(path).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from {path}")

    # Round each side down to an integer multiple of 32.
    w, h = (side - side % 32 for side in (w, h))
    if w == 0 or h == 0:
        raise ValueError(
            f"image {path} is too small: both sides must be at least 32 pixels"
        )

    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    pixels = np.array(image).astype(np.float32) / 255.0
    # (H, W, C) -> (1, C, H, W)
    pixels = pixels[None].transpose(0, 3, 1, 2)
    return 2. * torch.from_numpy(pixels) - 1.
For speed measurements.", + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across all samples ", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor, most often 8 or 16", + ) + parser.add_argument( + "--n_samples", + type=int, + default=2, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=5.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--strength", + type=float, + default=0.75, + help="strength for noising/unnoising. 
1.0 corresponds to full destruction of information in init image", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v1-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="models/ldm/stable-diffusion-v1/model.ckpt", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + + opt = parser.parse_args() + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.plms: + raise NotImplementedError("PLMS sampler not (yet) supported") + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + assert os.path.isfile(opt.init_img) + init_image = load_img(opt.init_img).to(device) + init_image = repeat(init_image, '1 ... 
-> b ...', b=batch_size) + init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + + sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) + + assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(opt.strength * opt.ddim_steps) + print(f"target t_enc is {t_enc} steps") + + precision_scope = autocast if opt.precision == "autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + + # encode (scaled latent) + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) + # decode it + samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc,) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + if not opt.skip_save: + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. 
def make_batch(image, mask, device):
    """Build the inpainting input batch from an image file and a mask file.

    The image is read as RGB and the mask as single-channel greyscale; the
    mask is binarised at 0.5. All three returned tensors are moved to
    ``device`` and rescaled from ``[0, 1]`` to ``[-1, 1]``.

    Returns:
        A dict with the keys ``"image"``, ``"mask"`` and ``"masked_image"``,
        where the masked image has the masked region zeroed out before
        rescaling.
    """
    rgb = np.array(Image.open(image).convert("RGB")).astype(np.float32) / 255.0
    # (H, W, C) -> (1, C, H, W)
    image_tensor = torch.from_numpy(rgb[None].transpose(0, 3, 1, 2))

    grey = np.array(Image.open(mask).convert("L")).astype(np.float32) / 255.0
    # Hard threshold: values below 0.5 become 0, everything else becomes 1.
    binary = (grey[None, None] >= 0.5).astype(np.float32)
    mask_tensor = torch.from_numpy(binary)

    tensors = {
        "image": image_tensor,
        "mask": mask_tensor,
        "masked_image": (1 - mask_tensor) * image_tensor,
    }
    return {
        key: tensor.to(device=device) * 2.0 - 1.0
        for key, tensor in tensors.items()
    }
config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") + model = instantiate_from_config(config.model) + model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], + strict=False) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + with torch.no_grad(): + with model.ema_scope(): + for image, mask in tqdm(zip(images, masks)): + outpath = os.path.join(opt.outdir, os.path.split(image)[1]) + batch = make_batch(image, mask, device=device) + + # encode masked image and concat downsampled mask + c = model.cond_stage_model.encode(batch["masked_image"]) + cc = torch.nn.functional.interpolate(batch["mask"], + size=c.shape[-2:]) + c = torch.cat((c, cc), dim=1) + + shape = (c.shape[1]-1,)+c.shape[2:] + samples_ddim, _ = sampler.sample(S=opt.steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False) + x_samples_ddim = model.decode_first_stage(samples_ddim) + + image = torch.clamp((batch["image"]+1.0)/2.0, + min=0.0, max=1.0) + mask = torch.clamp((batch["mask"]+1.0)/2.0, + min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, + min=0.0, max=1.0) + + inpainted = (1-mask)*image+mask*predicted_image + inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + Image.fromarray(inpainted.astype(np.uint8)).save(outpath) diff --git a/stable-diffusion/scripts/knn2img.py b/stable-diffusion/scripts/knn2img.py new file mode 100644 index 0000000..e6eaaec --- /dev/null +++ b/stable-diffusion/scripts/knn2img.py @@ -0,0 +1,398 @@ +import argparse, os, sys, glob +import clip +import torch +import torch.nn as nn +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +import scann +import time +from multiprocessing 
import cpu_count + +from ldm.util import instantiate_from_config, parallel_data_prefetch +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder + +DATABASES = [ + "openimages", + "artbench-art_nouveau", + "artbench-baroque", + "artbench-expressionism", + "artbench-impressionism", + "artbench-post_impressionism", + "artbench-realism", + "artbench-romanticism", + "artbench-renaissance", + "artbench-surrealism", + "artbench-ukiyo_e", +] + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +class Searcher(object): + def __init__(self, database, retriever_version='ViT-L/14'): + assert database in DATABASES + # self.database = self.load_database(database) + self.database_name = database + self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' + self.database_path = f'data/rdm/retrieval_databases/{self.database_name}' + self.retriever = self.load_retriever(version=retriever_version) + self.database = {'embedding': [], + 'img_id': [], + 'patch_coords': []} + self.load_database() + self.load_searcher() + + def train_searcher(self, k, + metric='dot_product', + searcher_savedir=None): + + print('Start training searcher') + searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / + np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], + k, 
metric) + self.searcher = searcher.score_brute_force().build() + print('Finish training searcher') + + if searcher_savedir is not None: + print(f'Save trained searcher under "{searcher_savedir}"') + os.makedirs(searcher_savedir, exist_ok=True) + self.searcher.serialize(searcher_savedir) + + def load_single_file(self, saved_embeddings): + compressed = np.load(saved_embeddings) + self.database = {key: compressed[key] for key in compressed.files} + print('Finished loading of clip embeddings.') + + def load_multi_files(self, data_archive): + out_data = {key: [] for key in self.database} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + out_data[key].append(d[key]) + + return out_data + + def load_database(self): + + print(f'Load saved patch embedding from "{self.database_path}"') + file_content = glob.glob(os.path.join(self.database_path, '*.npz')) + + if len(file_content) == 1: + self.load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(self.load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in + self.database} + else: + raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') + + def load_retriever(self, version='ViT-L/14', ): + model = FrozenClipImageEmbedder(model=version) + if torch.cuda.is_available(): + model.cuda() + model.eval() + return model + + def load_searcher(self): + print(f'load searcher for database {self.database_name} from {self.searcher_savedir}') + self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir) + print('Finished loading searcher.') + + def search(self, x, k): + if 
self.searcher is None and self.database['embedding'].shape[0] < 2e4: + self.train_searcher(k) # quickly fit searcher on the fly for small databases + assert self.searcher is not None, 'Cannot search with uninitialized searcher' + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if len(x.shape) == 3: + x = x[:, 0] + query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis] + + start = time.time() + nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k) + end = time.time() + + out_embeddings = self.database['embedding'][nns] + out_img_ids = self.database['img_id'][nns] + out_pc = self.database['patch_coords'][nns] + + out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + 'img_ids': out_img_ids, + 'patch_coords': out_pc, + 'queries': x, + 'exec_time': end - start, + 'nns': nns, + 'q_embeddings': query_embeddings} + + return out + + def __call__(self, x, n): + return self.search(x, n) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc) + # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt? + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. 
Helpful when evaluating lots of samples", + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--n_repeat", + type=int, + default=1, + help="number of repeats in CLIP latent space", + ) + + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + + parser.add_argument( + "--H", + type=int, + default=768, + help="image height, in pixel space", + ) + + parser.add_argument( + "--W", + type=int, + default=768, + help="image width, in pixel space", + ) + + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + + parser.add_argument( + "--scale", + type=float, + default=5.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + + parser.add_argument( + "--config", + type=str, + default="configs/retrieval-augmented-diffusion/768x768.yaml", + help="path to config which constructs model", + ) + + parser.add_argument( + "--ckpt", + type=str, + default="models/rdm/rdm768x768/model.ckpt", + help="path to checkpoint of model", + ) + + parser.add_argument( + "--clip_type", + type=str, + default="ViT-L/14", + help="which CLIP model to use for retrieval and NN encoding", + ) + parser.add_argument( + "--database", + type=str, + default='artbench-surrealism', + choices=DATABASES, + help="The database used for the search, only applied when --use_neighbors=True", + ) + 
parser.add_argument( + "--use_neighbors", + default=False, + action='store_true', + help="Include neighbors in addition to text prompt for conditioning", + ) + parser.add_argument( + "--knn", + default=10, + type=int, + help="The number of included neighbors, only applied when --use_neighbors=True", + ) + + opt = parser.parse_args() + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + print(f"sampling scale for cfg is {opt.scale:.2f}") + + searcher = None + if opt.use_neighbors: + searcher = Searcher(opt.database) + + with torch.no_grad(): + with model.ema_scope(): + for n in trange(opt.n_iter, desc="Sampling"): + all_samples = list() + for prompts in tqdm(data, desc="data"): + print("sampling prompts:", prompts) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = clip_text_encoder.encode(prompts) + uc = None + if searcher is not None: + nn_dict = searcher(c, opt.knn) + c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + if opt.scale != 1.0: + uc = torch.zeros_like(c) + if isinstance(prompts, tuple): + 
prompts = list(prompts) + shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples_ddim) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/stable-diffusion/scripts/latent_imagenet_diffusion.ipynb b/stable-diffusion/scripts/latent_imagenet_diffusion.ipynb new file mode 100644 index 0000000..607f94f --- /dev/null +++ b/stable-diffusion/scripts/latent_imagenet_diffusion.ipynb @@ -0,0 +1,429 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "latent-imagenet-diffusion.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Class-Conditional Synthesis with Latent Diffusion Models" + ], + "metadata": { + "id": "NUmmV5ZvrPbP" + } + }, + { + "cell_type": "markdown", + 
"source": [ + "Install all the requirements" + ], + "metadata": { + "id": "zh7u8gOx0ivw" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NHgUAp48qwoG", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "411d4df6-d91a-42d4-819e-9cf641c12248", + "cellView": "form" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cloning into 'latent-diffusion'...\n", + "remote: Enumerating objects: 992, done.\u001B[K\n", + "remote: Counting objects: 100% (695/695), done.\u001B[K\n", + "remote: Compressing objects: 100% (397/397), done.\u001B[K\n", + "remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001B[K\n", + "Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n", + "Resolving deltas: 100% (510/510), done.\n", + "Cloning into 'taming-transformers'...\n", + "remote: Enumerating objects: 1335, done.\u001B[K\n", + "remote: Counting objects: 100% (525/525), done.\u001B[K\n", + "remote: Compressing objects: 100% (493/493), done.\u001B[K\n", + "remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001B[K\n", + "Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n", + "Resolving deltas: 100% (267/267), done.\n", + "Obtaining file:///content/taming-transformers\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from taming-transformers==0.0.1) (1.10.0+cu111)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from taming-transformers==0.0.1) (1.21.5)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from taming-transformers==0.0.1) (4.63.0)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->taming-transformers==0.0.1) (3.10.0.2)\n", + "Installing collected packages: taming-transformers\n", + " Running 
setup.py develop for taming-transformers\n", + "Successfully installed taming-transformers-0.0.1\n", + "\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n", + "arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001B[0m\n" + ] + } + ], + "source": [ + "#@title Installation\n", + "!git clone https://github.com/CompVis/latent-diffusion.git\n", + "!git clone https://github.com/CompVis/taming-transformers\n", + "!pip install -e ./taming-transformers\n", + "!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n", + "\n", + "import sys\n", + "sys.path.append(\".\")\n", + "sys.path.append('./taming-transformers')\n", + "from taming.models import vqgan " + ] + }, + { + "cell_type": "markdown", + "source": [ + "Now, download the checkpoint (~1.7 GB). This will usually take 1-2 minutes." 
+ ], + "metadata": { + "id": "fNqCqQDoyZmq" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Download\n", + "%cd latent-diffusion/ \n", + "\n", + "!mkdir -p models/ldm/cin256-v2/\n", + "!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt " + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cNHvQBhzyXCI", + "outputId": "0a79e979-8484-4c62-96d9-7c79b1835162", + "cellView": "form" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/latent-diffusion\n", + "--2022-04-03 13:04:51-- https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt\n", + "Resolving ommer-lab.com (ommer-lab.com)... 141.84.41.65\n", + "Connecting to ommer-lab.com (ommer-lab.com)|141.84.41.65|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1827378153 (1.7G)\n", + "Saving to: ‘models/ldm/cin256-v2/model.ckpt’\n", + "\n", + "models/ldm/cin256-v 100%[===================>] 1.70G 24.9MB/s in 70s \n", + "\n", + "2022-04-03 13:06:02 (24.9 MB/s) - ‘models/ldm/cin256-v2/model.ckpt’ saved [1827378153/1827378153]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Let's also check what type of GPU we've got." 
+ ], + "metadata": { + "id": "ThxmCePqt1mt" + } + }, + { + "cell_type": "code", + "source": [ + "!nvidia-smi" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jbL2zJ7Pt7Jl", + "outputId": "c8242be9-dba2-4a9f-da44-a294a70bb449" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Sun Apr 3 13:06:21 2022 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n", + "| N/A 66C P8 33W / 149W | 0MiB / 11441MiB | 0% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Load it." 
+ ], + "metadata": { + "id": "1tWAqdwk0Nrn" + } + }, + { + "cell_type": "code", + "source": [ + "#@title loading utils\n", + "import torch\n", + "from omegaconf import OmegaConf\n", + "\n", + "from ldm.util import instantiate_from_config\n", + "\n", + "\n", + "def load_model_from_config(config, ckpt):\n", + " print(f\"Loading model from {ckpt}\")\n", + " pl_sd = torch.load(ckpt)#, map_location=\"cpu\")\n", + " sd = pl_sd[\"state_dict\"]\n", + " model = instantiate_from_config(config.model)\n", + " m, u = model.load_state_dict(sd, strict=False)\n", + " model.cuda()\n", + " model.eval()\n", + " return model\n", + "\n", + "\n", + "def get_model():\n", + " config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\") \n", + " model = load_model_from_config(config, \"models/ldm/cin256-v2/model.ckpt\")\n", + " return model" + ], + "metadata": { + "id": "fnGwQRhtyBhb", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from ldm.models.diffusion.ddim import DDIMSampler\n", + "\n", + "model = get_model()\n", + "sampler = DDIMSampler(model)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BPnyd-XUKbfE", + "outputId": "0fcd10e4-0df2-4ab9-cbf5-f08f4902c954" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Loading model from models/ldm/cin256-v2/model.ckpt\n", + "LatentDiffusion: Running in eps-prediction mode\n", + "DiffusionWrapper has 400.92 M params.\n", + "making attention of type 'vanilla' with 512 in_channels\n", + "Working with z of shape (1, 3, 64, 64) = 12288 dimensions.\n", + "making attention of type 'vanilla' with 512 in_channels\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "And go. Quality, sampling speed and diversity are best controlled via the `scale`, `ddim_steps` and `ddim_eta` variables. 
As a rule of thumb, higher values of `scale` produce better samples at the cost of a reduced output diversity. Furthermore, increasing `ddim_steps` generally also gives higher quality samples, but returns are diminishing for values > 250. Fast sampling (i e. low values of `ddim_steps`) while retaining good quality can be achieved by using `ddim_eta = 0.0`." + ], + "metadata": { + "id": "iIEAhY8AhUrh" + } + }, + { + "cell_type": "code", + "source": [ + "import numpy as np \n", + "from PIL import Image\n", + "from einops import rearrange\n", + "from torchvision.utils import make_grid\n", + "\n", + "\n", + "classes = [25, 187, 448, 992] # define classes to be sampled here\n", + "n_samples_per_class = 6\n", + "\n", + "ddim_steps = 20\n", + "ddim_eta = 0.0\n", + "scale = 3.0 # for unconditional guidance\n", + "\n", + "\n", + "all_samples = list()\n", + "\n", + "with torch.no_grad():\n", + " with model.ema_scope():\n", + " uc = model.get_learned_conditioning(\n", + " {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}\n", + " )\n", + " \n", + " for class_label in classes:\n", + " print(f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\")\n", + " xc = torch.tensor(n_samples_per_class*[class_label])\n", + " c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n", + " \n", + " samples_ddim, _ = sampler.sample(S=ddim_steps,\n", + " conditioning=c,\n", + " batch_size=n_samples_per_class,\n", + " shape=[3, 64, 64],\n", + " verbose=False,\n", + " unconditional_guidance_scale=scale,\n", + " unconditional_conditioning=uc, \n", + " eta=ddim_eta)\n", + "\n", + " x_samples_ddim = model.decode_first_stage(samples_ddim)\n", + " x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, \n", + " min=0.0, max=1.0)\n", + " all_samples.append(x_samples_ddim)\n", + "\n", + "\n", + "# display as grid\n", + "grid = torch.stack(all_samples, 0)\n", + "grid = rearrange(grid, 'n 
b c h w -> (n b) c h w')\n", + "grid = make_grid(grid, nrow=n_samples_per_class)\n", + "\n", + "# to image\n", + "grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n", + "Image.fromarray(grid.astype(np.uint8))" + ], + "metadata": { + "id": "jcbqWX2Ytu9t", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "3b7adde0-d80e-4c01-82d2-bf988aee7455" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "rendering 6 examples of class '25' in 20 steps and using s=3.00.\n", + "Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0\n", + "Running DDIM Sampling with 20 timesteps\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00, 1.89s/it]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "rendering 6 examples of class '187' in 20 steps and using s=3.00.\n", + "Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0\n", + "Running DDIM Sampling with 20 timesteps\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00, 1.87s/it]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "rendering 6 examples of class '448' in 20 steps and using s=3.00.\n", + "Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0\n", + "Running DDIM Sampling with 20 timesteps\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00, 1.86s/it]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "rendering 6 examples of class '992' in 20 steps and using s=3.00.\n", + "Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0\n", + "Running DDIM Sampling with 20 timesteps\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + 
"DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00, 1.86s/it]\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ], + "image/png": "\n" + }, + "metadata": {}, + "execution_count": 6 + } + ] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "92QkRfm0e6K0" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/stable-diffusion/scripts/sample_diffusion.py b/stable-diffusion/scripts/sample_diffusion.py new file mode 100644 index 0000000..876fe3c --- /dev/null +++ b/stable-diffusion/scripts/sample_diffusion.py @@ -0,0 +1,313 @@ +import argparse, os, sys, glob, datetime, yaml +import torch +import time +import numpy as np +from tqdm import trange + +from omegaconf import OmegaConf +from PIL import Image + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config + +rescale = lambda x: (x + 1.) / 2. + +def custom_to_pil(x): + x = x.detach().cpu() + x = torch.clamp(x, -1., 1.) + x = (x + 1.) / 2. + x = x.permute(1, 2, 0).numpy() + x = (255 * x).astype(np.uint8) + x = Image.fromarray(x) + if not x.mode == "RGB": + x = x.convert("RGB") + return x + + +def custom_to_np(x): + # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py + sample = x.detach().cpu() + sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) + sample = sample.permute(0, 2, 3, 1) + sample = sample.contiguous() + return sample + + +def logs2pil(logs, keys=["sample"]): + imgs = dict() + for k in logs: + try: + if len(logs[k].shape) == 4: + img = custom_to_pil(logs[k][0, ...]) + elif len(logs[k].shape) == 3: + img = custom_to_pil(logs[k]) + else: + print(f"Unknown format for key {k}. 
") + img = None + except: + img = None + imgs[k] = img + return imgs + + +@torch.no_grad() +def convsample(model, shape, return_intermediates=True, + verbose=True, + make_prog_row=False): + + + if not make_prog_row: + return model.p_sample_loop(None, shape, + return_intermediates=return_intermediates, verbose=verbose) + else: + return model.progressive_denoising( + None, shape, verbose=True + ) + + +@torch.no_grad() +def convsample_ddim(model, steps, shape, eta=1.0 + ): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): + + + log = dict() + + shape = [batch_size, + model.model.diffusion_model.in_channels, + model.model.diffusion_model.image_size, + model.model.diffusion_model.image_size] + + with model.ema_scope("Plotting"): + t0 = time.time() + if vanilla: + sample, progrow = convsample(model, shape, + make_prog_row=True) + else: + sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, + eta=eta) + + t1 = time.time() + + x_sample = model.decode_first_stage(sample) + + log["sample"] = x_sample + log["time"] = t1 - t0 + log['throughput'] = sample.shape[0] / (t1 - t0) + print(f'Throughput for this batch: {log["throughput"]}') + return log + +def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): + if vanilla: + print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') + else: + print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') + + + tstart = time.time() + n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 + # path = logdir + if model.cond_stage_model is None: + all_images = [] + + print(f"Running unconditional sampling for {n_samples} samples") + for _ in trange(n_samples 
// batch_size, desc="Sampling Batches (unconditional)"): + logs = make_convolutional_sample(model, batch_size=batch_size, + vanilla=vanilla, custom_steps=custom_steps, + eta=eta) + n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") + all_images.extend([custom_to_np(logs["sample"])]) + if n_saved >= n_samples: + print(f'Finish after generating {n_saved} samples') + break + all_img = np.concatenate(all_images, axis=0) + all_img = all_img[:n_samples] + shape_str = "x".join([str(x) for x in all_img.shape]) + nppath = os.path.join(nplog, f"{shape_str}-samples.npz") + np.savez(nppath, all_img) + + else: + raise NotImplementedError('Currently only sampling for unconditional models supported.') + + print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") + + +def save_logs(logs, path, n_saved=0, key="sample", np_path=None): + for k in logs: + if k == key: + batch = logs[key] + if np_path is None: + for x in batch: + img = custom_to_pil(x) + imgpath = os.path.join(path, f"{key}_{n_saved:06}.png") + img.save(imgpath) + n_saved += 1 + else: + npbatch = custom_to_np(batch) + shape_str = "x".join([str(x) for x in npbatch.shape]) + nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz") + np.savez(nppath, npbatch) + n_saved += npbatch.shape[0] + return n_saved + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-r", + "--resume", + type=str, + nargs="?", + help="load from logdir or checkpoint in logdir", + ) + parser.add_argument( + "-n", + "--n_samples", + type=int, + nargs="?", + help="number of samples to draw", + default=50000 + ) + parser.add_argument( + "-e", + "--eta", + type=float, + nargs="?", + help="eta for ddim sampling (0.0 yields deterministic sampling)", + default=1.0 + ) + parser.add_argument( + "-v", + "--vanilla_sample", + default=False, + action='store_true', + help="vanilla sampling (default option is DDIM sampling)?", + ) + parser.add_argument( + "-l", + 
"--logdir", + type=str, + nargs="?", + help="extra logdir", + default="none" + ) + parser.add_argument( + "-c", + "--custom_steps", + type=int, + nargs="?", + help="number of steps for ddim and fastdpm sampling", + default=50 + ) + parser.add_argument( + "--batch_size", + type=int, + nargs="?", + help="the bs", + default=10 + ) + return parser + + +def load_model_from_config(config, sd): + model = instantiate_from_config(config) + model.load_state_dict(sd,strict=False) + model.cuda() + model.eval() + return model + + +def load_model(config, ckpt, gpu, eval_mode): + if ckpt: + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + global_step = pl_sd["global_step"] + else: + pl_sd = {"state_dict": None} + global_step = None + model = load_model_from_config(config.model, + pl_sd["state_dict"]) + + return model, global_step + + +if __name__ == "__main__": + now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + sys.path.append(os.getcwd()) + command = " ".join(sys.argv) + + parser = get_parser() + opt, unknown = parser.parse_known_args() + ckpt = None + + if not os.path.exists(opt.resume): + raise ValueError("Cannot find {}".format(opt.resume)) + if os.path.isfile(opt.resume): + # paths = opt.resume.split("/") + try: + logdir = '/'.join(opt.resume.split('/')[:-1]) + # idx = len(paths)-paths[::-1].index("logs")+1 + print(f'Logdir is {logdir}') + except ValueError: + paths = opt.resume.split("/") + idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt + logdir = "/".join(paths[:idx]) + ckpt = opt.resume + else: + assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory" + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "model.ckpt") + + base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml"))) + opt.base = base_configs + + configs = [OmegaConf.load(cfg) for cfg in opt.base] + cli = OmegaConf.from_dotlist(unknown) + config = OmegaConf.merge(*configs, cli) + + gpu = True + eval_mode = True 
+ + if opt.logdir != "none": + locallog = logdir.split(os.sep)[-1] + if locallog == "": locallog = logdir.split(os.sep)[-2] + print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") + logdir = os.path.join(opt.logdir, locallog) + + print(config) + + model, global_step = load_model(config, ckpt, gpu, eval_mode) + print(f"global step: {global_step}") + print(75 * "=") + print("logging to:") + logdir = os.path.join(logdir, "samples", f"{global_step:08}", now) + imglogdir = os.path.join(logdir, "img") + numpylogdir = os.path.join(logdir, "numpy") + + os.makedirs(imglogdir) + os.makedirs(numpylogdir) + print(logdir) + print(75 * "=") + + # write config out + sampling_file = os.path.join(logdir, "sampling_config.yaml") + sampling_conf = vars(opt) + + with open(sampling_file, 'w') as f: + yaml.dump(sampling_conf, f, default_flow_style=False) + print(sampling_conf) + + + run(model, imglogdir, eta=opt.eta, + vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, + batch_size=opt.batch_size, nplog=numpylogdir) + + print("done.") diff --git a/stable-diffusion/scripts/tests/test_watermark.py b/stable-diffusion/scripts/tests/test_watermark.py new file mode 100644 index 0000000..f93f8a6 --- /dev/null +++ b/stable-diffusion/scripts/tests/test_watermark.py @@ -0,0 +1,18 @@ +import cv2 +import fire +from imwatermark import WatermarkDecoder + + +def testit(img_path): + bgr = cv2.imread(img_path) + decoder = WatermarkDecoder('bytes', 136) + watermark = decoder.decode(bgr, 'dwtDct') + try: + dec = watermark.decode('utf-8') + except: + dec = "null" + print(dec) + + +if __name__ == "__main__": + fire.Fire(testit) \ No newline at end of file diff --git a/stable-diffusion/scripts/train_searcher.py b/stable-diffusion/scripts/train_searcher.py new file mode 100644 index 0000000..1e79048 --- /dev/null +++ b/stable-diffusion/scripts/train_searcher.py @@ -0,0 +1,147 @@ +import os, sys +import numpy as np +import scann +import argparse 
+import glob +from multiprocessing import cpu_count +from tqdm import tqdm + +from ldm.util import parallel_data_prefetch + + +def search_bruteforce(searcher): + return searcher.score_brute_force().build() + + +def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search): + return searcher.tree(num_leaves=num_leaves, + num_leaves_to_search=num_leaves_to_search, + training_sample_size=partioning_trainsize). \ + score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() + + +def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): + return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( + reorder_k).build() + +def load_datapool(dpath): + + + def load_single_file(saved_embeddings): + compressed = np.load(saved_embeddings) + database = {key: compressed[key] for key in compressed.files} + return database + + def load_multi_files(data_archive): + database = {key: [] for key in data_archive[0].files} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + database[key].append(d[key]) + + return database + + print(f'Load saved patch embedding from "{dpath}"') + file_content = glob.glob(os.path.join(dpath, '*.npz')) + + if len(file_content) == 1: + data_pool = load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + else: + raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') + return data_pool + + +def 
train_searcher(opt, + metric='dot_product', + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None,): + + data_pool = load_datapool(opt.database) + k = opt.knn + + if not reorder_k: + reorder_k = 2 * k + + # normalize + # embeddings = + searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) + pool_size = data_pool['embedding'].shape[0] + + print(*(['#'] * 100)) + print('Initializing scaNN searcher with the following values:') + print(f'k: {k}') + print(f'metric: {metric}') + print(f'reorder_k: {reorder_k}') + print(f'anisotropic_quantization_threshold: {aiq_thld}') + print(f'dims_per_block: {dims_per_block}') + print(*(['#'] * 100)) + print('Start training searcher....') + print(f'N samples in pool is {pool_size}') + + # this reflects the recommended design choices proposed at + # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md + if pool_size < 2e4: + print('Using brute force search.') + searcher = search_bruteforce(searcher) + elif 2e4 <= pool_size and pool_size < 1e5: + print('Using asymmetric hashing search and reordering.') + searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + else: + print('Using using partioning, asymmetric hashing search and reordering.') + + if not partioning_trainsize: + partioning_trainsize = data_pool['embedding'].shape[0] // 10 + if not num_leaves: + num_leaves = int(np.sqrt(pool_size)) + + if not num_leaves_to_search: + num_leaves_to_search = max(num_leaves // 20, 1) + + print('Partitioning params:') + print(f'num_leaves: {num_leaves}') + print(f'num_leaves_to_search: {num_leaves_to_search}') + # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + searcher = search_partioned_ah(searcher, 
dims_per_block, aiq_thld, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search) + + print('Finish training searcher') + searcher_savedir = opt.target_path + os.makedirs(searcher_savedir, exist_ok=True) + searcher.serialize(searcher_savedir) + print(f'Saved trained searcher under "{searcher_savedir}"') + +if __name__ == '__main__': + sys.path.append(os.getcwd()) + parser = argparse.ArgumentParser() + parser.add_argument('--database', + '-d', + default='data/rdm/retrieval_databases/openimages', + type=str, + help='path to folder containing the clip feature of the database') + parser.add_argument('--target_path', + '-t', + default='data/rdm/searchers/openimages', + type=str, + help='path to the target folder where the searcher shall be stored.') + parser.add_argument('--knn', + '-k', + default=20, + type=int, + help='number of nearest neighbors, for which the searcher shall be optimized') + + opt, _ = parser.parse_known_args() + + train_searcher(opt,) \ No newline at end of file diff --git a/stable-diffusion/scripts/txt2img.py b/stable-diffusion/scripts/txt2img.py new file mode 100644 index 0000000..bc38640 --- /dev/null +++ b/stable-diffusion/scripts/txt2img.py @@ -0,0 +1,352 @@ +import argparse, os, sys, glob +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from imwatermark import WatermarkEncoder +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.dpm_solver import DPMSolverSampler + +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import 
AutoFeatureExtractor + + +# load safety model +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) +safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def load_replacement(x): + try: + hwc = x.shape + y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) + y = (np.array(y)/255.0).astype(x.dtype) + assert y.shape == x.shape + return y + except Exception: + return x + + +def check_safety(x_image): + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + assert x_checked_image.shape[0] == len(has_nsfw_concept) + for i in range(len(has_nsfw_concept)): + if 
has_nsfw_concept[i]: + x_checked_image[i] = load_replacement(x_checked_image[i]) + return x_checked_image, has_nsfw_concept + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save individual samples. For speed measurements.", + ) + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--dpm_solver", + action='store_true', + help="use dpm_solver sampling", + ) + parser.add_argument( + "--laion400m", + action='store_true', + help="uses the LAION400M model", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=2, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor", + ) + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many 
samples to produce for each given prompt. A.k.a. batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v1-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="models/ldm/stable-diffusion-v1/model.ckpt", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + opt = parser.parse_args() + + if opt.laion400m: + print("Falling back to LAION 400M model...") + opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" + opt.ckpt = "models/ldm/text2img-large/model.ckpt" + opt.outdir = "outputs/txt2img-samples-laion400m" + + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.dpm_solver: + sampler = DPMSolverSampler(model) + elif opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "StableDiffusionV1" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', 
wm.encode('utf-8')) + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + precision_scope = autocast if opt.precision=="autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + if not opt.skip_save: + for x_sample in x_checked_image_torch: + x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + + if not opt.skip_grid: + all_samples.append(x_checked_image_torch) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + toc = time.time() + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/stable-diffusion/scripts/txt2realistic_human.py b/stable-diffusion/scripts/txt2realistic_human.py new file mode 100644 index 0000000..437c065 --- /dev/null +++ b/stable-diffusion/scripts/txt2realistic_human.py @@ -0,0 +1,347 @@ +import argparse, os, sys, glob +import cv2 +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.dpm_solver import DPMSolverSampler + +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor + + +# load safety model +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = 
AutoFeatureExtractor.from_pretrained(safety_model_id) +safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def load_replacement(x): + try: + hwc = x.shape + y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) + y = (np.array(y)/255.0).astype(x.dtype) + assert y.shape == x.shape + return y + except Exception: + return x + + +def check_safety(x_image): + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + assert x_checked_image.shape[0] == len(has_nsfw_concept) + for i in range(len(has_nsfw_concept)): + if has_nsfw_concept[i]: + x_checked_image[i] = load_replacement(x_checked_image[i]) + return x_checked_image, has_nsfw_concept + + +def main(): + 
parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save individual samples. For speed measurements.", + ) + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--dpm_solver", + action='store_true', + help="use dpm_solver sampling", + ) + parser.add_argument( + "--laion400m", + action='store_true', + help="uses the LAION400M model", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor", + ) + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a. 
batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v1-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="models/ldm/stable-diffusion-v1/model.ckpt", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + opt = parser.parse_args() + + if opt.laion400m: + print("Falling back to LAION 400M model...") + opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" + opt.ckpt = "models/ldm/text2img-large/model.ckpt" + opt.outdir = "outputs/txt2img-samples-laion400m" + + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.dpm_solver: + sampler = DPMSolverSampler(model) + elif opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert 
prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + precision_scope = autocast if opt.precision=="autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + if not opt.skip_save: + for x_sample in x_checked_image_torch: + x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + + if not opt.skip_grid: + all_samples.append(x_checked_image_torch) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + img.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + toc = time.time() + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/stable-diffusion/setup.py b/stable-diffusion/setup.py new file mode 100644 index 0000000..a24d541 --- /dev/null +++ b/stable-diffusion/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name='latent-diffusion', + version='0.0.1', + description='', + packages=find_packages(), + install_requires=[ + 'torch', + 'numpy', + 'tqdm', + ], +) \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/.gitignore b/stable-dreamfusion-3DPortrait/.gitignore new file mode 100644 index 0000000..4684816 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/.gitignore @@ -0,0 +1,40 @@ +__pycache__/ +build/ +*.egg-info/ +*.so +venv_*/ + +tmp* +# data/ +ldm/data/ +data2 +scripts2 +trial*/ +.vs/ + +TOKEN +*.ckpt + +densegridencoder +tets/256_tets.npz + +.vscode/launch.json + +data2 +data/car* +data/chair* +data/warrior* +data/wd* +data/space* +data/corgi* +data/turtle* + +# Only keep the original image, not the automatically-generated depth, normals, rgba +data/baby_phoenix_on_ice_* +data/bollywood_actress_* +data/beach_house_1_* +data/beach_house_2_* +data/mona_lisa_* +data/futuristic_car_* +data/church_ruins_* + diff --git 
a/stable-dreamfusion-3DPortrait/LICENSE b/stable-dreamfusion-3DPortrait/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
class _trunc_exp(Function):
    """Exponential activation with a truncated gradient.

    The forward pass returns a plain ``exp(x)``. The backward pass clamps
    the saved input to at most 15 before exponentiating, so the gradient
    passed upstream is bounded by ``g * exp(15)``.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float)
    def forward(ctx, x):
        """Save ``x`` for the backward pass and return ``exp(x)``."""
        ctx.save_for_backward(x)
        return x.exp()

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        """Return ``g * exp(min(x, 15))`` for the input saved in forward."""
        (saved_input,) = ctx.saved_tensors
        # Only the gradient is truncated; the forward value is not.
        truncated = torch.clamp(saved_input, max=15)
        return g * truncated.exp()


# Functional entry point: call ``trunc_exp(tensor)``.
trunc_exp = _trunc_exp.apply


def biased_softplus(x, bias=0):
    """Return ``softplus(x - bias)``."""
    shifted = x - bias
    return torch.nn.functional.softplus(shifted)
call backward and set the grad now (need to retain graph since we will call a second backward for the other losses later) +latents.backward(gradient=grad, retain_graph=True) +return 0 # dummy loss + +# 3.2. use a custom function to set a hook in backward, so we only call backward once (credits to @elliottzheng) +class SpecifyGradient(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, input_tensor, gt_grad): + ctx.save_for_backward(gt_grad) + # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward. + return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype) + + @staticmethod + @custom_bwd + def backward(ctx, grad_scale): + gt_grad, = ctx.saved_tensors + gt_grad = gt_grad * grad_scale + return gt_grad, None + +loss = SpecifyGradient.apply(latents, grad) +return loss # functional loss + +# 3.3. reparameterization (credits to @Xallt) +# d(loss)/d(latents) = grad, since grad is already detached, it's this simple. +loss = (grad * latents).sum() +return loss + +# 3.4. reparameterization (credits to threestudio) +# this is the same as 3.3, but the loss value only reflects the magnitude of grad, which is more informative. +targets = (latents - grad).detach() +loss = 0.5 * F.mse_loss(latents, targets, reduction='sum') +return loss +``` +* Other regularizations are in `./nerf/utils.py > Trainer > train_step`. + * The generation seems quite sensitive to regularizations on weights_sum (alphas for each ray). The original opacity loss tends to make NeRF disappear (zero density everywhere), so we use an entropy loss to replace it for now (encourages alpha to be either 0 or 1). +* NeRF Rendering core function: `./nerf/renderer.py > NeRFRenderer > run & run_cuda`. +* Shading & normal evaluation: `./nerf/network*.py > NeRFNetwork > forward`. + * light direction: current implementation use a plane light source, instead of a point light source. 
+* View-dependent prompting: `./nerf/provider.py > get_view_direction`. + * use `--angle_overhead, --angle_front` to set the border. +* Network backbone (`./nerf/network*.py`) can be chosen by the `--backbone` option. +* Spatial density bias (density blob): `./nerf/network*.py > NeRFNetwork > density_blob`. + + +# Debugging + +`debugpy-run` is a convenient way to remotely debug this project. Simply replace a command like this one: + +```bash +python main.py --text "a hamburger" --workspace trial -O --vram_O +``` + +... with: + +```bash +debugpy-run main.py -- --text "a hamburger" --workspace trial -O --vram_O +``` + +For more details: https://github.com/bulletmark/debugpy-run + +# Axes and directions of polar, azimuth, etc. in NeRF and Zero123 + +NeRF_Zero123 + +This code refers to theta for polar, phi for azimuth. + diff --git a/stable-dreamfusion-3DPortrait/assets/update_logs.md b/stable-dreamfusion-3DPortrait/assets/update_logs.md new file mode 100644 index 0000000..b1c2e2c --- /dev/null +++ b/stable-dreamfusion-3DPortrait/assets/update_logs.md @@ -0,0 +1,39 @@ +### 2023.4.19 +* Fix depth supervision, migrate depth estimation model to omnidata. +* Add normal supervision (also by omnidata). + +https://user-images.githubusercontent.com/25863658/232403294-b77409bf-ddc7-4bb8-af32-ee0cc123825a.mp4 + +### 2023.4.7 +Improvement on mesh quality & DMTet finetuning support. + +https://user-images.githubusercontent.com/25863658/230535363-298c960e-bf9c-4906-8b96-cd60edcb24dd.mp4 + +### 2023.3.30 +* adopt ideas from [Fantasia3D](https://fantasia3d.github.io/) to concatenate normal and mask as the latent code in a warm up stage, which shows faster convergence of shape. 
+ +https://user-images.githubusercontent.com/25863658/230535373-6ee28f16-bb21-4ec4-bc86-d46597361a04.mp4 + +### 2023.1.30 +* Use an MLP to predict the surface normals as in Magic3D to avoid finite difference / second order gradient, generation quality is greatly improved. +* More efficient two-pass raymarching in training inspired by nerfacc. + +https://user-images.githubusercontent.com/25863658/215996308-9fd959f5-b5c7-4a8e-a241-0fe63ec86a4a.mp4 + +### 2022.12.3 +* Support Stable-diffusion 2.0 base. + +### 2022.11.15 +* Add the vanilla backbone that is pure-pytorch. + +### 2022.10.9 +* The shading (partially) starts to work, at least it won't make scene empty. For some prompts, it shows better results (less severe Janus problem). The textureless rendering mode is still disabled. +* Enable shading by default (--latent_iter_ratio 1000). + +### 2022.10.5 +* Basic reproduction finished. +* Non --cuda_ray, --tcnn are not working, need to fix. +* Shading is not working, disabled in utils.py for now. Surface normals are bad. +* Use an entropy loss to regularize weights_sum (alpha), the original L2 reg always leads to degenerated geometry... 
+ +https://user-images.githubusercontent.com/25863658/194241493-f3e68f78-aefe-479e-a4a8-001424a61b37.mp4 diff --git a/stable-dreamfusion-3DPortrait/config/anya.csv b/stable-dreamfusion-3DPortrait/config/anya.csv new file mode 100644 index 0000000..4509748 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/config/anya.csv @@ -0,0 +1,3 @@ +zero123_weight, radius, polar, azimuth, image +1, 3, 90, 0, data/anya_front_rgba.png +1, 3, 90, 180, data/anya_back_rgba.png \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/config/car.csv b/stable-dreamfusion-3DPortrait/config/car.csv new file mode 100644 index 0000000..c014cdc --- /dev/null +++ b/stable-dreamfusion-3DPortrait/config/car.csv @@ -0,0 +1,5 @@ +zero123_weight, radius, polar, azimuth, image +4, 3.2, 90, 0, data/car_left_rgba.png +1, 3, 90, 90, data/car_front_rgba.png +4, 3.2, 90, 180, data/car_right_rgba.png +1, 3, 90, -90, data/car_back_rgba.png \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/config/corgi.csv b/stable-dreamfusion-3DPortrait/config/corgi.csv new file mode 100644 index 0000000..bb1234a --- /dev/null +++ b/stable-dreamfusion-3DPortrait/config/corgi.csv @@ -0,0 +1,2 @@ +zero123_weight, radius, polar, azimuth, image +1, 3.2, 90, 0, data/corgi_puppy_sitting_looking_up_rgba.png \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/data/anya_back.webp b/stable-dreamfusion-3DPortrait/data/anya_back.webp new file mode 100644 index 0000000..827bb96 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/anya_back.webp differ diff --git a/stable-dreamfusion-3DPortrait/data/anya_back_depth.png b/stable-dreamfusion-3DPortrait/data/anya_back_depth.png new file mode 100644 index 0000000..3fece86 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/anya_back_depth.png differ diff --git a/stable-dreamfusion-3DPortrait/data/anya_back_normal.png 
b/stable-dreamfusion-3DPortrait/data/anya_back_normal.png new file mode 100644 index 0000000..f8550bf Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/anya_back_normal.png differ diff --git a/stable-dreamfusion-3DPortrait/data/anya_back_rgba.png b/stable-dreamfusion-3DPortrait/data/anya_back_rgba.png new file mode 100644 index 0000000..d583853 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/anya_back_rgba.png differ diff --git a/stable-dreamfusion-3DPortrait/data/anya_front.jpg b/stable-dreamfusion-3DPortrait/data/anya_front.jpg new file mode 100644 index 0000000..588c72d Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/anya_front.jpg differ diff --git a/stable-dreamfusion-3DPortrait/data/anya_front.png b/stable-dreamfusion-3DPortrait/data/anya_front.png new file mode 100644 index 0000000..276bd41 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/anya_front.png differ diff --git a/stable-dreamfusion-3DPortrait/data/anya_front_depth.png b/stable-dreamfusion-3DPortrait/data/anya_front_depth.png new file mode 100644 index 0000000..a98cc40 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/anya_front_depth.png differ diff --git a/stable-dreamfusion-3DPortrait/data/anya_front_normal.png b/stable-dreamfusion-3DPortrait/data/anya_front_normal.png new file mode 100644 index 0000000..fedf7f7 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/anya_front_normal.png differ diff --git a/stable-dreamfusion-3DPortrait/data/anya_front_rgba.png b/stable-dreamfusion-3DPortrait/data/anya_front_rgba.png new file mode 100644 index 0000000..089499e Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/anya_front_rgba.png differ diff --git a/stable-dreamfusion-3DPortrait/data/baby_phoenix_on_ice.png b/stable-dreamfusion-3DPortrait/data/baby_phoenix_on_ice.png new file mode 100644 index 0000000..02a15cf Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/baby_phoenix_on_ice.png 
differ diff --git a/stable-dreamfusion-3DPortrait/data/beach_house_1.png b/stable-dreamfusion-3DPortrait/data/beach_house_1.png new file mode 100644 index 0000000..cfde250 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/beach_house_1.png differ diff --git a/stable-dreamfusion-3DPortrait/data/beach_house_2.png b/stable-dreamfusion-3DPortrait/data/beach_house_2.png new file mode 100644 index 0000000..5a33e50 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/beach_house_2.png differ diff --git a/stable-dreamfusion-3DPortrait/data/bollywood_actress.png b/stable-dreamfusion-3DPortrait/data/bollywood_actress.png new file mode 100644 index 0000000..4316be3 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/bollywood_actress.png differ diff --git a/stable-dreamfusion-3DPortrait/data/cactus.png b/stable-dreamfusion-3DPortrait/data/cactus.png new file mode 100644 index 0000000..1f89ba8 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/cactus.png differ diff --git a/stable-dreamfusion-3DPortrait/data/cactus_depth.png b/stable-dreamfusion-3DPortrait/data/cactus_depth.png new file mode 100644 index 0000000..f086e99 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/cactus_depth.png differ diff --git a/stable-dreamfusion-3DPortrait/data/cactus_normal.png b/stable-dreamfusion-3DPortrait/data/cactus_normal.png new file mode 100644 index 0000000..f420869 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/cactus_normal.png differ diff --git a/stable-dreamfusion-3DPortrait/data/cactus_rgba.png b/stable-dreamfusion-3DPortrait/data/cactus_rgba.png new file mode 100644 index 0000000..1936f75 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/cactus_rgba.png differ diff --git a/stable-dreamfusion-3DPortrait/data/cake.png b/stable-dreamfusion-3DPortrait/data/cake.png new file mode 100644 index 0000000..dcfba04 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/cake.png differ diff 
--git a/stable-dreamfusion-3DPortrait/data/cake_depth.png b/stable-dreamfusion-3DPortrait/data/cake_depth.png new file mode 100644 index 0000000..ded7595 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/cake_depth.png differ diff --git a/stable-dreamfusion-3DPortrait/data/cake_normal.png b/stable-dreamfusion-3DPortrait/data/cake_normal.png new file mode 100644 index 0000000..c7b99b2 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/cake_normal.png differ diff --git a/stable-dreamfusion-3DPortrait/data/cake_rgba.png b/stable-dreamfusion-3DPortrait/data/cake_rgba.png new file mode 100644 index 0000000..f0ae0b0 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/cake_rgba.png differ diff --git a/stable-dreamfusion-3DPortrait/data/catstatue.png b/stable-dreamfusion-3DPortrait/data/catstatue.png new file mode 100644 index 0000000..7f58741 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/catstatue.png differ diff --git a/stable-dreamfusion-3DPortrait/data/catstatue_depth.png b/stable-dreamfusion-3DPortrait/data/catstatue_depth.png new file mode 100644 index 0000000..a22c328 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/catstatue_depth.png differ diff --git a/stable-dreamfusion-3DPortrait/data/catstatue_normal.png b/stable-dreamfusion-3DPortrait/data/catstatue_normal.png new file mode 100644 index 0000000..3baf000 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/catstatue_normal.png differ diff --git a/stable-dreamfusion-3DPortrait/data/catstatue_rgba.png b/stable-dreamfusion-3DPortrait/data/catstatue_rgba.png new file mode 100644 index 0000000..3b44eb5 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/catstatue_rgba.png differ diff --git a/stable-dreamfusion-3DPortrait/data/church_ruins.png b/stable-dreamfusion-3DPortrait/data/church_ruins.png new file mode 100644 index 0000000..951eccf Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/church_ruins.png 
differ diff --git a/stable-dreamfusion-3DPortrait/data/firekeeper.jpg b/stable-dreamfusion-3DPortrait/data/firekeeper.jpg new file mode 100644 index 0000000..9e57d14 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/firekeeper.jpg differ diff --git a/stable-dreamfusion-3DPortrait/data/firekeeper_depth.png b/stable-dreamfusion-3DPortrait/data/firekeeper_depth.png new file mode 100644 index 0000000..7d56a1f Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/firekeeper_depth.png differ diff --git a/stable-dreamfusion-3DPortrait/data/firekeeper_normal.png b/stable-dreamfusion-3DPortrait/data/firekeeper_normal.png new file mode 100644 index 0000000..614f8ac Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/firekeeper_normal.png differ diff --git a/stable-dreamfusion-3DPortrait/data/firekeeper_rgba.png b/stable-dreamfusion-3DPortrait/data/firekeeper_rgba.png new file mode 100644 index 0000000..73430de Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/firekeeper_rgba.png differ diff --git a/stable-dreamfusion-3DPortrait/data/futuristic_car.png b/stable-dreamfusion-3DPortrait/data/futuristic_car.png new file mode 100644 index 0000000..0cfc78f Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/futuristic_car.png differ diff --git a/stable-dreamfusion-3DPortrait/data/hamburger.png b/stable-dreamfusion-3DPortrait/data/hamburger.png new file mode 100644 index 0000000..2dc1268 Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/hamburger.png differ diff --git a/stable-dreamfusion-3DPortrait/data/hamburger_depth.png b/stable-dreamfusion-3DPortrait/data/hamburger_depth.png new file mode 100644 index 0000000..f76c80c Binary files /dev/null and b/stable-dreamfusion-3DPortrait/data/hamburger_depth.png differ diff --git a/stable-dreamfusion-3DPortrait/data/hamburger_normal.png b/stable-dreamfusion-3DPortrait/data/hamburger_normal.png new file mode 100644 index 0000000..26f0835 Binary files /dev/null and 
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04

# Remove any third-party apt sources to avoid issues with expiring keys.
RUN rm -f /etc/apt/sources.list.d/*.list

# Refresh the package index and install in the SAME layer, so a cached
# `apt-get update` layer can never be paired with a later install step.
# tzdata is installed non-interactively first (presumably so that the
# following packages cannot trigger its timezone prompt).
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive TZ=Europe/MADRID apt-get install -y tzdata \
    && apt-get install -y \
        curl \
        ca-certificates \
        sudo \
        git \
        bzip2 \
        libx11-6 \
        python3 \
        python3-pip \
        libglfw3-dev \
        libgles2-mesa-dev \
        libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*


# Create a working directory. WORKDIR already makes /app the current
# directory for every later instruction; a separate `RUN cd /app` would not
# persist beyond its own layer, so it is not needed.
RUN mkdir /app
WORKDIR /app

RUN git clone https://github.com/ashawkey/stable-dreamfusion.git


RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116

WORKDIR /app/stable-dreamfusion

RUN pip3 install -r requirements.txt
RUN pip3 install git+https://github.com/NVlabs/nvdiffrast/

# Needs nvidia runtime, if you have "No CUDA runtime is found" error: https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime, first answer
RUN pip3 install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch

RUN pip3 install git+https://github.com/openai/CLIP.git
RUN bash scripts/install_ext.sh


# Set the default command to python3
#CMD ["python3"]
+``` + +If you have the error **No CUDA runtime is found** when building the wheels for tiny-cuda-nn you need to set up the nvidia-runtime for docker. +``` +sudo apt-get install nvidia-container-runtime +``` +Then edit `/etc/docker/daemon.json` and add the default-runtime: +``` +{ + "runtimes": { + "nvidia": { + "path": "nvidia-container-runtime", + "runtimeArgs": [] + } + }, + "default-runtime": "nvidia" +} +``` +And restart docker: +``` +sudo systemctl restart docker +``` +Now you can build tiny-cuda-nn inside docker. + +## Download image +To download the image (~6GB) instead: +``` +docker pull supercabb/stable-dreamfusion:3080_0.0.1 +docker tag supercabb/stable-dreamfusion:3080_0.0.1 stable-dreamfusion +``` + +## Use image + +You can launch an interactive shell inside the container: + +``` +docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash +``` +From this shell, all the code in the repo should work. + +To run any single command `<command>` inside the docker container: +``` +docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "<command>" +``` +To train: +``` +export TOKEN="#HUGGING FACE ACCESS TOKEN#" +docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "echo ${TOKEN} > TOKEN \ +&& python3 main.py --text \"a hamburger\" --workspace trial -O" + +``` +Run test without gui: +``` +export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#" +docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \ +-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \ +main.py --workspace trial -O --test" +``` +Run test with gui: +``` +export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#" +xhost + +docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \ +-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3
\ +main.py --workspace trial -O --test --gui" +xhost - +``` + + + + + + + diff --git a/stable-dreamfusion-3DPortrait/dpt.py b/stable-dreamfusion-3DPortrait/dpt.py new file mode 100644 index 0000000..8cc0479 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/dpt.py @@ -0,0 +1,924 @@ +import math +import types + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import timm + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) + + +def unflatten_with_named_tensor(input, dim, sizes): + """Workaround for unflattening with named tensor.""" + # tracer acts up with unflatten. See https://github.com/pytorch/pytorch/issues/49538 + new_shape = list(input.shape)[:dim] + list(sizes) + list(input.shape)[dim+1:] + return input.view(*new_shape) + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return 
def _resize_pos_embed(self, posemb, gs_h, gs_w):
    """Resample a position embedding to a ``gs_h`` x ``gs_w`` patch grid.

    The first ``self.start_index`` tokens are passed through untouched. The
    remaining tokens are laid out as a square grid, bilinearly interpolated
    to the requested grid size, and flattened back into a token sequence.

    Args:
        posemb (tensor): position embedding, indexed as (1, tokens, channels)
        gs_h (int): target grid height in patches
        gs_w (int): target grid width in patches
    Returns:
        tensor: prefix tokens followed by ``gs_h * gs_w`` resampled tokens
    """
    n_prefix = self.start_index
    prefix_tokens = posemb[:, :n_prefix]
    grid_tokens = posemb[0, n_prefix:]

    # The stored grid is taken to be square: side = sqrt(number of tokens).
    side = int(math.sqrt(grid_tokens.shape[0]))

    # (tokens, C) -> (1, C, side, side) so F.interpolate sees an image.
    grid = grid_tokens.reshape(1, side, side, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(gs_h, gs_w), mode="bilinear")
    # Back to a (1, gs_h * gs_w, C) token sequence.
    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    return torch.cat([prefix_tokens, grid], dim=1)
size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + 
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    """Build the backbone wrapper around timm's ``vit_large_patch16_384``.

    Args:
        pretrained (bool): forwarded to ``timm.create_model``
        use_readout (str): readout handling, forwarded to the backbone builder
        hooks (list of int, optional): transformer block indices to tap;
            defaults to ``[5, 11, 17, 23]``
    Returns:
        the module produced by ``_make_vit_b16_backbone``
    """
    vit = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    if hooks is None:
        hooks = [5, 11, 17, 23]

    backbone_kwargs = dict(
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )
    return _make_vit_b16_backbone(vit, **backbone_kwargs)
start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = 
nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
+ pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone 
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    """Create the four 3x3 projection convs ``layer1_rn`` .. ``layer4_rn``.

    Args:
        in_shape (sequence of int): input channels of the four stages
        out_shape (int): base number of output channels
        groups (int): ``groups`` argument of every convolution
        expand (bool): if true, stage ``i`` outputs ``out_shape * 2**i``
            channels (x1, x2, x4, x8); otherwise every stage outputs
            ``out_shape`` channels
    Returns:
        nn.Module: container holding the four bias-free convolutions
    """
    scratch = nn.Module()

    widen = expand == True
    for stage, factor in enumerate((1, 2, 4, 8)):
        width = out_shape * factor if widen else out_shape
        conv = nn.Conv2d(
            in_shape[stage],
            width,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            groups=groups,
        )
        setattr(scratch, f"layer{stage + 1}_rn", conv)

    return scratch
class ResidualConvUnit(nn.Module):
    """Residual convolution module.

    Applies ReLU -> 3x3 conv -> ReLU -> 3x3 conv and adds the result to the
    input. Both convolutions keep the channel count and spatial size
    (kernel 3, stride 1, padding 1).
    """

    def __init__(self, features):
        """Init.
        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        # In-place ReLU: the first call in forward() overwrites its argument.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input; modified in place by the first ReLU
        Returns:
            tensor: output
        """
        # NOTE: self.relu is in-place, so after this line `x` itself holds
        # relu(x). The skip connection below therefore adds the rectified
        # input, and the caller's tensor is changed as well.
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x
class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module with configurable activation and batch norm.

    Computes ``x + conv2(act(conv1(act(x))))``, with a BatchNorm2d after each
    convolution when ``bn`` is set.
    """

    def __init__(self, features, activation, bn):
        """Init.
        Args:
            features (int): number of features
            activation (callable): activation applied before each convolution
            bn (bool): whether to apply batch normalisation after each convolution
        """
        super().__init__()

        self.bn = bn

        # Kept as an attribute for compatibility; the convolutions are always
        # built ungrouped.
        self.groups = 1

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        # FloatFunctional performs the residual add as a module call.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: output, same shape as the input
        """
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)

        # The former `groups > 1` branch called self.conv_merge, which is
        # never defined; groups is fixed to 1, so that branch was dead code
        # and has been removed.
        return self.skip_add.add(out, x)
+ Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + True, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + 
class FreqEncoder_torch(nn.Module):
    """Frequency (positional) encoding implemented in pure PyTorch.

    Each input coordinate is mapped through every function in
    ``periodic_fns`` at ``N_freqs`` frequencies, optionally preceded by the
    raw input. The output width is always ``self.output_dim``.
    """

    def __init__(self, input_dim, max_freq_log2, N_freqs,
                 log_sampling=True, include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):
        """Init.
        Args:
            input_dim (int): size of the last input dimension
            max_freq_log2 (float): log2 of the highest frequency
            N_freqs (int): number of frequency bands
            log_sampling (bool): space the bands evenly in log2 if true,
                evenly in linear space otherwise
            include_input (bool): prepend the raw input to the encoding
            periodic_fns (sequence of callable): functions applied per band
        """
        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns
        self.N_freqs = N_freqs

        self.output_dim = 0
        if self.include_input:
            self.output_dim += self.input_dim

        self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)

        if log_sampling:
            freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs)
        else:
            freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs)

        # Stored as plain Python floats.
        self.freq_bands = freq_bands.tolist()

    def forward(self, input, max_level=None, **kwargs):
        """Forward pass.
        Args:
            input (tensor): values to encode along the last dimension
            max_level (float, optional): fraction of the frequency bands to
                enable; the remaining bands are emitted as zeros. ``None``
                enables all bands. Values outside [0, 1] are clamped.
        Returns:
            tensor: encoding whose last dimension is ``self.output_dim``
        """
        if max_level is None:
            max_level = self.N_freqs
        else:
            # Clamp so an out-of-range fraction can neither index past
            # freq_bands nor produce a wrongly sized zero block.
            max_level = min(max(int(max_level * self.N_freqs), 0), self.N_freqs)

        out = []
        if self.include_input:
            out.append(input)

        for freq in self.freq_bands[:max_level]:
            for p_fn in self.periodic_fns:
                out.append(p_fn(input * freq))

        # Zero-fill the disabled bands. The width must use the real number of
        # periodic functions (not a hard-coded 2) to match output_dim.
        n_masked = self.N_freqs - max_level
        if n_masked > 0:
            pad_width = n_masked * len(self.periodic_fns) * input.shape[-1]
            out.append(torch.zeros(*input.shape[:-1], pad_width, device=input.device, dtype=input.dtype))

        return torch.cat(out, dim=-1)
+ from freqencoder import FreqEncoder + encoder = FreqEncoder(input_dim=input_dim, degree=multires) + + elif encoding == 'sphere_harmonics': + from shencoder import SHEncoder + encoder = SHEncoder(input_dim=input_dim, degree=degree) + + elif encoding == 'hashgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation) + + elif encoding == 'tiledgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation) + + elif encoding == 'hashgrid_taichi': + from taichi_modules.hash_encoder import HashEncoderTaichi + encoder = HashEncoderTaichi(batch_size=4096) #TODO: hard encoded batch size + + else: + raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]') + + return encoder, encoder.output_dim \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/evaluation/Prompt.py b/stable-dreamfusion-3DPortrait/evaluation/Prompt.py new file mode 100644 index 0000000..53603db --- /dev/null +++ b/stable-dreamfusion-3DPortrait/evaluation/Prompt.py @@ -0,0 +1,91 @@ +import textwrap +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification +from transformers import pipeline +import argparse +import sys +import warnings +warnings.filterwarnings("ignore", category=UserWarning) + + +#python Prompt.py --text "a dog is in front of a rabbit" --model vlt5 + + +if __name__ == '__main__': + + # Mimic the calling part of the main, using + parser = argparse.ArgumentParser() + 
parser.add_argument('--text', default="", type=str, help="text prompt") + #parser.add_argument('--workspace', default="trial", type=str, help="workspace") + parser.add_argument('--model', default='vlt5', type=str, help="model choices - vlt5, bert, XLNet") + + opt = parser.parse_args() + + if opt.model == "vlt5": + tokenizer = AutoTokenizer.from_pretrained("Voicelab/vlt5-base-keywords") + model = AutoModelForSeq2SeqLM.from_pretrained("Voicelab/vlt5-base-keywords") + + task_prefix = "Keywords: " + inputs = [ + opt.text + ] + + for sample in inputs: + input_sequences = [task_prefix + sample] + input_ids = tokenizer( + input_sequences, return_tensors="pt", truncation=True + ).input_ids + output = model.generate(input_ids, no_repeat_ngram_size=3, num_beams=4) + output_text = tokenizer.decode(output[0], skip_special_tokens=True) + #print(sample, "\n --->", output_text) + + elif opt.model == "bert": + tokenizer = AutoTokenizer.from_pretrained("yanekyuk/bert-uncased-keyword-extractor") + model = AutoModelForTokenClassification.from_pretrained("yanekyuk/bert-uncased-keyword-extractor") + + text = opt.text + input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt") + + # Classify tokens + outputs = model(input_ids) + predictions = outputs.logits.detach().numpy()[0] + labels = predictions.argmax(axis=1) + labels = labels[1:-1] + + print(labels) + tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) + tokens = tokens[1:-1] + output_tokens = [tokens[i] for i in range(len(tokens)) if labels[i] != 0] + output_text = tokenizer.convert_tokens_to_string(output_tokens) + + #print(output_text) + + + elif opt.model == "XLNet": + tokenizer = AutoTokenizer.from_pretrained("jasminejwebb/KeywordIdentifier") + model = AutoModelForTokenClassification.from_pretrained("jasminejwebb/KeywordIdentifier") + + text = opt.text + input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt") + + # Classify tokens + outputs = model(input_ids) + predictions = 
def render_video(anim_mesh):
    """Render a 360-frame turntable of ``anim_mesh`` and encode it with ffmpeg.

    The camera orbits the mesh's centre of mass in the XY plane at a fixed
    radius, with +Z as the up direction. One PNG per frame is written to
    ``result/`` and ffmpeg then assembles them into ``result/output.mp4``.

    Args:
        anim_mesh: mesh to render; must expose ``center_mass`` and be
            accepted by ``pyvista.Plotter.add_mesh``
    """
    center = anim_mesh.center_mass
    plotter = pv.Plotter(off_screen=True)
    plotter.add_mesh(anim_mesh)

    # The ffmpeg command below reads "result/frame_%d.png", so the frames
    # have to be written into that directory (previously they were saved to
    # the current working directory and ffmpeg found no input).
    frame_dir = 'result'
    os.makedirs(frame_dir, exist_ok=True)

    radius = 10
    n_frames = 360
    angle_step = 2 * np.pi / n_frames
    for i in tqdm(range(n_frames)):
        angle = i * angle_step
        camera_pos = [
            center[0] + radius * np.cos(angle),
            center[1] + radius * np.sin(angle),
            center[2],
        ]
        plotter.camera_position = (camera_pos, center, (0, 0, 1))
        # auto_close=False keeps the plotter alive for the next frame.
        plotter.show(screenshot=f'{frame_dir}/frame_{i}.png', auto_close=False)
    plotter.close()
    os.system('ffmpeg -r 30 -f image2 -s 1920x1080 -i "result/frame_%d.png" -vcodec libx264 -crf 25 -pix_fmt yuv420p result/output.mp4')
if __name__ == '__main__':
    # Scores one validation render against its text prompt: both are embedded
    # with a CLIP model and their cosine similarity is printed.

    parser = argparse.ArgumentParser()
    parser.add_argument('--text', default="", type=str, help="text prompt")
    # The help text used to be a copy of the --text one.
    parser.add_argument('--workspace', default="trial", type=str, help="workspace folder under ../results holding the validation images")
    parser.add_argument('--latest', default='ep0001', type=str, help="which epoch result you want to use for image path")
    parser.add_argument('--mode', default='rgb', type=str, help="mode of result, color(rgb) or textureless()")
    parser.add_argument('--clip', default="clip-ViT-B-32", type=str, help="CLIP model to encode the img and prompt")

    opt = parser.parse_args()

    # Load CLIP model
    model = SentenceTransformer(f'{opt.clip}')

    # Encode an image; the context manager closes the file handle afterwards.
    image_path = f'../results/{opt.workspace}/validation/df_{opt.latest}_0005_{opt.mode}.png'
    with Image.open(image_path) as img:
        img_emb = model.encode(img)

    # Encode text descriptions
    text_emb = model.encode([f'{opt.text}'])

    # Compute cosine similarities
    cos_scores = util.cos_sim(img_emb, text_emb)
    print("The final CLIP R-Precision is:", cos_scores[0][0].cpu().numpy())
+ For computing the CLIP R-Precision (image-text similarity) score
+ --text is the text prompt, written in the format used by the authors of Stable-Dreamfusion
+ --workspace is the workspace folder that is created for each prompt fed into Stable-Dreamfusion
+ --latest selects which checkpoint is used. Stable-Dreamfusion records data for every epoch. This is normally ep0100, unless training did not finish or was extended further
+ --mode is either rgb or depth, corresponding to the color and textureless results in Figure 5 of the original paper (Qualitative comparison with baselines).
+ --clip selects the CLIP model: clip-ViT-B-32, CLIP B/16, or CLIP L/14, the same choices as in the original paper
+ + ```bash + python r_precision.py --text "matte painting of a castle made of cheesecake surrounded by a moat made of ice cream" --workspace ../castle --latest ep0100 --mode rgb --clip clip-ViT-B-32 + ``` + + - Prompt.py (model name case sensitive)
+ For prompt separation

+ --text is the text prompt, written in the format used by the authors of Stable-Dreamfusion
+ --model chooses which pretrained model to use
+ + ```bash + python Prompt.py --text "a dog is in front of a rabbit" --model vlt5 + python Prompt.py --text "a dog is in front of a rabbit" --model bert + python Prompt.py --text "a dog is in front of a rabbit" --model XLNet + ``` + + + - mesh_to_video.py
+ --center_obj is the mesh file of the center object
+ --surround_obj is the mesh file of the surrounding object, i.e. the one that the transform is applied to
+ --transform_vector is the 3D (x, y, z) vector used to translate the surrounding object
+ + ```bash + python mesh_to_video.py --center_obj 'mesh_whiterabbit/mesh.obj' --surround_obj 'mesh_snake/mesh.obj' --transform_vector [1,0,0] + ``` diff --git a/stable-dreamfusion-3DPortrait/fit_latent_trigrid.py b/stable-dreamfusion-3DPortrait/fit_latent_trigrid.py new file mode 100644 index 0000000..1a35273 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/fit_latent_trigrid.py @@ -0,0 +1,467 @@ +import os + +import torch +import argparse +import pandas as pd +import sys + +from nerf.provider import NeRFDataset +from nerf.trigrid_utils import * + +if __name__ == '__main__': + # See https://stackoverflow.com/questions/27433316/how-to-get-argparse-to-read-arguments-from-a-file-with-an-option-rather-than-pre + class LoadFromFile(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + with values as f: + # parse arguments in the file and store them in the target namespace + parser.parse_args(f.read().split(), namespace) + + + parser = argparse.ArgumentParser() + parser.add_argument('--file', type=open, action=LoadFromFile, help="specify a file filled with more arguments") + parser.add_argument('--text', default=None, help="text prompt") + parser.add_argument('--negative', default='', type=str, help="negative text prompt") + parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray") + parser.add_argument('-O2', action='store_true', help="equals --backbone vanilla") + parser.add_argument('--test', action='store_true', help="test mode") + parser.add_argument('--six_views', action='store_true', help="six_views mode: save the images of the six views") + parser.add_argument('--eval_interval', type=int, default=1, help="evaluate on the valid set every interval epochs") + parser.add_argument('--test_interval', type=int, default=5, help="test on the test set every interval epochs") + parser.add_argument('--workspace', type=str, default='workspace') + 
parser.add_argument('--seed', default=None) + + parser.add_argument('--image', default=None, help="image prompt") + parser.add_argument('--image_config', default=None, help="image config csv") + + parser.add_argument('--known_view_interval', type=int, default=4, + help="train default view with RGB loss every & iters, only valid if --image is not None.") + + parser.add_argument('--IF', action='store_true', + help="experimental: use DeepFloyd IF as the guidance model for nerf stage") + + parser.add_argument('--guidance', type=str, nargs='*', default=['SD'], help='guidance model') + parser.add_argument('--guidance_scale', type=float, default=100, + help="diffusion model classifier-free guidance scale") + + parser.add_argument('--save_mesh', action='store_true', help="export an obj mesh with texture") + parser.add_argument('--mcubes_resolution', type=int, default=256, help="mcubes resolution for extracting mesh") + parser.add_argument('--decimate_target', type=int, default=5e4, help="target face number for mesh decimation") + + parser.add_argument('--dmtet', action='store_true', help="use dmtet finetuning") + parser.add_argument('--tet_grid_size', type=int, default=128, help="tet grid size") + parser.add_argument('--init_with', type=str, default='', help="ckpt to init dmtet") + parser.add_argument('--lock_geo', action='store_true', help="disable dmtet to learn geometry") + + ## Perp-Neg options + parser.add_argument('--perpneg', action='store_true', help="use perp_neg") + parser.add_argument('--negative_w', type=float, default=-2, + help="The scale of the weights of negative prompts. A larger value will help to avoid the Janus problem, but may cause flat faces. 
Vary between 0 to -4, depending on the prompt") + parser.add_argument('--front_decay_factor', type=float, default=2, help="decay factor for the front prompt") + parser.add_argument('--side_decay_factor', type=float, default=10, help="decay factor for the side prompt") + + ## Trigrid options + parser.add_argument('--trigrid_path', type=str, default='', help="path to trigrid") + parser.add_argument('--trigrid_decoder_ckpt', type=str, default='', help="path to trigrid decoder ckpt") + parser.add_argument('--train_decoder', action='store_true', help="train trigrid decoder") + parser.add_argument('--learnable_bg', action='store_true', help="Learnable background") + parser.add_argument('--noise_bg', action='store_true', help="use noise background") + + ### training options + parser.add_argument('--iters', type=int, default=10000, help="training iters") + parser.add_argument('--lr', type=float, default=1e-3, help="max learning rate") + parser.add_argument('--ckpt', type=str, default='latest', + help="possible options are ['latest', 'scratch', 'best', 'latest_model']") + parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") + parser.add_argument('--taichi_ray', action='store_true', help="use taichi raymarching") + parser.add_argument('--max_steps', type=int, default=1024, + help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=64, + help="num steps sampled per ray (only valid when not using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=32, + help="num steps up-sampled per ray (only valid when not using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=16, + help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, + help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)") + 
parser.add_argument('--latent_iter_ratio', type=float, default=0.2, + help="training iters that only use albedo shading") + parser.add_argument('--albedo_iter_ratio', type=float, default=0, + help="training iters that only use albedo shading") + parser.add_argument('--min_ambient_ratio', type=float, default=0.1, + help="minimum ambient ratio to use in lambertian shading") + parser.add_argument('--textureless_ratio', type=float, default=0.2, help="ratio of textureless shading") + parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses") + parser.add_argument('--jitter_center', type=float, default=0.2, + help="amount of jitter to add to sampled camera pose's center (camera location)") + parser.add_argument('--jitter_target', type=float, default=0.2, + help="amount of jitter to add to sampled camera pose's target (i.e. 'look-at')") + parser.add_argument('--jitter_up', type=float, default=0.02, + help="amount of jitter to add to sampled camera pose's up-axis (i.e. 
'camera roll')") + parser.add_argument('--uniform_sphere_rate', type=float, default=0, + help="likelihood of sampling camera location uniformly on the sphere surface area") + parser.add_argument('--grad_clip', type=float, default=-1, + help="clip grad of all grad to this limit, negative value disables it") + parser.add_argument('--grad_clip_rgb', type=float, default=-1, + help="clip grad of rgb space grad to this limit, negative value disables it") + # model options + parser.add_argument('--bg_radius', type=float, default=3.0, + help="if positive, use a background model at sphere(bg_radius)") + parser.add_argument('--density_activation', type=str, default='exp', choices=['softplus', 'exp'], + help="density activation function") + parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied") + parser.add_argument('--blob_density', type=float, default=5, help="max (center) density for the density blob") + parser.add_argument('--blob_radius', type=float, default=0.2, help="control the radius for the density blob") + # network backbone + parser.add_argument('--backbone', type=str, default='trigrid', choices=['trigrid'], help="nerf backbone") + parser.add_argument('--optim', type=str, default='adan', choices=['adan', 'adam'], help="optimizer") + parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], + help="stable diffusion version") + parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key") + # try this if CUDA OOM + parser.add_argument('--fp16', action='store_true', help="use float16 for training") + parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage") + # rendering resolution in training, increase these for better quality / decrease these if CUDA OOM even if --vram_O enabled. 
+ parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training") + parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training") + parser.add_argument('--known_view_scale', type=float, default=1.5, + help="multiply --h/w by this for known view rendering") + parser.add_argument('--known_view_noise_scale', type=float, default=2e-3, + help="random camera noise added to rays_o and rays_d") + parser.add_argument('--dmtet_reso_scale', type=float, default=8, help="multiply --h/w by this for dmtet finetuning") + parser.add_argument('--batch_size', type=int, default=1, help="images to render per batch using NeRF") + + ### dataset options + parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)") + parser.add_argument('--dt_gamma', type=float, default=0, + help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--min_near', type=float, default=0.01, help="minimum near distance for camera") + + parser.add_argument('--radius_range', type=float, nargs='*', default=[2.7, 2.71], + help="training camera radius range") + parser.add_argument('--theta_range', type=float, nargs='*', default=[60, 105], + help="training camera range along the polar angles (i.e. up and down). See advanced.md for details.") + parser.add_argument('--phi_range', type=float, nargs='*', default=[-180, 180], + help="training camera range along the azimuth angles (i.e. left and right). 
See advanced.md for details.") + parser.add_argument('--fovy_range', type=float, nargs='*', default=[11.5, 21], help="training camera fovy range") + + parser.add_argument('--default_radius', type=float, default=2.7, help="radius for the default view") + parser.add_argument('--default_polar', type=float, default=90, help="polar for the default view") + parser.add_argument('--default_azimuth', type=float, default=0, help="azimuth for the default view") + parser.add_argument('--default_fovy', type=float, default=12., help="fovy for the default view") + + parser.add_argument('--progressive_view', action='store_true', + help="progressively expand view sampling range from default to full") + parser.add_argument('--progressive_view_init_ratio', type=float, default=0.2, + help="initial ratio of final range, used for progressive_view") + + parser.add_argument('--progressive_level', action='store_true', + help="progressively increase gridencoder's max_level") + + parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region") + parser.add_argument('--angle_front', type=float, default=60, + help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.") + parser.add_argument('--t_range', type=float, nargs='+', default=[0.02, 0.98], + help="stable diffusion time steps range") + parser.add_argument('--dont_override_stuff', action='store_true', help="Don't override t_range, etc.") + + ### regularizations + parser.add_argument('--lambda_entropy', type=float, default=1e-3, help="loss scale for alpha entropy") + parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value") + parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation") + parser.add_argument('--lambda_tv', type=float, default=0, help="loss scale for total variation") + parser.add_argument('--lambda_wd', type=float, default=0, help="loss scale") 
+ + parser.add_argument('--lambda_mesh_normal', type=float, default=0.5, help="loss scale for mesh normal smoothness") + parser.add_argument('--lambda_mesh_laplacian', type=float, default=0.5, help="loss scale for mesh laplacian") + + parser.add_argument('--lambda_guidance', type=float, default=1, help="loss scale for SDS") + parser.add_argument('--lambda_rgb', type=float, default=1000, help="loss scale for RGB") + parser.add_argument('--lambda_mask', type=float, default=500, help="loss scale for mask (alpha)") + parser.add_argument('--lambda_normal', type=float, default=0, help="loss scale for normal map") + parser.add_argument('--lambda_depth', type=float, default=10, help="loss scale for relative depth") + parser.add_argument('--lambda_2d_normal_smooth', type=float, default=0, + help="loss scale for 2D normal image smoothness") + parser.add_argument('--lambda_3d_normal_smooth', type=float, default=0, + help="loss scale for 3D normal image smoothness") + + ### debugging options + parser.add_argument('--save_guidance', action='store_true', + help="save images of the per-iteration NeRF renders, added noise, denoised (i.e. guidance), fully-denoised. 
Useful for debugging, but VERY SLOW and takes lots of memory!") + parser.add_argument('--save_guidance_interval', type=int, default=10, help="save guidance every X step") + + ### GUI options + parser.add_argument('--gui', action='store_true', help="start a GUI") + parser.add_argument('--W', type=int, default=800, help="GUI width") + parser.add_argument('--H', type=int, default=800, help="GUI height") + parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center") + parser.add_argument('--fovy', type=float, default=20, help="default GUI camera fovy") + parser.add_argument('--light_theta', type=float, default=60, + help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]") + parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth") + parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") + + parser.add_argument('--zero123_config', type=str, + default='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml', + help="config file for zero123") + parser.add_argument('--zero123_ckpt', type=str, default='pretrained/zero123/zero123-xl.ckpt', + help="ckpt for zero123") + parser.add_argument('--zero123_grad_scale', type=str, default='angle', + help="whether to scale the gradients based on 'angle' or 'None'") + + parser.add_argument('--dataset_size_train', type=int, default=100, + help="Length of train dataset i.e. 
# of iterations per epoch") + parser.add_argument('--dataset_size_valid', type=int, default=8, + help="# of frames to render in the turntable video in validation") + parser.add_argument('--dataset_size_test', type=int, default=100, + help="# of frames to render in the turntable video at test time") + + parser.add_argument('--exp_start_iter', type=int, default=None, + help="start iter # for experiment, to calculate progressive_view and progressive_level") + parser.add_argument('--exp_end_iter', type=int, default=None, + help="end iter # for experiment, to calculate progressive_view and progressive_level") + + opt = parser.parse_args() + if opt.O: + raise NotImplementedError + opt.fp16 = True + opt.cuda_ray = True + + elif opt.O2: + raise NotImplementedError + opt.fp16 = True + opt.backbone = 'vanilla' + opt.progressive_level = True + + if opt.IF: + if 'SD' in opt.guidance: + opt.guidance.remove('SD') + opt.guidance.append('IF') + opt.latent_iter_ratio = 0 # must not do as_latent + + opt.images, opt.ref_radii, opt.ref_polars, opt.ref_azimuths, opt.zero123_ws = [], [], [], [], [] + opt.default_zero123_w = 1 + + opt.exp_start_iter = opt.exp_start_iter or 0 + opt.exp_end_iter = opt.exp_end_iter or opt.iters + + # parameters for image-conditioned generation + if opt.image is not None or opt.image_config is not None: + + if opt.text is None: + # use zero123 guidance model when only providing image + opt.guidance = ['zero123'] + if not opt.dont_override_stuff: + opt.fovy_range = [opt.default_fovy, opt.default_fovy] # fix fov as zero123 doesn't support changing fov + opt.guidance_scale = 5 + opt.lambda_3d_normal_smooth = 10 + else: + # use stable-diffusion when providing both text and image + opt.guidance = ['SD', 'clip'] + + if not opt.dont_override_stuff: + opt.guidance_scale = 10 + opt.t_range = [0.2, 0.6] + opt.known_view_interval = 2 + opt.lambda_3d_normal_smooth = 20 + opt.bg_radius = -1 + + # smoothness + opt.lambda_entropy = 1 + opt.lambda_orient = 1 + + # latent 
warmup is not needed + opt.latent_iter_ratio = 0 + if not opt.dont_override_stuff: + opt.albedo_iter_ratio = 0 + + # make shape init more stable + opt.progressive_view = True + opt.progressive_level = True + + if opt.image is not None: + opt.images += [opt.image] + opt.ref_radii += [opt.default_radius] + opt.ref_polars += [opt.default_polar] + opt.ref_azimuths += [opt.default_azimuth] + opt.zero123_ws += [opt.default_zero123_w] + + if opt.image_config is not None: + # for multiview (zero123) + conf = pd.read_csv(opt.image_config, skipinitialspace=True) + opt.images += list(conf.image) + opt.ref_radii += list(conf.radius) + opt.ref_polars += list(conf.polar) + opt.ref_azimuths += list(conf.azimuth) + opt.zero123_ws += list(conf.zero123_weight) + if opt.image is None: + opt.default_radius = opt.ref_radii[0] + opt.default_polar = opt.ref_polars[0] + opt.default_azimuth = opt.ref_azimuths[0] + opt.default_zero123_w = opt.zero123_ws[0] + + # reset to None + if len(opt.images) == 0: + opt.images = None + + if opt.learnable_bg: + assert opt.bg_radius > max( + opt.radius_range), f"bg_radius must be larger than max(radius_range) = {max(opt.radius_range)}" + assert opt.noise_bg is False + + if opt.noise_bg: + assert opt.learnable_bg is False + + assert opt.latent_iter_ratio == 1.0, "latent_iter_ratio must be 1.0 for now" + # default parameters for finetuning + if opt.dmtet: + + opt.h = int(opt.h * opt.dmtet_reso_scale) + opt.w = int(opt.w * opt.dmtet_reso_scale) + opt.known_view_scale = 1 + + if not opt.dont_override_stuff: + opt.t_range = [0.02, 0.50] # ref: magic3D + + if opt.images is not None: + + opt.lambda_normal = 0 + opt.lambda_depth = 0 + + if opt.text is not None and not opt.dont_override_stuff: + opt.t_range = [0.20, 0.50] + + # assume finetuning + opt.latent_iter_ratio = 0 + opt.albedo_iter_ratio = 0 + opt.progressive_view = False + # opt.progressive_level = False + os.makedirs(opt.workspace, exist_ok=True) + # record full range for progressive view expansion + 
if opt.progressive_view: + if not opt.dont_override_stuff: + # disable as they disturb progressive view + opt.jitter_pose = False + + opt.uniform_sphere_rate = 0 + # back up full range + opt.full_radius_range = opt.radius_range + opt.full_theta_range = opt.theta_range + opt.full_phi_range = opt.phi_range + opt.full_fovy_range = opt.fovy_range + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + from nerf.network_trigrid import NeRFNetwork + from nerf.network_trigrid_latent import LatentNeRFNetwork + + model = LatentNeRFNetwork( + opt=opt, + device=device + ) + teacher_model = NeRFNetwork( + opt=opt, + device=device + ) + # load + print(f'loading trigrid_renderer from {opt.trigrid_decoder_ckpt}...') + ckpt = torch.load(opt.trigrid_decoder_ckpt, map_location=lambda storage, loc: storage) + # ckpt = {'params': params, 'state_dict': ckpt} + state_dict = ckpt['state_dict'] + state_dict_wo_torgb = {} + for k, v in state_dict.items(): + if 'torgb' not in k: + state_dict_wo_torgb[k] = v + + model.model.load_state_dict(state_dict_wo_torgb, strict=False) + teacher_model.model.load_state_dict(state_dict) + # + model.model.rendering_kwargs['depth_resolution'] = 48 + model.model.rendering_kwargs['depth_resolution_importance'] = 48 + model.model.rendering_kwargs['ray_start'] = 2.35 + + teacher_model.model.rendering_kwargs['depth_resolution'] = 48 + teacher_model.model.rendering_kwargs['depth_resolution_importance'] = 48 + teacher_model.model.rendering_kwargs['ray_start'] = 2.35 + # + # load plane from pkl + print(f'loading trigrid from {opt.trigrid_path}...') + import pickle + + with open(opt.trigrid_path, 'rb') as f: + data = pickle.load(f) + trigrid = data['trigrids'] + ws = data['ws'] + model.load_state_dict( + { + 'trigrid': trigrid, + 'ws': ws, + }, strict=False + ) + teacher_model.load_state_dict( + { + 'trigrid': trigrid, + 'ws': ws, + }, strict=False + ) + # print(f'loading encoder from {opt.encoder_ckpt}...') + # encoder_ckpt = 
torch.load('./pretrained/encoder_sd1.5.pt', map_location=lambda storage, loc: storage) + # model.latent_net.load_state_dict(encoder_ckpt, strict=False) + + print('save trigrid to workspace...') + shutil.copy(opt.trigrid_path, os.path.join(opt.workspace, 'trigrid.pkl')) + + train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, + size=opt.dataset_size_train * opt.batch_size, teacher_H=opt.h * 8, + teacher_W=opt.w * 8).dataloader() + + if opt.optim == 'adan': + from optimizer import Adan + + # Adan usually requires a larger LR + optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-8, + weight_decay=2e-5, max_grad_norm=5.0, + foreach=False) + else: # adam + optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), + betas=(0.9, 0.99), eps=1e-15) + + if opt.backbone == 'vanilla': + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, + lambda iter: 0.1 ** min(iter / opt.iters, 1)) + else: + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed + # scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) + + guidance = nn.ModuleDict() + from guidance.sd_utils import StableDiffusion + + guidance['SD'] = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key, opt.t_range) + + trainer = TrigridTrainer(' '.join(sys.argv), 'trigrid', opt, model, teacher_model, guidance, device=device, + workspace=opt.workspace, + optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, + use_checkpoint=opt.ckpt, scheduler_update_every_step=True) + + trainer.default_view_data = train_loader._data.get_default_view_data() + + valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, + size=opt.dataset_size_valid, teacher_H=opt.h * 8, teacher_W=opt.w * 8).dataloader( + batch_size=1) + test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, 
size=opt.dataset_size_test, + teacher_H=opt.h * 8, teacher_W=opt.w * 8).dataloader( + batch_size=1) + + # # test output + # trainer.test(test_loader, save_path=os.path.join(opt.workspace, 'latent_trigrid_fit_initiation')) + + # TO BE DEBUGGED + + max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) + trainer.train(train_loader, valid_loader, test_loader, max_epoch) + + + + + + + diff --git a/stable-dreamfusion-3DPortrait/freqencoder/__init__.py b/stable-dreamfusion-3DPortrait/freqencoder/__init__.py new file mode 100644 index 0000000..69ec49c --- /dev/null +++ b/stable-dreamfusion-3DPortrait/freqencoder/__init__.py @@ -0,0 +1 @@ +from .freq import FreqEncoder \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/freqencoder/backend.py b/stable-dreamfusion-3DPortrait/freqencoder/backend.py new file mode 100644 index 0000000..fa0e820 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/freqencoder/backend.py @@ -0,0 +1,42 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-use_fast_math' +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. 
class _freq_encoder(Function):
    """Autograd function that forwards to the compiled ``_backend`` kernels."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # force float32 for better precision
    def forward(ctx, inputs, degree, output_dim):
        """Encode a 2-D ``inputs`` tensor of shape [B, input_dim] into [B, output_dim].

        The tensor is moved to CUDA if it is not already there and made
        contiguous before it is handed to the backend.
        """
        inputs = (inputs if inputs.is_cuda else inputs.cuda()).contiguous()

        batch, coord_dim = inputs.shape

        # The backend writes its result into this preallocated buffer.
        encoded = torch.empty(batch, output_dim, dtype=inputs.dtype, device=inputs.device)
        _backend.freq_encode_forward(inputs, batch, coord_dim, degree, output_dim, encoded)

        # Both tensors and the sizes are needed again in backward().
        ctx.save_for_backward(inputs, encoded)
        ctx.dims = [batch, coord_dim, degree, output_dim]

        return encoded

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        """Return the gradient w.r.t. ``inputs``; ``grad`` has the shape of the forward output."""
        inputs, encoded = ctx.saved_tensors
        batch, coord_dim, degree, output_dim = ctx.dims

        # Zero-initialised buffer that the backend fills in.
        grad_inputs = torch.zeros_like(inputs)
        _backend.freq_encode_backward(
            grad.contiguous(), encoded, batch, coord_dim, degree, output_dim, grad_inputs
        )

        # No gradients for the two non-tensor arguments (degree, output_dim).
        return grad_inputs, None, None


freq_encode = _freq_encoder.apply


class FreqEncoder(nn.Module):
    """Module wrapper around ``freq_encode`` for inputs of arbitrary leading shape."""

    def __init__(self, input_dim=3, degree=4):
        super().__init__()

        self.input_dim = input_dim
        self.degree = degree
        # The raw coordinates plus two extra blocks per degree.
        self.output_dim = input_dim * (1 + 2 * degree)

    def __repr__(self):
        return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"

    def forward(self, inputs, **kwargs):
        """Encode ``inputs`` of shape [..., input_dim] into [..., output_dim].

        Extra keyword arguments are accepted and ignored.
        """
        lead_dims = list(inputs.shape[:-1])

        # The autograd function works on 2-D input, so flatten the leading
        # dimensions and restore them afterwards.
        flat = inputs.reshape(-1, self.input_dim)
        encoded = freq_encode(flat, self.degree, self.output_dim)

        return encoded.reshape(lead_dims + [self.output_dim])
+ if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='freqencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_freqencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'freqencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/freqencoder/src/bindings.cpp b/stable-dreamfusion-3DPortrait/freqencoder/src/bindings.cpp new file mode 100644 index 0000000..bb5f285 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/freqencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "freqencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)"); + m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)"); +} \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/freqencoder/src/freqencoder.cu b/stable-dreamfusion-3DPortrait/freqencoder/src/freqencoder.cu new file mode 100644 index 0000000..072da74 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/freqencoder/src/freqencoder.cu @@ -0,0 +1,129 @@ +#include + +#include +#include +#include + +#include +#include + +#include +#include + +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == 
at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + +inline constexpr __device__ float PI() { return 3.141592653589793f; } + +template +__host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +// inputs: [B, D] +// outputs: [B, C], C = D + D * deg * 2 +__global__ void kernel_freq( + const float * __restrict__ inputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * outputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * C) return; + + // get index + const uint32_t b = t / C; + const uint32_t c = t - b * C; // t % C; + + // locate + inputs += b * D; + outputs += t; + + // write self + if (c < D) { + outputs[0] = inputs[c]; + // write freq + } else { + const uint32_t col = c / D - 1; + const uint32_t d = c % D; + const uint32_t freq = col / 2; + const float phase_shift = (col % 2) * (PI() / 2); + outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift); + } +} + +// grad: [B, C], C = D + D * deg * 2 +// outputs: [B, C] +// grad_inputs: [B, D] +__global__ void kernel_freq_backward( + const float * __restrict__ grad, + const float * __restrict__ outputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * grad_inputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; // t % D; + + // locate + grad += b * C; + outputs += b * C; + grad_inputs += t; + + // register + float result = grad[d]; + grad += D; + outputs += D; + + for (uint32_t f = 0; f < deg; f++) { + result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); + grad += 2 * D; + outputs += 2 * D; + } + + // write + grad_inputs[0] = result; +} + + +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) 
{ + CHECK_CUDA(inputs); + CHECK_CUDA(outputs); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(outputs); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(outputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr()); +} + + +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { + CHECK_CUDA(grad); + CHECK_CUDA(outputs); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(outputs); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(outputs); + CHECK_IS_FLOATING(grad_inputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr()); +} \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/freqencoder/src/freqencoder.h b/stable-dreamfusion-3DPortrait/freqencoder/src/freqencoder.h new file mode 100644 index 0000000..34f28c7 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/freqencoder/src/freqencoder.h @@ -0,0 +1,10 @@ +# pragma once + +#include +#include + +// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); + +// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/gridencoder/__init__.py b/stable-dreamfusion-3DPortrait/gridencoder/__init__.py new file mode 100644 index 0000000..f1476ce --- /dev/null +++ b/stable-dreamfusion-3DPortrait/gridencoder/__init__.py @@ -0,0 +1 @@ +from .grid import 
GridEncoder \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/gridencoder/backend.py b/stable-dreamfusion-3DPortrait/gridencoder/backend.py new file mode 100644 index 0000000..b403f34 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/gridencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) + if paths: + return paths[0] + # If cl.exe is not on path, try to find it. 
# Translate the user-facing `gridtype` names into the integer ids given to the backend.
_gridtype_to_id = {
    'hash': 0,
    'tiled': 1,
}

# Translate the user-facing `interpolation` names into the integer ids given to the backend.
_interp_to_id = {
    'linear': 0,
    'smoothstep': 1,
}


class _grid_encode(Function):
    """Autograd function around `_backend.grid_encode_forward` / `grid_encode_backward`."""

    @staticmethod
    @custom_fwd
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0, max_level=None):
        """Encode `inputs` with the multi-level grid.

        Args:
            inputs: [B, D], float in [0, 1].
            embeddings: [sO, C], float.
            offsets: [L + 1], int.
            per_level_scale: resolution multiplier between consecutive levels.
            base_resolution: resolution of the first level.
            calc_grad_inputs: when true, a `dy_dx` buffer is allocated and filled
                by the backend so `backward` can return a gradient for `inputs`.
            gridtype: integer id from `_gridtype_to_id`.
            align_corners: forwarded to the backend unchanged.
            interpolation: integer id from `_interp_to_id`.
            max_level: None to use all L levels, otherwise a fraction; the number
                of levels computed is ceil(max_level * L) clamped to [1, L].

        Returns:
            [B, L * C] tensor with the dtype of the (possibly half-cast) embeddings.
        """
        inputs = inputs.contiguous()

        B, D = inputs.shape            # batch size, coord dim
        L = offsets.shape[0] - 1       # number of levels
        C = embeddings.shape[1]        # embedding dim for each level
        S = np.log2(per_level_scale)   # log2 of the per-level multiplier (the CUDA side applies exp2f)
        H = base_resolution            # base resolution

        max_level = L if max_level is None else max(min(int(math.ceil(max_level * L)), L), 1)

        # Manual autocast handling: only the embeddings go to half, inputs stay
        # float for precision. Odd C is kept in float because half atomicAdd is
        # very slow.
        if torch.is_autocast_enabled() and C % 2 == 0:
            embeddings = embeddings.to(torch.half)

        # Level-first layout for the kernel; permuted back to batch-first below.
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)

        # Levels beyond max_level are not written by the kernel, so zero them.
        if max_level < L:
            outputs.zero_()

        if calc_grad_inputs:
            dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
            if max_level < L:
                dy_dx.zero_()
        else:
            dy_dx = None

        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interpolation)

        # [L, B, C] -> [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype, interpolation, max_level]
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        """Return gradients for `inputs` (or None) and `embeddings`; other args get None."""
        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype, interpolation, max_level = ctx.dims
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C].
        # reshape (not view) so a non-contiguous incoming gradient does not raise.
        grad = grad.reshape(B, L, C).permute(1, 0, 2).contiguous()

        grad_embeddings = torch.zeros_like(embeddings)

        # dy_dx is only saved when forward was asked to support input gradients.
        if dy_dx is not None:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            grad_inputs = None

        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation)

        if dy_dx is not None:
            grad_inputs = grad_inputs.to(inputs.dtype)

        # One entry per forward() argument after ctx.
        return grad_inputs, grad_embeddings, None, None, None, None, None, None, None, None


grid_encode = _grid_encode.apply


class GridEncoder(nn.Module):
    """Multi-resolution grid encoder backed by the `_backend` CUDA extension."""

    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'):
        """Allocate the per-level offsets table and the embedding parameters.

        Args:
            input_dim: coordinate dimensionality.
            num_levels: number of grid levels.
            level_dim: embedding channels per level.
            per_level_scale: resolution multiplier between consecutive levels.
            base_resolution: resolution of level 0.
            log2_hashmap_size: log2 of the maximum number of entries per level.
            desired_resolution: if given, the resolution wanted at the last level;
                overrides `per_level_scale`. Ignored when `num_levels == 1`,
                where there is no level-to-level growth to solve for.
            gridtype: 'hash' or 'tiled'.
            align_corners: forwarded to the backend.
            interpolation: 'linear' or 'smoothstep'.

        Raises:
            KeyError: if `gridtype` or `interpolation` is not a known name.
        """
        super().__init__()

        # The finest resolution desired at the last level, if provided,
        # overrides per_level_scale. With a single level the formula would
        # divide by zero and produce an inf/nan scale, so the override is skipped.
        if desired_resolution is not None and num_levels > 1:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim  # coord dims, 2 or 3
        self.num_levels = num_levels  # num levels, each level multiplies resolution by per_level_scale
        self.level_dim = level_dim  # encode channels per level
        self.per_level_scale = per_level_scale  # multiply resolution by this scale at each level.
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype]  # "tiled" or "hash"
        self.interpolation = interpolation
        self.interp_id = _interp_to_id[interpolation]  # "linear" or "smoothstep"
        self.align_corners = align_corners

        # Build the prefix-sum table: level i owns rows offsets[i]:offsets[i + 1].
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(self.max_params, (resolution) ** input_dim)  # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible by 8
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embeddings uniformly in [-1e-4, 1e-4]."""
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}"

    def forward(self, inputs, bound=1, max_level=None):
        """Encode positions.

        Args:
            inputs: [..., input_dim], positions in [-bound, bound].
            bound: half-extent used to map `inputs` to [0, 1].
            max_level: None to use all levels, otherwise the fraction of levels
                to compute (see `_grid_encode.forward`).

        Returns:
            [..., num_levels * level_dim]
        """
        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]

        prefix_shape = list(inputs.shape[:-1])
        # reshape (not view): callers may pass non-contiguous tensors.
        inputs = inputs.reshape(-1, self.input_dim)

        outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id, max_level)
        outputs = outputs.reshape(prefix_shape + [self.output_dim])

        return outputs

    # always run in float precision!
    @torch.cuda.amp.autocast(enabled=False)
    def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):
        """Call the backend total-variation routine with `self.embeddings.grad`.

        Args:
            weight: forwarded to the backend.
            inputs: [..., input_dim], float in [-bound, bound], locations at which
                to evaluate the TV term. When None, `B` random points in [0, 1]
                are drawn instead.
            bound: half-extent used to map `inputs` to [0, 1].
            B: number of random points when `inputs` is None.

        Raises:
            ValueError: if `self.embeddings.grad` is None.
        """
        D = self.input_dim
        C = self.embeddings.shape[1]  # embedding dim for each level
        L = self.offsets.shape[0] - 1  # level
        S = np.log2(self.per_level_scale)  # log2 of the per-level multiplier (the CUDA side applies exp2f)
        H = self.base_resolution  # base resolution

        if inputs is None:
            # randomized in [0, 1]
            inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)
        else:
            inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]
            # reshape (not view): callers may pass non-contiguous tensors.
            inputs = inputs.reshape(-1, self.input_dim)
            B = inputs.shape[0]

        if self.embeddings.grad is None:
            raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')

        _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners)

    @torch.cuda.amp.autocast(enabled=False)
    def grad_weight_decay(self, weight=0.1):
        """Call the backend weight-decay routine with `self.embeddings.grad`.

        Level-wise meaned weight decay (ref: zip-nerf).

        Raises:
            ValueError: if `self.embeddings.grad` is None.
        """
        B = self.embeddings.shape[0]  # size of embedding
        C = self.embeddings.shape[1]  # embedding dim for each level
        L = self.offsets.shape[0] - 1  # level

        if self.embeddings.grad is None:
            raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')

        _backend.grad_weight_decay(self.embeddings, self.embeddings.grad, self.offsets, weight, B, C, L)
'-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='gridencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_gridencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/gridencoder/src/bindings.cpp b/stable-dreamfusion-3DPortrait/gridencoder/src/bindings.cpp new file mode 100644 index 0000000..fc3dd5e --- /dev/null +++ b/stable-dreamfusion-3DPortrait/gridencoder/src/bindings.cpp @@ -0,0 +1,10 @@ +#include + +#include "gridencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); + m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); + m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); + m.def("grad_weight_decay", &grad_weight_decay, "grad_weight_decay (CUDA)"); +} \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/gridencoder/src/gridencoder.cu 
b/stable-dreamfusion-3DPortrait/gridencoder/src/gridencoder.cu new file mode 100644 index 0000000..93f5b80 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/gridencoder/src/gridencoder.cu @@ -0,0 +1,713 @@ +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here! + __device__ inline at::Half atomicAdd(at::Half *address, at::Half val) { + // requires CUDA >= 10 and ARCH >= 70 + // this is very slow compared to float or __half2, never use it. 
+ //return atomicAdd(reinterpret_cast<__half*>(address), val); +} + + +template +__host__ __device__ inline T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template +__device__ inline T smoothstep(T val) { + return val*val*(3.0f - 2.0f * val); +} + +template +__device__ inline T smoothstep_derivative(T val) { + return 6*val*(1.0f - val); +} + + +template +__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { + + // coherent type of hashing + constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u }; + + uint32_t result = 0; + #pragma unroll + for (uint32_t i = 0; i < D; ++i) { + result ^= pos_grid[i] * primes[i]; + } + + return result; +} + + +template +__device__ uint32_t get_grid_index(const uint32_t gridtype, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { + uint32_t stride = 1; + uint32_t index = 0; + + #pragma unroll + for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { + index += pos_grid[d] * stride; + stride *= resolution; + } + + // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. 
+ // gridtype: 0 == hash, 1 == tiled + if (gridtype == 0 && stride > hashmap_size) { + index = fast_hash(pos_grid); + } + + return (index % hashmap_size) * C + ch; +} + + +template +__global__ void kernel_grid( + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ outputs, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + scalar_t * __restrict__ dy_dx, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + grid += (uint32_t)offsets[level] * C; + inputs += b * D; + outputs += level * B * C + b * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + // if input out of bound, just set output to 0 + if (flag_oob) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = 0; + } + if (dy_dx) { + dy_dx += b * D * L * C + level * D * C; // B L D C + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[d * C + ch] = 0; + } + } + } + return; + } + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H); + + // calculate coordinate (always use float for precision!) 
+ float pos[D]; + float pos_deriv[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + + // align_corners + if (align_corners) { + pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1] + pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2] + } else { + pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1] + pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1] + } + pos[d] -= (float)pos_grid[d]; + + // smoothstep instead of linear + if (interp == 1) { + pos_deriv[d] = smoothstep_derivative(pos[d]); + pos[d] = smoothstep(pos[d]); + } else { + pos_deriv[d] = 1.0f; + } + } + + // verification of alignment + // if (level == L - 1 && b < 4) { + // printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + // } + + // interpolate + scalar_t results[C] = {0}; // temp results in register + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1); + } + } + + uint32_t index = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid_local); + + // writing to register (fast) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results[ch] += w * grid[index + ch]; + } + + //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = results[ch]; + } + + // prepare dy_dx + // differentiable (soft) indexing: 
https://discuss.pytorch.org/t/differentiable-indexing/17647/9 + if (dy_dx) { + + dy_dx += b * D * L * C + level * D * C; // B L D C + + #pragma unroll + for (uint32_t gd = 0; gd < D; gd++) { + + scalar_t results_grad[C] = {0}; + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { + float w = (float)(align_corners ? resolution - 1 : resolution); + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t nd = 0; nd < D - 1; nd++) { + const uint32_t d = (nd >= gd) ? (nd + 1) : nd; + + if ((idx & (1 << nd)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1); + } + } + + pos_grid_local[gd] = pos_grid[gd]; + uint32_t index_left = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid_local); + pos_grid_local[gd] = min(pos_grid[gd] + 1, resolution - 1); + uint32_t index_right = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid_local); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd]; + } + } + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[gd * C + ch] = results_grad[ch]; + } + } + } +} + + +template +__global__ void kernel_grid_backward( + const scalar_t * __restrict__ grad, + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ grad_grid, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; + if (b >= B) return; + + const uint32_t level = blockIdx.y; + const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; + + // locate + grad_grid += offsets[level] * C; + inputs += b * D; + grad += level 
* B * C + b * C + ch; // L, B, C + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H); + + // check input range (should be in [0, 1]) + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + return; // grad is init as 0, so we simply return. + } + } + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + // align_corners + if (align_corners) { + pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1] + pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2] + } else { + pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1] + pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1] + } + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos[d] = smoothstep(pos[d]); + } + } + + scalar_t grad_cur[N_C] = {0}; // fetch to register + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + grad_cur[c] = grad[c]; + } + + // interpolate + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1); + } + } + + uint32_t index = get_grid_index(gridtype, ch, hashmap_size, resolution, pos_grid_local); + + // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 + // TODO: use float which is better than __half, if N_C % 2 != 0 + if (std::is_same::value && N_C % 2 == 0) { + #pragma unroll + for (uint32_t c = 0; c < N_C; c += 2) { + // process two __half at once (by interpreting 
as a __half2) + __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; + atomicAdd((__half2*)&grad_grid[index + c], v); + } + // float, or __half when N_C % 2 != 0 (which means C == 1) + } else { + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + atomicAdd(&grad_grid[index + c], w * grad_cur[c]); + } + } + } +} + + +template +__global__ void kernel_input_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ dy_dx, + scalar_t * __restrict__ grad_inputs, + uint32_t B, uint32_t L +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; + + dy_dx += b * L * D * C; + + scalar_t result = 0; + + # pragma unroll + for (int l = 0; l < L; l++) { + # pragma unroll + for (int ch = 0; ch < C; ch++) { + result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; + } + } + + grad_inputs[t] = result; +} + + +template +void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), max_level, 1 }; + switch (C) { + case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 16: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, 
dy_dx, gridtype, align_corners, interp); break; + case 32: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."}; + } +} + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.) +// H: base resolution +// dy_dx: [B, L * D * C] +template +void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + switch (D) { + case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break; + case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break; + case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."}; + } +} + +template +void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 256; + const uint32_t 
// Host-side dispatcher for the backward grid encoding. Selects, from the
// runtime input dimension D, which kernel_grid_backward_wrapper call to make
// and forwards every other argument unchanged.
//
// grad: [L, B, C], float
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
//   NOTE(review): the parameter is declared `const int *`, not uint32_t -- confirm which is intended.
// grad_embeddings: [sO, C]
// H: base resolution
// dy_dx / grad_inputs: forwarded as-is; the wrapper in this file only launches
//   the input-gradient kernel when dy_dx is non-null.
//
// Throws std::runtime_error when D is not 2, 3, 4 or 5.
//
// NOTE(review): the template parameter list and the template arguments on the
// kernel_grid_backward_wrapper calls are not visible in this view, which is why
// the four cases read identically here. Presumably each case instantiates the
// wrapper for its own D -- confirm against the original file.
template
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    switch (D) {
        case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 4: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        // Any other dimensionality is rejected rather than silently ignored.
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."};
    }
}
// Tensor-level entry point for the backward pass of the grid encoder.
//
// Runs the CHECK_* macros on the five mandatory tensors, then dispatches on
// grad's scalar type (floating types and half) to grid_encode_backward_cuda
// with raw data pointers.
//
// dy_dx and grad_inputs are optional: when a value is absent, nullptr is passed
// down instead of a data pointer.
//
// NOTE(review): the CHECK_* macros are defined outside this view; by their
// names they validate device placement, contiguity and dtype -- confirm.
// NOTE(review): the template arguments of `at::optional` and `data_ptr` are not
// visible in this view; confirm against the original file.
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(grad_embeddings);
    // The optional tensors are deliberately left unchecked here (checks disabled):
    // CHECK_CUDA(dy_dx);
    // CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(grad_embeddings);
    // CHECK_CONTIGUOUS(dy_dx);
    // CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(grad_embeddings);
    // CHECK_IS_FLOATING(dy_dx);
    // CHECK_IS_FLOATING(grad_inputs);

    // The dispatch key is grad's dtype (the forward entry point keys on embeddings instead).
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad.scalar_type(), "grid_encode_backward", ([&] {
        grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners, interp);
    }));

}
// Launches kernel_grad_tv for the runtime feature count C.
//
// The launch grid is { div_round_up(B, N_THREAD), L, 1 } with N_THREAD = 512:
// the x dimension covers the B input points and the y dimension indexes the
// level (the kernel reads its level from blockIdx.y).
//
// Throws std::runtime_error when C is not 1, 2, 4, 8, 16 or 32.
//
// NOTE(review): the template parameter list, the kernel template arguments and
// the launch configuration between the chevrons are not visible in this view,
// which is why all cases read identically here. Presumably each case
// instantiates the kernel for its own C and launches with
// blocks_hashgrid / N_THREAD -- confirm against the original file.
template
void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
    static constexpr uint32_t N_THREAD = 512;
    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
    switch (C) {
        case 1: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        case 2: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        case 4: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        case 8: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        case 16: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        case 32: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        // Unsupported feature counts are rejected rather than silently ignored.
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
    }
}
// Tensor-level entry point that adds a weight-decay term to `grad`.
//
// Dispatches on the embeddings' scalar type (floating types and half) and
// launches kernel_grad_wd over a 1-D grid of div_round_up(B * C, N_THREAD)
// blocks with N_THREAD = 1024, i.e. one thread per (entry, channel) pair; the
// kernel itself returns early for indices >= B * C.
//
// Unlike grid_encode_forward/backward, no CHECK_* validation is performed on
// the tensors here.
//
// NOTE(review): the launch configuration between the chevrons and the
// `data_ptr` template arguments are not visible in this view; presumably the
// launch uses blocks_hashgrid / N_THREAD -- confirm against the original file.
void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L) {

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    embeddings.scalar_type(), "grad_weight_decay", ([&] {
        static constexpr uint32_t N_THREAD = 1024;
        const dim3 blocks_hashgrid = { div_round_up(B * C, N_THREAD), 1, 1 };
        // Note the argument order: the kernel takes (…, B, L, C), not (…, B, C, L).
        kernel_grad_wd<<>>(embeddings.data_ptr(), grad.data_ptr(), offsets.data_ptr(), weight, B, L, C);
    }));
}
class CLIP(nn.Module):
    """CLIP-based guidance: embeds prompts/images and scores rendered images.

    Loads the "ViT-B/16" CLIP model on ``device`` and prepares a tensor-space
    preprocessing pipeline (resize to 224x224, then channel normalisation)
    that is applied to every image before encoding.
    """

    def __init__(self, device, **kwargs):
        super().__init__()

        self.device = device
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)

        # Tensor-only preprocessing used instead of `clip_preprocess`.
        resize = T.Resize((224, 224))
        normalize = T.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        )
        self.aug = T.Compose([resize, normalize])

    def get_text_embeds(self, prompt, **kwargs):
        """Return the L2-normalised CLIP text embedding of ``prompt``."""
        tokens = clip.tokenize(prompt).to(self.device)
        features = self.clip_model.encode_text(tokens)
        return features / features.norm(dim=-1, keepdim=True)

    def get_img_embeds(self, image, **kwargs):
        """Return the L2-normalised CLIP image embedding of ``image``."""
        features = self.clip_model.encode_image(self.aug(image))
        return features / features.norm(dim=-1, keepdim=True)

    def train_step(self, clip_z, pred_rgb, grad_scale=10, **kwargs):
        """Return the negative, scaled similarity between ``pred_rgb`` and the targets.

        Args:
            clip_z: mapping that may hold reference embeddings under the keys
                'image' and/or 'text'; each present key contributes one term.
            pred_rgb: rendered image batch, preprocessed with ``self.aug``.
            grad_scale: scalar or 1-tensor of size [B], i.e. 1 grad_scale per batch item.

        Returns:
            The summed loss terms, or the integer 0 when neither key is present.
        """
        # `self.aug` resizes the render to 224x224 before encoding.
        image_z = self.get_img_embeds(pred_rgb)

        loss = 0
        for key in ('image', 'text'):
            if key in clip_z:
                similarity = (image_z * clip_z[key]).sum(-1)
                loss -= (similarity * grad_scale).mean()

        return loss
pipe.enable_attention_slicing(1) + pipe.enable_model_cpu_offload() + else: + pipe.to(device) + + self.unet = pipe.unet + self.tokenizer = pipe.tokenizer + self.text_encoder = pipe.text_encoder + self.unet = pipe.unet + self.scheduler = pipe.scheduler + + self.pipe = pipe + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + + print(f'[INFO] loaded DeepFloyd IF-I-XL!') + + @torch.no_grad() + def get_text_embeds(self, prompt): + # prompt: [str] + + # TODO: should I add the preprocessing at https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py#LL486C10-L486C28 + prompt = self.pipe._text_preprocessing(prompt, clean_caption=False) + inputs = self.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, add_special_tokens=True, return_tensors='pt') + embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] + + return embeddings + + + def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1): + + # [0, 1] to [-1, 1] and make sure shape is [64, 64] + images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device) + + # predict the noise residual with unet, NO grad! 
    def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, grad_scale=1):
        """Score-distillation step with Perp-Neg aggregation of several prompts.

        Args:
            text_embeddings: stacked prompt embeddings; its first dimension is
                (1 + K) * B, where the first B rows are used for the
                unconditional prediction and the rest for K prompt groups.
            weights: per-prompt weights forwarded to
                ``weighted_perpendicular_aggregator``.
            pred_rgb: rendered images in [0, 1]; resized to 64x64 and mapped to [-1, 1].
            guidance_scale: multiplier on the aggregated guidance direction.
            grad_scale: global multiplier on the resulting gradient.

        Returns:
            A scalar loss whose gradient w.r.t. the resized images equals the
            weighted noise residual ``grad`` (targets are detached).

        NOTE(review): indentation is not recoverable from this view; the
        ``no_grad`` scope is assumed to end once ``noise_pred`` is computed,
        since the loss must stay differentiable w.r.t. ``images`` -- confirm.
        """

        B = pred_rgb.shape[0]
        K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts

        # [0, 1] to [-1, 1] and make sure shape is [64, 64]
        images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1

        # one integer timestep per image, drawn from [min_step, max_step] to avoid very high/low noise levels
        t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(images)
            images_noisy = self.scheduler.add_noise(images, noise, t)

            # pred noise: one copy of the noisy batch per prompt group (unconditional + K)
            model_input = torch.cat([images_noisy] * (1 + K))
            model_input = self.scheduler.scale_model_input(model_input, t)
            tt = torch.cat([t] * (1 + K))
            unet_output = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:]
            # the UNet output is split along channels in chunks of the input channel count:
            # first chunk is the noise prediction, second chunk (the variance) is not used here
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
            # plain classifier-free guidance would be:
            # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1)
            noise_pred = noise_pred_uncond + guidance_scale * weighted_perpendicular_aggregator(delta_noise_preds, weights, B)



        # w(t), sigma_t^2
        w = (1 - self.alphas[t])
        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
        # replace NaN/inf entries so one bad sample cannot poison the update
        grad = torch.nan_to_num(grad)

        # 0.5 * ||images - targets||^2 has gradient (images - targets) == grad w.r.t. images
        targets = (images - grad).detach()
        loss = 0.5 * F.mse_loss(images.float(), targets, reduction='sum') / images.shape[0]

        return loss
+ model_input = torch.cat([images] * 2) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + images = self.scheduler.step(noise_pred, t, images).prev_sample + + images = (images + 1) / 2 + + return images + + + def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None): + + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(negative_prompts, str): + negative_prompts = [negative_prompts] + + # Prompts -> text embeds + pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768] + neg_embeds = self.get_text_embeds(negative_prompts) + text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] + + # Text embeds -> img + imgs = self.produce_imgs(text_embeds, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64] + + # Img to Numpy + imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() + imgs = (imgs * 255).round().astype('uint8') + + return imgs + + +if __name__ == '__main__': + + import argparse + import matplotlib.pyplot as plt + + parser = argparse.ArgumentParser() + parser.add_argument('prompt', type=str) + parser.add_argument('--negative', default='', type=str) + parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage") + parser.add_argument('-H', type=int, default=64) + 
class Adan(Optimizer):
    """Implements a PyTorch variant of the Adan optimizer.

    Adan was proposed in "Adan: Adaptive Nesterov Momentum Algorithm for
    Faster Optimizing Deep Models", arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677

    Arguments:
        params (iterable): iterable of parameters to optimize or
            dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for
            the gradient, gradient-difference and squared-gradient moving
            averages. (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay (default: 0)
        max_grad_norm (float, optional): value used to clip
            global grad norm (default: 0.0 no clip)
        no_prox (bool): how to perform the decoupled weight decay; when True
            the parameters are scaled before the update, otherwise they are
            divided after it. (default: False)
        foreach (bool): if True would use torch._foreach implementation.
            It's faster but uses slightly more memory. (default: True)

    Raises:
        ValueError: if ``max_grad_norm``, ``lr`` or ``eps`` is negative, or if
            any beta lies outside ``[0, 1)``.
    """
    def __init__(self,
                 params,
                 lr=1e-3,
                 betas=(0.98, 0.92, 0.99),
                 eps=1e-8,
                 weight_decay=0.0,
                 max_grad_norm=0.0,
                 no_prox=False,
                 foreach: bool = True):
        if not 0.0 <= max_grad_norm:
            raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm))
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= eps:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(
                betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(
                betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError('Invalid beta parameter at index 2: {}'.format(
                betas[2]))
        defaults = dict(lr=lr,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm,
                        no_prox=no_prox,
                        foreach=foreach)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super(Adan, self).__setstate__(state)
        for group in self.param_groups:
            # Backfill options that may be missing from older saved states so
            # that `step` can index them unconditionally.
            group.setdefault('no_prox', False)
            group.setdefault('foreach', True)

    @torch.no_grad()
    def restart_opt(self):
        """Reset the step counter and zero all moving averages.

        Only parameters with ``requires_grad`` are touched. The stored
        previous gradient is refreshed by ``step`` itself on the first step
        after a restart.
        """
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.requires_grad:
                    state = self.state[p]
                    # State initialization

                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    # Exponential moving average of gradient difference
                    state['exp_avg_diff'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns
                the loss; it is called with gradients enabled.

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if self.defaults['max_grad_norm'] > 0:
            device = self.param_groups[0]['params'][0].device
            global_grad_norm = torch.zeros(1, device=device)

            max_grad_norm = torch.tensor(self.defaults['max_grad_norm'],
                                         device=device)
            for group in self.param_groups:

                for p in group['params']:
                    if p.grad is not None:
                        grad = p.grad
                        global_grad_norm.add_(grad.pow(2).sum())

            global_grad_norm = torch.sqrt(global_grad_norm)

            # The norm is global across all groups, so use the optimizer-wide
            # eps rather than whichever group the loop above visited last.
            clip_global_grad_norm = torch.clamp(
                max_grad_norm / (global_grad_norm + self.defaults['eps']),
                max=1.0).item()
        else:
            clip_global_grad_norm = 1.0

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            exp_avg_diffs = []
            neg_pre_grads = []

            beta1, beta2, beta3 = group['betas']
            # assume same step across group now to simplify things
            # per parameter step can be easily support
            # by making it tensor, or pass list into kernel
            group['step'] = group.get('step', 0) + 1

            bias_correction1 = 1.0 - beta1**group['step']
            bias_correction2 = 1.0 - beta2**group['step']
            bias_correction3 = 1.0 - beta3**group['step']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg_diff'] = torch.zeros_like(p)

                # On the first step (also right after `restart_opt`) there is
                # no previous gradient, so seed it with the clipped current
                # one: the first gradient difference is then zero.
                if 'neg_pre_grad' not in state or group['step'] == 1:
                    state['neg_pre_grad'] = p.grad.clone().mul_(
                        -clip_global_grad_norm)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                exp_avg_diffs.append(state['exp_avg_diff'])
                neg_pre_grads.append(state['neg_pre_grad'])

            kwargs = dict(
                params=params_with_grad,
                grads=grads,
                exp_avgs=exp_avgs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avg_diffs=exp_avg_diffs,
                neg_pre_grads=neg_pre_grads,
                beta1=beta1,
                beta2=beta2,
                beta3=beta3,
                bias_correction1=bias_correction1,
                bias_correction2=bias_correction2,
                bias_correction3_sqrt=math.sqrt(bias_correction3),
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                eps=group['eps'],
                no_prox=group['no_prox'],
                clip_global_grad_norm=clip_global_grad_norm,
            )

            if group['foreach']:
                _multi_tensor_adan(**kwargs)
            else:
                _single_tensor_adan(**kwargs)

        return loss
no_prox: + param.mul_(1 - lr * weight_decay) + param.addcdiv_(exp_avg, denom, value=-step_size) + param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff) + else: + param.addcdiv_(exp_avg, denom, value=-step_size) + param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff) + param.div_(1 + lr * weight_decay) + + neg_grad_or_diff.zero_().add_(grad, alpha=-1.0) + + +def _multi_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + neg_pre_grads: List[Tensor], + *, + beta1: float, + beta2: float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + if len(params) == 0: + return + + torch._foreach_mul_(grads, clip_global_grad_norm) + + # for memory saving, we use `neg_pre_grads` + # to get some temp variable in a inplace way + torch._foreach_add_(neg_pre_grads, grads) + + torch._foreach_mul_(exp_avgs, beta1) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t + + torch._foreach_mul_(exp_avg_diffs, beta2) + torch._foreach_add_(exp_avg_diffs, neg_pre_grads, + alpha=1 - beta2) # diff_t + + torch._foreach_mul_(neg_pre_grads, beta2) + torch._foreach_add_(neg_pre_grads, grads) + torch._foreach_mul_(exp_avg_sqs, beta3) + torch._foreach_addcmul_(exp_avg_sqs, + neg_pre_grads, + neg_pre_grads, + value=1 - beta3) # n_t + + denom = torch._foreach_sqrt(exp_avg_sqs) + torch._foreach_div_(denom, bias_correction3_sqrt) + torch._foreach_add_(denom, eps) + + step_size_diff = lr * beta2 / bias_correction2 + step_size = lr / bias_correction1 + + if no_prox: + torch._foreach_mul_(params, 1 - lr * weight_decay) + torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size) + torch._foreach_addcdiv_(params, + exp_avg_diffs, + denom, + value=-step_size_diff) + else: + torch._foreach_addcdiv_(params, exp_avgs, denom, 
# Please refer to https://perp-neg.github.io/ for details about the paper and algorithm
def get_perpendicular_component(x, y):
    """Return the part of ``x`` that is perpendicular to ``y``.

    Both tensors are treated as flat vectors: the projection of ``x`` onto
    ``y`` is subtracted from ``x``. The squared norm of ``y`` is floored at
    ``1e-6`` so an all-zero ``y`` returns ``x`` unchanged instead of
    dividing by zero.

    Args:
        x: Tensor to project.
        y: Reference direction; must have the same shape as ``x``.

    Returns:
        Tensor with the shape of ``x``.
    """
    assert x.shape == y.shape
    # Tensor.clamp keeps the floor on-device; the builtin max() had to turn
    # the comparison into a Python bool first.
    squared_norm = (torch.norm(y) ** 2).clamp(min=1e-6)
    return x - (torch.mul(x, y).sum() / squared_norm) * y


def batch_get_perpendicular_component(x, y):
    """Apply :func:`get_perpendicular_component` to each pair along dim 0.

    Args:
        x: Batched tensor; dim 0 indexes samples.
        y: Batched reference directions with the same shape as ``x``.

    Returns:
        Tensor with the shape of ``x``. An empty batch yields an empty
        tensor.
    """
    assert x.shape == y.shape
    if x.shape[0] == 0:
        # torch.stack raises on an empty list, so return the empty batch.
        return torch.zeros_like(x)
    return torch.stack(
        [get_perpendicular_component(x_i, y_i) for x_i, y_i in zip(x, y)]
    )


def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size):
    """Combine per-prompt noise deltas using perpendicular projections.

    Notes:
     - weights: an array with the weights for combining the noise predictions
     - delta_noise_preds: [B x K, 4, 64, 64], K = max_prompts_per_dir

    The first ``batch_size`` rows are the main (positive) prediction and its
    weights must all equal 1. Every further group of ``batch_size`` rows is
    projected perpendicular to the main prediction, scaled by its weight and
    accumulated. Samples whose weight magnitude is at most ``1e-4`` are
    skipped.

    Args:
        delta_noise_preds: Stacked deltas, split along dim 0 into chunks of
            ``batch_size``.
        weights: 1-D weights, split the same way.
        batch_size: Number of samples per prompt group.

    Returns:
        Main prediction plus the accumulated weighted perpendicular
        components; same shape as one chunk.
    """
    delta_noise_preds = delta_noise_preds.split(batch_size, dim=0)  # K x [B, 4, 64, 64]
    weights = weights.split(batch_size, dim=0)  # K x [B]

    assert torch.all(weights[0] == 1.0)

    main_positive = delta_noise_preds[0]  # [B, 4, 64, 64]

    accumulated_output = torch.zeros_like(main_positive)
    for i, complementary_noise_pred in enumerate(delta_noise_preds[1:], start=1):
        idx_non_zero = torch.abs(weights[i]) > 1e-4
        # .any() replaces the element-by-element Python sum() over the mask.
        if not idx_non_zero.any():
            continue
        perpendicular = batch_get_perpendicular_component(
            complementary_noise_pred[idx_non_zero], main_positive[idx_non_zero]
        )
        accumulated_output[idx_non_zero] += (
            weights[i][idx_non_zero].reshape(-1, 1, 1, 1) * perpendicular
        )

    assert accumulated_output.shape == main_positive.shape, f"{accumulated_output.shape = }, {main_positive.shape = }"

    return accumulated_output + main_positive
not None: + print(f'[INFO] using hugging face custom model key: {hf_key}') + model_key = hf_key + elif self.sd_version == '2.1': + model_key = "stabilityai/stable-diffusion-2-1-base" + elif self.sd_version == '2.0': + model_key = "stabilityai/stable-diffusion-2-base" + elif self.sd_version == '1.5': + model_key = "runwayml/stable-diffusion-v1-5" + else: + raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.') + + self.precision_t = torch.float16 if fp16 else torch.float32 + + # Create model + pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.precision_t) + + if vram_O: + pipe.enable_sequential_cpu_offload() + pipe.enable_vae_slicing() + pipe.unet.to(memory_format=torch.channels_last) + pipe.enable_attention_slicing(1) + # pipe.enable_model_cpu_offload() + else: + pipe.to(device) + + self.vae = pipe.vae + #self.vae = AutoencoderKL.from_pretrained('F:/high_quality_3DPortraitGAN/exp/stable-dreamfusion/pretrained/vae-ft-mse-840000-ema-pruned', torch_dtype=self.precision_t).to(self.device) + self.tokenizer = pipe.tokenizer + self.text_encoder = pipe.text_encoder + self.unet = pipe.unet + + self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler", torch_dtype=self.precision_t) + + del pipe + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + + print(f'[INFO] loaded stable diffusion!') + + @torch.no_grad() + def get_text_embeds(self, prompt): + # prompt: [str] + + inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') + embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] + + return embeddings + + + def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1, + 
save_guidance_path:Path=None): + + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 + + # feature_image + (1 - weights_samples) * bcg_image + else: + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # encode image into latents with vae, requires grad! + latents = self.encode_imgs(pred_rgb_512) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + tt = torch.cat([t] * 2) + noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample + + # perform guidance (high scale from paper!) 
+ noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond) + + + # import kiui + # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device) + # latents_tmp = latents_tmp.detach() + # kiui.lo(latents_tmp) + # self.scheduler.set_timesteps(30) + # for i, t in enumerate(self.scheduler.timesteps): + # latent_model_input = torch.cat([latents_tmp] * 3) + # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] + # noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + # noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond) + # latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample'] + # imgs = self.decode_latents(latents_tmp) + # kiui.vis.plot_image(imgs) + + # w(t), sigma_t^2 + w = (1 - self.alphas[t]) + grad = grad_scale * w[:, None, None, None] * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + + + if save_guidance_path: + with torch.no_grad(): + if as_latent: + pred_rgb_512 = self.decode_latents(latents) + # + # # visualize predicted denoised image + # # The following block of code is equivalent to `predict_start_from_noise`... + # # see zero123_utils.py's version for a simpler implementation. 
+ # alphas = self.scheduler.alphas.to(latents) + # total_timesteps = self.max_step - self.min_step + 1 + # index = total_timesteps - t.to(latents.device) - 1 + # b = len(noise_pred) + # a_t = alphas[index].reshape(b, 1, 1, 1).to(self.device) + # sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + # sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b, 1, 1, 1)).to(self.device) + # pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0 + # result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t))) + # + # # visualize noisier image + # result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t)) + # + # # TODO: also denoise all-the-way + # # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512] + # # print(F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False).shape, pred_rgb_512.shape) + # viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image], dim=0) + # save_image(viz_images, save_guidance_path) + + guidance_eval_utils = { + "use_perp_neg": False, + "neg_guidance_weights": None, + "text_embeddings": text_embeddings, + "t_orig": t, + "latents_noisy": latents_noisy, + "noise_pred": noise_pred, + "guidance_scale": guidance_scale, + "return_imgs_final": False, + } + + guidance_eval_out = self.guidance_eval(**guidance_eval_utils) + # decode_latents(latents_1step).permute(0, 2, 3, 1) + # "imgs_noisy": imgs_noisy, + # "imgs_1step": imgs_1step, + # "imgs_1orig": imgs_1orig, + # "imgs_final": imgs_final, + viz_images = [pred_rgb_512] + for k in guidance_eval_out: + if k.startswith("imgs_"): + viz_images.append(guidance_eval_out[k]) + viz_images = torch.cat(viz_images, dim=0) + + save_image(viz_images, save_guidance_path) + + + + + + + targets = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] + + return loss + + 
@torch.no_grad() + def get_noise_pred( + self, + latents_noisy, + t, + text_embeddings, + use_perp_neg=False, + neg_guidance_weights=None, + guidance_scale=100.0, + ): + batch_size = latents_noisy.shape[0] + + if use_perp_neg: + raise NotImplementedError + else: + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2, dim=0) + noise_pred = self.unet( + latent_model_input, + torch.cat([t.reshape(1)] * 2).to(self.device), + encoder_hidden_states=text_embeddings, + ).sample + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_text + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + return noise_pred + + + @torch.no_grad() + def guidance_eval( + self, + t_orig, + text_embeddings, + latents_noisy, + noise_pred, + use_perp_neg=False, + neg_guidance_weights=None, + guidance_scale=100.0, + return_imgs_final=False, + ): + # use only 50 timesteps, and find nearest of those to t + self.scheduler.set_timesteps(50) + self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device) + max_items_eval = 4 + bs = ( + min(max_items_eval, latents_noisy.shape[0]) + if max_items_eval > 0 + else latents_noisy.shape[0] + ) # batch size + large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[:bs].unsqueeze( + -1) # sized [bs,50] > [bs,1] + idxs = torch.min(large_enough_idxs, dim=1)[1] + t = self.scheduler.timesteps_gpu[idxs] + + fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy()) + imgs_noisy = self.decode_latents(latents_noisy[:bs]) + + # get prev latent + latents_1step = [] + pred_1orig = [] + for b in range(bs): + step_output = self.scheduler.step( + noise_pred[b: b + 1], t[b], latents_noisy[b: b + 1], eta=1 + ) + latents_1step.append(step_output["prev_sample"]) + pred_1orig.append(step_output["pred_original_sample"]) + latents_1step = torch.cat(latents_1step) + pred_1orig = torch.cat(pred_1orig) + imgs_1step = 
self.decode_latents(latents_1step) + imgs_1orig = self.decode_latents(pred_1orig) + + res = { + "bs": bs, + "noise_levels": fracs, + "imgs_noisy": imgs_noisy, + "imgs_1step": imgs_1step, + "imgs_1orig": imgs_1orig, + + } + if return_imgs_final: + latents_final = [] + for b, i in enumerate(idxs): + latents = latents_1step[b: b + 1] + text_emb = ( + text_embeddings[ + [b, b + len(idxs), b + 2 * len(idxs), b + 3 * len(idxs)], ... + ] + if use_perp_neg + else text_embeddings[[b, b + len(idxs)], ...] + ) + neg_guid = neg_guidance_weights[b: b + 1] if use_perp_neg else None + for t in self.scheduler.timesteps[i + 1:]: + # pred noise + # noise_pred = self.get_noise_pred( + # latents, t, text_emb, use_perp_neg, neg_guid,guidance_scale = guidance_scale + # ) + + latent_model_input = torch.cat([latents] * 2, dim=0) + noise_pred = self.unet( + latent_model_input, + torch.cat([t.reshape(1)] * 2).to(self.device), + encoder_hidden_states=text_emb, + ).sample + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_text + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + + # get prev latent + latents = self.scheduler.step(noise_pred, t, latents, eta=1)[ + "prev_sample" + ] + latents_final.append(latents) + + latents_final = torch.cat(latents_final) + imgs_final = self.decode_latents(latents_final) + + res["imgs_final"] = imgs_final + + return res + def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1, + save_guidance_path:Path=None): + + B = pred_rgb.shape[0] + K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts + + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 + else: + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # encode image into latents with vae, requires grad! 
+ latents = self.encode_imgs(pred_rgb_512) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * (1 + K)) + tt = torch.cat([t] * (1 + K)) + unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:] + delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1) + noise_pred = noise_pred_uncond + guidance_scale * weighted_perpendicular_aggregator(delta_noise_preds, weights, B) + + # import kiui + # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device) + # latents_tmp = latents_tmp.detach() + # kiui.lo(latents_tmp) + # self.scheduler.set_timesteps(30) + # for i, t in enumerate(self.scheduler.timesteps): + # latent_model_input = torch.cat([latents_tmp] * 3) + # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] + # noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + # noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond) + # latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample'] + # imgs = self.decode_latents(latents_tmp) + # kiui.vis.plot_image(imgs) + + # w(t), sigma_t^2 + w = (1 - self.alphas[t]) + grad = grad_scale * w[:, None, None, None] * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + if save_guidance_path: + with torch.no_grad(): + if as_latent: + pred_rgb_512 = self.decode_latents(latents) + + # visualize predicted denoised image + # The following block of code is equivalent to `predict_start_from_noise`... 
+ # see zero123_utils.py's version for a simpler implementation. + alphas = self.alphas.to(latents) + total_timesteps = self.max_step - self.min_step + 1 + index = total_timesteps - t.to(latents.device) - 1 + b = len(noise_pred) + a_t = alphas[index].reshape(b,1,1,1).to(self.device) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(self.device) + pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0 + result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t))) + + # visualize noisier image + result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t)) + + + + # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512] + viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0) + save_image(viz_images, save_guidance_path) + + targets = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] + + return loss + + + @torch.no_grad() + def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None): + + if latents is None: + latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(self.scheduler.timesteps): + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
+ latent_model_input = torch.cat([latents] * 2) + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] + + # perform guidance + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents)['prev_sample'] + + return latents + + def decode_latents(self, latents): + + latents = 1 / self.vae.config.scaling_factor * latents + + imgs = self.vae.decode(latents).sample + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs + + def encode_imgs(self, imgs): + # imgs: [B, 3, H, W] + + imgs = 2 * imgs - 1 + + posterior = self.vae.encode(imgs).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor + + return latents + + def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None): + + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(negative_prompts, str): + negative_prompts = [negative_prompts] + + # Prompts -> text embeds + pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768] + neg_embeds = self.get_text_embeds(negative_prompts) + text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] + + # Text embeds -> img latents + latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64] + + # Img latents -> imgs + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + + # Img to Numpy + imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() + imgs = (imgs * 255).round().astype('uint8') + + return imgs + + +if __name__ == '__main__': + + import argparse + import matplotlib.pyplot as plt + + parser = argparse.ArgumentParser() + parser.add_argument('prompt', type=str) + 
parser.add_argument('--negative', default='', type=str) + parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version") + parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key") + parser.add_argument('--fp16', action='store_true', help="use float16 for training") + parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage") + parser.add_argument('-H', type=int, default=512) + parser.add_argument('-W', type=int, default=512) + parser.add_argument('--seed', type=int, default=2) + parser.add_argument('--steps', type=int, default=50) + opt = parser.parse_args() + + seed_everything(opt.seed) + + device = torch.device('cuda') + + sd = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key) + + imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) + + # visualize image + plt.imshow(imgs[0]) + plt.show() + + +# python guidance/sd_utils.py "upper body photo of caucasian man in black clothes, night city street, bokeh" --hf_key pretrained/SG161222Realistic_Vision_V5.1_noVAE -H 512 -W 512 --seed 42 diff --git a/stable-dreamfusion-3DPortrait/guidance/sdedit.py b/stable-dreamfusion-3DPortrait/guidance/sdedit.py new file mode 100644 index 0000000..6c4b28f --- /dev/null +++ b/stable-dreamfusion-3DPortrait/guidance/sdedit.py @@ -0,0 +1,605 @@ +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline,AutoencoderTiny +import numpy as np +from pathlib import Path +import glob +import os +# suppress partial model loading warning +logging.set_verbosity_error() + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.utils import save_image + +import tqdm +try: + from .perpneg_utils import weighted_perpendicular_aggregator +except: + from perpneg_utils 
def seed_everything(seed):
    """Seed PyTorch's random number generators with ``seed``.

    Calls ``torch.manual_seed`` followed by ``torch.cuda.manual_seed``.
    The cuDNN ``deterministic`` / ``benchmark`` flags are left untouched.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
self.scheduler.alphas_cumprod.to(self.device) # for convenience + + print(f'[INFO] loaded stable diffusion!') + + @torch.no_grad() + def get_text_embeds(self, prompt): + # prompt: [str] + + inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') + embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] + + return embeddings + + + def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1, + save_guidance_path:Path=None): + + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 + + # feature_image + (1 - weights_samples) * bcg_image + else: + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # encode image into latents with vae, requires grad! + latents = self.encode_imgs(pred_rgb_512) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + tt = torch.cat([t] * 2) + noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample + + # perform guidance (high scale from paper!) 
+ noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond) + + + # import kiui + # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device) + # latents_tmp = latents_tmp.detach() + # kiui.lo(latents_tmp) + # self.scheduler.set_timesteps(30) + # for i, t in enumerate(self.scheduler.timesteps): + # latent_model_input = torch.cat([latents_tmp] * 3) + # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] + # noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + # noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond) + # latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample'] + # imgs = self.decode_latents(latents_tmp) + # kiui.vis.plot_image(imgs) + + # w(t), sigma_t^2 + w = (1 - self.alphas[t]) + grad = grad_scale * w[:, None, None, None] * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + + + if save_guidance_path: + with torch.no_grad(): + if as_latent: + pred_rgb_512 = self.decode_latents(latents) + # + # # visualize predicted denoised image + # # The following block of code is equivalent to `predict_start_from_noise`... + # # see zero123_utils.py's version for a simpler implementation. 
+ # alphas = self.scheduler.alphas.to(latents) + # total_timesteps = self.max_step - self.min_step + 1 + # index = total_timesteps - t.to(latents.device) - 1 + # b = len(noise_pred) + # a_t = alphas[index].reshape(b, 1, 1, 1).to(self.device) + # sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + # sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b, 1, 1, 1)).to(self.device) + # pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0 + # result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t))) + # + # # visualize noisier image + # result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t)) + # + # # TODO: also denoise all-the-way + # # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512] + # # print(F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False).shape, pred_rgb_512.shape) + # viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image], dim=0) + # save_image(viz_images, save_guidance_path) + + guidance_eval_utils = { + "use_perp_neg": False, + "neg_guidance_weights": None, + "text_embeddings": text_embeddings, + "t_orig": t, + "latents_noisy": latents_noisy, + "noise_pred": noise_pred, + "guidance_scale": guidance_scale, + "return_imgs_final": False, + } + + guidance_eval_out = self.guidance_eval(**guidance_eval_utils) + # decode_latents(latents_1step).permute(0, 2, 3, 1) + # "imgs_noisy": imgs_noisy, + # "imgs_1step": imgs_1step, + # "imgs_1orig": imgs_1orig, + # "imgs_final": imgs_final, + viz_images = [pred_rgb_512] + for k in guidance_eval_out: + if k.startswith("imgs_"): + viz_images.append(guidance_eval_out[k]) + viz_images = torch.cat(viz_images, dim=0) + + save_image(viz_images, save_guidance_path) + + + + + + + targets = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] + + return loss + + 
@torch.no_grad() + def get_noise_pred( + self, + latents_noisy, + t, + text_embeddings, + use_perp_neg=False, + neg_guidance_weights=None, + guidance_scale=100.0, + ): + batch_size = latents_noisy.shape[0] + + if use_perp_neg: + raise NotImplementedError + else: + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2, dim=0) + noise_pred = self.unet( + latent_model_input, + torch.cat([t.reshape(1)] * 2).to(self.device), + encoder_hidden_states=text_embeddings, + ).sample + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_text + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + return noise_pred + + + @torch.no_grad() + def guidance_eval( + self, + t_orig, + text_embeddings, + latents_noisy, + noise_pred, + use_perp_neg=False, + neg_guidance_weights=None, + guidance_scale=100.0, + return_imgs_final=False, + ): + # use only 50 timesteps, and find nearest of those to t + self.scheduler.set_timesteps(50) + self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device) + max_items_eval = 4 + bs = ( + min(max_items_eval, latents_noisy.shape[0]) + if max_items_eval > 0 + else latents_noisy.shape[0] + ) # batch size + large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[:bs].unsqueeze( + -1) # sized [bs,50] > [bs,1] + idxs = torch.min(large_enough_idxs, dim=1)[1] + t = self.scheduler.timesteps_gpu[idxs] + + fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy()) + imgs_noisy = self.decode_latents(latents_noisy[:bs]) + + # get prev latent + latents_1step = [] + pred_1orig = [] + for b in range(bs): + step_output = self.scheduler.step( + noise_pred[b: b + 1], t[b], latents_noisy[b: b + 1], eta=1 + ) + latents_1step.append(step_output["prev_sample"]) + pred_1orig.append(step_output["pred_original_sample"]) + latents_1step = torch.cat(latents_1step) + pred_1orig = torch.cat(pred_1orig) + imgs_1step = 
self.decode_latents(latents_1step) + imgs_1orig = self.decode_latents(pred_1orig) + + res = { + "bs": bs, + "noise_levels": fracs, + "imgs_noisy": imgs_noisy, + "imgs_1step": imgs_1step, + "imgs_1orig": imgs_1orig, + + } + if return_imgs_final: + latents_final = [] + for b, i in enumerate(idxs): + latents = latents_1step[b: b + 1] + text_emb = ( + text_embeddings[ + [b, b + len(idxs), b + 2 * len(idxs), b + 3 * len(idxs)], ... + ] + if use_perp_neg + else text_embeddings[[b, b + len(idxs)], ...] + ) + neg_guid = neg_guidance_weights[b: b + 1] if use_perp_neg else None + for t in self.scheduler.timesteps[i + 1:]: + # pred noise + # noise_pred = self.get_noise_pred( + # latents, t, text_emb, use_perp_neg, neg_guid,guidance_scale = guidance_scale + # ) + + latent_model_input = torch.cat([latents] * 2, dim=0) + noise_pred = self.unet( + latent_model_input, + torch.cat([t.reshape(1)] * 2).to(self.device), + encoder_hidden_states=text_emb, + ).sample + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_text + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + + # get prev latent + latents = self.scheduler.step(noise_pred, t, latents, eta=1)[ + "prev_sample" + ] + latents_final.append(latents) + + latents_final = torch.cat(latents_final) + imgs_final = self.decode_latents(latents_final) + + res["imgs_final"] = imgs_final + + return res + def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1, + save_guidance_path:Path=None): + + B = pred_rgb.shape[0] + K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts + + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 + else: + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # encode image into latents with vae, requires grad! 
+ latents = self.encode_imgs(pred_rgb_512) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * (1 + K)) + tt = torch.cat([t] * (1 + K)) + unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:] + delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1) + noise_pred = noise_pred_uncond + guidance_scale * weighted_perpendicular_aggregator(delta_noise_preds, weights, B) + + # import kiui + # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device) + # latents_tmp = latents_tmp.detach() + # kiui.lo(latents_tmp) + # self.scheduler.set_timesteps(30) + # for i, t in enumerate(self.scheduler.timesteps): + # latent_model_input = torch.cat([latents_tmp] * 3) + # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] + # noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + # noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond) + # latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample'] + # imgs = self.decode_latents(latents_tmp) + # kiui.vis.plot_image(imgs) + + # w(t), sigma_t^2 + w = (1 - self.alphas[t]) + grad = grad_scale * w[:, None, None, None] * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + if save_guidance_path: + with torch.no_grad(): + if as_latent: + pred_rgb_512 = self.decode_latents(latents) + + # visualize predicted denoised image + # The following block of code is equivalent to `predict_start_from_noise`... 
+ # see zero123_utils.py's version for a simpler implementation. + alphas = self.alphas.to(latents) + total_timesteps = self.max_step - self.min_step + 1 + index = total_timesteps - t.to(latents.device) - 1 + b = len(noise_pred) + a_t = alphas[index].reshape(b,1,1,1).to(self.device) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(self.device) + pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0 + result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t))) + + # visualize noisier image + result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t)) + + + + # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512] + viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0) + save_image(viz_images, save_guidance_path) + + targets = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] + + return loss + + + @torch.no_grad() + def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None): + + if latents is None: + latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(self.scheduler.timesteps): + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
+ latent_model_input = torch.cat([latents] * 2) + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] + + # perform guidance + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents)['prev_sample'] + + return latents + + + + def decode_latents(self, latents): + + latents = 1 / self.vae.config.scaling_factor * latents + + imgs = self.vae.decode(latents).sample + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs + + def encode_imgs(self, imgs): + # imgs: [B, 3, H, W] + + imgs = 2 * imgs - 1 + + posterior = self.vae.encode(imgs).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor + + return latents + + def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None): + + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(negative_prompts, str): + negative_prompts = [negative_prompts] + + # Prompts -> text embeds + pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768] + neg_embeds = self.get_text_embeds(negative_prompts) + text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] + + # Text embeds -> img latents + latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64] + + # Img latents -> imgs + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + + # Img to Numpy + imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() + imgs = (imgs * 255).round().astype('uint8') + + return imgs + + + def denoise_latents(self, text_embeddings, start_t,num_inference_steps=50, guidance_scale=7.5, latents=None): + + + 
self.scheduler.set_timesteps(num_inference_steps) + for t in tqdm.tqdm(self.scheduler.timesteps): + if t>start_t: + continue + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. + latent_model_input = torch.cat([latents] * 2) + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] + + # perform guidance + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents)['prev_sample'] + + return latents + + + def sdedit(self, data_dir, height=512, width=512, num_inference_steps=50,test_data_dir = None, guidance_scale=7.5): + + + noise_level = 200 + res_dir = data_dir + origin_data_dir = os.path.join(res_dir, 'data') + if not os.path.exists(origin_data_dir): + print('no data dir') + return + + update_data_dir = os.path.join(res_dir, 'update_data') + os.makedirs(update_data_dir, exist_ok=True) + + if len(glob.glob(origin_data_dir + '/*.png')) == len(glob.glob(update_data_dir + '/*.png')): + print('already done') + return + print('gen data for ', res_dir) + + name = os.path.basename(res_dir) + + prompt_path = os.path.join(test_data_dir, f'{name}/prompt.txt') + if os.path.exists(prompt_path): + with open(prompt_path, 'r') as f: + prompts = f.read().strip() + else: + raise ValueError('prompt.txt not exists') + + if isinstance(prompts, str): + prompts = [prompts] + # Prompts -> text embeds + pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768] + neg_embeds = self.get_text_embeds('worst quality, low quality, jpeg artifacts, blurry') + text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] + + for image_path in glob.glob(origin_data_dir + '/*.png'): + image = PIL.Image.open(image_path).convert('RGB') + image = np.array(image) + + origin_img = 
torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(self.device) # --> 0,1 + origin_img = origin_img / 255.0 + + latents = self.encode_imgs(origin_img) + + t = torch.tensor([noise_level], dtype=torch.long, + device=self.device) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + + latents = self.denoise_latents(text_embeds, noise_level, num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, latents=latents_noisy) + + # Img latents -> imgs + img = self.decode_latents(latents) # [1, 3, 512, 512] + # Img to Numpy + img = img.detach().cpu().permute(0, 2, 3, 1).numpy() + img = (img * 255).round().astype('uint8')[0] + + PIL.Image.fromarray(img).save(os.path.join(update_data_dir, os.path.basename(image_path))) + + + +if __name__ == '__main__': + + import argparse + import matplotlib.pyplot as plt + import PIL + + parser = argparse.ArgumentParser() + parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version") + parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key") + + + parser.add_argument('--data_dir', type=str,help='Network pickle filename', required=True) + parser.add_argument('--test_data_dir', type=str,help='test_data_dir', required=True) + + + parser.add_argument('--fp16', action='store_true', help="use float16 for training") + parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage") + parser.add_argument('-H', type=int, default=512) + parser.add_argument('-W', type=int, default=512) + parser.add_argument('--seed', type=int, default=2) + parser.add_argument('--steps', type=int, default=50) + opt = parser.parse_args() + + seed_everything(opt.seed) + + device = torch.device('cuda') + + + + sd = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key) 
+ + imgs = sd.sdedit(opt.data_dir,opt.H, opt.W, opt.steps,opt.test_data_dir) + + + +# \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/guidance/test_taesd.py b/stable-dreamfusion-3DPortrait/guidance/test_taesd.py new file mode 100644 index 0000000..7e2aef9 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/guidance/test_taesd.py @@ -0,0 +1,626 @@ +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline,AutoencoderTiny +from diffusers.utils.import_utils import is_xformers_available +from os.path import isfile +from pathlib import Path + +# suppress partial model loading warning +logging.set_verbosity_error() + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.utils import save_image + +from torch.cuda.amp import custom_bwd, custom_fwd +try: + from .perpneg_utils import weighted_perpendicular_aggregator +except: + from perpneg_utils import weighted_perpendicular_aggregator +def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + #torch.backends.cudnn.deterministic = True + #torch.backends.cudnn.benchmark = True + + + +from dataclasses import dataclass +from diffusers.utils.outputs import BaseOutput +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.vae import DecoderOutput,DecoderTiny, EncoderTiny +from typing import Tuple, Union + +@dataclass +class TaesdOutput(BaseOutput): + """ + Output of AutoencoderTiny encoding method. + + Args: + latents (`torch.Tensor`): Encoded outputs of the `Encoder`. 
+ + """ + + latents: torch.Tensor + + +class Taesd(nn.Module): + def __init__(self,device, fp16 + ): + super().__init__() + self.device = device + self.precision_t = torch.float16 if fp16 else torch.float32 + #if init_ckpt + + in_channels = 3 + out_channels = 3 + encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64) + decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64) + act_fn: str = "relu" + latent_channels: int = 4 + upsampling_scaling_factor: int = 2 + num_encoder_blocks: Tuple[int] = (1, 3, 3, 3) + num_decoder_blocks: Tuple[int] = (3, 3, 3, 1) + latent_magnitude: int = 3 + latent_shift: float = 0.5 + force_upcast: float = False + scaling_factor: float = 1.0 + + + self.decoder = DecoderTiny( + in_channels=latent_channels, + out_channels=out_channels, + num_blocks=num_decoder_blocks, + block_out_channels=decoder_block_out_channels, + upsampling_scaling_factor=upsampling_scaling_factor, + act_fn=act_fn, + ).to(device).to(self.precision_t) + self.encoder = EncoderTiny( + in_channels=in_channels, + out_channels=latent_channels, + num_blocks=num_encoder_blocks, + block_out_channels=encoder_block_out_channels, + act_fn=act_fn, + ).to(device).to(self.precision_t) + + print("Loading pretrained model") + vae = AutoencoderTiny.from_pretrained("pretrained/taesd", torch_dtype=self.precision_t).to(device) + self.decoder.load_state_dict(vae.decoder.state_dict()) + self.encoder.load_state_dict(vae.encoder.state_dict()) + del vae + + + self.latent_magnitude = latent_magnitude + self.latent_shift = latent_shift + self.scaling_factor = scaling_factor + + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.spatial_scale_factor = 2 ** out_channels + self.tile_overlap_factor = 0.125 + self.tile_sample_min_size = 512 + self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor + + + + def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""Encode a batch of images using a tiled 
encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a + plain `tuple` is returned. + """ + # scale of encoder output relative to input + sf = self.spatial_scale_factor + tile_size = self.tile_sample_min_size + + # number of pixels to blend and to traverse between tile + blend_size = int(tile_size * self.tile_overlap_factor) + traverse_size = tile_size - blend_size + + # tiles index (up/left) + ti = range(0, x.shape[-2], traverse_size) + tj = range(0, x.shape[-1], traverse_size) + + # mask for blending + blend_masks = torch.stack( + torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij") + ) + blend_masks = blend_masks.clamp(0, 1).to(x.device) + + # output array + out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device) + for i in ti: + for j in tj: + tile_in = x[..., i : i + tile_size, j : j + tile_size] + # tile result + tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf] + tile = self.encoder(tile_in) + h, w = tile.shape[-2], tile.shape[-1] + # blend tile result into output + blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0] + blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1] + blend_mask = blend_mask_i * blend_mask_j + tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w] 
+ tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out) + return out + + def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + # scale of decoder output relative to input + sf = self.spatial_scale_factor + tile_size = self.tile_latent_min_size + + # number of pixels to blend and to traverse between tiles + blend_size = int(tile_size * self.tile_overlap_factor) + traverse_size = tile_size - blend_size + + # tiles index (up/left) + ti = range(0, x.shape[-2], traverse_size) + tj = range(0, x.shape[-1], traverse_size) + + # mask for blending + blend_masks = torch.stack( + torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij") + ) + blend_masks = blend_masks.clamp(0, 1).to(x.device) + + # output array + out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device) + for i in ti: + for j in tj: + tile_in = x[..., i : i + tile_size, j : j + tile_size] + # tile result + tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf] + tile = self.decoder(tile_in) + h, w = tile.shape[-2], tile.shape[-1] + # blend tile result into output + blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0] + blend_mask_j = 
torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1] + blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w] + tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out) + return out + + + @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[TaesdOutput, Tuple[torch.FloatTensor]]: + if self.use_slicing and x.shape[0] > 1: + output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x) for x_slice in x.split(1)] + output = torch.cat(output) + else: + output = self._tiled_encode(x) if self.use_tiling else self.encoder(x) + + if not return_dict: + return (output,) + + return TaesdOutput(latents=output) + + @apply_forward_hook + def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: + print('decode in Taesd') + if self.use_slicing and x.shape[0] > 1: + output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)] + output = torch.cat(output) + else: + output = self._tiled_decode(x) if self.use_tiling else self.decoder(x) + + if not return_dict: + return (output,) + + return DecoderOutput(sample=output) + + + def decode_latents(self, latents): + ''' + + :param latents: [B, 4, H, W], + :return: imgs [B, 3, H, W], in [0, 1] + ''' + latents = 1 / self.scaling_factor * latents + + imgs = self.decode(latents).sample + + + + + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs + + def encode_imgs(self, imgs): + ''' + + :param imgs: [B, 3, H, W], in [0, 1] + :return: latents: [B, 4, H, W], + ''' + + + imgs = 2 * imgs - 1 # to [-1, 1] + + posterior = self.encode(imgs).latent_dist + + latents = posterior.sample() * self.scaling_factor + + return latents + + +class StableDiffusion(nn.Module): + def __init__(self, device, fp16, vram_O, sd_version='2.1', hf_key=None, t_range=[0.02, 0.98]): + super().__init__() + + self.device = device + self.sd_version = sd_version + + print(f'[INFO] loading stable 
diffusion...') + + if hf_key is not None: + print(f'[INFO] using hugging face custom model key: {hf_key}') + model_key = hf_key + elif self.sd_version == '2.1': + model_key = "stabilityai/stable-diffusion-2-1-base" + elif self.sd_version == '2.0': + model_key = "stabilityai/stable-diffusion-2-base" + elif self.sd_version == '1.5': + model_key = "runwayml/stable-diffusion-v1-5" + else: + raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.') + + self.precision_t = torch.float16 if fp16 else torch.float32 + + # Create model + pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.precision_t) + + if vram_O: + pipe.enable_sequential_cpu_offload() + pipe.enable_vae_slicing() + pipe.unet.to(memory_format=torch.channels_last) + pipe.enable_attention_slicing(1) + # pipe.enable_model_cpu_offload() + else: + pipe.to(device) + + self.taesd = Taesd(device, fp16) + self.vae = AutoencoderKL.from_pretrained('./pretrained/vae-ft-mse-840000-ema-pruned', torch_dtype=self.precision_t).to(self.device) + + + + self.tokenizer = pipe.tokenizer + self.text_encoder = pipe.text_encoder + self.unet = pipe.unet + + self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler", torch_dtype=self.precision_t) + + del pipe + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + + print(f'[INFO] loaded stable diffusion!') + + @torch.no_grad() + def get_text_embeds(self, prompt): + # prompt: [str] + + inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') + embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] + + return embeddings + + + def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1, + 
save_guidance_path:Path=None): + + + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 + + # feature_image + (1 - weights_samples) * bcg_image + else: + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # encode image into latents with vae, requires grad! + latents = self.encode_imgs(pred_rgb_512) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + tt = torch.cat([t] * 2) + noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample + + # perform guidance (high scale from paper!) 
+ noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond) + + + # import kiui + # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device) + # latents_tmp = latents_tmp.detach() + # kiui.lo(latents_tmp) + # self.scheduler.set_timesteps(30) + # for i, t in enumerate(self.scheduler.timesteps): + # latent_model_input = torch.cat([latents_tmp] * 3) + # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] + # noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + # noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond) + # latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample'] + # imgs = self.decode_latents(latents_tmp) + # kiui.vis.plot_image(imgs) + + # w(t), sigma_t^2 + w = (1 - self.alphas[t]) + grad = grad_scale * w[:, None, None, None] * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + if save_guidance_path: + with torch.no_grad(): + if as_latent: + pred_rgb_512 = self.decode_latents(latents) + + # visualize predicted denoised image + # The following block of code is equivalent to `predict_start_from_noise`... + # see zero123_utils.py's version for a simpler implementation. 
+ alphas = self.scheduler.alphas.to(latents) + total_timesteps = self.max_step - self.min_step + 1 + index = total_timesteps - t.to(latents.device) - 1 + b = len(noise_pred) + a_t = alphas[index].reshape(b,1,1,1).to(self.device) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(self.device) + pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0 + result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t))) + + # visualize noisier image + result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t)) + + # TODO: also denoise all-the-way + + # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512] + #print(F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False).shape, pred_rgb_512.shape) + viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0) + save_image(viz_images, save_guidance_path) + + targets = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] + + return loss + + + def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1, + save_guidance_path:Path=None): + + + B = pred_rgb.shape[0] + K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts + + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 + else: + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # encode image into latents with vae, requires grad! 
+ latents = self.encode_imgs(pred_rgb_512) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * (1 + K)) + tt = torch.cat([t] * (1 + K)) + unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:] + delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1) + noise_pred = noise_pred_uncond + guidance_scale * weighted_perpendicular_aggregator(delta_noise_preds, weights, B) + + # import kiui + # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device) + # latents_tmp = latents_tmp.detach() + # kiui.lo(latents_tmp) + # self.scheduler.set_timesteps(30) + # for i, t in enumerate(self.scheduler.timesteps): + # latent_model_input = torch.cat([latents_tmp] * 3) + # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] + # noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + # noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond) + # latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample'] + # imgs = self.decode_latents(latents_tmp) + # kiui.vis.plot_image(imgs) + + # w(t), sigma_t^2 + w = (1 - self.alphas[t]) + grad = grad_scale * w[:, None, None, None] * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + if save_guidance_path: + with torch.no_grad(): + if as_latent: + pred_rgb_512 = self.decode_latents(latents) + + # visualize predicted denoised image + # The following block of code is equivalent to `predict_start_from_noise`... 
@torch.no_grad()
def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
    """Run the scheduler's denoising loop with classifier-free guidance.

    Args:
        text_embeddings: embeddings fed to the UNet as ``encoder_hidden_states``.
            The UNet output is split in two halves, the first treated as the
            unconditional prediction and the second as the conditional one, so
            the embeddings are expected to be stacked in that order.
        height, width: only used to size freshly sampled latents
            (``height // 8`` by ``width // 8``).
        num_inference_steps: passed to ``self.scheduler.set_timesteps``.
        guidance_scale: weight applied to (conditional - unconditional).
        latents: optional starting latents; sampled with ``torch.randn`` on
            ``self.device`` when omitted.

    Returns:
        The latents after the last scheduler step.
    """
    if latents is None:
        # Half of the embedding batch is the unconditional copy.
        batch = text_embeddings.shape[0] // 2
        shape = (batch, self.unet.in_channels, height // 8, width // 8)
        latents = torch.randn(shape, device=self.device)

    self.scheduler.set_timesteps(num_inference_steps)

    for step_t in self.scheduler.timesteps:
        # Duplicate the latents so one UNet call yields both predictions.
        doubled = torch.cat([latents, latents])
        eps = self.unet(doubled, step_t, encoder_hidden_states=text_embeddings)['sample']

        # Classifier-free guidance.
        eps_uncond, eps_cond = eps.chunk(2)
        guided = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # x_t -> x_{t-1}
        latents = self.scheduler.step(guided, step_t, latents)['prev_sample']

    return latents
as plt + + parser = argparse.ArgumentParser() + parser.add_argument('prompt', type=str) + parser.add_argument('--negative', default='', type=str) + parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version") + parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key") + parser.add_argument('--use_tiny_vae', action='store_true', help="use tiny vae") + parser.add_argument('--fp16', action='store_true', help="use float16 for training") + parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage") + parser.add_argument('-H', type=int, default=512) + parser.add_argument('-W', type=int, default=512) + parser.add_argument('--seed', type=int, default=2) + parser.add_argument('--steps', type=int, default=50) + opt = parser.parse_args() + + seed_everything(opt.seed) + + device = torch.device('cuda') + + sd = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key) + + imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) + + # visualize image + plt.imshow(imgs[0]) + plt.show() + + +# python guidance/test_taesd.py "upper body photo of caucasian man in black clothes, night city street, bokeh" --hf_key pretrained/SG161222Realistic_Vision_V5.1_noVAE -H 512 -W 512 --seed 42 --use_tiny_vae + diff --git a/stable-dreamfusion-3DPortrait/guidance/zero123_utils.py b/stable-dreamfusion-3DPortrait/guidance/zero123_utils.py new file mode 100644 index 0000000..cc41161 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/guidance/zero123_utils.py @@ -0,0 +1,320 @@ +import math +import numpy as np +from omegaconf import OmegaConf +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd +from torchvision.utils import save_image + +from diffusers import DDIMScheduler + +import sys +from os import path 
# load model
def load_model_from_config(config, ckpt, device, vram_O=False, verbose=False):
    """Instantiate ``config.model``, load weights from ``ckpt`` and move it to ``device``.

    Args:
        config: object exposing a ``model`` entry understood by
            ``instantiate_from_config``.
        ckpt: checkpoint path handed to ``torch.load``; it must contain a
            ``'state_dict'`` entry.
        device: target device for the returned model.
        vram_O: when true, the first-stage decoder is deleted to save memory.
        verbose: print the global step, missing/unexpected keys and EMA
            loading messages.

    Returns:
        The model in eval mode on ``device``.
    """
    checkpoint = torch.load(ckpt, map_location='cpu')

    if verbose and 'global_step' in checkpoint:
        print(f'[INFO] Global Step: {checkpoint["global_step"]}')

    state_dict = checkpoint['state_dict']

    model = instantiate_from_config(config.model)
    # strict=False: mismatching keys are reported (when verbose) instead of raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if verbose:
        if len(missing) > 0:
            print('[INFO] missing keys: \n', missing)
        if len(unexpected) > 0:
            print('[INFO] unexpected keys: \n', unexpected)

    # manually load ema and delete it to save GPU memory
    if model.use_ema:
        if verbose:
            print('[INFO] loading EMA...')
        model.model_ema.copy_to(model.model)
        del model.model_ema

    if vram_O:
        # we don't need decoder
        del model.first_stage_model.decoder

    torch.cuda.empty_cache()

    model.eval().to(device)

    return model
def angle_between(self, sph_v1, sph_v2):
    """Return the pairwise angles, in radians, between two sets of spherical vectors.

    Every vector is indexed as ``(r, theta, phi)`` and converted with
    ``x = r sin(theta) cos(phi)``, ``y = r sin(theta) sin(phi)``,
    ``z = r cos(theta)``.

    Args:
        sph_v1: iterable of ``n`` vectors (e.g. a tensor of shape [n, 3]).
        sph_v2: iterable of ``m`` vectors (e.g. a tensor of shape [m, 3]).

    Returns:
        Tensor of shape ``[n, m]`` created with ``torch.empty`` defaults;
        entry ``[i, j]`` is the angle between ``sph_v1[i]`` and ``sph_v2[j]``.
    """
    def to_unit_cartesian(sv):
        r, theta, phi = sv[0], sv[1], sv[2]
        v = torch.tensor([r * torch.sin(theta) * torch.cos(phi),
                          r * torch.sin(theta) * torch.sin(phi),
                          r * torch.cos(theta)])
        return v / torch.linalg.norm(v)

    # Convert and normalise each vector once, instead of once per (i, j) pair.
    units_1 = [to_unit_cartesian(sv) for sv in sph_v1]
    units_2 = [to_unit_cartesian(sv) for sv in sph_v2]

    angles = torch.empty(len(sph_v1), len(sph_v2))
    for i, u1 in enumerate(units_1):
        for j, u2 in enumerate(units_2):
            # Clip guards arccos against dot products slightly outside [-1, 1].
            angles[i, j] = torch.arccos(torch.clip(torch.dot(u1, u2), -1.0, 1.0))
    return angles
# polar,azimuth,radius are all actually delta wrt default + v2 = torch.stack([torch.tensor(ref_radii), torch.deg2rad(torch.tensor(ref_polars)), torch.deg2rad(torch.tensor(ref_azimuths))], dim=-1) + angles = torch.rad2deg(self.angle_between(v1, v2)).to(self.device) + if self.opt.zero123_grad_scale == 'angle': + grad_scale = (angles.min(dim=1)[0] / (180/len(ref_azimuths))) * grad_scale # rethink 180/len(ref_azimuths) # claforte: try inverting grad_scale or just fixing it to 1.0 + elif self.opt.zero123_grad_scale == 'None': + grad_scale = 1.0 # claforte: I think this might converge faster...? + else: + assert False, f'Unrecognized `zero123_grad_scale`: {self.opt.zero123_grad_scale}' + + if as_latent: + latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1 + else: + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) + latents = self.encode_imgs(pred_rgb_256) + + t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device) + + # Set weights acc to closeness in angle + if len(ref_azimuths) > 1: + inv_angles = 1/angles + inv_angles[inv_angles > 100] = 100 + inv_angles /= inv_angles.max(dim=-1, keepdim=True)[0] + inv_angles[inv_angles < 0.1] = 0 + else: + inv_angles = torch.tensor([1.]).to(self.device) + + # Multiply closeness-weight by user-given weights + zero123_ws = torch.tensor(embeddings['zero123_ws'])[None, :].to(self.device) * inv_angles + zero123_ws /= zero123_ws.max(dim=-1, keepdim=True)[0] + zero123_ws[zero123_ws < 0.1] = 0 + + with torch.no_grad(): + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + + x_in = torch.cat([latents_noisy] * 2) + t_in = torch.cat([t] * 2) + + noise_preds = [] + # Loop through each ref image + for (zero123_w, c_crossattn, c_concat, ref_polar, ref_azimuth, ref_radius) in zip(zero123_ws.T, + embeddings['c_crossattn'], embeddings['c_concat'], + ref_polars, 
ref_azimuths, ref_radii): + # polar,azimuth,radius are all actually delta wrt default + p = polar + ref_polars[0] - ref_polar + a = azimuth + ref_azimuths[0] - ref_azimuth + a[a > 180] -= 360 # range in [-180, 180] + r = radius + ref_radii[0] - ref_radius + # T = torch.tensor([math.radians(p), math.sin(math.radians(-a)), math.cos(math.radians(a)), r]) + # T = T[None, None, :].to(self.device) + T = torch.stack([torch.deg2rad(p), torch.sin(torch.deg2rad(-a)), torch.cos(torch.deg2rad(a)), r], dim=-1)[:, None, :] + cond = {} + clip_emb = self.model.cc_projection(torch.cat([c_crossattn.repeat(len(T), 1, 1), T], dim=-1)) + cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)] + cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).repeat(len(T), 1, 1, 1).to(self.device), c_concat.repeat(len(T), 1, 1, 1)], dim=0)] + noise_pred = self.model.apply_model(x_in, t_in, cond) + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + noise_preds.append(zero123_w[:, None, None, None] * noise_pred) + + noise_pred = torch.stack(noise_preds).sum(dim=0) / zero123_ws.sum(dim=-1)[:, None, None, None] + + w = (1 - self.alphas[t]) + grad = (grad_scale * w)[:, None, None, None] * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + # import kiui + # if not as_latent: + # kiui.vis.plot_image(pred_rgb_256) + # kiui.vis.plot_matrix(latents) + # kiui.vis.plot_matrix(grad) + + # import kiui + # latents = torch.randn((1, 4, 32, 32), device=self.device) + # kiui.lo(latents) + # self.scheduler.set_timesteps(30) + # with torch.no_grad(): + # for i, t in enumerate(self.scheduler.timesteps): + # x_in = torch.cat([latents] * 2) + # t_in = torch.cat([t.view(1)] * 2).to(self.device) + + # noise_pred = self.model.apply_model(x_in, t_in, cond) + # noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + # noise_pred = noise_pred_uncond + 3 * (noise_pred_cond - 
# verification
@torch.no_grad()
def __call__(self,
        image, # image tensor [1, 3, H, W] in [0, 1]
        polar=0, azimuth=0, radius=0, # new view params
        scale=3, ddim_steps=50, ddim_eta=1, h=256, w=256, # diffusion params
        c_crossattn=None, c_concat=None, post_process=True,
    ):
    """Sample a novel view of ``image`` with the scheduler's denoising loop.

    Args:
        image: reference image; only used when ``c_crossattn`` or ``c_concat``
            is not supplied.
        polar, azimuth, radius: view change, angles in degrees.
        scale: classifier-free guidance scale.
        ddim_steps: number of scheduler steps.
        ddim_eta: ``eta`` forwarded to ``self.scheduler.step``.
        h, w: output size used to shape the initial latents (``h // 8``, ``w // 8``).
        c_crossattn, c_concat: optional precomputed conditioning; whichever is
            missing is computed from ``image``.
        post_process: when true, return a NumPy array transposed to
            ``(0, 2, 3, 1)``; otherwise return the decoded tensor.

    Returns:
        Decoded images from ``self.decode_latents``.
    """
    if c_crossattn is None or c_concat is None:
        # get_img_embeds returns a pair of per-image lists (c, v), not a dict.
        # Only the first image is used, since a single latent is sampled below.
        img_c, img_v = self.get_img_embeds(image)
        if c_crossattn is None:
            c_crossattn = img_c[0]
        if c_concat is None:
            c_concat = img_v[0]

    # NOTE(review): train_step builds its conditioning with sin(deg2rad(-a));
    # here the azimuth is not negated -- confirm which sign convention is intended.
    T = torch.tensor([math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), radius])
    T = T[None, None, :].to(self.device)

    # Unconditional half (zeros) first, conditional half second.
    cond = {}
    clip_emb = self.model.cc_projection(torch.cat([c_crossattn, T], dim=-1))
    cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
    cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)]

    # produce latents loop
    latents = torch.randn((1, 4, h // 8, w // 8), device=self.device)
    self.scheduler.set_timesteps(ddim_steps)

    for t in self.scheduler.timesteps:
        x_in = torch.cat([latents] * 2)
        t_in = torch.cat([t.view(1)] * 2).to(self.device)

        noise_pred = self.model.apply_model(x_in, t_in, cond)
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond)

        latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']

    imgs = self.decode_latents(latents)
    imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs

    return imgs
cv2.imread(opt.input, cv2.IMREAD_UNCHANGED) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA) + image = image.astype(np.float32) / 255.0 + image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device) + + print(f'[INFO] loading model ...') + zero123 = Zero123(device, opt.fp16, opt=opt) + + print(f'[INFO] running model ...') + outputs = zero123(image, polar=opt.polar, azimuth=opt.azimuth, radius=opt.radius) + plt.imshow(outputs[0]) + plt.show() \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/ldm/extras.py b/stable-dreamfusion-3DPortrait/ldm/extras.py new file mode 100644 index 0000000..62e654b --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/extras.py @@ -0,0 +1,77 @@ +from pathlib import Path +from omegaconf import OmegaConf +import torch +from ldm.util import instantiate_from_config +import logging +from contextlib import contextmanager + +from contextlib import contextmanager +import logging + +@contextmanager +def all_logging_disabled(highest_level=logging.CRITICAL): + """ + A context manager that will prevent any logging messages + triggered during the body from being processed. + + :param highest_level: the maximum logging level in use. + This would only need to be changed if a custom level greater than CRITICAL + is defined. + + https://gist.github.com/simon-weber/7853144 + """ + # two kind-of hacks here: + # * can't get the highest logging level in effect => delegate to the user + # * can't get the current module-level override => use an undocumented + # (but non-private!) 
def load_training_dir(train_dir, device, epoch="last"):
    """Load a checkpoint and config from training directory

    Searches ``train_dir`` recursively for exactly one ``*{epoch}.ckpt`` file
    and at least one ``*-project.yaml`` config. When several configs match,
    the last one in sorted order is used.

    Raises:
        AssertionError: if the number of matching checkpoints is not one, or
            if no config file is found.
    """
    train_dir = Path(train_dir)

    ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
    # Raised explicitly (same exception type as the former asserts) so the
    # checks survive ``python -O``.
    if len(ckpt) != 1:
        raise AssertionError(f"found {len(ckpt)} matching ckpt files")

    config = list(train_dir.rglob("*-project.yaml"))
    # This check used to test ``ckpt`` and therefore never fired.
    if len(config) == 0:
        raise AssertionError(f"didn't find any config in {train_dir}")

    if len(config) > 1:
        print(f"found {len(config)} matching config files")
        config = sorted(config)[-1]
        print(f"selecting {config}")
    else:
        config = config[0]

    config = OmegaConf.load(config)
    return load_model_from_config(config, ckpt[0], device)

def load_model_from_config(config, ckpt, device="cpu", verbose=False):
    """Loads a model from config and a ckpt
    if config is a path will use omegaconf to load

    Args:
        config: loaded config, or a ``str``/``Path`` to load with OmegaConf.
        ckpt: checkpoint path; must contain a ``"state_dict"`` entry.
        device: device the model is moved to.
        verbose: print missing and unexpected state-dict keys.

    Returns:
        The model in eval mode on ``device``.
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)

    with all_logging_disabled():
        print(f"Loading model from {ckpt}")
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
        # "global_step" is not read: it was unused and made checkpoints
        # without that key fail with KeyError.
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    model.to(device)
    model.eval()
    model.cond_stage_model.device = device
    return model
def get_scales(self):
    """Return one guidance scale per entry of ``self.ddim_timesteps``.

    A float schedule yields a list repeating that value. Otherwise the
    ``"times"``/``"values"`` schedule is interpolated (``interp1d``) at each
    timestep divided by ``self.ddpm_num_timesteps`` and the resulting array
    is returned.
    """
    schedule = self.scale_schedule
    if not isinstance(schedule, float):
        interp_fn = interpolate.interp1d(schedule["times"], schedule["values"])
        step_fractions = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps
        return interp_fn(step_fractions)

    # Constant scale: same value for every sampling step.
    return [schedule] * len(self.ddim_timesteps)
class LambdaWarmUpCosineScheduler:
    """
    Linear warm-up from lr_start to lr_max, then cosine decay to lr_min.

    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        """Return the multiplier for step ``n`` and remember it in ``last_lr``."""
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
        else:
            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            # Past max_decay_steps the multiplier stays at lr_min.
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                    1 + np.cos(t * np.pi))
        self.last_lr = lr
        return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.

    Every argument is a list with one entry per cycle; each cycle warms up
    linearly from f_start to f_max and then follows a cosine decay to f_min.
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        # cum_cycles[i] is the global step at which cycle i starts.
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        """Return the index of the cycle containing global step ``n``.

        Steps beyond the end of the last cycle map to the last cycle (this
        used to return ``None``, which made ``schedule`` fail).
        """
        for interval, cycle_end in enumerate(self.cum_cycles[1:]):
            if n <= cycle_end:
                return interval
        return len(self.cycle_lengths) - 1

    def schedule(self, n, **kwargs):
        """Return the multiplier for global step ``n`` and remember it in ``last_f``."""
        cycle = self.find_in_interval(n)
        # From here on ``n`` is the step relative to the start of its cycle.
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                                                       f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                    1 + np.cos(t * np.pi))
        self.last_f = f
        return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """Same cycle handling as the parent, with a linear decay after warm-up."""

    def schedule(self, n, **kwargs):
        """Return the multiplier for global step ``n`` and remember it in ``last_f``."""
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                                                       f"current cycle {cycle}")

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
        else:
            # Clamp so steps past the end of the last cycle stay at f_min
            # instead of extrapolating below it.
            remaining = max(self.cycle_lengths[cycle] - n, 0)
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * remaining / (self.cycle_lengths[cycle])
        self.last_f = f
        return f
import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + 
self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + 
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, 
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + 
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + 
posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", 
log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/stable-dreamfusion-3DPortrait/ldm/models/diffusion/__init__.py b/stable-dreamfusion-3DPortrait/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stable-dreamfusion-3DPortrait/ldm/models/diffusion/classifier.py b/stable-dreamfusion-3DPortrait/ldm/models/diffusion/classifier.py new file mode 100644 index 0000000..67e98b9 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + 
monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in 
self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets 
+ + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 
'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/stable-dreamfusion-3DPortrait/ldm/models/diffusion/ddim.py 
b/stable-dreamfusion-3DPortrait/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000..0683d16 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/models/diffusion/ddim.py @@ -0,0 +1,328 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial +from einops import rearrange + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor +from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def to(self, device): + """Same as to in torch module + Don't really underestand why this isn't a module in the first place""" + for k, v in self.__dict__.items(): + if isinstance(v, torch.Tensor): + new_v = getattr(self, k).to(device) + setattr(self, k, new_v) + + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for 
diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + t_start=-1): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / 
self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + timesteps = timesteps[:t_start] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs + if callback: + img = callback(i, img, pred_x0) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, 
c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + + print(t, sqrt_one_minus_at, a_t) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + + out = {'x_encoded': x_next, 'intermediate_steps': 
inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/ldm/models/diffusion/ddpm.py b/stable-dreamfusion-3DPortrait/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000..3fcb7ad --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1994 @@ 
+""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager, nullcontext +from functools import partial +import itertools +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from omegaconf import ListConfig + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.modules.attention import CrossAttention + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def 
__init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + make_it_fit=False, + ucg_training=None, + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? 
+ self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.make_it_fit = make_it_fit + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. 
- alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + + if self.make_it_fit: + n_params = len([name for name, _ in + itertools.chain(self.named_parameters(), + self.named_buffers())]) + for name, param in tqdm( + itertools.chain(self.named_parameters(), + self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape)==len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in 
range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
+ """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) 
+ + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = 
torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1-p, 
p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = 
self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + unet_trainable=True, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.unet_trainable = unet_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + 
self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + + # construct linear projection layer for concatenating image CLIP embedding and RT + self.cc_projection = nn.Linear(772, 768) + nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768]) + nn.init.zeros_(list(self.cc_projection.parameters())[1]) + self.cc_projection.requires_grad_(True) + + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = 
rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * 
Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + 
unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, uncond=0.05): + x = super().get_input(batch, k) + T = batch['T'].to(memory_format=torch.contiguous_format).float() + + if bs is not None: + x = x[:bs] + T = T[:bs].to(self.device) + + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + cond_key = cond_key or self.cond_stage_key + xc = super().get_input(batch, cond_key).to(self.device) + if bs is not None: + xc = xc[:bs] + cond = {} + + # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%. 
+ random = torch.rand(x.size(0), device=x.device) + prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1") + input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1") + null_prompt = self.get_learned_conditioning([""]) + + # z.shape: [8, 4, 64, 64]; c.shape: [8, 1, 768] + # print('=========== xc shape ===========', xc.shape) + with torch.enable_grad(): + clip_emb = self.get_learned_conditioning(xc).detach() + null_prompt = self.get_learned_conditioning([""]).detach() + cond["c_crossattn"] = [self.cc_projection(torch.cat([torch.where(prompt_mask, null_prompt, clip_emb), T[:, None, :]], dim=-1))] + cond["c_concat"] = [input_mask * self.encode_first_stage((xc.to(self.device))).mode().detach()] + out = [z, cond] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + # @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. 
Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # @torch.no_grad() # wasted two hours to find this bug... why no grad here! + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + # if self.cond_stage_trainable: + # c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] 
/ crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly 
rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + adapted_cond = self.get_learned_conditioning(adapted_cond) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model 
outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / 
torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
@torch.no_grad()
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
             return_codebook_ids=False, quantize_denoised=False, return_x0=False,
             temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
    """
    Draw one reverse-diffusion sample from ``x`` at timestep(s) ``t``.

    :param x: current noisy tensor; its first dimension is the batch size
    :param c: conditioning forwarded to ``p_mean_variance``
    :param t: per-sample timestep tensor; entries equal to 0 get no noise added
    :param clip_denoised: forwarded to ``p_mean_variance``
    :param repeat_noise: forwarded to ``noise_like``
    :param return_codebook_ids: no longer supported, see ``:raises:``
    :param quantize_denoised: forwarded to ``p_mean_variance``
    :param return_x0: if true, also return the x0 estimate from ``p_mean_variance``
    :param temperature: multiplier applied to the sampled noise
    :param noise_dropout: dropout probability applied to the noise when > 0
    :param score_corrector: forwarded to ``p_mean_variance``
    :param corrector_kwargs: forwarded to ``p_mean_variance``
    :return: the sampled tensor, or ``(sample, x0)`` when ``return_x0`` is set
    :raises DeprecationWarning: if ``return_codebook_ids`` is true
    """
    if return_codebook_ids:
        # This mode always ended in this exception; raise before spending a
        # full model pass on a result that can never be returned.
        raise DeprecationWarning("Support dropped.")

    b, device = x.shape[0], x.device
    outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                   return_codebook_ids=return_codebook_ids,
                                   quantize_denoised=quantize_denoised,
                                   return_x0=return_x0,
                                   score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
    if return_x0:
        model_mean, _, model_log_variance, x0 = outputs
    else:
        model_mean, _, model_log_variance = outputs

    noise = noise_like(x.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

    sample = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
    if return_x0:
        return sample, x0
    return sample
noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. 
@torch.no_grad()
def p_sample_loop(self, cond, shape, return_intermediates=False,
                  x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                  mask=None, x0=None, img_callback=None, start_T=None,
                  log_every_t=None):
    """
    Run the reverse process from step ``timesteps - 1`` down to 0.

    :param cond: conditioning passed to ``p_sample`` (re-noised every step
                 when ``self.shorten_cond_schedule`` is set)
    :param shape: shape of the tensor to sample; ``shape[0]`` is the batch size
    :param return_intermediates: also return the list of logged tensors
    :param x_T: starting tensor; drawn from a standard normal when ``None``
    :param verbose: wrap the step iterator in ``tqdm``
    :param callback: called as ``callback(i)`` after every step
    :param timesteps: number of steps; defaults to ``self.num_timesteps``
    :param quantize_denoised: forwarded to ``p_sample``
    :param mask: optional blend mask; where it is 1 the result is taken from
                 the re-noised ``x0``, where it is 0 from the sample
    :param x0: reference tensor, required when ``mask`` is given
    :param img_callback: called as ``img_callback(img, i)`` after every step
    :param start_T: upper bound on the number of steps
    :param log_every_t: logging interval; defaults to ``self.log_every_t``
    :return: the final tensor, or ``(final, intermediates)``
    :raises AssertionError: if ``mask`` is given without ``x0`` or their
                            spatial sizes differ
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    device = self.betas.device
    b = shape[0]
    img = torch.randn(shape, device=device) if x_T is None else x_T

    intermediates = [img]
    if timesteps is None:
        timesteps = self.num_timesteps
    if start_T is not None:
        timesteps = min(timesteps, start_T)

    steps = reversed(range(0, timesteps))
    iterator = tqdm(steps, desc='Sampling t', total=timesteps) if verbose else steps

    if mask is not None:
        assert x0 is not None
        # compare height AND width; a `[2:3]` slice would only look at the height
        assert x0.shape[2:] == mask.shape[2:]  # spatial size has to match

    for i in iterator:
        ts = torch.full((b,), i, device=device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != 'hybrid'
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img = self.p_sample(img, cond, ts,
                            clip_denoised=self.clip_denoised,
                            quantize_denoised=quantize_denoised)
        if mask is not None:
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1. - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(img)
        if callback:
            callback(i)
        if img_callback:
            img_callback(img, i)

    if return_intermediates:
        return img, intermediates
    return img
-> b ...', b=batch_size).to(self.device) + cond = {} + cond["c_crossattn"] = [c] + cond["c_concat"] = [torch.zeros([batch_size, 4, image_size // 8, image_size // 8]).to(self.device)] + return cond + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = 
torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = self.get_unconditional_conditioning(N, unconditional_guidance_label, image_size=x.shape[-1]) + # uc = torch.zeros_like(c) + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if inpaint: + # make 
def configure_optimizers(self):
    """
    Build the AdamW optimizer (and optional LambdaLR schedule) for training.

    ``self.unet_trainable`` must be ``"attn"``, ``"conv_in"``, ``"all"`` or
    ``True``. The optimizer gets one parameter group for ``self.model`` at
    ``self.learning_rate`` and, when ``self.cc_projection`` is set, a second
    group for it at ten times that rate.

    :return: the optimizer, or ``([optimizer], [scheduler_dict])`` when
             ``self.use_scheduler`` is set
    :raises ValueError: if ``self.unet_trainable`` is not a recognised setting
    """
    lr = self.learning_rate
    params = []
    # One exclusive chain: with a separate leading `if`, the "attn" setting
    # fell through to the final `else` and was rejected as unrecognised.
    if self.unet_trainable == "attn":
        print("Training only unet attention layers")
        for n, m in self.model.named_modules():
            if isinstance(m, CrossAttention) and n.endswith('attn2'):
                params.extend(m.parameters())
    elif self.unet_trainable == "conv_in":
        print("Training only unet input conv layers")
        params = list(self.model.diffusion_model.input_blocks[0][0].parameters())
    elif self.unet_trainable is True or self.unet_trainable == "all":
        print("Training the full unet")
        params = list(self.model.parameters())
    else:
        raise ValueError(f"Unrecognised setting for unet_trainable: {self.unet_trainable}")

    if self.cond_stage_trainable:
        print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
        params = params + list(self.cond_stage_model.parameters())
    if self.learn_logvar:
        print('Diffusion model optimizing logvar')
        params.append(self.logvar)

    param_groups = [{"params": self.model.parameters(), "lr": lr}]
    if self.cc_projection is not None:
        params = params + list(self.cc_projection.parameters())
        print('========== optimizing for cc projection weight ==========')
        # only add the projection group when the projection actually exists
        param_groups.append({"params": self.cc_projection.parameters(), "lr": 10. * lr})

    # NOTE(review): the optimizer is built from all of `self.model` (plus
    # `cc_projection`); the `params` list collected above is not handed to it,
    # so `unet_trainable` does not restrict what gets optimized here. Kept
    # as-is to preserve training behaviour - confirm whether this is intended.
    opt = torch.optim.AdamW(param_groups, lr=lr)
    if self.use_scheduler:
        assert 'target' in self.scheduler_config
        scheduler = instantiate_from_config(self.scheduler_config)

        print("Setting up LambdaLR scheduler...")
        scheduler = [
            {
                'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                'interval': 'step',
                'frequency': 1
            }]
        return [opt], scheduler
    return opt
class DiffusionWrapper(pl.LightningModule):
    """
    Thin adapter that feeds conditioning to the wrapped diffusion model.

    ``conditioning_key`` selects how the conditioning lists are combined with
    the input before ``self.diffusion_model`` is called; see ``forward``.
    """

    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        """
        Call the diffusion model with conditioning routed by ``conditioning_key``.

        - ``None``: ``x`` only
        - ``'concat'``: ``c_concat`` tensors are concatenated to ``x`` along dim 1
        - ``'crossattn'``: ``c_crossattn`` tensors are concatenated along dim 1
          and passed as ``context``
        - ``'hybrid'``: both of the above
        - ``'hybrid-adm'``: both of the above plus ``c_adm`` passed as ``y``
        - ``'adm'``: the first ``c_crossattn`` entry is passed as ``y``

        :raises NotImplementedError: for any other ``conditioning_key``
        """
        mode = self.conditioning_key

        if mode is None:
            return self.diffusion_model(x, t)

        if mode == 'concat':
            stacked = torch.cat([x] + c_concat, dim=1)
            return self.diffusion_model(stacked, t)

        if mode == 'crossattn':
            # c_crossattn dimension: torch.Size([8, 1, 768]) 1
            # cc dimension: torch.Size([8, 1, 768]
            context = torch.cat(c_crossattn, 1)
            return self.diffusion_model(x, t, context=context)

        if mode == 'hybrid':
            stacked = torch.cat([x] + c_concat, dim=1)
            context = torch.cat(c_crossattn, 1)
            return self.diffusion_model(stacked, t, context=context)

        if mode == 'hybrid-adm':
            assert c_adm is not None
            stacked = torch.cat([x] + c_concat, dim=1)
            context = torch.cat(c_crossattn, 1)
            return self.diffusion_model(stacked, t, context=context, y=c_adm)

        if mode == 'adm':
            label = c_crossattn[0]
            return self.diffusion_model(x, t, y=label)

        raise NotImplementedError()
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
    """Build the latent and the conditioning dict for the upscaling model.

    The low-resolution image stored under ``self.low_scale_key`` is passed
    through ``self.low_scale_model``; its two outputs become the ``c_concat``
    and ``c_adm`` entries of the conditioning dict.

    Returns ``(z, all_conds)``; with ``log_mode=True`` returns
    ``(z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level)``.
    """
    if log_mode:
        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key,
                                              return_first_stage_outputs=True,
                                              force_c_encode=True,
                                              return_original_cond=True, bs=bs)
    else:
        z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)

    # Batch images are indexed as b h w c; the model consumes b c h w.
    low_res = rearrange(batch[self.low_scale_key][:bs], 'b h w c -> b c h w')
    low_res = low_res.to(memory_format=torch.contiguous_format).float()

    zx, noise_level = self.low_scale_model(low_res)
    all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}

    if not log_mode:
        return z, all_conds

    # TODO: maybe disable if too expensive
    interpretability = False
    if interpretability:
        zx = zx[:, :, ::2, ::2]
    x_low_rec = self.low_scale_model.decode(zx)
    return z, all_conds, x, xrec, xc, low_res, x_low_rec, noise_level
size=x.shape[2]//25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? 
+ uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? + assert isinstance(c[k], torch.Tensor) + uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + return log + + +class LatentInpaintDiffusion(LatentDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. 
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
    """Load weights from ``path``, optionally widening fine-tuned input weights.

    Args:
        path: Checkpoint file readable by ``torch.load``. If it holds a
            ``"state_dict"`` entry, that entry is used.
        ignore_keys: Iterable of key prefixes; every checkpoint entry whose
            key starts with one of them is dropped before loading.
        only_model: If true, load into ``self.model`` instead of ``self``.

    For every remaining key listed in ``self.finetune_keys`` the checkpoint
    tensor is copied into the first ``self.keep_dims`` channels (dim 1) of a
    zero tensor shaped like the matching current parameter, so extra input
    channels start at zero.
    """
    # ``None`` replaces the former shared mutable default ``list()``.
    prefixes = tuple(ignore_keys) if ignore_keys is not None else ()

    sd = torch.load(path, map_location="cpu")
    if "state_dict" in sd:
        sd = sd["state_dict"]

    finetune_keys = self.finetune_keys
    for k in list(sd.keys()):
        # str.startswith accepts a tuple, so a key matching several prefixes is
        # deleted exactly once (the nested loop used to ``del`` it repeatedly).
        if k.startswith(prefixes):
            print("Deleting key {} from state_dict.".format(k))
            del sd[k]
            # A deleted key must not reach the fine-tune branch below.
            continue

        # make it explicit, finetune by including extra input channels
        if finetune_keys is not None and k in finetune_keys:
            new_entry = None
            for name, param in self.named_parameters():
                if name in finetune_keys:
                    print(f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
                    new_entry = torch.zeros_like(param)  # zero init
            assert new_entry is not None, 'did not find matching parameter to modify'
            new_entry[:, :self.keep_dims, ...] = sd[k]
            sd[k] = new_entry

    target = self.model if only_model else self
    missing, unexpected = target.load_state_dict(sd, strict=False)
    print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
    if len(missing) > 0:
        print(f"Missing Keys: {missing}")
    if len(unexpected) > 0:
        print(f"Unexpected Keys: {unexpected}")
return_first_stage_outputs=True) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:,self.c_concat_log_start:self.c_concat_log_end]) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, 
class Layout2ImgDiffusion(LatentDiffusion):
    """Latent diffusion conditioned on tokenized bounding-box layouts."""

    # TODO: move all layout-specific hacks to this class
    def __init__(self, cond_stage_key, *args, **kwargs):
        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)

    def log_images(self, batch, N=8, *args, **kwargs):
        """Extend the parent's image log with a ``'bbox_image'`` entry."""
        logs = super().log_images(batch=batch, N=N, *args, **kwargs)

        split = 'train' if self.training else 'validation'
        dataset = self.trainer.datamodule.datasets[split]
        builder = dataset.conditional_builders[self.cond_stage_key]

        def label_of(catno):
            # category number -> category id -> textual label
            return dataset.get_textual_label(dataset.get_category_id(catno))

        rendered = [
            builder.plot(tokens.detach().cpu(), label_of, (256, 256))
            for tokens in batch[self.cond_stage_key][:N]
        ]
        logs['bbox_image'] = torch.stack(rendered, dim=0)
        return logs
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
    """Encode two 3-channel frames as concat-conditioning latents.

    The tensor under ``self.low_scale_key`` is sliced along its last axis
    into two consecutive 3-channel frames. Each frame is encoded with the
    first-stage model and the latents are concatenated on dim 1.

    Returns ``(z, all_conds)``; with ``log_mode=True`` returns
    ``(z, all_conds, x, xrec, xc, x_low)``, where ``x_low`` is the last frame.
    """
    n = 2  # number of 3-channel frames packed into the conditioning image

    if log_mode:
        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key,
                                              return_first_stage_outputs=True,
                                              force_c_encode=True,
                                              return_original_cond=True, bs=bs)
    else:
        z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)

    cat_conds = batch[self.low_scale_key][:bs]
    latents = []
    for start in range(0, 3 * n, 3):
        x_low = rearrange(cat_conds[:, :, :, start:start + 3], 'b h w c -> b c h w')
        x_low = x_low.to(memory_format=torch.contiguous_format).float()
        posterior = self.encode_first_stage(x_low)
        latents.append(self.get_first_stage_encoding(posterior).detach())

    all_conds = {"c_concat": [torch.cat(latents, dim=1)], "c_crossattn": [c]}

    if not log_mode:
        return z, all_conds
    return z, all_conds, x, xrec, xc, x_low
x.shape[3]), batch["human_label"], size=x.shape[2]//25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + return log diff --git a/stable-dreamfusion-3DPortrait/ldm/models/diffusion/plms.py b/stable-dreamfusion-3DPortrait/ldm/models/diffusion/plms.py new file mode 100644 index 0000000..080edee --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/models/diffusion/plms.py @@ -0,0 +1,259 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.models.diffusion.sampling_util import norm_thresholding + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = 
def register_buffer(self, name, attr):
    """Store ``attr`` on the sampler as ``name``, moving tensors to the model's device.

    Tensors are placed on ``self.model.device`` instead of a hard-coded
    ``"cuda"``, so the sampler also works on CPU and on a non-default GPU.
    If the model exposes no ``device`` attribute, CUDA is used when available
    and CPU otherwise. Non-tensor values are stored unchanged.
    """
    if isinstance(attr, torch.Tensor):
        target = getattr(self.model, "device", None)
        if target is None:
            target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if attr.device != target:
            attr = attr.to(target)
    setattr(self, name, attr)
@torch.no_grad()
def sample(self,
           S,
           batch_size,
           shape,
           conditioning=None,
           callback=None,
           normals_sequence=None,
           img_callback=None,
           quantize_x0=False,
           eta=0.,
           mask=None,
           x0=None,
           temperature=1.,
           noise_dropout=0.,
           score_corrector=None,
           corrector_kwargs=None,
           verbose=True,
           x_T=None,
           log_every_t=100,
           unconditional_guidance_scale=1.,
           unconditional_conditioning=None,
           # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
           dynamic_threshold=None,
           **kwargs
           ):
    """Run PLMS sampling for ``S`` steps and return ``(samples, intermediates)``.

    Prints a warning when the conditioning's leading dimension differs from
    ``batch_size``, builds the schedule via ``make_schedule``, then delegates
    to ``plms_sampling`` with size ``(batch_size, *shape)``.
    ``normals_sequence`` and extra ``**kwargs`` are accepted but not used.
    """
    if conditioning is not None:
        if isinstance(conditioning, dict):
            # Probe the first entry, descending through nested lists to a tensor.
            probe = conditioning[list(conditioning.keys())[0]]
            while isinstance(probe, list):
                probe = probe[0]
            n_cond = probe.shape[0]
        else:
            n_cond = conditioning.shape[0]
        if n_cond != batch_size:
            print(f"Warning: Got {n_cond} conditionings but batch-size is {batch_size}")

    self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)

    # sampling
    C, H, W = shape
    size = (batch_size, C, H, W)
    print(f'Data shape for PLMS sampling is {size}')

    sampler_args = dict(
        callback=callback,
        img_callback=img_callback,
        quantize_denoised=quantize_x0,
        mask=mask,
        x0=x0,
        ddim_use_original_steps=False,
        noise_dropout=noise_dropout,
        temperature=temperature,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
        x_T=x_T,
        log_every_t=log_every_t,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=unconditional_conditioning,
        dynamic_threshold=dynamic_threshold,
    )
    samples, intermediates = self.plms_sampling(conditioning, size, **sampler_args)
    return samples, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
                  dynamic_threshold=None):
    """Perform one PLMS update and return ``(x_prev, pred_x0, e_t)``.

    Args:
        x: Current sample; its first dimension is the batch size.
        c: Conditioning passed to ``self.model.apply_model`` (tensor, or a
            dict of tensors / lists of tensors).
        t: Timestep tensor for the current step.
        index: Position in the alpha/sigma tables for this step.
        old_eps: Previous noise predictions, oldest first. ``None`` is
            treated as an empty history.
        t_next: Timestep tensor of the following step; used only when the
            history is empty (extra model evaluation).
        unconditional_guidance_scale, unconditional_conditioning:
            Classifier-free guidance; disabled when the conditioning is
            ``None`` or the scale equals 1.

    Returns:
        ``x_prev`` (next sample), ``pred_x0`` (current x_0 estimate) and the
        raw ``e_t`` for this step, which the caller appends to ``old_eps``.
    """
    b, *_, device = *x.shape, x.device
    # The ``None`` defaults used to crash in ``len(old_eps)`` and
    # ``**corrector_kwargs``.
    old_eps = [] if old_eps is None else old_eps
    corrector_kwargs = {} if corrector_kwargs is None else corrector_kwargs

    def get_model_output(x, t):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            # Evaluate the unconditional and conditional branches in one batch.
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]])
                                   for i in range(len(c[k]))]
                    else:
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        return e_t

    if use_original_steps:
        alphas = self.model.alphas_cumprod
        alphas_prev = self.model.alphas_cumprod_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod
        # make_schedule registers this buffer on the sampler, not on the model;
        # prefer a model attribute if one exists, else use the sampler's.
        sigmas = getattr(self.model, "ddim_sigmas_for_original_num_steps", None)
        if sigmas is None:
            sigmas = self.ddim_sigmas_for_original_num_steps
    else:
        alphas = self.ddim_alphas
        alphas_prev = self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
        sigmas = self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        if dynamic_threshold is not None:
            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    e_t = get_model_output(x, t)
    if len(old_eps) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        e_t_next = get_model_output(x_prev, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(old_eps) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - old_eps[-1]) / 2
    elif len(old_eps) == 2:
        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    else:
        # 4th order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

    return x_prev, pred_x0, e_t
def renorm_thresholding(x0, value):
    """Rescale ``x0`` to [-1, 1], clip by a per-sample quantile, and rescale back.

    Args:
        x0: Floating-point tensor with the batch on dim 0.
        value: Quantile in [0, 1] of the per-sample absolute values used as
            the clipping threshold; the threshold is never below 1.

    Returns:
        Tensor with the shape, dtype and device of ``x0``.

    Runs entirely in torch on the input's device. The earlier version called
    ``rearrange`` (not imported in this module) and referenced
    ``self.model.device`` inside a free function, so it could not run.
    """
    # renorm
    pred_max = x0.max()
    pred_min = x0.min()
    span = pred_max - pred_min
    if span == 0:
        # Constant input: rescaling would divide by zero; nothing to clip.
        return x0.clone()
    pred_x0 = (x0 - pred_min) / span  # 0 ... 1
    pred_x0 = 2 * pred_x0 - 1.  # -1 ... 1

    s = torch.quantile(pred_x0.reshape(pred_x0.shape[0], -1).abs(), value, dim=-1)
    s.clamp_(min=1.0)
    s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))

    # clip by threshold; min/max with tensor bounds also works on older torch
    pred_x0 = torch.max(torch.min(pred_x0, s), -s) / s

    # re.renorm
    pred_x0 = (pred_x0 + 1.) / 2.  # 0 ... 1
    pred_x0 = span * pred_x0 + pred_min  # orig range
    return pred_x0
class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a feature map.

    Keys are softmax-normalised along the flattened spatial axis and contracted
    with the values first, so the result is built from a per-head
    (dim_head x dim_head) context rather than a position-by-position matrix.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        inner_channels = dim_head * heads
        # One bias-free 1x1 convolution produces queries, keys and values stacked along channels.
        self.to_qkv = nn.Conv2d(dim, inner_channels * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_channels, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        stacked = rearrange(self.to_qkv(x), 'b (qkv heads c) h w -> qkv b heads c (h w)',
                            heads = self.heads, qkv=3)
        queries, keys, values = stacked[0], stacked[1], stacked[2]
        keys = keys.softmax(dim=-1)
        # context[b, h, d, e] = sum_n keys[b, h, d, n] * values[b, h, e, n]
        context = torch.einsum('bhdn,bhen->bhde', keys, values)
        attended = torch.einsum('bhde,bhdn->bhen', context, queries)
        merged = rearrange(attended, 'b heads c (h w) -> b (heads c) h w',
                           heads=self.heads, h=height, w=width)
        return self.to_out(merged)
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from ``x`` onto ``context``.

    When no ``context`` is supplied to ``forward`` the layer attends ``x`` onto
    itself. ``context_dim`` falls back to ``query_dim`` when it is ``None``.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        n_heads = self.heads

        queries = self.to_q(x)
        source = x if context is None else context
        keys = self.to_k(source)
        values = self.to_v(source)

        # Fold the head axis into the batch axis so every head is an independent batch item.
        queries = rearrange(queries, 'b n (h d) -> (b h) n d', h=n_heads)
        keys = rearrange(keys, 'b n (h d) -> (b h) n d', h=n_heads)
        values = rearrange(values, 'b n (h d) -> (b h) n d', h=n_heads)

        sim = einsum('b i d, b j d -> b i j', queries, keys) * self.scale

        if mask is not None:
            # Masked-out key positions receive the most negative finite value before the softmax.
            mask = rearrange(mask, 'b ... -> b (...)')
            fill_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=n_heads)
            sim.masked_fill_(~mask, fill_value)

        weights = sim.softmax(dim=-1)

        attended = einsum('b i j, b j d -> b i d', weights, values)
        attended = rearrange(attended, '(b h) n d -> b n (h d)', h=n_heads)
        return self.to_out(attended)
def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal timestep embeddings.

    The first ``embedding_dim // 2`` columns hold sines and the next
    ``embedding_dim // 2`` columns hold cosines of ``timesteps`` multiplied by
    log-spaced frequencies running from 1 down to 1/10000. An odd
    ``embedding_dim`` gets one trailing zero column.

    :param timesteps: 1-D tensor of timesteps.
    :param embedding_dim: width of the returned embedding.
    :return: float tensor of shape ``(len(timesteps), embedding_dim)`` on the
        device of ``timesteps``.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim 2 or 3) `half_dim - 1` is zero and
    # the unguarded division raised ZeroDivisionError. The only exponent is 0 in
    # that case, so the guarded denominator does not affect the result.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_step)
    freqs = freqs.to(device=timesteps.device)
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0,1,0,0))
    return emb
class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions with a residual connection.

    The input is normalised, projected to queries, keys and values by 1x1
    convolutions, attended across the flattened h*w positions, projected back
    by another 1x1 convolution, and added to the unmodified input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise():
            # Channel-preserving 1x1 convolution shared by all four projections.
            return torch.nn.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=1,
                                   stride=1,
                                   padding=0)

        self.norm = Normalize(in_channels)
        self.q = pointwise()
        self.k = pointwise()
        self.v = pointwise()
        self.proj_out = pointwise()

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        # compute attention
        batch, channels, height, width = query.shape
        positions = height * width
        query = query.reshape(batch, channels, positions).permute(0, 2, 1)  # b,hw,c
        key = key.reshape(batch, channels, positions)                       # b,c,hw
        scores = torch.bmm(query, key)  # scores[b,i,j] = sum_c query[b,i,c] * key[b,c,j]
        scores = scores * (int(channels)**(-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # attend to values
        value = value.reshape(batch, channels, positions)
        weights = weights.permute(0, 2, 1)  # b,hw,hw (key position first, query position second)
        attended = torch.bmm(value, weights)  # attended[b,c,j] = sum_i value[b,c,i] * weights[b,i,j]
        attended = attended.reshape(batch, channels, height, width)

        return x + self.proj_out(attended)
self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def 
class Encoder(nn.Module):
    """Convolutional encoder: a stack of ResNet stages with optional attention.

    Every resolution level applies ``num_res_blocks`` ResNet blocks (each
    followed by an attention block when the current resolution is listed in
    ``attn_resolutions``); all levels except the last end with a 2x
    downsample. A middle ResNet/attention/ResNet section and a normalised
    output convolution follow. The output has ``2 * z_channels`` channels when
    ``double_z`` is true, otherwise ``z_channels``. ``out_ch`` and any extra
    keyword arguments are accepted but not used.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # the encoder's ResNet blocks get no timestep projection
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        feature_res = resolution
        self.in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        final_level = self.num_resolutions-1
        for i_level in range(self.num_resolutions):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_in = ch*self.in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_blocks.append(ResnetBlock(in_channels=block_in,
                                              out_channels=block_out,
                                              temb_channels=self.temb_ch,
                                              dropout=dropout))
                block_in = block_out
                if feature_res in attn_resolutions:
                    attn_blocks.append(make_attn(block_in, attn_type=attn_type))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != final_level:
                stage.downsample = Downsample(block_in, resamp_with_conv)
                feature_res = feature_res // 2
            self.down.append(stage)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # no timestep embedding is used by the encoder
        temb = None

        # downsampling; only the most recent activation is ever consumed,
        # so a single running tensor replaces the list of activations
        h = self.conv_in(x)
        final_level = self.num_resolutions-1
        for i_level in range(self.num_resolutions):
            stage = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](h, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            if i_level != final_level:
                h = stage.downsample(h)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        return self.conv_out(h)
class SimpleDecoder(nn.Module):
    """Small decoder: 1x1 conv, three ResNet blocks, 1x1 conv, 2x upsample, output conv.

    The ResNet blocks change the width in_channels -> 2x -> 4x -> 2x before
    the second 1x1 convolution brings it back to ``in_channels``. Extra
    positional and keyword arguments are accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        widths = (in_channels, 2 * in_channels, 4 * in_channels, 2 * in_channels)
        layers = [nn.Conv2d(in_channels, in_channels, 1)]
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            layers.append(ResnetBlock(in_channels=width_in,
                                      out_channels=width_out,
                                      temb_channels=0, dropout=0.0))
        layers.append(nn.Conv2d(2*in_channels, in_channels, 1))
        layers.append(Upsample(in_channels, with_conv=True))
        self.model = nn.ModuleList(layers)
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        for position, layer in enumerate(self.model):
            # Positions 1-3 are the ResNet blocks, which take a (here absent) timestep embedding.
            x = layer(x, None) if position in (1, 2, 3) else layer(x)

        h = nonlinearity(self.norm_out(x))
        return self.conv_out(h)
class UpsampleDecoder(nn.Module):
    """Decoder made of ResNet stages separated by learned 2x upsampling.

    Each of the ``len(ch_mult)`` levels applies ``num_res_blocks + 1`` ResNet
    blocks at width ``ch * ch_mult[level]``; every level except the last is
    followed by an ``Upsample`` with convolution. A normalised 3x3 output
    convolution maps to ``out_channels``.

    ``resolution`` is kept in the signature for compatibility with existing
    configs. The previous code only used it to maintain a ``curr_res`` counter
    that was never read, so that bookkeeping has been removed.
    """

    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2,2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            stage = []
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                stage.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(stage))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # upsampling
        h = x
        last_level = self.num_resolutions - 1
        for i_level, stage in enumerate(self.res_blocks):
            for block in stage:
                h = block(h, None)  # no timestep embedding
            if i_level != last_level:
                h = self.upsample_blocks[i_level](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class MergedRescaleEncoder(nn.Module):
    """An ``Encoder`` followed by a ``LatentRescaler``.

    The encoder emits ``ch * ch_mult[-1]`` channels (``double_z`` disabled);
    the rescaler resizes that latent by ``rescale_factor`` and maps it to
    ``out_ch`` channels.
    """

    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        latent_channels = ch * ch_mult[-1]
        encoder_kwargs = dict(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=latent_channels,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        self.encoder = Encoder(**encoder_kwargs)
        self.rescaler = LatentRescaler(factor=rescale_factor,
                                       in_channels=latent_channels,
                                       mid_channels=latent_channels,
                                       out_channels=out_ch,
                                       depth=rescale_module_depth)

    def forward(self, x):
        return self.rescaler(self.encoder(x))
class Resize(nn.Module):
    """Resize a feature map by a scale factor using fixed interpolation.

    :param in_channels: unused by the fixed-interpolation path; kept for
        signature compatibility.
    :param learned: requesting learned (convolutional) resizing is not
        supported and raises ``NotImplementedError``.
    :param mode: interpolation mode handed to ``torch.nn.functional.interpolate``.
    :raises NotImplementedError: if ``learned`` is true.
    """

    # Modes for which F.interpolate accepts an explicit align_corners value.
    _INTERPOLATING_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # The old code read `self.__class__.__name` (name-mangled inside the
            # class body), so it raised AttributeError before reaching its own
            # `raise NotImplementedError()`. The convolution set up after that
            # raise was unreachable and has been dropped.
            raise NotImplementedError(
                f"{self.__class__.__name__}: learned downsampling is not implemented "
                f"(it would ignore the fixed {mode} mode)"
            )

    def forward(self, x, scale_factor=1.0):
        if scale_factor==1.0:
            return x
        # align_corners may only be set for interpolating modes; passing False
        # with e.g. "nearest" makes F.interpolate raise.
        align_corners = False if self.mode in self._INTERPOLATING_MODES else None
        return torch.nn.functional.interpolate(x, mode=self.mode, align_corners=align_corners,
                                               scale_factor=scale_factor)
"pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/stable-dreamfusion-3DPortrait/ldm/modules/diffusionmodules/openaimodel.py b/stable-dreamfusion-3DPortrait/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000..09f0ae1 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,996 @@ +from abc import abstractmethod +from functools import partial +import math +from 
typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer +from ldm.util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. 
+ """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. 
Matches legacy QKVAttention + input/output heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. 
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
+ from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + #self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < 
num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in 
list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + 
    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.

        Applies ``convert_module_to_f16`` to every submodule of the input,
        middle and output blocks. NOTE: in this module ``convert_module_to_f16``
        is a placeholder whose body is ``pass``, so the call currently leaves
        the weights untouched.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.

        Applies ``convert_module_to_f32`` to every submodule of the input,
        middle and output blocks. NOTE: in this module ``convert_module_to_f32``
        is a placeholder whose body is ``pass``, so the call currently leaves
        the weights untouched.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps. Defaults to None, but the
            value is passed straight to ``timestep_embedding``.
        :param context: conditioning plugged in via crossattn; forwarded to
            every input, middle and output block.
        :param y: an [N] Tensor of labels, if class-conditional.
        :param kwargs: accepted and ignored.
        :return: an [N x C x ...] Tensor of outputs; the output of
            ``id_predictor`` instead of ``out`` when the model was built with
            ``n_embed``.
        :raises AssertionError: if ``y`` is given for a model without classes,
            omitted for a class-conditional model, or not of shape [N].
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        # Activations of every input block, consumed in reverse as skip connections.
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            # The class embedding is added onto the timestep embedding.
            emb = emb + self.label_emb(y)

        # Run the torso in the configured compute dtype (float16 when use_fp16 was set).
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            # Concatenate the matching skip activation along the channel axis.
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        # Cast back to the caller's dtype before the output head.
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)


class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.
    For usage, see UNet.

    Only the downsampling path and the middle block are built; the result is
    reduced by the head selected with ``pool``.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs
    ):
        """
        Build the encoder.
        :param image_size: input resolution; read only by the "attention"
            pooling head, as ``image_size // ds``.
        :param in_channels: channels in the input Tensor.
        :param model_channels: base channel count for the model.
        :param out_channels: width of the pooled output.
        :param num_res_blocks: number of residual blocks per level (an int;
            it is passed to ``range``).
        :param attention_resolutions: downsample rates at which an
            AttentionBlock follows each residual block.
        :param dropout: the dropout probability handed to every ResBlock.
        :param channel_mult: channel multiplier for each level.
        :param conv_resample: forwarded to Downsample as its ``use_conv`` flag.
        :param dims: determines if the signal is 1D, 2D, or 3D.
        :param use_checkpoint: forwarded to every ResBlock and AttentionBlock.
        :param use_fp16: selects float16 instead of float32 for ``self.dtype``.
        :param num_heads: attention heads per AttentionBlock.
        :param num_head_channels: channels per attention head; must not be -1
            for the "attention" pooling head.
        :param num_heads_upsample: stored on the instance (defaulting to
            ``num_heads``); not otherwise used by this constructor.
        :param use_scale_shift_norm: forwarded to every ResBlock.
        :param resblock_updown: downsample with a ResBlock(down=True) instead
            of a Downsample layer.
        :param use_new_attention_order: forwarded to every AttentionBlock.
        :param pool: one of "adaptive", "attention", "spatial", "spatial_v2".
        :param args: accepted and ignored.
        :param kwargs: accepted and ignored.
        :raises NotImplementedError: for any other ``pool`` value.
        """
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Stem: a single 3x3 convolution lifting the input to model_channels.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        # Running sum of the channel counts of every block built so far; it is
        # the input width of the "spatial" and "spatial_v2" heads below.
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        # Current downsample rate, doubled after every level but the last.
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels, keeping the channel count.
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            # Normalise, average-pool to 1x1, project with a 1x1 conv, flatten.
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            # The pool is sized for the resolution left after all downsampling.
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")
def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/stable-dreamfusion-3DPortrait/ldm/modules/diffusionmodules/util.py b/stable-dreamfusion-3DPortrait/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000..a952e6c --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,267 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build a beta (noise variance) schedule.

    :param schedule: "linear", "cosine", "sqrt_linear" or "sqrt".
    :param n_timestep: number of betas to produce.
    :param linear_start: first value for the linspace-based schedules.
    :param linear_end: last value for the linspace-based schedules.
    :param cosine_s: offset used by the cosine schedule.
    :return: a float64 numpy array of length ``n_timestep``.
    :raises ValueError: if ``schedule`` is not one of the known names.
    """
    if schedule == "linear":
        # Linear in sqrt(beta) space, then squared.
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch so `betas` stays a Tensor: np.clip on a Tensor may
        # hand back an ndarray, which has no .numpy() for the return below.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """
    Pick the subset of DDPM timesteps visited by a DDIM sampler.

    :param ddim_discr_method: "uniform" (constant stride) or "quad"
                              (quadratic spacing).
    :param num_ddim_timesteps: number of DDIM steps requested.
    :param num_ddpm_timesteps: number of steps of the underlying DDPM.
    :param verbose: print the selected steps.
    :return: integer numpy array of selected timesteps, each shifted by +1.
    :raises ValueError: if "uniform" is asked for more DDIM steps than there
                        are DDPM steps (the stride would be zero).
    :raises NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        if c == 0:
            # range() would raise an opaque "arg 3 must not be zero" here.
            raise ValueError(
                f'num_ddim_timesteps ({num_ddim_timesteps}) must not exceed '
                f'num_ddpm_timesteps ({num_ddpm_timesteps})'
            )
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """
    Compute the per-step sigma / alpha values used by a DDIM sampler.

    :param alphacums: array of cumulative alpha products, indexed by timestep.
    :param ddim_timesteps: integer array of timesteps to select.
    :param eta: scale applied to the sigma schedule (0 gives all-zero sigmas).
    :param verbose: print the selected values.
    :return: tuple ``(sigmas, alphas, alphas_prev)``.
    """
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar(t) function into a beta schedule.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: callable taking t in [0, 1] and returning the
                      cumulative product of (1 - beta) up to that point.
    :param max_beta: upper bound applied to every beta; keep it below 1 to
                     prevent singularities.
    :return: numpy array with ``num_diffusion_timesteps`` entries.
    """
    n = num_diffusion_timesteps
    return np.array(
        [
            min(1 - alpha_bar((step + 1) / n) / alpha_bar(step / n), max_beta)
            for step in range(n)
        ]
    )


def extract_into_tensor(a, t, x_shape):
    """
    Gather ``a[t]`` and reshape it to broadcast against a tensor of ``x_shape``.

    :param a: tensor indexed along its last dimension.
    :param t: index tensor; its first dimension is the batch size.
    :param x_shape: shape of the tensor the result will be combined with.
    :return: tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` dims.
    """
    batch, *_ = t.shape
    singleton_dims = (1,) * (len(x_shape) - 1)
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *singleton_dims)


def checkpoint(func, inputs, params, flag):
    """
    Call ``func`` with optional gradient checkpointing (intermediate
    activations are not cached; they are recomputed in the backward pass).

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(*inputs)
    packed = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *packed)
class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function that runs `run_function` without recording a graph in
    the forward pass and re-runs it with gradients enabled in backward.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # The first `length` entries of args are the inputs of run_function;
        # the remainder are parameters it depends on implicitly.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach the saved inputs so the recomputation builds a fresh graph.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # No gradients for the two non-tensor arguments (run_function, length).
        return (None, None) + input_grads


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and just repeat each
                        timestep value `dim` times.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # For odd dim, pad one zero column so the width is exactly `dim`.
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the module.
    """
    for param in module.parameters():
        param.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Multiply every parameter of ``module`` by ``scale`` in place and return
    the module.
    """
    for param in module.parameters():
        param.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Average over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)


def normalization(channels):
    """
    Build the standard normalization layer: a 32-group GroupNorm32.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """x * sigmoid(x) activation."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class GroupNorm32(nn.GroupNorm):
    """GroupNorm computed in float32, with the result cast back to the input dtype."""

    def forward(self, x):
        normed = super().forward(x.float())
        return normed.type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
class HybridConditioner(nn.Module):
    """Runs one conditioner for concat-conditioning and one for cross-attention."""

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        """Return both conditionings, each wrapped in a one-element list."""
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}


def noise_like(shape, device, repeat=False):
    """
    Sample standard-normal noise of ``shape`` on ``device``.

    With ``repeat=True`` a single sample of shape ``(1, *shape[1:])`` is drawn
    and tiled along the first dimension, so every batch entry is identical.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)


class AbstractDistribution:
    """Interface for distributions exposing ``sample`` and ``mode``."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    """Point mass: both ``sample`` and ``mode`` return the stored value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    """
    Diagonal Gaussian whose mean and log-variance are the two halves of
    ``parameters`` along dim 1.
    """

    def __init__(self, parameters, deterministic=False):
        """
        :param parameters: tensor split in two equal chunks along dim 1
                           (mean, then log-variance).
        :param deterministic: if True, variance and std are forced to zero.
        """
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp the log-variance before exponentiating it.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        """Draw ``mean + std * eps`` with eps ~ N(0, I)."""
        # Noise is drawn on the CPU and then moved, which keeps the random
        # stream independent of the target device.
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        """
        KL divergence to ``other`` (or to a standard normal when ``other`` is
        None), summed over dims 1-3. Deterministic distributions return a
        one-element zero tensor on the parameters' device.
        """
        if self.deterministic:
            # Same device as the parameters so the result can be combined
            # with other loss terms without a device mismatch.
            return torch.zeros(1, device=self.parameters.device)
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """
        Negative log-likelihood of ``sample``, summed over ``dims``.
        Deterministic distributions return a one-element zero tensor on the
        parameters' device.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=list(dims))

    def mode(self):
        """Return the mean."""
        return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    KL divergence between two diagonal Gaussians given by mean / log-variance.
    Arguments broadcast against each other, so batches may be compared to
    scalars; at least one argument has to be a Tensor.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, torch.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # torch.exp() needs Tensors, so promote scalar log-variances to the
    # dtype/device of the reference tensor.
    if not isinstance(logvar1, torch.Tensor):
        logvar1 = torch.tensor(logvar1).to(reference)
    if not isinstance(logvar2, torch.Tensor):
        logvar2 = torch.tensor(logvar2).to(reference)

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
class LitEma(nn.Module):
    """
    Exponential moving average of a model's trainable parameters, kept as
    buffers on this module.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        """
        :param model: module whose trainable parameters are shadowed.
        :param decay: EMA decay, must lie in [0, 1].
        :param use_num_upates: when True the update counter starts at 0 and
                               warms the decay up; when False it is fixed at -1.
        :raises ValueError: if ``decay`` is outside [0, 1].
        """
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        initial_count = 0 if use_num_upates else -1
        self.register_buffer('num_updates', torch.tensor(initial_count, dtype=torch.int))

        for param_name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # Buffer names may not contain '.', so strip it from the name.
            shadow_name = param_name.replace('.', '')
            self.m_name2s_name[param_name] = shadow_name
            self.register_buffer(shadow_name, param.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        """Move every shadow buffer one EMA step towards ``model``'s parameters."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # Warm-up: use the smaller of the configured decay and (1+n)/(10+n).
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            live = dict(model.named_parameters())
            shadows = dict(self.named_buffers())

            for name, param in live.items():
                if not param.requires_grad:
                    assert not name in self.m_name2s_name
                    continue
                shadow_name = self.m_name2s_name[name]
                shadow = shadows[shadow_name].type_as(param)
                shadows[shadow_name] = shadow
                shadow.sub_(one_minus_decay * (shadow - param))

    def copy_to(self, model):
        """Overwrite ``model``'s trainable parameters with the shadow values."""
        shadows = dict(self.named_buffers())
        for name, param in dict(model.named_parameters()).items():
            if not param.requires_grad:
                assert not name in self.m_name2s_name
                continue
            param.data.copy_(shadows[self.m_name2s_name[name]].data)

    def store(self, parameters):
        """
        Keep a copy of the given parameters so they can be restored later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        snapshot = []
        for param in parameters:
            snapshot.append(param.clone())
        self.collected_params = snapshot

    def restore(self, parameters):
        """
        Write the values saved by `store` back into the given parameters.
        Typical use: `store`, then `copy_to` to validate or save with the EMA
        weights, then `restore` to continue training with the originals.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for saved, param in zip(self.collected_params, parameters):
            param.data.copy_(saved.data)
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement ``encode``."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):
    """Encoder that returns its input unchanged."""

    def encode(self, x):
        return x


class FaceClipEncoder(AbstractEncoder):
    """
    Encodes a fixed face crop and the rest of the image (crop zeroed out)
    separately with a CLIP image embedder and concatenates the two along dim 1.
    """

    def __init__(self, augment=True, retreival_key=None):
        super().__init__()
        self.encoder = FrozenCLIPImageEmbedder()
        self.augment = augment
        self.retreival_key = retreival_key

    def forward(self, img):
        with torch.no_grad():
            x_offset = 125
            # Hard-coded face window inside the image.
            rows = slice(190, 440)
            cols = slice(x_offset, 512 - x_offset)
            if self.retreival_key:
                # Assumes retrieved image are packed into the second half of channels
                face = img[:, 3:, rows, cols]
                other = img[:, :3, ...].clone()
            else:
                face = img[:, :, rows, cols]
                other = img.clone()

            if self.augment:
                face = K.RandomHorizontalFlip()(face)

            # Blank the face window in the context image.
            other[:, :, rows, cols] *= 0
            face_code = self.encoder.encode(face)
            context_code = self.encoder.encode(other)

        return torch.cat([face_code, context_code], dim=1)

    def encode(self, img):
        if not isinstance(img, list):
            return self(img)
        # Uncondition
        return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)


class FaceIdClipEncoder(AbstractEncoder):
    """
    Combines a face-identity embedding of the resized image with a CLIP
    embedding of the image whose face window has been zeroed out.
    """

    def __init__(self):
        super().__init__()
        self.encoder = FrozenCLIPImageEmbedder()
        for clip_param in self.encoder.parameters():
            clip_param.requires_grad = False
        # NOTE(review): absolute, machine-specific checkpoint path.
        self.id = FrozenFaceEncoder("/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True)

    def forward(self, img):
        with torch.no_grad():
            face = kornia.geometry.resize(img, (256, 256),
                                          interpolation='bilinear', align_corners=True)

            other = img.clone()
            other[:, :, 184:452, 122:396] *= 0
            id_code = self.id.encode(face)
            context_code = self.encoder.encode(other)

        return torch.cat([id_code, context_code], dim=1)

    def encode(self, img):
        if not isinstance(img, list):
            return self(img)
        # Uncondition
        return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
class ClassEmbedder(nn.Module):
    """Looks up a learned embedding for the class ids stored in a batch dict."""

    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        lookup_key = self.key if key is None else key
        # this is for use in crossattn
        class_ids = batch[lookup_key][:, None]
        return self.embedding(class_ids)


class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers"""

    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
        super().__init__()
        self.device = device
        attn_layers = Encoder(dim=n_embed, depth=n_layer)
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=attn_layers)

    def forward(self, tokens):
        on_device = tokens.to(self.device)  # meh
        return self.transformer(on_device, return_embeddings=True)

    def encode(self, x):
        return self(x)


class BERTTokenizer(AbstractEncoder):
    """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""

    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to requirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.device = device
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        """Tokenize ``text``, padded/truncated to ``max_length``, and return the ids on ``self.device``."""
        encoded = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        return encoded["input_ids"].to(self.device)

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if self.vq_interface:
            # Mimic the (quant, loss, info) triple of a VQ model's encode().
            return None, None, [None, None, tokens]
        return tokens

    def decode(self, text):
        return text
class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer model and adds some transformer encoder layers"""

    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.device = device
        attn_layers = Encoder(dim=n_embed, depth=n_layer)
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=attn_layers,
                                              emb_dropout=embedding_dropout)

    def forward(self, text):
        # With the tokenizer disabled, `text` is passed to the transformer as is.
        tokens = self.tknz_fn(text) if self.use_tknz_fn else text
        return self.transformer(tokens, return_embeddings=True)

    def encode(self, text):
        # output of length 77
        return self(text)


def disabled_train(self, mode=True):
    """No-op replacement for ``model.train``: ignores ``mode`` and returns
    ``self``, so the train/eval state does not change anymore."""
    return self


class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length   # TODO: typical value?
        self.freeze()

    def freeze(self):
        """Put the transformer in eval mode and disable gradients for all parameters."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the encoder's last hidden state."""
        encoded = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        input_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=input_ids).last_hidden_state

    def encode(self, text):
        return self(text)
class FrozenFaceEncoder(AbstractEncoder):
    """
    Face-identity encoder: a frozen IDFeatures network followed by a trainable
    linear mapper from 512 to 768 features.
    """

    def __init__(self, model_path, augment=False):
        """
        :param model_path: checkpoint path handed to IDFeatures.
        :param augment: if True, apply random flip / equalize augmentations
                        before encoding; otherwise ``self.augment`` is False.
        """
        super().__init__()
        self.loss_fn = IDFeatures(model_path)
        # face encoder is frozen
        for p in self.loss_fn.parameters():
            p.requires_grad = False
        # Mapper is trainable
        self.mapper = torch.nn.Linear(512, 768)
        p = 0.25
        if augment:
            self.augment = K.AugmentationSequential(
                K.RandomHorizontalFlip(p=0.5),
                K.RandomEqualize(p=p),
                # K.RandomPlanckianJitter(p=p),
                # K.RandomPlasmaBrightness(p=p),
                # K.RandomPlasmaContrast(p=p),
                # K.ColorJiggle(0.02, 0.2, 0.2, p=p),
            )
        else:
            self.augment = False

    def forward(self, img):
        """
        :param img: image batch in [-1, 1], or a list to request the
                    unconditional embedding.
        :return: tensor of shape (N, 1, 768); zeros of shape (1, 1, 768) for
                 the unconditional case.
        """
        if isinstance(img, list):
            # Uncondition
            return torch.zeros((1, 1, 768), device=self.mapper.weight.device)

        # __init__ stores False (not None) when augmentation is disabled, so
        # test for both; an `is not None` check alone would try to call False.
        if self.augment is not None and self.augment is not False:
            # Transforms require 0-1
            img = self.augment((img + 1)/2)
            img = 2*img - 1

        feat = self.loss_fn(img, crop=True)
        feat = self.mapper(feat.unsqueeze(1))
        return feat

    def encode(self, img):
        return self(img)
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length   # TODO: typical value?
        self.freeze()

    def freeze(self):
        """Put the transformer in eval mode and disable gradients for all parameters."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the text model's last hidden state."""
        encoded = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        input_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=input_ids).last_hidden_state

    def encode(self, text):
        return self(text)


class ClipImageProjector(AbstractEncoder):
    """
    Uses the CLIP image encoder, maps its hidden states from 1024 to 768
    features and pads the sequence to ``max_length``.
    """

    def __init__(self, version="openai/clip-vit-large-patch14", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.model = CLIPVisionModel.from_pretrained(version)
        self.model.train()
        self.max_length = max_length   # TODO: typical value?
        self.antialias = True
        self.mapper = torch.nn.Linear(1024, 768)
        clip_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
        clip_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
        self.register_buffer('mean', clip_mean, persistent=False)
        self.register_buffer('std', clip_std, persistent=False)
        self.register_buffer('null_cond', self.get_null_cond(version, max_length))

    @torch.no_grad()
    def get_null_cond(self, version, max_length):
        """Embed the empty prompt with a FrozenCLIPEmbedder; used as the unconditional output."""
        text_embedder = FrozenCLIPEmbedder(version=version, device=self.mean.device, max_length=max_length)
        return text_embedder([""])

    def preprocess(self, x):
        # Expects inputs in the range -1, 1
        resized = kornia.geometry.resize(x, (224, 224),
                                         interpolation='bicubic', align_corners=True,
                                         antialias=self.antialias)
        unit_range = (resized + 1.) / 2.
        # renormalize according to clip
        return kornia.enhance.normalize(unit_range, self.mean, self.std)

    def forward(self, x):
        if isinstance(x, list):
            return self.null_cond
        # x is assumed to be in range [-1,1]
        pixels = self.preprocess(x)
        hidden = self.model(pixel_values=pixels).last_hidden_state
        hidden = self.mapper(hidden)
        # Right-pad the token dimension up to max_length.
        missing = self.max_length - hidden.shape[1]
        return F.pad(hidden, [0, 0, 0, missing, 0, 0])

    def encode(self, im):
        return self(im)
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
    """FrozenCLIPEmbedder followed by a 768 -> 768 linear projection."""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
        self.projection = torch.nn.Linear(768, 768)

    def forward(self, text):
        embedded = self.embedder(text)
        return self.projection(embedded)

    def encode(self, text):
        return self(text)


class FrozenCLIPImageEmbedder(AbstractEncoder):
    """
        Uses the CLIP image encoder.
        Not actually frozen... If you want that set cond_stage_trainable=False in cfg
        """

    def __init__(
            self,
            model='ViT-L/14',
            jit=False,
            device='cpu',
            antialias=False,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)
        # We don't use the text part so delete it
        del self.model.transformer
        self.antialias = antialias
        clip_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
        clip_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
        self.register_buffer('mean', clip_mean, persistent=False)
        self.register_buffer('std', clip_std, persistent=False)

    def preprocess(self, x):
        # Expects inputs in the range -1, 1
        resized = kornia.geometry.resize(x, (224, 224),
                                         interpolation='bicubic', align_corners=True,
                                         antialias=self.antialias)
        unit_range = (resized + 1.) / 2.
        # renormalize according to clip
        return kornia.enhance.normalize(unit_range, self.mean, self.std)

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        if not isinstance(x, list):
            return self.model.encode_image(self.preprocess(x)).float()
        # [""] denotes condition dropout for ucg
        device = self.model.visual.conv1.weight.device
        return torch.zeros(1, 768, device=device)

    def encode(self, im):
        return self(im).unsqueeze(1)
class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
    """
        Uses the CLIP image encoder on several random crops per image.
        Not actually frozen... If you want that set cond_stage_trainable=False in cfg
        """

    def __init__(
            self,
            model='ViT-L/14',
            jit=False,
            device='cpu',
            antialias=True,
            max_crops=5,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)
        # We don't use the text part so delete it
        del self.model.transformer
        self.antialias = antialias
        clip_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
        clip_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
        self.register_buffer('mean', clip_mean, persistent=False)
        self.register_buffer('std', clip_std, persistent=False)
        self.max_crops = max_crops

    def preprocess(self, x):
        """Take ``max_crops`` random square 224px crops of ``x`` and normalize them for CLIP."""
        # Expects inputs in the range -1, 1
        randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1, 1))
        crops = []
        for _ in range(self.max_crops):
            crops.append(randcrop(x))
        stacked = torch.cat(crops, dim=0)
        unit_range = (stacked + 1.) / 2.
        # renormalize according to clip
        return kornia.enhance.normalize(unit_range, self.mean, self.std)

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        if isinstance(x, list):
            # [""] denotes condition dropout for ucg
            device = self.model.visual.conv1.weight.device
            return torch.zeros(1, self.max_crops, 768, device=device)

        per_image = []
        for image in x:
            crops = self.preprocess(image.unsqueeze(0))
            tokens = self.model.encode_image(crops).float()
            # Zero each crop embedding in place with probability 0.1.
            for token in tokens:
                if random.random() < 0.1:
                    token *= 0
            per_image.append(tokens.unsqueeze(0))

        return torch.cat(per_image, dim=0)

    def encode(self, im):
        return self(im)
class SpatialRescaler(nn.Module):
    """
    Rescales its input ``n_stages`` times by ``multiplier`` and optionally
    remaps the channel count with a 1x1 convolution.
    """

    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area']
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        # The channel mapper only exists when an output width was requested.
        self.remap_output = out_channels is not None
        if self.remap_output:
            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        out = x
        for _ in range(self.n_stages):
            out = self.interpolator(out, scale_factor=self.multiplier)
        if not self.remap_output:
            return out
        return self.channel_mapper(out)

    def encode(self, x):
        return self(x)
class LowScaleEncoder(nn.Module):
    """
    Encodes an image with a wrapped first-stage model, noises the latent at a
    randomly drawn level and optionally resizes it.
    """

    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
                 scale_factor=1.0):
        super().__init__()
        self.max_noise_level = max_noise_level
        self.model = instantiate_from_config(model_config)
        # register_schedule has no return statement, so this attribute is None.
        self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
                                                            linear_end=linear_end)
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Build the beta schedule and register the derived diffusion constants as float32 buffers."""
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        # Registration order is kept as is: it determines the buffer order.
        schedule_buffers = (
            ('betas', betas),
            ('alphas_cumprod', alphas_cumprod),
            ('alphas_cumprod_prev', alphas_cumprod_prev),
            # calculations for diffusion q(x_t | x_{t-1}) and others
            ('sqrt_alphas_cumprod', np.sqrt(alphas_cumprod)),
            ('sqrt_one_minus_alphas_cumprod', np.sqrt(1. - alphas_cumprod)),
            ('log_one_minus_alphas_cumprod', np.log(1. - alphas_cumprod)),
            ('sqrt_recip_alphas_cumprod', np.sqrt(1. / alphas_cumprod)),
            ('sqrt_recipm1_alphas_cumprod', np.sqrt(1. / alphas_cumprod - 1)),
        )
        for buffer_name, values in schedule_buffers:
            self.register_buffer(buffer_name, to_torch(values))

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to step ``t``; Gaussian noise is drawn when ``noise`` is not supplied."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise

    def forward(self, x):
        """Return ``(noised latent, noise level)`` for the batch ``x``."""
        z = self.model.encode(x).sample()
        z = z * self.scale_factor
        noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        z = self.q_sample(z, noise_level)
        if self.out_size is not None:
            z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
        return z, noise_level

    def decode(self, z):
        """Undo the latent scaling and decode with the wrapped model."""
        unscaled = z / self.scale_factor
        return self.model.decode(unscaled)
/ alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + z = self.model.encode(x).sample() + z = z * self.scale_factor + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + z = self.q_sample(z, noise_level) + if self.out_size is not None: + z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode + # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1) + return z, noise_level + + def decode(self, z): + z = z / self.scale_factor + return self.model.decode(z) + + +if __name__ == "__main__": + from ldm.util import count_params + sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! 
def main():
    """Evaluate a sample batch against a reference batch and save the metrics.

    Command-line arguments:
        --ref_batch: path to the reference batch ``.npz`` file.
        --sample_batch: path to the sample batch ``.npz`` file.

    Prints the Inception Score, FID, sFID, precision and recall, then writes
    them to ``evaluation_metrics.yaml`` in the directory that contains the
    sample batch.
    """
    parser = argparse.ArgumentParser()
    # Both paths are mandatory: without them the run would only fail after the
    # TensorFlow session has been created and warmed up.
    parser.add_argument("--ref_batch", required=True, help="path to reference batch npz file")
    parser.add_argument("--sample_batch", required=True, help="path to sample batch npz file")
    args = parser.parse_args()

    config = tf.ConfigProto(
        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph
    )
    config.gpu_options.allow_growth = True
    evaluator = Evaluator(tf.Session(config=config))

    print("warming up TensorFlow...")
    # This will cause TF to print a bunch of verbose stuff now rather
    # than after the next print(), to help prevent confusion.
    evaluator.warmup()

    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)

    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)

    print("Computing evaluations...")
    is_ = evaluator.compute_inception_score(sample_acts[0])
    print("Inception Score:", is_)
    # frechet_distance is computed with numpy; convert to a plain float so the
    # YAML dump below contains ordinary numbers instead of numpy object tags.
    fid = float(sample_stats.frechet_distance(ref_stats))
    print("FID:", fid)
    sfid = float(sample_stats_spatial.frechet_distance(ref_stats_spatial))
    print("sFID:", sfid)
    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
    print("Precision:", prec)
    print("Recall:", recall)

    # dirname copes with any path separator and yields '' for a bare filename,
    # in which case the results land in the current working directory.
    savepath = os.path.dirname(args.sample_batch)
    results_file = os.path.join(savepath, 'evaluation_metrics.yaml')
    print(f'Saving evaluation results to "{results_file}"')

    results = {
        'IS': is_,
        'FID': fid,
        'sFID': sfid,
        'Precision': prec,  # key previously carried a stray trailing colon
        'Recall': recall,
    }

    with open(results_file, 'w') as f:
        yaml.dump(results, f, default_flow_style=False)
+ """ + # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 + mu1, sigma1 = self.mu, self.sigma + mu2, sigma2 = other.mu, other.sigma + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert ( + mu1.shape == mu2.shape + ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" + assert ( + sigma1.shape == sigma2.shape + ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" + + diff = mu1 - mu2 + + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; adding %s to diagonal of cov estimates" + % eps + ) + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +class Evaluator: + def __init__( + self, + session, + batch_size=64, + softmax_batch_size=512, + ): + self.sess = session + self.batch_size = batch_size + self.softmax_batch_size = softmax_batch_size + self.manifold_estimator = ManifoldEstimator(session) + with self.sess.graph.as_default(): + self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3]) + self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048]) + self.pool_features, self.spatial_features = _create_feature_graph(self.image_input) + self.softmax = _create_softmax_graph(self.softmax_input) + + def 
warmup(self): + self.compute_activations(np.zeros([1, 8, 64, 64, 3])) + + def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]: + with open_npz_array(npz_path, "arr_0") as reader: + return self.compute_activations(reader.read_batches(self.batch_size)) + + def compute_activations(self, batches: Iterable[np.ndarray],silent=False) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute image features for downstream evals. + + :param batches: a iterator over NHWC numpy arrays in [0, 255]. + :return: a tuple of numpy arrays of shape [N x X], where X is a feature + dimension. The tuple is (pool_3, spatial). + """ + preds = [] + spatial_preds = [] + it = batches if silent else tqdm(batches) + for batch in it: + batch = batch.astype(np.float32) + pred, spatial_pred = self.sess.run( + [self.pool_features, self.spatial_features], {self.image_input: batch} + ) + preds.append(pred.reshape([pred.shape[0], -1])) + spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1])) + return ( + np.concatenate(preds, axis=0), + np.concatenate(spatial_preds, axis=0), + ) + + def read_statistics( + self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray] + ) -> Tuple[FIDStatistics, FIDStatistics]: + obj = np.load(npz_path) + if "mu" in list(obj.keys()): + return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics( + obj["mu_s"], obj["sigma_s"] + ) + return tuple(self.compute_statistics(x) for x in activations) + + def compute_statistics(self, activations: np.ndarray) -> FIDStatistics: + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return FIDStatistics(mu, sigma) + + def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float: + softmax_out = [] + for i in range(0, len(activations), self.softmax_batch_size): + acts = activations[i : i + self.softmax_batch_size] + softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts})) + preds = np.concatenate(softmax_out, 
class ManifoldEstimator:
    """
    A helper for comparing manifolds of feature vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
    """

    def __init__(
        self,
        session,
        row_batch_size=10000,
        col_batch_size=10000,
        nhood_sizes=(3,),
        clamp_to_percentile=None,
        eps=1e-5,
    ):
        """
        Estimate the manifold of given feature vectors.

        :param session: the TensorFlow session.
        :param row_batch_size: row batch size to compute pairwise distances
                               (parameter to trade-off between memory usage and performance).
        :param col_batch_size: column batch size to compute pairwise distances.
        :param nhood_sizes: number of neighbors used to estimate the manifold.
        :param clamp_to_percentile: prune hyperspheres that have radius larger than
                                    the given percentile.
        :param eps: small number for numerical stability.
        """
        self.distance_block = DistanceBlock(session)
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        self.nhood_sizes = nhood_sizes
        self.num_nhoods = len(nhood_sizes)
        self.clamp_to_percentile = clamp_to_percentile
        self.eps = eps

    def warmup(self):
        """Run one tiny precision/recall evaluation on all-zero inputs."""
        feats, radii = (
            np.zeros([1, 2048], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32),
        )
        self.evaluate_pr(feats, radii, feats, radii)

    def manifold_radii(self, features: np.ndarray) -> np.ndarray:
        """
        Compute, for every feature vector, its distance to the k-th neighbor.

        :param features: [N x D] feature vectors.
        :return: an [N x len(nhood_sizes)] float32 array with one radius per
                 neighborhood size, in the units returned by the distance block.
        """
        num_images = len(features)

        # Estimate manifold of features by calculating distances to k-NN of each sample.
        radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
        distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)

        for begin1 in range(0, num_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_images)
            row_batch = features[begin1:end1]

            for begin2 in range(0, num_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_images)
                col_batch = features[begin2:end2]

                # Compute distances between batches.
                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(row_batch, col_batch)

            # Find the k-nearest neighbor from the current batch.
            radii[begin1:end1, :] = np.concatenate(
                [
                    x[:, self.nhood_sizes]
                    for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
                ],
                axis=0,
            )

        if self.clamp_to_percentile is not None:
            max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
            radii[radii > max_distances] = 0
        return radii

    def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
        """
        Evaluate if new feature vectors are at the manifold.

        :param features: reference feature vectors the radii belong to.
        :param radii: per-reference radii, one column per neighborhood size.
        :param eval_features: feature vectors to test against the manifold.
        :return: a dict with the in-manifold fraction, the per-sample
                 predictions, a realism score and the nearest reference index.
        """
        num_eval_images = eval_features.shape[0]
        num_ref_images = radii.shape[0]
        distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
        max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
        nearest_indices = np.zeros([num_eval_images], dtype=np.int32)

        for begin1 in range(0, num_eval_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_eval_images)
            feature_batch = eval_features[begin1:end1]

            for begin2 in range(0, num_ref_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_ref_images)
                ref_batch = features[begin2:end2]

                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)

            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
            # If a feature vector is inside a hypersphere of some reference sample, then
            # the new sample lies at the estimated manifold.
            # The radii of the hyperspheres are determined from distances of neighborhood size k.
            samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
            batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)

            max_realism_score[begin1:end1] = np.max(
                radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
            )
            nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)

        return {
            "fraction": float(np.mean(batch_predictions)),
            "batch_predictions": batch_predictions,
            # Key spelling kept as-is: callers may already look it up by this name.
            "max_realisim_score": max_realism_score,
            "nearest_indices": nearest_indices,
        }

    def evaluate_pr(
        self,
        features_1: np.ndarray,
        radii_1: np.ndarray,
        features_2: np.ndarray,
        radii_2: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Evaluate precision and recall efficiently.

        :param features_1: [N1 x D] feature vectors for reference batch.
        :param radii_1: [N1 x K1] radii for reference vectors.
        :param features_2: [N2 x D] feature vectors for the other batch.
        :param radii_2: [N2 x K2] radii for other vectors.
        :return: a tuple of arrays for (precision, recall):
                 - precision: an np.ndarray of length K1
                 - recall: an np.ndarray of length K2
        """
        # Builtin bool instead of the np.bool alias, which NumPy 1.24 removed.
        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
        for begin_1 in range(0, len(features_1), self.row_batch_size):
            # Slicing clamps at the array end, so end_1/end_2 need no min().
            end_1 = begin_1 + self.row_batch_size
            batch_1 = features_1[begin_1:end_1]
            for begin_2 in range(0, len(features_2), self.col_batch_size):
                end_2 = begin_2 + self.col_batch_size
                batch_2 = features_2[begin_2:end_2]
                batch_1_in, batch_2_in = self.distance_block.less_thans(
                    batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
                )
                # A vector counts as covered once any batch pairing covers it.
                features_1_status[begin_1:end_1] |= batch_1_in
                features_2_status[begin_2:end_2] |= batch_2_in
        return (
            np.mean(features_2_status.astype(np.float64), axis=0),
            np.mean(features_1_status.astype(np.float64), axis=0),
        )
+ with session.graph.as_default(): + self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None]) + self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None]) + distance_block_16 = _batch_pairwise_distances( + tf.cast(self._features_batch1, tf.float16), + tf.cast(self._features_batch2, tf.float16), + ) + self.distance_block = tf.cond( + tf.reduce_all(tf.math.is_finite(distance_block_16)), + lambda: tf.cast(distance_block_16, tf.float32), + lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2), + ) + + # Extra logic for less thans. + self._radii1 = tf.placeholder(tf.float32, shape=[None, None]) + self._radii2 = tf.placeholder(tf.float32, shape=[None, None]) + dist32 = tf.cast(self.distance_block, tf.float32)[..., None] + self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1) + self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0) + + def pairwise_distances(self, U, V): + """ + Evaluate pairwise distances between two batches of feature vectors. + """ + return self.session.run( + self.distance_block, + feed_dict={self._features_batch1: U, self._features_batch2: V}, + ) + + def less_thans(self, batch_1, radii_1, batch_2, radii_2): + return self.session.run( + [self._batch_1_in, self._batch_2_in], + feed_dict={ + self._features_batch1: batch_1, + self._features_batch2: batch_2, + self._radii1: radii_1, + self._radii2: radii_2, + }, + ) + + +def _batch_pairwise_distances(U, V): + """ + Compute pairwise distances between two batches of feature vectors. + """ + with tf.variable_scope("pairwise_dist_block"): + # Squared norms of each row in U and V. + norm_u = tf.reduce_sum(tf.square(U), 1) + norm_v = tf.reduce_sum(tf.square(V), 1) + + # norm_u as a column and norm_v as a row vectors. + norm_u = tf.reshape(norm_u, [-1, 1]) + norm_v = tf.reshape(norm_v, [1, -1]) + + # Pairwise squared Euclidean distances. 
class BatchIterator:
    """Sized, re-iterable view over batches produced by a generator factory.

    ``gen_fn`` is a zero-argument callable returning an iterator and
    ``length`` is the value reported by ``len()``; the length is stored as
    given and never checked against what the iterator actually yields.
    """

    def __init__(self, gen_fn, length):
        self.length = length
        self.gen_fn = gen_fn

    def __iter__(self):
        # Every iteration request calls the factory again.
        fresh_iterator = self.gen_fn()
        return fresh_iterator

    def __len__(self):
        return self.length
@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
    """Context manager yielding a batch reader for one array of an ``.npz`` file.

    The array's ``.npy`` member is opened and its header parsed. Arrays with
    format version 1.0 or 2.0 that are C-ordered and contain no Python
    objects are streamed from the open member; everything else (unknown
    format version, Fortran order, object dtype) is loaded fully into memory.

    :param path: path to the ``.npz`` archive.
    :param arr_name: name of the array inside the archive (without ``.npy``).
    """
    header_readers = {
        (1, 0): np.lib.format.read_array_header_1_0,
        (2, 0): np.lib.format.read_array_header_2_0,
    }
    with _open_npy_file(path, arr_name) as arr_f:
        read_header = header_readers.get(np.lib.format.read_magic(arr_f))
        if read_header is None:
            # Unrecognised format version: fall back to a full in-memory load.
            yield MemoryNpzArrayReader.load(path, arr_name)
            return
        shape, fortran, dtype = read_header(arr_f)
        streamable = not (fortran or dtype.hasobject)
        if streamable:
            yield StreamingNpzArrayReader(arr_f, shape, dtype)
        else:
            yield MemoryNpzArrayReader.load(path, arr_name)
@contextmanager
def _open_npy_file(path: str, arr_name: str):
    """Context manager yielding the open ``<arr_name>.npy`` member of an ``.npz`` file.

    :param path: path to the ``.npz`` (zip) archive.
    :param arr_name: array name without the ``.npy`` suffix.
    :raises ValueError: if the archive has no member for ``arr_name``.
    """
    member = f"{arr_name}.npy"
    with open(path, "rb") as raw_file, zipfile.ZipFile(raw_file, "r") as archive:
        if member not in archive.namelist():
            raise ValueError(f"missing {arr_name} in npz file")
        with archive.open(member, "r") as arr_f:
            yield arr_f
def _numpy_partition(arr, kth, **kwargs):
    """Apply ``np.partition`` to row chunks of ``arr`` on a thread pool.

    ``arr`` is split along its first axis into at most ``cpu_count()`` nearly
    equal chunks, and each chunk is partitioned independently.

    :param arr: array whose first axis is split across workers.
    :param kth: forwarded to ``np.partition``.
    :param kwargs: extra keyword arguments for ``np.partition`` (e.g. ``axis``).
    :return: a list with one partitioned array per chunk, in row order.
    """
    if len(arr) == 0:
        # No rows means no workers; partition directly instead of dividing by
        # zero, and keep the one-array-per-chunk list shape for callers.
        return [np.partition(arr, kth=kth, **kwargs)]

    num_workers = min(cpu_count(), len(arr))
    # array_split gives the first (len % workers) chunks one extra row.
    batches = np.array_split(arr, num_workers)

    with ThreadPool(num_workers) as pool:
        return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
def cos_sim(in0, in1):
    """Mean cosine similarity between two feature maps, one value per sample.

    Both inputs are normalized with ``normalize_tensor``, multiplied
    element-wise and summed over dim 1, then averaged over dim 2 and dim 3 in
    turn.

    :param in0: first feature map; its sizes along dims 0, 2 and 3 define the
        reshapes, so ``in1`` presumably has the same shape (not checked here).
    :param in1: second feature map.
    :return: tensor of shape ``[in0.size(0)]``.
    """
    normed0 = normalize_tensor(in0)
    normed1 = normalize_tensor(in1)
    dims = in0.size()
    batch, height, width = dims[0], dims[2], dims[3]

    per_position = torch.sum(normed0 * normed1, dim=1).view(batch, 1, height, width)
    per_column = torch.mean(per_position, dim=2).view(batch, 1, 1, width)
    return torch.mean(per_column, dim=3).view(batch)
class alexnet(torch.nn.Module):
    """torchvision AlexNet feature stack split into five sequential slices.

    ``forward`` returns the output of every slice as an ``AlexnetOutputs``
    namedtuple with fields ``relu1`` .. ``relu5``.

    :param requires_grad: when false, every parameter is frozen.
    :param pretrained: forwarded to ``models.alexnet``.
    """

    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        backbone = models.alexnet(pretrained=pretrained).features
        # (first layer, one-past-last layer) of the backbone for slice1..slice5.
        slice_bounds = ((0, 2), (2, 5), (5, 8), (8, 10), (10, 12))
        for slice_number, (first, stop) in enumerate(slice_bounds, start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(first, stop):
                # Layers keep their backbone index as the submodule name.
                stage.add_module(str(layer_idx), backbone[layer_idx])
            setattr(self, "slice%d" % slice_number, stage)
        self.N_slices = 5
        if requires_grad:
            return
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X):
        """Run ``X`` through the slices in order and collect each slice's output."""
        activations = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        alexnet_outputs = namedtuple(
            "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
        )
        return alexnet_outputs(*activations)
class resnet(torch.nn.Module):
    """torchvision ResNet exposing five intermediate activations.

    ``forward`` returns an ``Outputs`` namedtuple with the activation after
    the stem ReLU (``relu1``) and after each of the four residual stages
    (``conv2`` .. ``conv5``).

    :param requires_grad: when false, every parameter is frozen, matching the
        other backbone wrappers in this module (previously ignored here).
    :param pretrained: forwarded to the torchvision constructor.
    :param num: network depth; one of 18, 34, 50, 101 or 152.
    :raises ValueError: if ``num`` is not a supported depth.
    """

    # Supported depths mapped to the torchvision constructor name.
    _CONSTRUCTORS = {
        18: "resnet18",
        34: "resnet34",
        50: "resnet50",
        101: "resnet101",
        152: "resnet152",
    }

    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        if num not in self._CONSTRUCTORS:
            # Fail here rather than with a missing-attribute error further down.
            raise ValueError(
                "unsupported resnet depth {!r}; expected one of {}".format(
                    num, sorted(self._CONSTRUCTORS)
                )
            )
        self.net = getattr(models, self._CONSTRUCTORS[num])(pretrained=pretrained)
        self.N_slices = 5

        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run ``X`` through the network and collect the five tapped activations."""
        h = self.conv1(X)
        h = self.bn1(h)
        h = self.relu(h)
        h_relu1 = h
        h = self.maxpool(h)
        h = self.layer1(h)
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple(
            "Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]
        )
        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)

        return out
weighted by default""" + + def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True): + super(PNet, self).__init__() + + self.use_gpu = use_gpu + + self.pnet_type = pnet_type + self.pnet_rand = pnet_rand + + self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1) + self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1) + + if self.pnet_type in ["vgg", "vgg16"]: + self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False) + elif self.pnet_type == "alex": + self.net = alexnet( + pretrained=not self.pnet_rand, requires_grad=False + ) + elif self.pnet_type[:-2] == "resnet": + self.net = resnet( + pretrained=not self.pnet_rand, + requires_grad=False, + num=int(self.pnet_type[-2:]), + ) + elif self.pnet_type == "squeeze": + self.net = squeezenet( + pretrained=not self.pnet_rand, requires_grad=False + ) + + self.L = self.net.N_slices + + if use_gpu: + self.net.cuda() + self.shift = self.shift.cuda() + self.scale = self.scale.cuda() + + def forward(self, in0, in1, retPerLayer=False): + in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0) + in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0) + + outs0 = self.net.forward(in0_sc) + outs1 = self.net.forward(in1_sc) + + if retPerLayer: + all_scores = [] + for (kk, out0) in enumerate(outs0): + cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk]) + if kk == 0: + val = 1.0 * cur_score + else: + val = val + cur_score + if retPerLayer: + all_scores += [cur_score] + + if retPerLayer: + return (val, all_scores) + else: + return val + + + + +# The SSIM metric +def ssim_metric(img1, img2, mask=None): + return ssim(img1, img2, mask=mask, size_average=False) + + +# The PSNR metric +def psnr(img1, img2, mask=None,reshape=False): + b = img1.size(0) + if not (mask is None): + b = img1.size(0) + mse_err = (img1 - img2).pow(2) * mask + if reshape: + mse_err = mse_err.reshape(b, -1).sum(dim=1) / ( + 3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1) + ) + else: + mse_err = 
mse_err.view(b, -1).sum(dim=1) / ( + 3 * mask.view(b, -1).sum(dim=1).clamp(min=1) + ) + else: + if reshape: + mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1) + else: + mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1) + + psnr = 10 * (1 / mse_err).log10() + return psnr + + +# The perceptual similarity metric +def perceptual_sim(img1, img2, vgg16): + # First extract features + dist = vgg16(img1 * 2 - 1, img2 * 2 - 1) + + return dist + +def load_img(img_name, size=None): + try: + img = Image.open(img_name) + + if type(size) == int: + img = img.resize((size, size)) + elif size is not None: + img = img.resize((size[1], size[0])) + + img = transform(img).cuda() + img = img.unsqueeze(0) + except Exception as e: + print("Failed at loading %s " % img_name) + print(e) + img = torch.zeros(1, 3, 256, 256).cuda() + raise + return img + + +def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other): + + # Load VGG16 for feature similarity + vgg16 = PNet().to("cuda") + vgg16.eval() + vgg16.cuda() + + values_percsim = [] + values_ssim = [] + values_psnr = [] + folders = os.listdir(folder) + for i, f in tqdm(enumerate(sorted(folders))): + pred_imgs = glob.glob(folder + f + "/" + pred_img) + tgt_imgs = glob.glob(folder + f + "/" + tgt_img) + assert len(tgt_imgs) == 1 + + perc_sim = 10000 + ssim_sim = -10 + psnr_sim = -10 + for p_img in pred_imgs: + t_img = load_img(tgt_imgs[0]) + p_img = load_img(p_img, size=t_img.shape[2:]) + t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item() + perc_sim = min(perc_sim, t_perc_sim) + + ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item()) + psnr_sim = max(psnr_sim, psnr(p_img, t_img).item()) + + values_percsim += [perc_sim] + values_ssim += [ssim_sim] + values_psnr += [psnr_sim] + + if take_every_other: + n_valuespercsim = [] + n_valuesssim = [] + n_valuespsnr = [] + for i in range(0, len(values_percsim) // 2): + n_valuespercsim += [ + min(values_percsim[2 * i], values_percsim[2 * i + 1]) + ] + 
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])] + n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])] + + values_percsim = n_valuespercsim + values_ssim = n_valuesssim + values_psnr = n_valuespsnr + + avg_percsim = np.mean(np.array(values_percsim)) + std_percsim = np.std(np.array(values_percsim)) + + avg_psnr = np.mean(np.array(values_psnr)) + std_psnr = np.std(np.array(values_psnr)) + + avg_ssim = np.mean(np.array(values_ssim)) + std_ssim = np.std(np.array(values_ssim)) + + return { + "Perceptual similarity": (avg_percsim, std_percsim), + "PSNR": (avg_psnr, std_psnr), + "SSIM": (avg_ssim, std_ssim), + } + + +def compute_perceptual_similarity_from_list(pred_imgs_list, tgt_imgs_list, + take_every_other, + simple_format=True): + + # Load VGG16 for feature similarity + vgg16 = PNet().to("cuda") + vgg16.eval() + vgg16.cuda() + + values_percsim = [] + values_ssim = [] + values_psnr = [] + equal_count = 0 + ambig_count = 0 + for i, tgt_img in enumerate(tqdm(tgt_imgs_list)): + pred_imgs = pred_imgs_list[i] + tgt_imgs = [tgt_img] + assert len(tgt_imgs) == 1 + + if type(pred_imgs) != list: + pred_imgs = [pred_imgs] + + perc_sim = 10000 + ssim_sim = -10 + psnr_sim = -10 + assert len(pred_imgs)>0 + for p_img in pred_imgs: + t_img = load_img(tgt_imgs[0]) + p_img = load_img(p_img, size=t_img.shape[2:]) + t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item() + perc_sim = min(perc_sim, t_perc_sim) + + ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item()) + psnr_sim = max(psnr_sim, psnr(p_img, t_img).item()) + + values_percsim += [perc_sim] + values_ssim += [ssim_sim] + if psnr_sim != np.float("inf"): + values_psnr += [psnr_sim] + else: + if torch.allclose(p_img, t_img): + equal_count += 1 + print("{} equal src and wrp images.".format(equal_count)) + else: + ambig_count += 1 + print("{} ambiguous src and wrp images.".format(ambig_count)) + + if take_every_other: + n_valuespercsim = [] + n_valuesssim = [] + n_valuespsnr = [] + for i in 
range(0, len(values_percsim) // 2): + n_valuespercsim += [ + min(values_percsim[2 * i], values_percsim[2 * i + 1]) + ] + n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])] + n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])] + + values_percsim = n_valuespercsim + values_ssim = n_valuesssim + values_psnr = n_valuespsnr + + avg_percsim = np.mean(np.array(values_percsim)) + std_percsim = np.std(np.array(values_percsim)) + + avg_psnr = np.mean(np.array(values_psnr)) + std_psnr = np.std(np.array(values_psnr)) + + avg_ssim = np.mean(np.array(values_ssim)) + std_ssim = np.std(np.array(values_ssim)) + + if simple_format: + # just to make yaml formatting readable + return { + "Perceptual similarity": [float(avg_percsim), float(std_percsim)], + "PSNR": [float(avg_psnr), float(std_psnr)], + "SSIM": [float(avg_ssim), float(std_ssim)], + } + else: + return { + "Perceptual similarity": (avg_percsim, std_percsim), + "PSNR": (avg_psnr, std_psnr), + "SSIM": (avg_ssim, std_ssim), + } + + +def compute_perceptual_similarity_from_list_topk(pred_imgs_list, tgt_imgs_list, + take_every_other, resize=False): + + # Load VGG16 for feature similarity + vgg16 = PNet().to("cuda") + vgg16.eval() + vgg16.cuda() + + values_percsim = [] + values_ssim = [] + values_psnr = [] + individual_percsim = [] + individual_ssim = [] + individual_psnr = [] + for i, tgt_img in enumerate(tqdm(tgt_imgs_list)): + pred_imgs = pred_imgs_list[i] + tgt_imgs = [tgt_img] + assert len(tgt_imgs) == 1 + + if type(pred_imgs) != list: + assert False + pred_imgs = [pred_imgs] + + perc_sim = 10000 + ssim_sim = -10 + psnr_sim = -10 + sample_percsim = list() + sample_ssim = list() + sample_psnr = list() + for p_img in pred_imgs: + if resize: + t_img = load_img(tgt_imgs[0], size=(256,256)) + else: + t_img = load_img(tgt_imgs[0]) + p_img = load_img(p_img, size=t_img.shape[2:]) + + t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item() + sample_percsim.append(t_perc_sim) + perc_sim = min(perc_sim, 
t_perc_sim) + + t_ssim = ssim_metric(p_img, t_img).item() + sample_ssim.append(t_ssim) + ssim_sim = max(ssim_sim, t_ssim) + + t_psnr = psnr(p_img, t_img).item() + sample_psnr.append(t_psnr) + psnr_sim = max(psnr_sim, t_psnr) + + values_percsim += [perc_sim] + values_ssim += [ssim_sim] + values_psnr += [psnr_sim] + individual_percsim.append(sample_percsim) + individual_ssim.append(sample_ssim) + individual_psnr.append(sample_psnr) + + if take_every_other: + assert False, "Do this later, after specifying topk to get proper results" + n_valuespercsim = [] + n_valuesssim = [] + n_valuespsnr = [] + for i in range(0, len(values_percsim) // 2): + n_valuespercsim += [ + min(values_percsim[2 * i], values_percsim[2 * i + 1]) + ] + n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])] + n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])] + + values_percsim = n_valuespercsim + values_ssim = n_valuesssim + values_psnr = n_valuespsnr + + avg_percsim = np.mean(np.array(values_percsim)) + std_percsim = np.std(np.array(values_percsim)) + + avg_psnr = np.mean(np.array(values_psnr)) + std_psnr = np.std(np.array(values_psnr)) + + avg_ssim = np.mean(np.array(values_ssim)) + std_ssim = np.std(np.array(values_ssim)) + + individual_percsim = np.array(individual_percsim) + individual_psnr = np.array(individual_psnr) + individual_ssim = np.array(individual_ssim) + + return { + "avg_of_best": { + "Perceptual similarity": [float(avg_percsim), float(std_percsim)], + "PSNR": [float(avg_psnr), float(std_psnr)], + "SSIM": [float(avg_ssim), float(std_ssim)], + }, + "individual": { + "PSIM": individual_percsim, + "PSNR": individual_psnr, + "SSIM": individual_ssim, + } + } + + +if __name__ == "__main__": + args = argparse.ArgumentParser() + args.add_argument("--folder", type=str, default="") + args.add_argument("--pred_image", type=str, default="") + args.add_argument("--target_image", type=str, default="") + args.add_argument("--take_every_other", action="store_true", 
default=False) + args.add_argument("--output_file", type=str, default="") + + opts = args.parse_args() + + folder = opts.folder + pred_img = opts.pred_image + tgt_img = opts.target_image + + results = compute_perceptual_similarity( + folder, pred_img, tgt_img, opts.take_every_other + ) + + f = open(opts.output_file, 'w') + for key in results: + print("%s for %s: \n" % (key, opts.folder)) + print( + "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]) + ) + + f.write("%s for %s: \n" % (key, opts.folder)) + f.write( + "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]) + ) + + f.close() diff --git a/stable-dreamfusion-3DPortrait/ldm/modules/evaluate/frechet_video_distance.py b/stable-dreamfusion-3DPortrait/ldm/modules/evaluate/frechet_video_distance.py new file mode 100644 index 0000000..d9e13c4 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/modules/evaluate/frechet_video_distance.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python2, python3 +"""Minimal Reference implementation for the Frechet Video Distance (FVD). + +FVD is a metric for the quality of video generation models. It is inspired by +the FID (Frechet Inception Distance) used for images, but uses a different +embedding to be better suitable for videos. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import six +import tensorflow.compat.v1 as tf +import tensorflow_gan as tfgan +import tensorflow_hub as hub + + +def preprocess(videos, target_resolution): + """Runs some preprocessing on the videos for I3D model. + + Args: + videos: [batch_size, num_frames, height, width, depth] The videos to be + preprocessed. We don't care about the specific dtype of the videos, it can + be anything that tf.image.resize_bilinear accepts. Values are expected to + be in the range 0-255. + target_resolution: (width, height): target video resolution + + Returns: + videos: [batch_size, num_frames, height, width, depth] + """ + videos_shape = list(videos.shape) + all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) + resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution) + target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] + output_videos = tf.reshape(resized_videos, target_shape) + scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1 + return scaled_videos + + +def _is_in_graph(tensor_name): + """Checks whether a given tensor does exists in the graph.""" + try: + tf.get_default_graph().get_tensor_by_name(tensor_name) + except KeyError: + return False + return True + + +def create_id3_embedding(videos,warmup=False,batch_size=16): + """Embeds the given videos using the Inflated 3D Convolution ne twork. + + Downloads the graph of the I3D from tf.hub and adds it to the graph on the + first call. + + Args: + videos: [batch_size, num_frames, height=224, width=224, depth=3]. + Expected range is [-1, 1]. + + Returns: + embedding: [batch_size, embedding_size]. embedding_size depends + on the model used. + + Raises: + ValueError: when a provided embedding_layer is not supported. 
+ """ + + # batch_size = 16 + module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" + + + # Making sure that we import the graph separately for + # each different input video tensor. + module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str( + videos.name).replace(":", "_") + + + + assert_ops = [ + tf.Assert( + tf.reduce_max(videos) <= 1.001, + ["max value in frame is > 1", videos]), + tf.Assert( + tf.reduce_min(videos) >= -1.001, + ["min value in frame is < -1", videos]), + tf.assert_equal( + tf.shape(videos)[0], + batch_size, ["invalid frame batch size: ", + tf.shape(videos)], + summarize=6), + ] + with tf.control_dependencies(assert_ops): + videos = tf.identity(videos) + + module_scope = "%s_apply_default/" % module_name + + # To check whether the module has already been loaded into the graph, we look + # for a given tensor name. If this tensor name exists, we assume the function + # has been called before and the graph was imported. Otherwise we import it. + # Note: in theory, the tensor could exist, but have wrong shapes. + # This will happen if create_id3_embedding is called with a frames_placehoder + # of wrong size/batch size, because even though that will throw a tf.Assert + # on graph-execution time, it will insert the tensor (with wrong shape) into + # the graph. This is why we need the following assert. 
+ if warmup: + video_batch_size = int(videos.shape[0]) + assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}" + tensor_name = module_scope + "RGB/inception_i3d/Mean:0" + if not _is_in_graph(tensor_name): + i3d_model = hub.Module(module_spec, name=module_name) + i3d_model(videos) + + # gets the kinetics-i3d-400-logits layer + tensor_name = module_scope + "RGB/inception_i3d/Mean:0" + tensor = tf.get_default_graph().get_tensor_by_name(tensor_name) + return tensor + + +def calculate_fvd(real_activations, + generated_activations): + """Returns a list of ops that compute metrics as funcs of activations. + + Args: + real_activations: [num_samples, embedding_size] + generated_activations: [num_samples, embedding_size] + + Returns: + A scalar that contains the requested FVD. + """ + return tfgan.eval.frechet_classifier_distance_from_activations( + real_activations, generated_activations) diff --git a/stable-dreamfusion-3DPortrait/ldm/modules/evaluate/ssim.py b/stable-dreamfusion-3DPortrait/ldm/modules/evaluate/ssim.py new file mode 100644 index 0000000..4e8883c --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/modules/evaluate/ssim.py @@ -0,0 +1,124 @@ +# MIT Licence + +# Methods to predict the SSIM, taken from +# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py + +from math import exp + +import torch +import torch.nn.functional as F +from torch.autograd import Variable + +def gaussian(window_size, sigma): + gauss = torch.Tensor( + [ + exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) + for x in range(window_size) + ] + ) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable( + _2D_window.expand(channel, 1, window_size, window_size).contiguous() + ) + return window + 
+ +def _ssim( + img1, img2, window, window_size, channel, mask=None, size_average=True +): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = ( + F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) + - mu1_sq + ) + sigma2_sq = ( + F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) + - mu2_sq + ) + sigma12 = ( + F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) + - mu1_mu2 + ) + + C1 = (0.01) ** 2 + C2 = (0.03) ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ) + + if not (mask is None): + b = mask.size(0) + ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask + ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum( + dim=1 + ).clamp(min=1) + return ssim_map + + import pdb + + pdb.set_trace + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2, mask=None): + (_, channel, _, _) = img1.size() + + if ( + channel == self.channel + and self.window.data.type() == img1.data.type() + ): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim( + img1, + img2, + window, + self.window_size, + channel, + mask, + self.size_average, + ) + + +def ssim(img1, img2, window_size=11, mask=None, size_average=True): + (_, channel, _, _) = 
img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, mask, size_average) diff --git a/stable-dreamfusion-3DPortrait/ldm/modules/evaluate/torch_frechet_video_distance.py b/stable-dreamfusion-3DPortrait/ldm/modules/evaluate/torch_frechet_video_distance.py new file mode 100644 index 0000000..04856b8 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/modules/evaluate/torch_frechet_video_distance.py @@ -0,0 +1,294 @@ +# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks! +import os +import numpy as np +import io +import re +import requests +import html +import hashlib +import urllib +import urllib.request +import scipy.linalg +import multiprocessing as mp +import glob + + +from tqdm import tqdm +from typing import Any, List, Tuple, Union, Dict, Callable + +from torchvision.io import read_video +import torch; torch.set_grad_enabled(False) +from einops import rearrange + +from nitro.util import isvideo + +def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float: + print('Calculate frechet distance...') + m = np.square(mu_sample - mu_ref).sum() + s, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False) # pylint: disable=no-member + fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2)) + + return float(fid) + + +def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + mu = feats.mean(axis=0) # [d] + sigma = np.cov(feats, rowvar=False) # [d, d] + + return mu, sigma + + +def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + + # Doesn't look like an URL scheme so interpret it as a local 
filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." 
% url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Return data as file object. 
+ assert not return_filename + return io.BytesIO(url_data) + +def load_video(ip): + vid, *_ = read_video(ip) + vid = rearrange(vid, 't h w c -> t c h w').to(torch.uint8) + return vid + +def get_data_from_str(input_str,nprc = None): + assert os.path.isdir(input_str), f'Specified input folder "{input_str}" is not a directory' + vid_filelist = glob.glob(os.path.join(input_str,'*.mp4')) + print(f'Found {len(vid_filelist)} videos in dir {input_str}') + + if nprc is None: + try: + nprc = mp.cpu_count() + except NotImplementedError: + print('WARNING: cpu_count() not avlailable, using only 1 cpu for video loading') + nprc = 1 + + pool = mp.Pool(processes=nprc) + + vids = [] + for v in tqdm(pool.imap_unordered(load_video,vid_filelist),total=len(vid_filelist),desc='Loading videos...'): + vids.append(v) + + + vids = torch.stack(vids,dim=0).float() + + return vids + +def get_stats(stats): + assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}' + + print(f'Using precomputed statistics under {stats}') + stats = np.load(stats) + stats = {key: stats[key] for key in stats.files} + + return stats + + + + +@torch.no_grad() +def compute_fvd(ref_input, sample_input, bs=32, + ref_stats=None, + sample_stats=None, + nprc_load=None): + + + + calc_stats = ref_stats is None or sample_stats is None + + if calc_stats: + + only_ref = sample_stats is not None + only_sample = ref_stats is not None + + + if isinstance(ref_input,str) and not only_sample: + ref_input = get_data_from_str(ref_input,nprc_load) + + if isinstance(sample_input, str) and not only_ref: + sample_input = get_data_from_str(sample_input, nprc_load) + + stats = compute_statistics(sample_input,ref_input, + device='cuda' if torch.cuda.is_available() else 'cpu', + bs=bs, + only_ref=only_ref, + only_sample=only_sample) + + if only_ref: + stats.update(get_stats(sample_stats)) + elif only_sample: + stats.update(get_stats(ref_stats)) + + + + else: + stats = get_stats(sample_stats) + 
stats.update(get_stats(ref_stats)) + + fvd = compute_frechet_distance(**stats) + + return {'FVD' : fvd,} + + +@torch.no_grad() +def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False,only_sample=False) -> Dict: + detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1' + detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer. + + with open_url(detector_url, verbose=False) as f: + detector = torch.jit.load(f).eval().to(device) + + + + assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive' + + ref_embed, sample_embed = [], [] + + info = f'Computing I3D activations for FVD score with batch size {bs}' + + if only_ref: + + if not isvideo(videos_real): + # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] + videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() + print(videos_real.shape) + + if videos_real.shape[0] % bs == 0: + n_secs = videos_real.shape[0] // bs + else: + n_secs = videos_real.shape[0] // bs + 1 + + videos_real = torch.tensor_split(videos_real, n_secs, dim=0) + + for ref_v in tqdm(videos_real, total=len(videos_real),desc=info): + + feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() + ref_embed.append(feats_ref) + + elif only_sample: + + if not isvideo(videos_fake): + # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] + videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() + print(videos_fake.shape) + + if videos_fake.shape[0] % bs == 0: + n_secs = videos_fake.shape[0] // bs + else: + n_secs = videos_fake.shape[0] // bs + 1 + + videos_real = torch.tensor_split(videos_real, n_secs, dim=0) + + for sample_v in tqdm(videos_fake, total=len(videos_real),desc=info): + feats_sample 
= detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() + sample_embed.append(feats_sample) + + + else: + + if not isvideo(videos_real): + # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] + videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() + + if not isvideo(videos_fake): + videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() + + if videos_fake.shape[0] % bs == 0: + n_secs = videos_fake.shape[0] // bs + else: + n_secs = videos_fake.shape[0] // bs + 1 + + videos_real = torch.tensor_split(videos_real, n_secs, dim=0) + videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0) + + for ref_v, sample_v in tqdm(zip(videos_real,videos_fake),total=len(videos_fake),desc=info): + # print(ref_v.shape) + # ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False) + # sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False) + + + feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() + feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() + sample_embed.append(feats_sample) + ref_embed.append(feats_ref) + + out = dict() + if len(sample_embed) > 0: + sample_embed = np.concatenate(sample_embed,axis=0) + mu_sample, sigma_sample = compute_stats(sample_embed) + out.update({'mu_sample': mu_sample, + 'sigma_sample': sigma_sample}) + + if len(ref_embed) > 0: + ref_embed = np.concatenate(ref_embed,axis=0) + mu_ref, sigma_ref = compute_stats(ref_embed) + out.update({'mu_ref': mu_ref, + 'sigma_ref': sigma_ref}) + + + return out diff --git a/stable-dreamfusion-3DPortrait/ldm/modules/image_degradation/__init__.py b/stable-dreamfusion-3DPortrait/ldm/modules/image_degradation/__init__.py new file mode 100644 index 0000000..7836cad --- /dev/null +++ 
b/stable-dreamfusion-3DPortrait/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/stable-dreamfusion-3DPortrait/ldm/modules/image_degradation/bsrgan.py b/stable-dreamfusion-3DPortrait/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 0000000..32ef561 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] 
+ + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. 
+ Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of 
https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def 
fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE 
Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. 
+ threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. 
+# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. 
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] 
# nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] 
# nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ 
== '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/stable-dreamfusion-3DPortrait/ldm/modules/image_degradation/bsrgan_light.py b/stable-dreamfusion-3DPortrait/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 0000000..dfa7606 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy 
image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. 
+ Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of 
https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def 
fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE 
Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. 
+ threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. 
+# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. 
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] 
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    Light variant of the BSRGAN degradation model from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution".

    The input is converted with ``util.uint2single``, cropped so that both
    spatial sizes are multiples of ``sf``, and then sent through a randomly
    ordered chain of blur, two downsampling steps, Gaussian noise and an
    optional JPEG compression. A final JPEG compression is always applied.
    ----------
    image: HxWxC image accepted by ``util.uint2single``
           (presumably uint8 in [0, 255] -- TODO confirm against callers)
    sf: scale factor
    isp_model: camera ISP model; kept for signature compatibility, unused
               because the ISP step is disabled in this variant
    Returns
    -------
    example: dict with the single key "image" holding the degraded image after
             ``util.single2uint``
    """
    image = util.uint2single(image)
    jpeg_prob, scale2_prob = 0.9, 0.25

    h1, w1 = image.shape[:2]
    # mod crop: rows are bounded by the height and columns by the width.
    # The bounds used to be swapped, which truncated non-square images.
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    # Seven slots are shuffled to keep the same random stream as the full model;
    # slots 1 (second blur) and 6 (ISP noise) have no effect in this variant.
    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            image = add_blur(image, sf=sf)

        elif i == 2:
            # remember the size before downsample2; downsample3 targets 1/sf of it
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.8:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling

            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3 (always runs after step 2, so a and b are defined)
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example
def get_timestamp():
    """Return the current local time as a ``yymmdd-HHMMSS`` string."""
    return f"{datetime.now():%y%m%d-%H%M%S}"
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    """Split a large image into overlapping square patches.

    Args:
        img: array indexed as (dim0, dim1, channels).
        p_size: side length of every patch.
        p_overlap: overlap between neighbouring patches, so the stride is
            ``p_size - p_overlap``.
        p_max: the image is split only when BOTH spatial sizes exceed this.

    Returns:
        A list of patches ordered by first-axis offset, then second-axis
        offset. A final patch aligned to each far edge is always included.
        When the image is not split, the list holds ``img`` itself.
    """
    dim0, dim1 = img.shape[:2]
    if not (dim0 > p_max and dim1 > p_max):
        return [img]

    stride = p_size - p_overlap
    # The ``np.int`` alias was removed from NumPy; the builtin ``int`` is the
    # dtype it used to stand for.
    starts0 = list(np.arange(0, dim0 - p_size, stride, dtype=int))
    starts1 = list(np.arange(0, dim1 - p_size, stride, dtype=int))
    # make sure the far edges are covered by a full-size patch
    starts0.append(dim0 - p_size)
    starts1.append(dim1 - p_size)

    return [img[i:i + p_size, j:j + p_size, :] for i in starts0 for j in starts1]
def mkdir(path):
    """Create ``path`` (with any missing parents) unless it already exists.

    The directory is created directly and an "already exists" error is
    ignored. Checking first and creating second leaves a window in which
    another process can create the path and make ``os.makedirs`` fail.
    As before, nothing happens when ``path`` already exists.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        pass
def imread_uint(path, n_channels=3):
    """Read an image with OpenCV and return it as an H x W x n_channels array.

    Args:
        path: image file to read.
        n_channels: 1 for a grayscale result (H x W x 1), 3 for RGB. A
            grayscale file read with ``n_channels=3`` is replicated to three
            identical channels.

    Returns:
        The decoded image; colour images are converted from OpenCV's BGR
        order to RGB.

    Raises:
        ValueError: ``n_channels`` is neither 1 nor 3 (this used to fail
            with an unbound local variable).
        OSError: OpenCV could not read ``path`` (``cv2.imread`` returned None).
    """
    if n_channels not in (1, 3):
        raise ValueError('n_channels must be 1 or 3, got {}'.format(n_channels))

    flag = cv2.IMREAD_GRAYSCALE if n_channels == 1 else cv2.IMREAD_UNCHANGED
    img = cv2.imread(path, flag)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise OSError('cannot read image: {}'.format(path))

    if n_channels == 1:
        return np.expand_dims(img, axis=2)  # HxWx1
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
def single2uint(img):
    """Convert a float image to uint8: clip to [0, 1], scale by 255, round."""
    clipped = img.clip(0, 1)
    scaled = clipped * 255.
    return np.uint8(scaled.round())
# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
    """Convert a tensor with values in [0, 1] to a uint8 NumPy image.

    Size-1 dimensions are squeezed away, values are clamped to [0, 1], and a
    remaining 3-D (C, H, W) result is transposed to (H, W, C) before scaling
    by 255 and rounding.

    The clamp is done out of place. The in-place ``clamp_`` used before wrote
    through to the caller's tensor whenever it was already float32 on the CPU,
    because ``squeeze``/``float``/``cpu`` return views in that case.
    """
    img = img.detach().squeeze().float().clamp(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img*255.0).round())
def augment_img(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Apply one of eight flip/rotate augmentations to a NumPy image.

    mode 0 returns the input unchanged. Modes 1-7 combine ``np.rot90``
    (counter-clockwise quarter turns of the first two axes) with ``np.flipud``
    as listed in the table below.

    Raises:
        ValueError: ``mode`` is not one of 0..7. Such a call used to return
            None silently, which only failed later in the caller.
    '''
    transforms = {
        0: lambda a: a,
        1: lambda a: np.flipud(np.rot90(a)),
        2: lambda a: np.flipud(a),
        3: lambda a: np.rot90(a, k=3),
        4: lambda a: np.flipud(np.rot90(a, k=2)),
        5: lambda a: np.rot90(a),
        6: lambda a: np.rot90(a, k=2),
        7: lambda a: np.flipud(np.rot90(a, k=3)),
    }
    try:
        transform = transforms[mode]
    except KeyError:
        raise ValueError('mode must be an integer in [0, 7], got {!r}'.format(mode)) from None
    return transform(img)
def augment_img_np3(img, mode=0):
    """Flip and/or transpose a 3-axis NumPy image according to ``mode``.

    Each mode in 0..7 selects a combination of three steps, applied in this
    order: reverse the second axis, reverse the first axis, swap the first
    two axes. Mode 0 returns ``img`` itself. Any other value returns None.
    """
    # mode -> (reverse second axis, reverse first axis, swap first two axes)
    recipes = {
        0: (False, False, False),
        1: (False, False, True),
        2: (False, True, False),
        3: (False, True, True),
        4: (True, False, False),
        5: (True, False, True),
        6: (True, True, False),
        7: (True, True, True),
    }
    recipe = recipes.get(mode)
    if recipe is None:
        return None

    reverse_second, reverse_first, swap_axes = recipe
    if reverse_second:
        img = img[:, ::-1, :]
    if reverse_first:
        img = img[::-1, :, :]
    if swap_axes:
        img = img.transpose(1, 0, 2)
    return img
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        same dtype and range convention as the input; shape (..., ) for
        only_y=True, (..., 3) otherwise.

    The input array is never modified. The previous implementation discarded
    the result of ``astype`` and then rescaled float inputs in place with
    ``img *= 255.``, so the caller's array was changed by the call.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        # out-of-place scaling to the [0, 255] range used by the coefficients
        img = img * 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    """Return the PSNR in dB between two images with values in [0, 255].

    ``border`` pixels are removed from every side of the first two axes
    before comparing. Identical (cropped) images give ``inf``.

    Raises:
        ValueError: the two images have different shapes.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    height, width = img1.shape[:2]
    rows = slice(border, height - border)
    cols = slice(border, width - border)

    first = img1[rows, cols].astype(np.float64)
    second = img2[rows, cols].astype(np.float64)
    mse = np.mean((first - second)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    """Evaluate the piecewise cubic kernel used by the bicubic resize.

    With a = |x|:
        a <= 1      ->  1.5*a^3 - 2.5*a^2 + 1
        1 < a <= 2  -> -0.5*a^3 + 2.5*a^2 - 4*a + 2
        otherwise   ->  0
    ``x`` is a torch tensor; the result has the same shape.
    """
    magnitude = torch.abs(x)
    squared = magnitude**2
    cubed = magnitude**3

    # 0/1 masks selecting which polynomial piece applies to each element
    near_mask = (magnitude <= 1).type_as(magnitude)
    far_mask = ((magnitude > 1) * (magnitude <= 2)).type_as(magnitude)

    near_poly = 1.5*cubed - 2.5*squared + 1
    far_poly = -0.5*cubed + 2.5*squared - 4*magnitude + 2
    return near_poly * near_mask + far_poly * far_mask
kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. 
# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    """Bicubic resize of a tensor image by a single scale factor.

    Args:
        img: pytorch tensor, CHW or HW, values in [0, 1].
        scale: factor applied to both H and W; output sizes are
            ``ceil(in_H * scale)`` and ``ceil(in_W * scale)``.
        antialiasing: forwarded to ``calculate_weights_indices``.

    Returns:
        A float tensor, CHW or HW to match the input, without rounding.

    The input tensor is left untouched. A 2-D input used to be reshaped in
    place with ``unsqueeze_``, so the caller's tensor stayed 3-D after the call.
    """
    need_squeeze = img.dim() == 2
    if need_squeeze:
        img = img.unsqueeze(0)  # out of place: do not reshape the caller's tensor
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)

    # process H dimension
    # symmetric copying: pad top and bottom with mirrored rows
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # every output row is a weighted sum of kernel_width padded input rows
    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying: pad left and right with mirrored columns
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    # every output column is a weighted sum of kernel_width padded columns
    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2
j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/ldm/modules/losses/__init__.py b/stable-dreamfusion-3DPortrait/ldm/modules/losses/__init__.py new file mode 100644 index 0000000..876d7c5 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/ldm/modules/losses/contperceptual.py b/stable-dreamfusion-3DPortrait/ldm/modules/losses/contperceptual.py new file mode 100644 index 0000000..672c1e3 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/ldm/modules/losses/contperceptual.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
class LPIPSWithDiscriminator(nn.Module):
    """Reconstruction + KL + adversarial loss with a learned log-variance.

    ``forward`` is called once per optimizer: ``optimizer_idx == 0`` returns
    the generator-side loss (NLL + KL + weighted GAN term), and
    ``optimizer_idx == 1`` returns the discriminator loss.
    """

    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):
        """
        Args:
            disc_start: Global step before which the GAN factor is forced to 0.
            logvar_init: Initial value of the learned scalar log-variance.
            kl_weight: Multiplier on the KL term.
            pixelloss_weight: Stored as ``pixel_weight``; not read by the
                methods of this class.
            disc_num_layers, disc_in_channels, use_actnorm: Forwarded to
                ``NLayerDiscriminator``.
            disc_factor: Global multiplier on GAN terms (0 disables the
                adaptive weight computation).
            disc_weight: Extra multiplier on the adaptive weight.
            perceptual_weight: Multiplier on the LPIPS term (0 skips it).
            disc_conditional: Whether ``cond`` must be supplied in ``forward``.
            disc_loss: ``"hinge"`` or ``"vanilla"``.
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the GAN term against the NLL term via gradient norms at one layer.

        Returns ``||grad nll|| / (||grad g|| + 1e-4)``, clamped to [0, 1e4],
        detached and scaled by ``discriminator_weight``.
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # NOTE(review): ``self.last_layer`` is not assigned anywhere in this
            # class; this branch presumably relies on a subclass - confirm.
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        """Compute the loss for one optimizer.

        Args:
            inputs, reconstructions: Tensors compared element-wise.
            posteriors: Object exposing ``kl()``.
            optimizer_idx: 0 for the generator update, 1 for the discriminator.
            global_step: Current step, compared against ``disc_start``.
            last_layer: Parameter used for the adaptive GAN weight.
            cond: Optional conditioning concatenated on dim 1 for the
                discriminator; required iff ``disc_conditional``.
            split: Prefix for the keys of the returned log dict.
            weights: Optional multiplier applied to the NLL before summing.

        Returns:
            ``(loss, log_dict)``.

        Raises:
            ValueError: If ``optimizer_idx`` is neither 0 nor 1.
        """
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights*nll_loss
        # Sum over everything, then average over the batch dimension only.
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    # autograd.grad fails when no graph was recorded; that is
                    # only acceptable outside of training.
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log

        # Previously fell through and returned None, which only surfaced later
        # as an unpacking error in the caller.
        raise ValueError(f"Unsupported optimizer_idx {optimizer_idx!r}; expected 0 or 1.")
class VQLPIPSWithDiscriminator(nn.Module):
    """Pixel + LPIPS reconstruction loss, codebook loss and adversarial loss.

    ``forward`` is called once per optimizer: ``optimizer_idx == 0`` returns
    the generator-side loss, ``optimizer_idx == 1`` the discriminator loss.
    """

    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        """
        Args:
            disc_start: Global step before which the GAN factor is forced to 0.
            codebook_weight: Multiplier on the mean codebook loss.
            pixelloss_weight: Stored as ``pixel_weight``; not read by the
                methods of this class.
            disc_num_layers, disc_in_channels, use_actnorm, disc_ndf:
                Forwarded to ``NLayerDiscriminator``.
            disc_factor: Global multiplier on GAN terms.
            disc_weight: Extra multiplier on the adaptive weight.
            perceptual_weight: Multiplier on the perceptual term (0 skips it).
            disc_conditional: Whether ``cond`` must be supplied in ``forward``.
            disc_loss: ``"hinge"`` or ``"vanilla"``.
            n_classes: Codebook size; required when ``predicted_indices`` is
                passed to ``forward``.
            perceptual_loss: Only ``"lpips"`` is implemented.
            pixel_loss: ``"l1"`` or ``"l2"``.

        Raises:
            ValueError: For a perceptual loss other than ``"lpips"``.
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the GAN term against the NLL term via gradient norms at one layer.

        Returns ``||grad nll|| / (||grad g|| + 1e-4)``, clamped to [0, 1e4],
        detached and scaled by ``discriminator_weight``.
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # NOTE(review): ``self.last_layer`` is not assigned anywhere in this
            # class; this branch presumably relies on a subclass - confirm.
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
        """Compute the loss for one optimizer.

        Args:
            codebook_loss: Quantisation loss tensor, or ``None`` (treated as 0).
            inputs, reconstructions: Tensors compared by the pixel loss.
            optimizer_idx: 0 for the generator update, 1 for the discriminator.
            global_step: Current step, compared against ``disc_start``.
            last_layer: Parameter used for the adaptive GAN weight.
            cond: Optional conditioning concatenated on dim 1 for the
                discriminator; required iff ``disc_conditional``.
            split: Prefix for the keys of the returned log dict.
            predicted_indices: Optional code indices; when given, perplexity and
                cluster usage are added to the log.

        Returns:
            ``(loss, log_dict)``.

        Raises:
            ValueError: If ``optimizer_idx`` is neither 0 nor 1.
        """
        # ``exists`` is neither defined nor imported in this module, so the
        # original check raised NameError on every call; test for None directly.
        if codebook_loss is None:
            codebook_loss = torch.tensor([0.]).to(inputs.device)
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                # autograd.grad fails when no graph was recorded; that is only
                # acceptable outside of training.
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log

        # Previously fell through and returned None, which only surfaced later
        # as an unpacking error in the caller.
        raise ValueError(f"Unsupported optimizer_idx {optimizer_idx!r}; expected 0 or 1.")
# constants

# Per-head width used when the caller supplies no ``dim_head``.
DEFAULT_DIM_HEAD = 64

# Attention maps captured by a single attention layer.
Intermediates = namedtuple('Intermediates', [
    'pre_softmax_attn',
    'post_softmax_attn'
])

# Record returned by a layer stack when hidden states are requested.
# The typename now matches the variable; it was 'Intermediates', which made
# repr() output indistinguishable from the per-layer tuple above.
LayerIntermediates = namedtuple('LayerIntermediates', [
    'hiddens',
    'attn_intermediates'
])
# helpers

def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None


def default(val, d):
    """Return ``val`` unless it is None; otherwise ``d`` (called first if it is a plain function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d


def always(val):
    """Build a callable that ignores its arguments and returns ``val``."""
    def inner(*args, **kwargs):
        return val
    return inner


def not_equals(val):
    """Build a predicate that is true for arguments different from ``val``."""
    def inner(x):
        return x != val
    return inner


def equals(val):
    """Build a predicate that is true for arguments equal to ``val``."""
    def inner(x):
        return x == val
    return inner


def max_neg_value(tensor):
    """Return the most negative finite value of ``tensor``'s dtype."""
    return -torch.finfo(tensor.dtype).max


# keyword argument helpers

def pick_and_pop(keys, d):
    """Remove ``keys`` from ``d`` in place and return them as a new dict.

    A single pass over ``keys`` keeps this correct for one-shot iterators; the
    previous map-then-zip form iterated ``keys`` twice and returned ``{}`` for
    them. Raises KeyError when a key is missing, as ``dict.pop`` does.
    """
    return {key: d.pop(key) for key in keys}


def group_dict_by_key(cond, d):
    """Split ``d`` into ``(entries whose key satisfies cond, the rest)``."""
    matched, unmatched = {}, {}
    for key, value in d.items():
        (matched if cond(key) else unmatched)[key] = value
    return matched, unmatched


def string_begins_with(prefix, str):  # parameter name kept for keyword callers
    """Return True when ``str`` starts with ``prefix``."""
    return str.startswith(prefix)


def group_by_key_prefix(prefix, d):
    """Split ``d`` into ``(entries whose key starts with prefix, the rest)``."""
    return group_dict_by_key(partial(string_begins_with, prefix), d)


def groupby_prefix_and_trim(prefix, d):
    """Like ``group_by_key_prefix``, but strip ``prefix`` from the matched keys.

    Returns:
        ``(prefixed entries with the prefix removed, remaining entries)``;
        ``d`` itself is left unmodified.
    """
    with_prefix, remaining = group_dict_by_key(partial(string_begins_with, prefix), d)
    trimmed = {key[len(prefix):]: value for key, value in with_prefix.items()}
    return trimmed, remaining
# feedforward

class GEGLU(nn.Module):
    """Gated linear unit: one projection split into a value half and a GELU gate half."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # A single linear layer produces both halves at once.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)


class FeedForward(nn.Module):
    """Position-wise feed-forward block: project in, dropout, project out.

    Args:
        dim: Input feature size.
        dim_out: Output feature size; defaults to ``dim``.
        mult: Hidden width multiplier (``int(dim * mult)`` hidden units).
        glu: Use a ``GEGLU`` input projection instead of Linear + GELU.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        if dim_out is None:
            dim_out = dim

        if glu:
            project_in = GEGLU(dim, hidden)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, hidden),
                nn.GELU()
            )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden, dim_out)
        )

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Multi-head softmax attention used for both self- and cross-attention.

    Optional features, each switched by a constructor argument: causal
    masking, talking-heads mixing across heads, top-k sparsification of the
    logits, learned memory key/values, and a GLU output ("attention on
    attention").
    """

    def __init__(
            self,
            dim,
            dim_head=DEFAULT_DIM_HEAD,
            heads=8,
            causal=False,
            mask=None,
            talking_heads=False,
            sparse_topk=None,
            use_entmax15=False,
            num_mem_kv=0,
            dropout=0.,
            on_attn=False
    ):
        """
        Args:
            dim: Feature size of the input and of the output.
            dim_head: Width of each head; q/k/v are projected to ``dim_head * heads``.
            heads: Number of attention heads.
            causal: Mask out keys positioned after the query.
            mask: Stored on the instance; not read by ``forward``.
            talking_heads: Mix logits and weights across heads with learned
                ``heads x heads`` matrices.
            sparse_topk: Keep only the top-k logits per query when set.
            use_entmax15: Not supported; raises NotImplementedError when true.
            num_mem_kv: Number of learned memory key/value slots per head.
            dropout: Dropout probability applied to the attention weights.
            on_attn: Use Linear(inner_dim, 2 * dim) + GLU as output projection.
        """
        super().__init__()
        if use_entmax15:
            raise NotImplementedError("Check out entmax activation instead of softmax activation!")
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.causal = causal
        self.mask = mask

        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        # talking heads
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        # entmax
        #self.attn_fn = entmax15 if use_entmax15 else F.softmax
        self.attn_fn = F.softmax

        # add memory key / values
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            rel_pos=None,
            sinusoidal_emb=None,
            prev_attn=None,
            mem=None
    ):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: Query input; unpacked as (batch, length, features).
            context: Optional key/value source for cross-attention.
            mask: Optional boolean mask over query positions, shape (b, n).
            context_mask: Optional boolean mask over context positions.
            rel_pos: Optional callable applied to the attention logits.
            sinusoidal_emb: Optional callable whose output is added to the
                query and key inputs (not to the values).
            prev_attn: Optional logits from an earlier layer, added to this
                layer's logits (residual attention).
            mem: Optional tensor prepended to the key/value inputs along the
                sequence axis.

        Returns:
            ``(output, Intermediates(pre_softmax_attn, post_softmax_attn))``.
        """
        b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        if exists(mem):
            k_input = torch.cat((mem, k_input), dim=-2)
            v_input = torch.cat((mem, v_input), dim=-2)

        if exists(sinusoidal_emb):
            # in shortformer, the query would start at a position offset depending on the past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        # split the projected features into heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        input_mask = None
        if any(map(exists, (mask, context_mask))):
            # a missing mask defaults to all-True; without a context the key
            # mask is the query mask
            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            # learned memory slots are prepended to k/v and are never masked
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if exists(input_mask):
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            dots = dots + prev_attn

        # NOTE: this is an alias, not a copy. The in-place masked_fill_ calls
        # below also modify pre_softmax_attn unless ``dots`` has been rebound
        # to a new tensor in between (e.g. by the talking-heads einsum).
        pre_softmax_attn = dots

        if talking_heads:
            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

        if exists(rel_pos):
            dots = rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if self.causal:
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            # True where the key index is greater than the query index
            mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
            # left-pad with False so the extra leading key columns stay visible
            mask = F.pad(mask, (j - i, 0), value=False)
            dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            # k-th largest logit per query row; everything below it is masked
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, dim=-1)
        post_softmax_attn = attn

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # merge the heads back into the feature axis
        out = rearrange(out, 'b h n d -> b n (h d)')

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn,
            post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates
cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + 
class TransformerWrapper(nn.Module):
    """Token embedding + positional embedding + attention stack + output head."""

    def __init__(
            self,
            *,
            num_tokens,
            max_seq_len,
            attn_layers,
            emb_dim=None,
            max_mem_len=0.,
            emb_dropout=0.,
            num_memory_tokens=None,
            tie_embedding=False,
            use_pos_emb=True
    ):
        """
        Args:
            num_tokens: Vocabulary size.
            max_seq_len: Size of the absolute positional embedding table.
            attn_layers: An ``AttentionLayers`` instance (asserted).
            emb_dim: Embedding width; defaults to ``attn_layers.dim``. A linear
                projection is inserted when it differs from that dimension.
            max_mem_len: Number of trailing positions kept per layer when
                ``forward(..., return_mems=True)``; coerced to ``int`` at use.
            emb_dropout: Dropout probability applied to the embeddings.
            num_memory_tokens: Number of learned tokens prepended to the
                sequence (``None`` means 0).
            tie_embedding: Reuse the token embedding matrix as output head.
            use_pos_emb: Add absolute positional embeddings unless the
                attention stack reports ``has_pos_emb``.
        """
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
                use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

            # let funnel encoder know number of memory tokens, if specified
            if hasattr(attn_layers, 'num_memory_tokens'):
                attn_layers.num_memory_tokens = num_memory_tokens

    def init_(self):
        """Initialise the token embedding with N(0, 0.02)."""
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
            self,
            x,
            return_embeddings=False,
            mask=None,
            return_mems=False,
            return_attn=False,
            mems=None,
            **kwargs
    ):
        """Embed token ids, run the attention stack and project to logits.

        Args:
            x: Token ids; unpacked as (batch, length).
            return_embeddings: Return the normalised hidden states instead of logits.
            mask: Optional boolean mask over positions; padded with True for
                the memory tokens.
            return_mems: Also return the per-layer hiddens, concatenated with
                ``mems`` when given, cut to the last ``max_mem_len`` positions
                and detached.
            return_attn: Also return the post-softmax attention maps. Ignored
                when ``return_mems`` is set, which is checked first.
            mems: Optional per-layer memories passed to the attention stack.
            **kwargs: Forwarded to the attention stack.

        Returns:
            ``out``, or ``(out, new_mems)``, or ``(out, attn_maps)``.
        """
        b, _ = x.shape  # also enforces a 2-D (batch, length) input
        num_mem = self.num_memory_tokens
        x = self.token_emb(x)
        x += self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if exists(mask):
                mask = F.pad(mask, (num_mem, 0), value=True)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        # drop the memory-token positions again
        x = x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_mems:
            hiddens = intermediates.hiddens
            if exists(mems):
                new_mems = [torch.cat(pair, dim=-2) for pair in zip(mems, hiddens)]
            else:
                new_mems = hiddens
            # The default ``max_mem_len`` is the float 0., and a float slice
            # bound raises TypeError, so coerce to int first.
            # NOTE(review): with a length of 0 the slice ``-0:`` keeps the full
            # sequence (unchanged from passing the int 0 before) - confirm
            # whether an empty memory was intended.
            mem_len = int(self.max_mem_len)
            new_mems = [t[..., -mem_len:, :].detach() for t in new_mems]
            return out, new_mems

        if return_attn:
            attn_maps = [inter.post_softmax_attn for inter in intermediates.attn_intermediates]
            return out, attn_maps

        return out
""" + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. 
class SEModule(Module):
    """Channel-gating block: pool, squeeze, expand, sigmoid, rescale.

    The input is average-pooled to 1x1 per channel, passed through two
    bias-free 1x1 convolutions (``channels -> channels // reduction ->
    channels``) with a ReLU in between, squashed by a sigmoid, and the
    resulting per-channel weights multiply the original input.

    Args:
        channels: number of input (and output) channels.
        reduction: integer divisor for the hidden width of the gate.
    """

    def __init__(self, channels, reduction):
        super().__init__()
        hidden = channels // reduction
        # Attribute names and creation order are kept as-is so that
        # state_dict keys and parameter initialisation order are unchanged.
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, hidden, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(hidden, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled by its per-channel gate."""
        squeezed = self.relu(self.fc1(self.avg_pool(x)))
        gate = self.sigmoid(self.fc2(squeezed))
        return x * gate
class Backbone(Module):
    """IR / IR-SE residual network that maps an image to a 512-d unit vector.

    Pipeline: a 3->64 conv stem (``input_layer``), a stack of bottleneck
    units described by ``get_blocks(num_layers)`` (``body``), and a head
    (``output_layer``) that flattens the 512-channel map and projects it to
    512 features.  ``forward`` L2-normalises the result.

    Args:
        input_size: 112 or 224; selects the flattened size fed to the final
            ``Linear`` (512*7*7 or 512*14*14 respectively).
        num_layers: 50, 100 or 152.
        mode: ``'ir'`` uses ``bottleneck_IR`` units, ``'ir_se'`` uses
            ``bottleneck_IR_SE`` units.
        drop_ratio: dropout probability in the head.
        affine: forwarded to the final ``BatchNorm1d``.
    """

    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
        super().__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"

        stages = get_blocks(num_layers)
        unit_module = {'ir': bottleneck_IR, 'ir_se': bottleneck_IR_SE}[mode]

        # Sub-modules are created in the same order as before (stem, head,
        # body) so registration order and weight-init order are unchanged.
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            BatchNorm2d(64),
            PReLU(64),
        )

        # Side length of the final feature map: 7 for 112 input, else 14.
        side = 7 if input_size == 112 else 14
        self.output_layer = Sequential(
            BatchNorm2d(512),
            Dropout(drop_ratio),
            Flatten(),
            Linear(512 * side * side, 512),
            BatchNorm1d(512, affine=affine),
        )

        units = [
            unit_module(spec.in_channel, spec.depth, spec.stride)
            for stage in stages
            for spec in stage
        ]
        self.body = Sequential(*units)

    def forward(self, x):
        """Run stem, body and head, then L2-normalise the embedding."""
        features = self.body(self.input_layer(x))
        embedding = self.output_layer(features)
        return l2_norm(embedding)
def pil_rectangle_crop(im):
    """Crop the largest centred square out of an image.

    ``im`` must expose a ``size`` attribute of ``(width, height)`` and a
    ``crop((left, top, right, bottom))`` method, as PIL images do.  The
    shorter side is kept in full and the longer side is trimmed equally at
    both ends.

    Args:
        im: image to crop.

    Returns:
        Whatever ``im.crop`` returns for the centred square box.
    """
    width, height = im.size  # Get dimensions

    if width <= height:
        # Portrait or square: keep the full width, trim top and bottom.
        left = 0
        right = width
        top = (height - width) / 2
        bottom = (height + width) / 2
    else:
        # Landscape: keep the full height, trim left and right.
        # (Previously `right` was never assigned here and `bottom` was
        # overwritten, so every landscape image raised UnboundLocalError.)
        top = 0
        bottom = height
        left = (width - height) / 2
        right = (width + height) / 2

    # Crop the center of the image
    return im.crop((left, top, right, bottom))
class AdamWwithEMAandWings(optim.Optimizer):
    """AdamW that also keeps an exponential moving average of each parameter.

    The EMA copy of a parameter lives in ``self.state[p]['param_exp_avg']``
    as a float clone, and is updated after every AdamW update of that
    parameter.
    """

    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,   # ema decay to match previous code
                 ema_power=1., param_names=()):
        """AdamW that saves EMA versions of the parameters.

        Args:
            params: iterable of parameters or parameter-group dicts.
            lr: learning rate (must be >= 0).
            betas: AdamW moment coefficients, each in [0, 1).
            eps: AdamW epsilon (must be >= 0).
            weight_decay: decoupled weight decay (must be >= 0).
            amsgrad: whether to track the running max of the second moment.
            ema_decay: upper bound on the EMA decay, in [0, 1].
            ema_power: exponent of the warm-up ``1 - step ** -ema_power``.
            param_names: stored in the group defaults; not read by this class.

        Raises:
            ValueError: if any hyperparameter is outside its valid range.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
                        ema_power=ema_power, param_names=param_names)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, defaulting ``amsgrad`` to False if absent."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            ema_decay = group['ema_decay']
            ema_power = group['ema_power']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of parameter values
                    state['param_exp_avg'] = p.detach().float().clone()

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                ema_params_with_grad.append(state['param_exp_avg'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

            if not params_with_grad:
                # Nothing to update in this group.  Without this guard the
                # EMA warm-up below read a per-parameter `state` that was
                # either unbound (first group) or left over from the
                # previous group.
                continue

            # NOTE(review): `optim._functional` is a private torch module;
            # confirm this call signature against the pinned torch version.
            optim._functional.adamw(params_with_grad,
                                    grads,
                                    exp_avgs,
                                    exp_avg_sqs,
                                    max_exp_avg_sqs,
                                    state_steps,
                                    amsgrad=amsgrad,
                                    beta1=beta1,
                                    beta2=beta2,
                                    lr=group['lr'],
                                    weight_decay=group['weight_decay'],
                                    eps=group['eps'],
                                    maximize=False)

            # Warm-up: the decay ramps as 1 - step**-ema_power, capped at
            # ema_decay.  Uses the step count of the last updated parameter
            # in this group, as before.
            cur_ema_decay = min(ema_decay, 1 - state_steps[-1] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

        return loss
type=open, action=LoadFromFile, help="specify a file filled with more arguments") + parser.add_argument('--text', default=None, help="text prompt") + parser.add_argument('--negative', default='', type=str, help="negative text prompt") + parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray") + parser.add_argument('-O2', action='store_true', help="equals --backbone vanilla") + parser.add_argument('--test', action='store_true', help="test mode") + parser.add_argument('--six_views', action='store_true', help="six_views mode: save the images of the six views") + parser.add_argument('--eval_interval', type=int, default=1, help="evaluate on the valid set every interval epochs") + parser.add_argument('--test_interval', type=int, default=100, help="test on the test set every interval epochs") + parser.add_argument('--workspace', type=str, default='workspace') + parser.add_argument('--seed', default=None) + + parser.add_argument('--image', default=None, help="image prompt") + parser.add_argument('--image_config', default=None, help="image config csv") + + parser.add_argument('--known_view_interval', type=int, default=4, help="train default view with RGB loss every & iters, only valid if --image is not None.") + + parser.add_argument('--IF', action='store_true', help="experimental: use DeepFloyd IF as the guidance model for nerf stage") + + parser.add_argument('--guidance', type=str, nargs='*', default=['SD'], help='guidance model') + parser.add_argument('--guidance_scale', type=float, default=100, help="diffusion model classifier-free guidance scale") + + parser.add_argument('--save_mesh', action='store_true', help="export an obj mesh with texture") + parser.add_argument('--mcubes_resolution', type=int, default=256, help="mcubes resolution for extracting mesh") + parser.add_argument('--decimate_target', type=int, default=5e4, help="target face number for mesh decimation") + + parser.add_argument('--dmtet', action='store_true', help="use dmtet 
finetuning") + parser.add_argument('--tet_grid_size', type=int, default=128, help="tet grid size") + parser.add_argument('--init_with', type=str, default='', help="ckpt to init dmtet") + parser.add_argument('--lock_geo', action='store_true', help="disable dmtet to learn geometry") + + ## Perp-Neg options + parser.add_argument('--perpneg', action='store_true', help="use perp_neg") + parser.add_argument('--negative_w', type=float, default=-2, help="The scale of the weights of negative prompts. A larger value will help to avoid the Janus problem, but may cause flat faces. Vary between 0 to -4, depending on the prompt") + parser.add_argument('--front_decay_factor', type=float, default=2, help="decay factor for the front prompt") + parser.add_argument('--side_decay_factor', type=float, default=10, help="decay factor for the side prompt") + + ### training options + parser.add_argument('--iters', type=int, default=10000, help="training iters") + parser.add_argument('--lr', type=float, default=1e-3, help="max learning rate") + parser.add_argument('--ckpt', type=str, default='latest', help="possible options are ['latest', 'scratch', 'best', 'latest_model']") + parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") + parser.add_argument('--taichi_ray', action='store_true', help="use taichi raymarching") + parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=32, help="num steps up-sampled per ray (only valid when not using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch 
size of rays at inference to avoid OOM (only valid when not using --cuda_ray)") + parser.add_argument('--latent_iter_ratio', type=float, default=0.2, help="training iters that only use albedo shading") + parser.add_argument('--albedo_iter_ratio', type=float, default=0, help="training iters that only use albedo shading") + parser.add_argument('--min_ambient_ratio', type=float, default=0.1, help="minimum ambient ratio to use in lambertian shading") + parser.add_argument('--textureless_ratio', type=float, default=0.2, help="ratio of textureless shading") + parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses") + parser.add_argument('--jitter_center', type=float, default=0.2, help="amount of jitter to add to sampled camera pose's center (camera location)") + parser.add_argument('--jitter_target', type=float, default=0.2, help="amount of jitter to add to sampled camera pose's target (i.e. 'look-at')") + parser.add_argument('--jitter_up', type=float, default=0.02, help="amount of jitter to add to sampled camera pose's up-axis (i.e. 
'camera roll')") + parser.add_argument('--uniform_sphere_rate', type=float, default=0, help="likelihood of sampling camera location uniformly on the sphere surface area") + parser.add_argument('--grad_clip', type=float, default=-1, help="clip grad of all grad to this limit, negative value disables it") + parser.add_argument('--grad_clip_rgb', type=float, default=-1, help="clip grad of rgb space grad to this limit, negative value disables it") + # model options + parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)") + parser.add_argument('--density_activation', type=str, default='exp', choices=['softplus', 'exp'], help="density activation function") + parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied") + parser.add_argument('--blob_density', type=float, default=5, help="max (center) density for the density blob") + parser.add_argument('--blob_radius', type=float, default=0.2, help="control the radius for the density blob") + # network backbone + parser.add_argument('--backbone', type=str, default='grid', choices=['grid_tcnn', 'grid', 'vanilla', 'grid_taichi'], help="nerf backbone") + parser.add_argument('--optim', type=str, default='adan', choices=['adan', 'adam'], help="optimizer") + parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version") + parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key") + # try this if CUDA OOM + parser.add_argument('--fp16', action='store_true', help="use float16 for training") + parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage") + # rendering resolution in training, increase these for better quality / decrease these if CUDA OOM even if --vram_O enabled. 
+ parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training") + parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training") + parser.add_argument('--known_view_scale', type=float, default=1.5, help="multiply --h/w by this for known view rendering") + parser.add_argument('--known_view_noise_scale', type=float, default=2e-3, help="random camera noise added to rays_o and rays_d") + parser.add_argument('--dmtet_reso_scale', type=float, default=8, help="multiply --h/w by this for dmtet finetuning") + parser.add_argument('--batch_size', type=int, default=1, help="images to render per batch using NeRF") + + ### dataset options + parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)") + parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--min_near', type=float, default=0.01, help="minimum near distance for camera") + + parser.add_argument('--radius_range', type=float, nargs='*', default=[3.0, 3.5], help="training camera radius range") + parser.add_argument('--theta_range', type=float, nargs='*', default=[45, 105], help="training camera range along the polar angles (i.e. up and down). See advanced.md for details.") + parser.add_argument('--phi_range', type=float, nargs='*', default=[-180, 180], help="training camera range along the azimuth angles (i.e. left and right). 
See advanced.md for details.") + parser.add_argument('--fovy_range', type=float, nargs='*', default=[10, 30], help="training camera fovy range") + + parser.add_argument('--default_radius', type=float, default=3.2, help="radius for the default view") + parser.add_argument('--default_polar', type=float, default=90, help="polar for the default view") + parser.add_argument('--default_azimuth', type=float, default=0, help="azimuth for the default view") + parser.add_argument('--default_fovy', type=float, default=20, help="fovy for the default view") + + parser.add_argument('--progressive_view', action='store_true', help="progressively expand view sampling range from default to full") + parser.add_argument('--progressive_view_init_ratio', type=float, default=0.2, help="initial ratio of final range, used for progressive_view") + + parser.add_argument('--progressive_level', action='store_true', help="progressively increase gridencoder's max_level") + + parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region") + parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.") + parser.add_argument('--t_range', type=float, nargs='*', default=[0.02, 0.98], help="stable diffusion time steps range") + parser.add_argument('--dont_override_stuff',action='store_true', help="Don't override t_range, etc.") + + + ### regularizations + parser.add_argument('--lambda_entropy', type=float, default=1e-3, help="loss scale for alpha entropy") + parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value") + parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation") + parser.add_argument('--lambda_tv', type=float, default=0, help="loss scale for total variation") + parser.add_argument('--lambda_wd', type=float, default=0, help="loss scale") + + 
parser.add_argument('--lambda_mesh_normal', type=float, default=0.5, help="loss scale for mesh normal smoothness") + parser.add_argument('--lambda_mesh_laplacian', type=float, default=0.5, help="loss scale for mesh laplacian") + + parser.add_argument('--lambda_guidance', type=float, default=1, help="loss scale for SDS") + parser.add_argument('--lambda_rgb', type=float, default=1000, help="loss scale for RGB") + parser.add_argument('--lambda_mask', type=float, default=500, help="loss scale for mask (alpha)") + parser.add_argument('--lambda_normal', type=float, default=0, help="loss scale for normal map") + parser.add_argument('--lambda_depth', type=float, default=10, help="loss scale for relative depth") + parser.add_argument('--lambda_2d_normal_smooth', type=float, default=0, help="loss scale for 2D normal image smoothness") + parser.add_argument('--lambda_3d_normal_smooth', type=float, default=0, help="loss scale for 3D normal image smoothness") + + ### debugging options + parser.add_argument('--save_guidance', action='store_true', help="save images of the per-iteration NeRF renders, added noise, denoised (i.e. guidance), fully-denoised. 
Useful for debugging, but VERY SLOW and takes lots of memory!") + parser.add_argument('--save_guidance_interval', type=int, default=10, help="save guidance every X step") + + ### GUI options + parser.add_argument('--gui', action='store_true', help="start a GUI") + parser.add_argument('--W', type=int, default=800, help="GUI width") + parser.add_argument('--H', type=int, default=800, help="GUI height") + parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center") + parser.add_argument('--fovy', type=float, default=20, help="default GUI camera fovy") + parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]") + parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth") + parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") + + parser.add_argument('--zero123_config', type=str, default='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml', help="config file for zero123") + parser.add_argument('--zero123_ckpt', type=str, default='pretrained/zero123/zero123-xl.ckpt', help="ckpt for zero123") + parser.add_argument('--zero123_grad_scale', type=str, default='angle', help="whether to scale the gradients based on 'angle' or 'None'") + + parser.add_argument('--dataset_size_train', type=int, default=100, help="Length of train dataset i.e. 
# of iterations per epoch") + parser.add_argument('--dataset_size_valid', type=int, default=8, help="# of frames to render in the turntable video in validation") + parser.add_argument('--dataset_size_test', type=int, default=100, help="# of frames to render in the turntable video at test time") + + parser.add_argument('--exp_start_iter', type=int, default=None, help="start iter # for experiment, to calculate progressive_view and progressive_level") + parser.add_argument('--exp_end_iter', type=int, default=None, help="end iter # for experiment, to calculate progressive_view and progressive_level") + + opt = parser.parse_args() + + if opt.O: + opt.fp16 = True + opt.cuda_ray = True + + elif opt.O2: + opt.fp16 = True + opt.backbone = 'vanilla' + opt.progressive_level = True + + if opt.IF: + if 'SD' in opt.guidance: + opt.guidance.remove('SD') + opt.guidance.append('IF') + opt.latent_iter_ratio = 0 # must not do as_latent + + opt.images, opt.ref_radii, opt.ref_polars, opt.ref_azimuths, opt.zero123_ws = [], [], [], [], [] + opt.default_zero123_w = 1 + + opt.exp_start_iter = opt.exp_start_iter or 0 + opt.exp_end_iter = opt.exp_end_iter or opt.iters + + # parameters for image-conditioned generation + if opt.image is not None or opt.image_config is not None: + + if opt.text is None: + # use zero123 guidance model when only providing image + opt.guidance = ['zero123'] + if not opt.dont_override_stuff: + opt.fovy_range = [opt.default_fovy, opt.default_fovy] # fix fov as zero123 doesn't support changing fov + opt.guidance_scale = 5 + opt.lambda_3d_normal_smooth = 10 + else: + # use stable-diffusion when providing both text and image + opt.guidance = ['SD', 'clip'] + + if not opt.dont_override_stuff: + opt.guidance_scale = 10 + opt.t_range = [0.2, 0.6] + opt.known_view_interval = 2 + opt.lambda_3d_normal_smooth = 20 + opt.bg_radius = -1 + + # smoothness + opt.lambda_entropy = 1 + opt.lambda_orient = 1 + + # latent warmup is not needed + opt.latent_iter_ratio = 0 + if not 
opt.dont_override_stuff: + opt.albedo_iter_ratio = 0 + + # make shape init more stable + opt.progressive_view = True + opt.progressive_level = True + + if opt.image is not None: + opt.images += [opt.image] + opt.ref_radii += [opt.default_radius] + opt.ref_polars += [opt.default_polar] + opt.ref_azimuths += [opt.default_azimuth] + opt.zero123_ws += [opt.default_zero123_w] + + if opt.image_config is not None: + # for multiview (zero123) + conf = pd.read_csv(opt.image_config, skipinitialspace=True) + opt.images += list(conf.image) + opt.ref_radii += list(conf.radius) + opt.ref_polars += list(conf.polar) + opt.ref_azimuths += list(conf.azimuth) + opt.zero123_ws += list(conf.zero123_weight) + if opt.image is None: + opt.default_radius = opt.ref_radii[0] + opt.default_polar = opt.ref_polars[0] + opt.default_azimuth = opt.ref_azimuths[0] + opt.default_zero123_w = opt.zero123_ws[0] + + # reset to None + if len(opt.images) == 0: + opt.images = None + + # default parameters for finetuning + if opt.dmtet: + + opt.h = int(opt.h * opt.dmtet_reso_scale) + opt.w = int(opt.w * opt.dmtet_reso_scale) + opt.known_view_scale = 1 + + if not opt.dont_override_stuff: + opt.t_range = [0.02, 0.50] # ref: magic3D + + if opt.images is not None: + + opt.lambda_normal = 0 + opt.lambda_depth = 0 + + if opt.text is not None and not opt.dont_override_stuff: + opt.t_range = [0.20, 0.50] + + # assume finetuning + opt.latent_iter_ratio = 0 + opt.albedo_iter_ratio = 0 + opt.progressive_view = False + # opt.progressive_level = False + + # record full range for progressive view expansion + if opt.progressive_view: + if not opt.dont_override_stuff: + # disable as they disturb progressive view + opt.jitter_pose = False + + opt.uniform_sphere_rate = 0 + # back up full range + opt.full_radius_range = opt.radius_range + opt.full_theta_range = opt.theta_range + opt.full_phi_range = opt.phi_range + opt.full_fovy_range = opt.fovy_range + + if opt.backbone == 'vanilla': + from nerf.network import NeRFNetwork + 
elif opt.backbone == 'grid': + from nerf.network_grid import NeRFNetwork + elif opt.backbone == 'grid_tcnn': + from nerf.network_grid_tcnn import NeRFNetwork + elif opt.backbone == 'grid_taichi': + opt.cuda_ray = False + opt.taichi_ray = True + import taichi as ti + from nerf.network_grid_taichi import NeRFNetwork + taichi_half2_opt = True + taichi_init_args = {"arch": ti.cuda, "device_memory_GB": 4.0} + if taichi_half2_opt: + taichi_init_args["half2_vectorization"] = True + ti.init(**taichi_init_args) + else: + raise NotImplementedError(f'--backbone {opt.backbone} is not implemented!') + + print(opt) + + if opt.seed is not None: + seed_everything(int(opt.seed)) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + model = NeRFNetwork(opt).to(device) + + if opt.dmtet and opt.init_with != '': + if opt.init_with.endswith('.pth'): + # load pretrained weights to init dmtet + state_dict = torch.load(opt.init_with, map_location=device) + model.load_state_dict(state_dict['model'], strict=False) + if opt.cuda_ray: + model.mean_density = state_dict['mean_density'] + model.init_tet() + else: + # assume a mesh to init dmtet (experimental, not working well now!) 
+ import trimesh + mesh = trimesh.load(opt.init_with, force='mesh', skip_material=True, process=False) + model.init_tet(mesh=mesh) + + print(model) + + if opt.six_views: + guidance = None # no need to load guidance model at test + + trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt) + + test_loader = NeRFDataset(opt, device=device, type='six_views', H=opt.H, W=opt.W, size=6).dataloader(batch_size=1) + trainer.test(test_loader, write_video=False) + + if opt.save_mesh: + trainer.save_mesh() + + elif opt.test: + guidance = None # no need to load guidance model at test + + trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt) + + if opt.gui: + from nerf.gui import NeRFGUI + gui = NeRFGUI(opt, trainer) + gui.render() + + else: + test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader(batch_size=1) + trainer.test(test_loader) + + if opt.save_mesh: + trainer.save_mesh() + + else: + + train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=opt.dataset_size_train * opt.batch_size).dataloader() + + if opt.optim == 'adan': + from optimizer import Adan + # Adan usually requires a larger LR + optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False) + else: # adam + optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15) + + if opt.backbone == 'vanilla': + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) + else: + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed + # scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) + + guidance = 
nn.ModuleDict() + + if 'SD' in opt.guidance: + from guidance.sd_utils import StableDiffusion + guidance['SD'] = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key, opt.t_range) + + if 'IF' in opt.guidance: + from guidance.if_utils import IF + guidance['IF'] = IF(device, opt.vram_O, opt.t_range) + + if 'zero123' in opt.guidance: + from guidance.zero123_utils import Zero123 + guidance['zero123'] = Zero123(device=device, fp16=opt.fp16, config=opt.zero123_config, ckpt=opt.zero123_ckpt, vram_O=opt.vram_O, t_range=opt.t_range, opt=opt) + + if 'clip' in opt.guidance: + from guidance.clip_utils import CLIP + guidance['clip'] = CLIP(device) + + trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, scheduler_update_every_step=True) + + trainer.default_view_data = train_loader._data.get_default_view_data() + + if opt.gui: + from nerf.gui import NeRFGUI + gui = NeRFGUI(opt, trainer, train_loader) + gui.render() + + else: + valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=opt.dataset_size_valid).dataloader(batch_size=1) + test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader(batch_size=1) + + max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) + trainer.train(train_loader, valid_loader, test_loader, max_epoch) + + if opt.save_mesh: + trainer.save_mesh() diff --git a/stable-dreamfusion-3DPortrait/main_3DPortraitGAN.py b/stable-dreamfusion-3DPortrait/main_3DPortraitGAN.py new file mode 100644 index 0000000..15a2359 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/main_3DPortraitGAN.py @@ -0,0 +1,474 @@ +import os + +import torch +import argparse +import pandas as pd +import sys + +from nerf.provider import NeRFDataset +from nerf.utils import * + + +if __name__ == '__main__': + # See 
# https://stackoverflow.com/questions/27433316/how-to-get-argparse-to-read-arguments-from-a-file-with-an-option-rather-than-pre
class LoadFromFile(argparse.Action):
    """argparse action for ``--file``: read additional arguments from a file.

    ``values`` is the already-open file object produced by ``type=open``.
    Its content is split on whitespace and parsed by the same parser into the
    same namespace, so the file behaves like extra command-line arguments.
    The file is closed once it has been read.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        with values as f:
            # parse arguments in the file and store them in the target namespace
            parser.parse_args(f.read().split(), namespace)


parser = argparse.ArgumentParser()
parser.add_argument('--file', type=open, action=LoadFromFile, help="specify a file filled with more arguments")
parser.add_argument('--text', default=None, help="text prompt")
parser.add_argument('--negative', default='', type=str, help="negative text prompt")
# -O / -O2 are still accepted, but this script aborts with NotImplementedError
# when either is set, so the help text says so instead of promising a preset.
parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray (not supported by this script: raises NotImplementedError)")
parser.add_argument('-O2', action='store_true', help="equals --backbone vanilla (not supported by this script: raises NotImplementedError)")
parser.add_argument('--test', action='store_true', help="test mode")
parser.add_argument('--six_views', action='store_true', help="six_views mode: save the images of the six views")
parser.add_argument('--eval_interval', type=int, default=1, help="evaluate on the valid set every interval epochs")
parser.add_argument('--test_interval', type=int, default=10, help="test on the test set every interval epochs")
parser.add_argument('--workspace', type=str, default='workspace')
parser.add_argument('--seed', default=None)

parser.add_argument('--image', default=None, help="image prompt")
parser.add_argument('--image_config', default=None, help="image config csv")

parser.add_argument('--known_view_interval', type=int, default=4, help="train default view with RGB loss every N iters, only valid if --image is not None.")

parser.add_argument('--IF', action='store_true', help="experimental: use DeepFloyd IF as the guidance model for nerf stage")

parser.add_argument('--guidance', type=str, nargs='*', default=['SD'], help='guidance model')
parser.add_argument('--guidance_scale', type=float, default=100, help="diffusion model classifier-free guidance scale")

parser.add_argument('--save_mesh', action='store_true', help="export an obj mesh with texture")
parser.add_argument('--mcubes_resolution', type=int, default=256, help="mcubes resolution for extracting mesh")
# argparse applies `type` only to string defaults, so the default must already
# be an int (the previous 5e4 literal leaked a float into opt.decimate_target).
parser.add_argument('--decimate_target', type=int, default=50000, help="target face number for mesh decimation")

parser.add_argument('--dmtet', action='store_true', help="use dmtet finetuning")
parser.add_argument('--tet_grid_size', type=int, default=128, help="tet grid size")
parser.add_argument('--init_with', type=str, default='', help="ckpt to init dmtet")
parser.add_argument('--lock_geo', action='store_true', help="disable dmtet to learn geometry")

## Perp-Neg options
parser.add_argument('--perpneg', action='store_true', help="use perp_neg")
parser.add_argument('--negative_w', type=float, default=-2, help="The scale of the weights of negative prompts. A larger value will help to avoid the Janus problem, but may cause flat faces. Vary between 0 to -4, depending on the prompt")
parser.add_argument('--front_decay_factor', type=float, default=2, help="decay factor for the front prompt")
parser.add_argument('--side_decay_factor', type=float, default=10, help="decay factor for the side prompt")

## Trigrid options
parser.add_argument('--trigrid_path', type=str, default='', help="path to trigrid")
parser.add_argument('--trigrid_decoder_ckpt', type=str, default='', help="path to trigrid decoder ckpt")
parser.add_argument('--train_decoder', action='store_true', help="train trigrid decoder")
parser.add_argument('--learnable_bg', action='store_true', help="Learnable background")
# The help text used to be a copy of the --t_range description.
parser.add_argument('--trigrid_lr_ratio', type=float, nargs='+', default=[100, 100, 100, 100, 20, 20, 20], help="learning-rate ratios for the trigrid parameter groups (forwarded to model.get_params)")
parser.add_argument('--scheduler_annealing', action='store_true', help="annealing scheduler")

### training options
parser.add_argument('--iters', type=int, default=10000, help="training iters")
parser.add_argument('--lr', type=float, default=1e-3, help="max learning rate")
parser.add_argument('--ckpt', type=str, default='latest', help="possible options are ['latest', 'scratch', 'best', 'latest_model']")
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
parser.add_argument('--taichi_ray', action='store_true', help="use taichi raymarching")
parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=32, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
# The help text used to be a copy of the --albedo_iter_ratio description.
parser.add_argument('--latent_iter_ratio', type=float, default=0.2, help="ratio of training iters used for the latent-space warmup")
parser.add_argument('--albedo_iter_ratio', type=float, default=0, help="training iters that only use albedo shading")
parser.add_argument('--min_ambient_ratio', type=float, default=0.1, help="minimum ambient ratio to use in lambertian shading")
parser.add_argument('--textureless_ratio', type=float, default=0.2, help="ratio of textureless shading")
parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")
parser.add_argument('--jitter_center', type=float, default=0.2, help="amount of jitter to add to sampled camera pose's center (camera location)")
parser.add_argument('--jitter_target', type=float, default=0.2, help="amount of jitter to add to sampled camera pose's target (i.e. 'look-at')")
parser.add_argument('--jitter_up', type=float, default=0.02, help="amount of jitter to add to sampled camera pose's up-axis (i.e. 'camera roll')")
parser.add_argument('--uniform_sphere_rate', type=float, default=0, help="likelihood of sampling camera location uniformly on the sphere surface area")
parser.add_argument('--grad_clip', type=float, default=-1, help="clip grad of all grad to this limit, negative value disables it")
parser.add_argument('--grad_clip_rgb', type=float, default=-1, help="clip grad of rgb space grad to this limit, negative value disables it")
# model options
parser.add_argument('--bg_radius', type=float, default=3.0, help="if positive, use a background model at sphere(bg_radius)")
parser.add_argument('--density_activation', type=str, default='exp', choices=['softplus', 'exp'], help="density activation function")
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
parser.add_argument('--blob_density', type=float, default=5, help="max (center) density for the density blob")
parser.add_argument('--blob_radius', type=float, default=0.2, help="control the radius for the density blob")
# network backbone
parser.add_argument('--optim', type=str, default='adan', choices=['adan', 'adam'], help="optimizer")
parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
# try this if CUDA OOM
parser.add_argument('--fp16', action='store_true', help="use float16 for training")
parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
# rendering resolution in training, increase these for better quality / decrease these if CUDA OOM even if --vram_O enabled.
parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
parser.add_argument('--known_view_scale', type=float, default=1.5, help="multiply --h/w by this for known view rendering")
parser.add_argument('--known_view_noise_scale', type=float, default=2e-3, help="random camera noise added to rays_o and rays_d")
parser.add_argument('--dmtet_reso_scale', type=float, default=8, help="multiply --h/w by this for dmtet finetuning")
parser.add_argument('--batch_size', type=int, default=1, help="images to render per batch using NeRF")

### dataset options
parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
parser.add_argument('--min_near', type=float, default=0.01, help="minimum near distance for camera")

parser.add_argument('--radius_range', type=float, nargs='*', default=[2.6, 2.8], help="training camera radius range")
parser.add_argument('--theta_range', type=float, nargs='*', default=[60, 105], help="training camera range along the polar angles (i.e. up and down). See advanced.md for details.")
parser.add_argument('--phi_range', type=float, nargs='*', default=[-180, 180], help="training camera range along the azimuth angles (i.e. left and right). See advanced.md for details.")
parser.add_argument('--fovy_range', type=float, nargs='*', default=[11, 13], help="training camera fovy range")

parser.add_argument('--default_radius', type=float, default=2.7, help="radius for the default view")
parser.add_argument('--default_polar', type=float, default=90, help="polar for the default view")
parser.add_argument('--default_azimuth', type=float, default=0, help="azimuth for the default view")
parser.add_argument('--default_fovy', type=float, default=12., help="fovy for the default view")

parser.add_argument('--progressive_view', action='store_true', help="progressively expand view sampling range from default to full")
parser.add_argument('--progressive_view_init_ratio', type=float, default=0.2, help="initial ratio of final range, used for progressive_view")

parser.add_argument('--progressive_level', action='store_true', help="progressively increase gridencoder's max_level")

parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region")
parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
parser.add_argument('--t_range', type=float, nargs='+', default=[0.02, 0.98], help="stable diffusion time steps range")
parser.add_argument('--dont_override_stuff', action='store_true', help="Don't override t_range, etc.")

### regularizations
parser.add_argument('--lambda_entropy', type=float, default=1e-3, help="loss scale for alpha entropy")
parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")
parser.add_argument('--lambda_tv', type=float, default=0, help="loss scale for total variation")
parser.add_argument('--lambda_wd', type=float, default=0, help="loss scale")
parser.add_argument('--lambda_mesh_normal', type=float, default=0.5, help="loss scale for mesh normal smoothness") + parser.add_argument('--lambda_mesh_laplacian', type=float, default=0.5, help="loss scale for mesh laplacian") + + parser.add_argument('--lambda_guidance', type=float, default=1, help="loss scale for SDS") + parser.add_argument('--lambda_rgb', type=float, default=1000, help="loss scale for RGB") + parser.add_argument('--lambda_mask', type=float, default=500, help="loss scale for mask (alpha)") + parser.add_argument('--lambda_normal', type=float, default=0, help="loss scale for normal map") + parser.add_argument('--lambda_depth', type=float, default=10, help="loss scale for relative depth") + parser.add_argument('--lambda_2d_normal_smooth', type=float, default=0, help="loss scale for 2D normal image smoothness") + parser.add_argument('--lambda_3d_normal_smooth', type=float, default=0, help="loss scale for 3D normal image smoothness") + + ### debugging options + parser.add_argument('--save_guidance', action='store_true', help="save images of the per-iteration NeRF renders, added noise, denoised (i.e. guidance), fully-denoised. 
Useful for debugging, but VERY SLOW and takes lots of memory!") + parser.add_argument('--save_guidance_interval', type=int, default=10, help="save guidance every X step") + + ### GUI options + parser.add_argument('--gui', action='store_true', help="start a GUI") + parser.add_argument('--W', type=int, default=800, help="GUI width") + parser.add_argument('--H', type=int, default=800, help="GUI height") + parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center") + parser.add_argument('--fovy', type=float, default=20, help="default GUI camera fovy") + parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]") + parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth") + parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") + + parser.add_argument('--zero123_config', type=str, default='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml', help="config file for zero123") + parser.add_argument('--zero123_ckpt', type=str, default='pretrained/zero123/zero123-xl.ckpt', help="ckpt for zero123") + parser.add_argument('--zero123_grad_scale', type=str, default='angle', help="whether to scale the gradients based on 'angle' or 'None'") + + parser.add_argument('--dataset_size_train', type=int, default=100, help="Length of train dataset i.e. 
# of iterations per epoch") + parser.add_argument('--dataset_size_valid', type=int, default=8, help="# of frames to render in the turntable video in validation") + parser.add_argument('--dataset_size_test', type=int, default=100, help="# of frames to render in the turntable video at test time") + + parser.add_argument('--exp_start_iter', type=int, default=None, help="start iter # for experiment, to calculate progressive_view and progressive_level") + parser.add_argument('--exp_end_iter', type=int, default=None, help="end iter # for experiment, to calculate progressive_view and progressive_level") + + opt = parser.parse_args() + # NOTE(review): -O and -O2 abort with NotImplementedError, so the assignments listed after each raise look unreachable; the nesting cannot be confirmed from this flattened text + if opt.O: + raise NotImplementedError + opt.fp16 = True + opt.cuda_ray = True + + elif opt.O2: + raise NotImplementedError + opt.fp16 = True + opt.progressive_level = True + + if opt.IF: + if 'SD' in opt.guidance: + opt.guidance.remove('SD') + opt.guidance.append('IF') + opt.latent_iter_ratio = 0 # must not do as_latent + + opt.images, opt.ref_radii, opt.ref_polars, opt.ref_azimuths, opt.zero123_ws = [], [], [], [], [] + opt.default_zero123_w = 1 + + # fall back to 0 / opt.iters when the experiment window was not given (an explicit 0 is also replaced by the fallback) + opt.exp_start_iter = opt.exp_start_iter or 0 + opt.exp_end_iter = opt.exp_end_iter or opt.iters + + # parameters for image-conditioned generation + if opt.image is not None or opt.image_config is not None: + + if opt.text is None: + # use zero123 guidance model when only providing image + opt.guidance = ['zero123'] + if not opt.dont_override_stuff: + opt.fovy_range = [opt.default_fovy, opt.default_fovy] # fix fov as zero123 doesn't support changing fov + opt.guidance_scale = 5 + opt.lambda_3d_normal_smooth = 10 + else: + # use stable-diffusion when providing both text and image + opt.guidance = ['SD', 'clip'] + + if not opt.dont_override_stuff: + opt.guidance_scale = 10 + opt.t_range = [0.2, 0.6] + opt.known_view_interval = 2 + opt.lambda_3d_normal_smooth = 20 + opt.bg_radius = -1 + + # smoothness + opt.lambda_entropy = 1 + opt.lambda_orient = 1 + + # latent warmup is not needed + 
opt.latent_iter_ratio = 0 + if not opt.dont_override_stuff: + opt.albedo_iter_ratio = 0 + + # make shape init more stable + opt.progressive_view = True + opt.progressive_level = True + + if opt.image is not None: + opt.images += [opt.image] + opt.ref_radii += [opt.default_radius] + opt.ref_polars += [opt.default_polar] + opt.ref_azimuths += [opt.default_azimuth] + opt.zero123_ws += [opt.default_zero123_w] + + if opt.image_config is not None: + # for multiview (zero123) + # the csv must provide image, radius, polar, azimuth and zero123_weight columns (all five are read below) + conf = pd.read_csv(opt.image_config, skipinitialspace=True) + opt.images += list(conf.image) + opt.ref_radii += list(conf.radius) + opt.ref_polars += list(conf.polar) + opt.ref_azimuths += list(conf.azimuth) + opt.zero123_ws += list(conf.zero123_weight) + if opt.image is None: + opt.default_radius = opt.ref_radii[0] + opt.default_polar = opt.ref_polars[0] + opt.default_azimuth = opt.ref_azimuths[0] + opt.default_zero123_w = opt.zero123_ws[0] + + # reset to None + if len(opt.images) == 0: + opt.images = None + + + # NOTE(review): this check is an assert, so it is skipped when Python runs with -O + if opt.learnable_bg: + assert opt.bg_radius> max(opt.radius_range), f"bg_radius must be larger than max(radius_range) = {max(opt.radius_range)}" + + + # default parameters for finetuning + if opt.dmtet: + + opt.h = int(opt.h * opt.dmtet_reso_scale) + opt.w = int(opt.w * opt.dmtet_reso_scale) + opt.known_view_scale = 1 + + if not opt.dont_override_stuff: + opt.t_range = [0.02, 0.50] # ref: magic3D + + if opt.images is not None: + + opt.lambda_normal = 0 + opt.lambda_depth = 0 + + if opt.text is not None and not opt.dont_override_stuff: + opt.t_range = [0.20, 0.50] + + # assume finetuning + opt.latent_iter_ratio = 0 + opt.albedo_iter_ratio = 0 + opt.progressive_view = False + # opt.progressive_level = False + # make sure the workspace directory exists (an existing directory is not an error) + os.makedirs(opt.workspace, exist_ok=True) + # record full range for progressive view expansion + if opt.progressive_view: + if not opt.dont_override_stuff: + # disable as they disturb progressive view + opt.jitter_pose = False + + opt.uniform_sphere_rate = 0 + # back up full range + 
opt.full_radius_range = opt.radius_range + opt.full_theta_range = opt.theta_range + opt.full_phi_range = opt.phi_range + opt.full_fovy_range = opt.fovy_range + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + from nerf.network_trigrid_heirarchy import NeRFNetwork + + # load plane from pkl + # NOTE(review): pickle.load can execute arbitrary code from the file; only point --trigrid_path at trusted pickles + if os.path.isfile(opt.trigrid_path): + print(f'loading plane from {opt.trigrid_path}...') + import pickle + threeDRepresentation = {} + threeDRepresentation_shapes = {} + with open(opt.trigrid_path, 'rb') as f: + data = pickle.load(f) + ws = data['ws'] + threeDRepresentation_shapes['ws'] = ws.shape + for key in data: + # entries whose key contains 'trigrids' are viewed as (1, 3, -1, <last two dims>) and their resulting shape is recorded + if 'trigrids' in key: + threeDRepresentation[key] = data[key].view(1, 3, -1, data[key].shape[-2], data[key].shape[-1]) + + threeDRepresentation_shapes[key] =threeDRepresentation[key].shape + + + # NOTE(review): shutil is not imported explicitly in this script; presumably it arrives via `from nerf.utils import *` - confirm + print('save trigrid to workspace...') + shutil.copy(opt.trigrid_path, os.path.join(opt.workspace, 'trigrid.pkl')) + + + + # NOTE(review): threeDRepresentation_shapes is only assigned inside the isfile(opt.trigrid_path) branch above; if this call sits outside that branch (indentation is lost in this text) a missing trigrid file ends in NameError - confirm + model = NeRFNetwork( + opt=opt, + device=device, + trigrid_shapes=threeDRepresentation_shapes + ) + + # load + # the decoder checkpoint is indexed with the 'state_dict' and 'params' keys below + if os.path.isfile(opt.trigrid_decoder_ckpt): + print(f'loading trigrid_renderer from {opt.trigrid_decoder_ckpt}...') + ckpt = torch.load(opt.trigrid_decoder_ckpt, map_location=lambda storage, loc: storage) + # ckpt = {'params': params, 'state_dict': ckpt} + state_dict = ckpt['state_dict'] + params = ckpt['params'] + print(f'loading params: {params}') + + model.model.load_state_dict(state_dict) + # + # rendering overrides: depth_resolution and depth_resolution_importance are set to 48, ray_start to 2.35 + model.model.rendering_kwargs['depth_resolution'] = 48 + model.model.rendering_kwargs['depth_resolution_importance'] = 48 + # + model.model.rendering_kwargs['ray_start'] = 2.35 + # + # load plane from pkl + if os.path.isfile(opt.trigrid_path): + model.load_state_dict( + threeDRepresentation, strict=False + ) + + print(opt) + + if opt.seed is not None: + seed_everything(int(opt.seed)) + + # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # + + + + + print(model) + + if opt.six_views: + raise 
NotImplementedError(f'--six_views {opt.six_views} is not implemented!') + guidance = None # no need to load guidance model at test + # NOTE(review): the --six_views branch starts with a raise, so the statements that follow it look unreachable - confirm the nesting + + trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt) + + test_loader = NeRFDataset(opt, device=device, type='six_views', H=opt.H, W=opt.W, size=6).dataloader(batch_size=1) + trainer.test(test_loader, write_video=False) + + if opt.save_mesh: + trainer.save_mesh() + + elif opt.test: + raise NotImplementedError(f'--test {opt.test} is not implemented!') + guidance = None # no need to load guidance model at test + # NOTE(review): same pattern for --test: the raise comes first, so the test-mode statements look unreachable - confirm the nesting + + trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt) + + if opt.gui: + from nerf.gui import NeRFGUI + gui = NeRFGUI(opt, trainer) + gui.render() + + else: + test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader(batch_size=1) + trainer.test(test_loader) + + if opt.save_mesh: + trainer.save_mesh() + + else: + + train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=opt.dataset_size_train * opt.batch_size).dataloader() + # data = { + # 'H': self.H, + # 'W': self.W, + # 'rays_o': rays['rays_o'], + # 'rays_d': rays['rays_d'], + # 'dir': dirs, + # 'mvp': mvp, + # 'polar': delta_polar, + # 'azimuth': delta_azimuth, + # 'radius': delta_radius, + # } + if opt.optim == 'adan': + from optimizer import Adan + + # Adan usually requires a larger LR + # (hence 5 * opt.lr below; the Adam branch passes opt.lr unchanged) + optimizer = lambda model: Adan(model.get_params(5 * opt.lr,trigrid_lr_ratio = opt.trigrid_lr_ratio), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, + foreach=False) + else: # adam + optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr,trigrid_lr_ratio = opt.trigrid_lr_ratio), betas=(0.9, 0.99), eps=1e-15) + + if opt.scheduler_annealing: + print('=============== use scheduler: annealing') + scheduler = 
lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, + lambda iter: 0.1 ** min(iter / opt.iters, 1)) + else: + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed + # (the annealing variant multiplies the LR by 0.1 ** min(iter / opt.iters, 1), i.e. the factor goes from 1 down to 0.1 over opt.iters steps) + # scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) + + # one guidance model is instantiated per name found in opt.guidance + guidance = nn.ModuleDict() + + if 'SD' in opt.guidance: + from guidance.sd_utils import StableDiffusion + guidance['SD'] = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key, opt.t_range) + + if 'IF' in opt.guidance: + from guidance.if_utils import IF + guidance['IF'] = IF(device, opt.vram_O, opt.t_range) + + if 'zero123' in opt.guidance: + from guidance.zero123_utils import Zero123 + guidance['zero123'] = Zero123(device=device, fp16=opt.fp16, config=opt.zero123_config, ckpt=opt.zero123_ckpt, vram_O=opt.vram_O, t_range=opt.t_range, opt=opt) + + if 'clip' in opt.guidance: + from guidance.clip_utils import CLIP + guidance['clip'] = CLIP(device) + + trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, + optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, scheduler_update_every_step=True) + + trainer.default_view_data = train_loader._data.get_default_view_data() + + + + + if opt.gui: + from nerf.gui import NeRFGUI + gui = NeRFGUI(opt, trainer, train_loader) + gui.render() + + else: + valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=opt.dataset_size_valid).dataloader(batch_size=1) + test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader(batch_size=1) + + # # test output + # trainer.test(test_loader,save_path = os.path.join(opt.workspace, 'initiation')) + + # TO BE DEBUGGED + + # number of epochs needed to cover opt.iters iterations, rounded up + max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) + trainer.train(train_loader, valid_loader, test_loader, max_epoch) + + if opt.save_mesh: 
+ trainer.save_mesh() diff --git a/stable-dreamfusion-3DPortrait/main_3DPortraitGAN_cam.py b/stable-dreamfusion-3DPortrait/main_3DPortraitGAN_cam.py new file mode 100644 index 0000000..e5d0671 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/main_3DPortraitGAN_cam.py @@ -0,0 +1,480 @@ +import os + +import torch +import argparse +import pandas as pd +import sys + +from nerf.provider_3DPortraitGAN import NeRFDataset +from nerf.utils import * + + +if __name__ == '__main__': + # See https://stackoverflow.com/questions/27433316/how-to-get-argparse-to-read-arguments-from-a-file-with-an-option-rather-than-pre + class LoadFromFile (argparse.Action): + def __call__ (self, parser, namespace, values, option_string = None): + with values as f: + # parse arguments in the file and store them in the target namespace + # (the file content is split on whitespace, so quoted values containing spaces are not kept together) + parser.parse_args(f.read().split(), namespace) + + parser = argparse.ArgumentParser() + parser.add_argument('--file', type=open, action=LoadFromFile, help="specify a file filled with more arguments") + parser.add_argument('--text', default=None, help="text prompt") + parser.add_argument('--negative', default='', type=str, help="negative text prompt") + parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray") + parser.add_argument('-O2', action='store_true', help="equals --backbone vanilla") + parser.add_argument('--test', action='store_true', help="test mode") + parser.add_argument('--six_views', action='store_true', help="six_views mode: save the images of the six views") + parser.add_argument('--eval_interval', type=int, default=1, help="evaluate on the valid set every interval epochs") + parser.add_argument('--test_interval', type=int, default=10, help="test on the test set every interval epochs") + parser.add_argument('--workspace', type=str, default='workspace') + parser.add_argument('--seed', default=None) + + parser.add_argument('--image', default=None, help="image prompt") + 
# Command-line options of main_3DPortraitGAN_cam.py (continued); `parser` is
# the argparse.ArgumentParser created earlier in the script.
parser.add_argument('--image_config', default=None, help="image config csv")

parser.add_argument('--known_view_interval', type=int, default=4, help="train default view with RGB loss every N iters, only valid if --image is not None.")

parser.add_argument('--IF', action='store_true', help="experimental: use DeepFloyd IF as the guidance model for nerf stage")

parser.add_argument('--guidance', type=str, nargs='*', default=['SD'], help='guidance model')
parser.add_argument('--guidance_scale', type=float, default=100, help="diffusion model classifier-free guidance scale")

parser.add_argument('--save_mesh', action='store_true', help="export an obj mesh with texture")
parser.add_argument('--mcubes_resolution', type=int, default=256, help="mcubes resolution for extracting mesh")
# argparse applies `type` only to string defaults, so the default must already
# be an int (the previous 5e4 literal leaked a float into opt.decimate_target).
parser.add_argument('--decimate_target', type=int, default=50000, help="target face number for mesh decimation")

parser.add_argument('--dmtet', action='store_true', help="use dmtet finetuning")
parser.add_argument('--tet_grid_size', type=int, default=128, help="tet grid size")
parser.add_argument('--init_with', type=str, default='', help="ckpt to init dmtet")
parser.add_argument('--lock_geo', action='store_true', help="disable dmtet to learn geometry")

## Perp-Neg options
parser.add_argument('--perpneg', action='store_true', help="use perp_neg")
parser.add_argument('--negative_w', type=float, default=-2, help="The scale of the weights of negative prompts. A larger value will help to avoid the Janus problem, but may cause flat faces. Vary between 0 to -4, depending on the prompt")
parser.add_argument('--front_decay_factor', type=float, default=2, help="decay factor for the front prompt")
parser.add_argument('--side_decay_factor', type=float, default=10, help="decay factor for the side prompt")

## Trigrid options
parser.add_argument('--trigrid_path', type=str, default='', help="path to trigrid")
parser.add_argument('--trigrid_decoder_ckpt', type=str, default='', help="path to trigrid decoder ckpt")
parser.add_argument('--train_decoder', action='store_true', help="train trigrid decoder")
parser.add_argument('--learnable_bg', action='store_true', help="Learnable background")
# The help text used to be a copy of the --t_range description.
parser.add_argument('--trigrid_lr_ratio', type=float, nargs='+', default=[100, 100, 100, 100, 100, 100, 100, 40], help="learning-rate ratios for the trigrid parameter groups")
parser.add_argument('--scheduler_annealing', action='store_true', help="annealing scheduler")
parser.add_argument('--use_body_pose', action='store_true', help="use_body_pose")

### training options
parser.add_argument('--iters', type=int, default=10000, help="training iters")
parser.add_argument('--lr', type=float, default=1e-3, help="max learning rate")
parser.add_argument('--ckpt', type=str, default='latest', help="possible options are ['latest', 'scratch', 'best', 'latest_model']")
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
parser.add_argument('--taichi_ray', action='store_true', help="use taichi raymarching")
parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=32, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
# The help text used to be a copy of the --albedo_iter_ratio description.
parser.add_argument('--latent_iter_ratio', type=float, default=0.2, help="ratio of training iters used for the latent-space warmup")
parser.add_argument('--albedo_iter_ratio', type=float, default=0, help="training iters that only use albedo shading")
parser.add_argument('--min_ambient_ratio', type=float, default=0.1, help="minimum ambient ratio to use in lambertian shading")
parser.add_argument('--textureless_ratio', type=float, default=0.2, help="ratio of textureless shading")
parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")
parser.add_argument('--jitter_center', type=float, default=0.2, help="amount of jitter to add to sampled camera pose's center (camera location)")
parser.add_argument('--jitter_target', type=float, default=0.2, help="amount of jitter to add to sampled camera pose's target (i.e. 'look-at')")
parser.add_argument('--jitter_up', type=float, default=0.02, help="amount of jitter to add to sampled camera pose's up-axis (i.e. 'camera roll')")
parser.add_argument('--uniform_sphere_rate', type=float, default=0, help="likelihood of sampling camera location uniformly on the sphere surface area")
parser.add_argument('--grad_clip', type=float, default=-1, help="clip grad of all grad to this limit, negative value disables it")
parser.add_argument('--grad_clip_rgb', type=float, default=-1, help="clip grad of rgb space grad to this limit, negative value disables it")
# model options
parser.add_argument('--bg_radius', type=float, default=3.0, help="if positive, use a background model at sphere(bg_radius)")
parser.add_argument('--density_activation', type=str, default='exp', choices=['softplus', 'exp'], help="density activation function")
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
parser.add_argument('--blob_density', type=float, default=5, help="max (center) density for the density blob")
parser.add_argument('--blob_radius', type=float, default=0.2, help="control the radius for the density blob")
# network backbone
parser.add_argument('--optim', type=str, default='adan', choices=['adan', 'adam'], help="optimizer")
parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
# try this if CUDA OOM
parser.add_argument('--fp16', action='store_true', help="use float16 for training")
parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
# rendering resolution in training, increase these for better quality / decrease these if CUDA OOM even if --vram_O enabled.
+ parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training") + parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training") + parser.add_argument('--known_view_scale', type=float, default=1.5, help="multiply --h/w by this for known view rendering") + parser.add_argument('--known_view_noise_scale', type=float, default=2e-3, help="random camera noise added to rays_o and rays_d") + parser.add_argument('--dmtet_reso_scale', type=float, default=8, help="multiply --h/w by this for dmtet finetuning") + parser.add_argument('--batch_size', type=int, default=1, help="images to render per batch using NeRF") + + ### dataset options + parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)") + parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--min_near', type=float, default=0.01, help="minimum near distance for camera") + + parser.add_argument('--radius_range', type=float, nargs='*', default=[2.7, 2.7], + help="training camera radius range") + parser.add_argument('--theta_range', type=float, nargs='*', default=[60, 105], + help="training camera range along the polar angles (i.e. up and down). See advanced.md for details.") + parser.add_argument('--phi_range', type=float, nargs='*', default=[-180, 180], + help="training camera range along the azimuth angles (i.e. left and right). 
See advanced.md for details.") + parser.add_argument('--fovy_range', type=float, nargs='*', default=[12.447863, 12.447864], + help="training camera fovy range") + + parser.add_argument('--default_radius', type=float, default=2.7, help="radius for the default view") + parser.add_argument('--default_polar', type=float, default=90, help="polar for the default view") + parser.add_argument('--default_azimuth', type=float, default=0, help="azimuth for the default view") + parser.add_argument('--default_fovy', type=float, default=12.447863, help="fovy for the default view") + + parser.add_argument('--progressive_view', action='store_true', help="progressively expand view sampling range from default to full") + parser.add_argument('--progressive_view_init_ratio', type=float, default=0.2, help="initial ratio of final range, used for progressive_view") + + parser.add_argument('--progressive_level', action='store_true', help="progressively increase gridencoder's max_level") + + parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region") + parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.") + parser.add_argument('--t_range', type=float, nargs='+', default=[0.02, 0.98], help="stable diffusion time steps range") + parser.add_argument('--dont_override_stuff',action='store_true', help="Don't override t_range, etc.") + + + ### regularizations + parser.add_argument('--lambda_entropy', type=float, default=1e-3, help="loss scale for alpha entropy") + parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value") + parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation") + parser.add_argument('--lambda_tv', type=float, default=0, help="loss scale for total variation") + parser.add_argument('--lambda_wd', type=float, default=0, 
help="loss scale") + + parser.add_argument('--lambda_mesh_normal', type=float, default=0.5, help="loss scale for mesh normal smoothness") + parser.add_argument('--lambda_mesh_laplacian', type=float, default=0.5, help="loss scale for mesh laplacian") + + parser.add_argument('--lambda_guidance', type=float, default=1, help="loss scale for SDS") + parser.add_argument('--lambda_rgb', type=float, default=1000, help="loss scale for RGB") + parser.add_argument('--lambda_mask', type=float, default=500, help="loss scale for mask (alpha)") + parser.add_argument('--lambda_normal', type=float, default=0, help="loss scale for normal map") + parser.add_argument('--lambda_depth', type=float, default=10, help="loss scale for relative depth") + parser.add_argument('--lambda_2d_normal_smooth', type=float, default=0, help="loss scale for 2D normal image smoothness") + parser.add_argument('--lambda_3d_normal_smooth', type=float, default=0, help="loss scale for 3D normal image smoothness") + + ### debugging options + parser.add_argument('--save_guidance', action='store_true', help="save images of the per-iteration NeRF renders, added noise, denoised (i.e. guidance), fully-denoised. 
Useful for debugging, but VERY SLOW and takes lots of memory!") + parser.add_argument('--save_guidance_interval', type=int, default=10, help="save guidance every X step") + + ### GUI options + parser.add_argument('--gui', action='store_true', help="start a GUI") + parser.add_argument('--W', type=int, default=800, help="GUI width") + parser.add_argument('--H', type=int, default=800, help="GUI height") + parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center") + parser.add_argument('--fovy', type=float, default=20, help="default GUI camera fovy") + parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]") + parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth") + parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") + + parser.add_argument('--zero123_config', type=str, default='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml', help="config file for zero123") + parser.add_argument('--zero123_ckpt', type=str, default='pretrained/zero123/zero123-xl.ckpt', help="ckpt for zero123") + parser.add_argument('--zero123_grad_scale', type=str, default='angle', help="whether to scale the gradients based on 'angle' or 'None'") + + parser.add_argument('--dataset_size_train', type=int, default=100, help="Length of train dataset i.e. 
# of iterations per epoch") + parser.add_argument('--dataset_size_valid', type=int, default=8, help="# of frames to render in the turntable video in validation") + parser.add_argument('--dataset_size_test', type=int, default=100, help="# of frames to render in the turntable video at test time") + + parser.add_argument('--exp_start_iter', type=int, default=None, help="start iter # for experiment, to calculate progressive_view and progressive_level") + parser.add_argument('--exp_end_iter', type=int, default=None, help="end iter # for experiment, to calculate progressive_view and progressive_level") + + opt = parser.parse_args() + if opt.O: + raise NotImplementedError + opt.fp16 = True + opt.cuda_ray = True + + elif opt.O2: + raise NotImplementedError + opt.fp16 = True + opt.progressive_level = True + + if opt.IF: + if 'SD' in opt.guidance: + opt.guidance.remove('SD') + opt.guidance.append('IF') + opt.latent_iter_ratio = 0 # must not do as_latent + + opt.images, opt.ref_radii, opt.ref_polars, opt.ref_azimuths, opt.zero123_ws = [], [], [], [], [] + opt.default_zero123_w = 1 + + opt.exp_start_iter = opt.exp_start_iter or 0 + opt.exp_end_iter = opt.exp_end_iter or opt.iters + + # parameters for image-conditioned generation + if opt.image is not None or opt.image_config is not None: + + if opt.text is None: + # use zero123 guidance model when only providing image + opt.guidance = ['zero123'] + if not opt.dont_override_stuff: + opt.fovy_range = [opt.default_fovy, opt.default_fovy] # fix fov as zero123 doesn't support changing fov + opt.guidance_scale = 5 + opt.lambda_3d_normal_smooth = 10 + else: + # use stable-diffusion when providing both text and image + opt.guidance = ['SD', 'clip'] + + if not opt.dont_override_stuff: + opt.guidance_scale = 10 + opt.t_range = [0.2, 0.6] + opt.known_view_interval = 2 + opt.lambda_3d_normal_smooth = 20 + opt.bg_radius = -1 + + # smoothness + opt.lambda_entropy = 1 + opt.lambda_orient = 1 + + # latent warmup is not needed + 
opt.latent_iter_ratio = 0 + if not opt.dont_override_stuff: + opt.albedo_iter_ratio = 0 + + # make shape init more stable + opt.progressive_view = True + opt.progressive_level = True + + if opt.image is not None: + opt.images += [opt.image] + opt.ref_radii += [opt.default_radius] + opt.ref_polars += [opt.default_polar] + opt.ref_azimuths += [opt.default_azimuth] + opt.zero123_ws += [opt.default_zero123_w] + + if opt.image_config is not None: + # for multiview (zero123) + conf = pd.read_csv(opt.image_config, skipinitialspace=True) + opt.images += list(conf.image) + opt.ref_radii += list(conf.radius) + opt.ref_polars += list(conf.polar) + opt.ref_azimuths += list(conf.azimuth) + opt.zero123_ws += list(conf.zero123_weight) + if opt.image is None: + opt.default_radius = opt.ref_radii[0] + opt.default_polar = opt.ref_polars[0] + opt.default_azimuth = opt.ref_azimuths[0] + opt.default_zero123_w = opt.zero123_ws[0] + + # reset to None + if len(opt.images) == 0: + opt.images = None + + + if opt.learnable_bg: + assert opt.bg_radius> max(opt.radius_range), f"bg_radius must be larger than max(radius_range) = {max(opt.radius_range)}" + + + # default parameters for finetuning + if opt.dmtet: + + opt.h = int(opt.h * opt.dmtet_reso_scale) + opt.w = int(opt.w * opt.dmtet_reso_scale) + opt.known_view_scale = 1 + + if not opt.dont_override_stuff: + opt.t_range = [0.02, 0.50] # ref: magic3D + + if opt.images is not None: + + opt.lambda_normal = 0 + opt.lambda_depth = 0 + + if opt.text is not None and not opt.dont_override_stuff: + opt.t_range = [0.20, 0.50] + + # assume finetuning + opt.latent_iter_ratio = 0 + opt.albedo_iter_ratio = 0 + opt.progressive_view = False + # opt.progressive_level = False + os.makedirs(opt.workspace, exist_ok=True) + # record full range for progressive view expansion + if opt.progressive_view: + if not opt.dont_override_stuff: + # disable as they disturb progressive view + opt.jitter_pose = False + + opt.uniform_sphere_rate = 0 + # back up full range + 
opt.full_radius_range = opt.radius_range + opt.full_theta_range = opt.theta_range + opt.full_phi_range = opt.phi_range + opt.full_fovy_range = opt.fovy_range + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + from nerf.network_trigrid_heirarchy import NeRFNetwork + + # load plane from pkl + if os.path.isfile(opt.trigrid_path): + print(f'loading plane from {opt.trigrid_path}...') + import pickle + threeDRepresentation = {} + threeDRepresentation_shapes = {} + with open(opt.trigrid_path, 'rb') as f: + data = pickle.load(f) + ws = data['ws'] + threeDRepresentation['ws'] = ws + threeDRepresentation_shapes['ws'] = ws.shape + for key in data: + if 'trigrids' in key: + threeDRepresentation[key] = data[key]#.view(1, 3, -1, data[key].shape[-2], data[key].shape[-1]) + + threeDRepresentation_shapes[key] =threeDRepresentation[key].shape + + + print('save trigrid to workspace...') + shutil.copy(opt.trigrid_path, os.path.join(opt.workspace, 'trigrid.pkl')) + + + + model = NeRFNetwork( + opt=opt, + device=device, + trigrid_shapes=threeDRepresentation_shapes + ) + + # load + if os.path.isfile(opt.trigrid_decoder_ckpt): + print(f'loading trigrid_renderer from {opt.trigrid_decoder_ckpt}...') + ckpt = torch.load(opt.trigrid_decoder_ckpt, map_location=lambda storage, loc: storage) + # ckpt = {'params': params, 'state_dict': ckpt} + state_dict = ckpt['state_dict'] + params = ckpt['params'] + print(f'loading params: {params}') + + model.model.load_state_dict(state_dict, strict=True) + # + model.model.rendering_kwargs['depth_resolution'] = 48 + model.model.rendering_kwargs['depth_resolution_importance'] = 48 + # + model.model.rendering_kwargs['ray_start'] = 2.35 + # + # load plane from pkl + if os.path.isfile(opt.trigrid_path): + model.load_state_dict( + threeDRepresentation, strict=False + ) + + print(opt) + + if opt.seed is not None: + seed_everything(int(opt.seed)) + + # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # + + + + + 
print(model) + + if opt.six_views: + raise NotImplementedError(f'--six_views {opt.six_views} is not implemented!') + guidance = None # no need to load guidance model at test + + trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt) + + test_loader = NeRFDataset(opt, device=device, type='six_views', H=opt.H, W=opt.W, size=6).dataloader(batch_size=1) + trainer.test(test_loader, write_video=False) + + if opt.save_mesh: + trainer.save_mesh() + + elif opt.test: + raise NotImplementedError(f'--test {opt.test} is not implemented!') + guidance = None # no need to load guidance model at test + + trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt) + + if opt.gui: + from nerf.gui import NeRFGUI + gui = NeRFGUI(opt, trainer) + gui.render() + + else: + test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader(batch_size=1) + trainer.test(test_loader) + + if opt.save_mesh: + trainer.save_mesh() + + else: + + train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=opt.dataset_size_train * opt.batch_size).dataloader() + # data = { + # 'H': self.H, + # 'W': self.W, + # 'rays_o': rays['rays_o'], + # 'rays_d': rays['rays_d'], + # 'dir': dirs, + # 'mvp': mvp, + # 'polar': delta_polar, + # 'azimuth': delta_azimuth, + # 'radius': delta_radius, + # } + if opt.optim == 'adan': + from optimizer import Adan + + # Adan usually requires a larger LR + optimizer = lambda model: Adan(model.get_params(5 * opt.lr,trigrid_lr_ratio = opt.trigrid_lr_ratio), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, + foreach=False) + else: # adam + optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr,trigrid_lr_ratio = opt.trigrid_lr_ratio), betas=(0.9, 0.99), eps=1e-15) + + if opt.scheduler_annealing: + print('=============== use 
scheduler: annealing') + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, + lambda iter: 0.1 ** min(iter / opt.iters, 1)) + else: + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed + # scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) + + guidance = nn.ModuleDict() + + if 'SD' in opt.guidance: + from guidance.sd_utils import StableDiffusion + guidance['SD'] = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key, opt.t_range) + + if 'IF' in opt.guidance: + from guidance.if_utils import IF + guidance['IF'] = IF(device, opt.vram_O, opt.t_range) + + if 'zero123' in opt.guidance: + from guidance.zero123_utils import Zero123 + guidance['zero123'] = Zero123(device=device, fp16=opt.fp16, config=opt.zero123_config, ckpt=opt.zero123_ckpt, vram_O=opt.vram_O, t_range=opt.t_range, opt=opt) + + if 'clip' in opt.guidance: + from guidance.clip_utils import CLIP + guidance['clip'] = CLIP(device) + + trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, + optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, scheduler_update_every_step=True) + + trainer.default_view_data = train_loader._data.get_default_view_data() + + + + + if opt.gui: + from nerf.gui import NeRFGUI + gui = NeRFGUI(opt, trainer, train_loader) + gui.render() + + else: + valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=opt.dataset_size_valid).dataloader(batch_size=1) + test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader(batch_size=1) + + # # test output + # trainer.test(test_loader,save_path = os.path.join(opt.workspace, 'initiation')) + + # TO BE DEBUGGED + + max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) + trainer.train(train_loader, valid_loader, 
def poisson_mesh_reconstruction(points, normals=None):
    """Build a triangle mesh from a point cloud with Open3D's Poisson reconstruction.

    Args:
        points: [N, 3] np.ndarray of point positions.
        normals: optional [N, 3] np.ndarray of per-point normals. When omitted,
            normals are estimated by Open3D.

    Returns:
        ``(vertices, triangles)`` as numpy arrays taken from the reconstructed mesh.

    Note:
        Two Open3D viewer windows are opened (point cloud, then mesh) via
        ``draw_geometries``.
    """
    import open3d as o3d

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    # outlier removal; `ind` holds the indices of the points that were kept
    pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)

    # normals
    if normals is None:
        pcd.estimate_normals()
    else:
        # keep only the normals belonging to the surviving points
        pcd.normals = o3d.utility.Vector3dVector(normals[ind])

    # visualize
    o3d.visualization.draw_geometries([pcd], point_show_normal=False)

    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9)
    # drop the vertices whose density falls below the 10% quantile
    vertices_to_remove = densities < np.quantile(densities, 0.1)
    mesh.remove_vertices_by_mask(vertices_to_remove)

    # visualize
    o3d.visualization.draw_geometries([mesh])

    vertices = np.asarray(mesh.vertices)
    triangles = np.asarray(mesh.triangles)

    print(f'[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}')

    return vertices, triangles


def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True):
    """Simplify a triangle mesh down to about ``target`` faces.

    Args:
        verts: vertex array of the input mesh.
        faces: face index array of the input mesh.
        target: requested face count; converted with ``int()`` so that float
            values such as ``5e4`` are accepted by both backends.
        backend: ``'pyfqmr'`` to use pyfqmr, anything else uses pymeshlab.
        remesh: pymeshlab backend only; run isotropic explicit remeshing after
            decimation.
        optimalplacement: pymeshlab backend only; default is True, but for a
            flat mesh it must be turned off to prevent spike artifacts.

    Returns:
        ``(verts, faces)`` of the decimated mesh.
    """
    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    # Normalise once so the pyfqmr path gets the same integer count that the
    # pymeshlab path always received.
    target = int(target)

    if backend == 'pyfqmr':
        import pyfqmr
        solver = pyfqmr.Simplify()
        solver.setMesh(verts, faces)
        solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
        # getMesh() also returns normals, which are not needed here
        verts, faces, _ = solver.getMesh()
    else:
        m = pml.Mesh(verts, faces)
        ms = pml.MeshSet()
        ms.add_mesh(m, 'mesh')  # will copy!

        # filters
        ms.meshing_decimation_quadric_edge_collapse(targetfacenum=target, optimalplacement=optimalplacement)

        if remesh:
            ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))

        # extract mesh
        m = ms.current_mesh()
        verts = m.vertex_matrix()
        faces = m.face_matrix()

    print(f'[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')

    return verts, faces


def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True, remesh_size=0.01):
    """Run a fixed sequence of pymeshlab clean-up filters on a triangle mesh.

    Args:
        verts: [N, 3] vertex array.
        faces: [M, 3] face index array.
        v_pct: if > 0, merge close vertices using ``pml.Percentage(v_pct)`` as
            the threshold.
        min_f: if > 0, remove connected components with fewer faces than this.
        min_d: if > 0, remove connected components whose diameter is below
            ``pml.Percentage(min_d)``.
        repair: repair non-manifold edges and vertices.
        remesh: run isotropic explicit remeshing with an absolute target edge
            length of ``remesh_size``.
        remesh_size: target edge length used when ``remesh`` is True.

    Returns:
        ``(verts, faces)`` of the cleaned mesh.
    """
    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    m = pml.Mesh(verts, faces)
    ms = pml.MeshSet()
    ms.add_mesh(m, 'mesh')  # will copy!

    # filters
    ms.meshing_remove_unreferenced_vertices()  # verts not refed by any faces

    if v_pct > 0:
        # NOTE(review): threshold is a percentage value; presumably relative to
        # the bounding box diagonal -- confirm against the pymeshlab docs.
        ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct))

    ms.meshing_remove_duplicate_faces()  # faces defined by the same verts
    ms.meshing_remove_null_faces()  # faces with area == 0

    if min_d > 0:
        ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d))

    if min_f > 0:
        ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)

    if repair:
        ms.meshing_repair_non_manifold_edges(method=0)
        ms.meshing_repair_non_manifold_vertices(vertdispratio=0)

    if remesh:
        ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.AbsoluteValue(remesh_size))

    # extract mesh
    m = ms.current_mesh()
    verts = m.vertex_matrix()
    faces = m.face_matrix()

    print(f'[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')

    return verts, faces
# Util classes
# ------------------------------------------------------------------------------------------


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            # Report a missing key as a missing attribute so hasattr()/getattr() work.
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""

    def __init__(self, file_name: Union[str, None] = None, file_mode: str = "w", should_flush: bool = True):
        """Install this object as ``sys.stdout`` and ``sys.stderr``.

        Args:
            file_name: optional path of a file that receives a copy of all output.
            file_mode: mode used to open ``file_name``.
            should_flush: flush after every write when True.
        """
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Remember the streams being replaced so close() can restore them.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None


# Cache directories
# ------------------------------------------------------------------------------------------

_dnnlib_cache_dir = None


def set_cache_dir(path: str) -> None:
    """Set the cache root that make_cache_dir_path() prefers over every other location."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path


def make_cache_dir_path(*paths: str) -> str:
    """Join ``paths`` onto the cache root.

    The root is the first available of: the directory given to set_cache_dir(),
    ``$DNNLIB_CACHE_DIR``, ``$HOME/.cache/dnnlib``, ``$USERPROFILE/.cache/dnnlib``
    and finally ``<tempdir>/.cache/dnnlib``.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    if 'DNNLIB_CACHE_DIR' in os.environ:
        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
    if 'HOME' in os.environ:
        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
    if 'USERPROFILE' in os.environ:
        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)


# Small util functions
# ------------------------------------------------------------------------------------------


def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    s = int(np.rint(seconds))

    if s < 60:
        return "{0}s".format(s)
    elif s < 60 * 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
    else:
        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)


def format_time_brief(seconds: Union[int, float]) -> str:
    """Convert the seconds to a short human readable string showing at most the two largest units."""
    s = int(np.rint(seconds))

    if s < 60:
        return "{0}s".format(s)
    elif s < 60 * 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
    else:
        return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)


# Answers accepted by ask_yes_no(); the same spellings distutils' strtobool() understood.
_YES_ANSWERS = frozenset(("y", "yes", "t", "true", "on", "1"))
_NO_ANSWERS = frozenset(("n", "no", "f", "false", "off", "0"))


def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    Returns a real ``bool``. The answer is parsed locally rather than through
    ``distutils.util.strtobool`` (distutils is no longer part of the standard
    library from Python 3.12 on, and strtobool returned an int).
    """
    while True:
        print("{0} [y/n]".format(question))
        answer = input().lower()
        if answer in _YES_ANSWERS:
            return True
        if answer in _NO_ANSWERS:
            return False
        # anything else: ask again


def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements."""
    result = 1

    for v in t:
        result *= v

    return result


_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.

    Raises:
        RuntimeError: if no type name can be derived from ``type_obj``.
        AssertionError: if the name is not one of the supported types, or the
            Numpy and ctypes sizes disagree.
    """
    type_str = None

    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()

    my_dtype = np.dtype(type_str)
    my_ctype = _str_to_ctype[type_str]

    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)

    return my_dtype, my_ctype


def is_pickleable(obj: Any) -> bool:
    """Return True if ``obj`` can be serialised with pickle, False otherwise."""
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        # Pickling failures surface as several exception types; KeyboardInterrupt
        # and SystemExit are deliberately not swallowed.
        return False
# ------------------------------------------------------------------------------------------

def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed).

    Raises:
        ImportError: if no split of `obj_name` into module and attribute path
            resolves, or if importing a candidate module fails for a reason
            other than the module itself being missing.
        AttributeError: if a candidate module imports but lacks the attribute.
    """

    # allow convenience shorthands, substitute them by full names
    # (the dot is escaped so that e.g. "npx.foo" is not rewritten)
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
            return module, local_obj_name
        except Exception:
            # Best-effort probing: diagnosis happens in the loops below.
            # KeyboardInterrupt/SystemExit are no longer swallowed here.
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name)  # may raise ImportError
        except ImportError as err:
            # Re-raise unless the error is simply "this candidate module does not exist".
            if not str(err).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object.

    An empty `obj_name` returns the module itself.
    """
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj


def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    module, obj_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, obj_name)


def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func_obj = get_obj_by_name(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)


def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    return call_func_by_name(*args, func_name=class_name, **kwargs)


def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))


def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.

    Callables without a `__name__`, or whose `__module__` is not present in
    `sys.modules`, are reported as not top-level instead of raising.
    """
    if not callable(obj):
        return False
    name = getattr(obj, "__name__", None)
    module = sys.modules.get(getattr(obj, "__module__", None))
    return name is not None and module is not None and name in module.__dict__


def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module = obj.__module__
    if module == '__main__':
        # Use the script's file name (without extension) in place of '__main__'.
        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
    return module + "." + obj.__name__


# File system helpers
# ------------------------------------------------------------------------------------------

def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths.

    `ignores` holds fnmatch patterns applied to directory and file names.
    With `add_base_to_relative`, the relative paths are prefixed with the
    base name of `dir_path`.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place
            for d in dirs_to_remove:
                dirs.remove(d)

            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result


def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for file in files:
        target_dir_name = os.path.dirname(file[1])

        # will create all intermediate-level directories
        # An empty dirname means dst is a bare file name: there is nothing to
        # create, and os.makedirs('') would raise. exist_ok avoids the
        # exists()/makedirs() race.
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(file[0], file[1])


# URL helpers
# ------------------------------------------------------------------------------------------

def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    A string qualifies when it contains "://" and both it and its root
    ("/" joined onto it) parse to a scheme plus a netloc containing a dot.
    `file://` strings are accepted only when `allow_file_urls` is set.
    """
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:
        # A parse failure means "not a URL".
        return False
    return True


def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Strings without a URL scheme and `file://` URLs are opened locally.
    With `return_filename` the path is returned instead of a file object,
    which requires `cache` to be enabled. Downloads are attempted up to
    `num_attempts` times; the last failure is re-raised.
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs.  This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid.  Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname()
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    # The cache key is the MD5 of the URL as originally requested.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        # Small payloads are scanned for Google Drive
                        # interstitial pages. They may also be genuine binary
                        # data, so decode leniently: a strict decode would
                        # make every attempt fail with UnicodeDecodeError.
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Retry against the confirmation link on the next attempt.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # KeyboardInterrupt/SystemExit are not Exception subclasses
                # and therefore propagate immediately.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file)  # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)

# diff --git a/stable-dreamfusion-3DPortrait/nerf/gui.py b/stable-dreamfusion-3DPortrait/nerf/gui.py
# new file mode 100644
# index 0000000..65faa5c
# --- /dev/null
# +++ b/stable-dreamfusion-3DPortrait/nerf/gui.py
# @@ -0,0 +1,485 @@
import math
import os
import torch
import torch.nn as nn
import numpy as np
import dearpygui.dearpygui as dpg
from scipy.spatial.transform import Rotation as R

# `os` and `nn` are imported explicitly above because this module uses them
# directly (os.path.basename, nn.Module) and should not depend on the star
# import happening to re-export them.
from nerf.utils import *


class OrbitCamera:
    """Camera state (radius, rotation, look-at center, vertical FoV) driven by the GUI mouse handlers."""

    def __init__(self, W, H, r=2, fovy=60):
        self.W = W
        self.H = H
        self.radius = r # camera distance from center
        self.fovy = fovy # in degree
        self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
        self.rot = R.from_matrix(np.eye(3))
        self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized!
        self.near = 0.001
        self.far = 1000

    # pose
    @property
    def pose(self):
        """Return the 4x4 float32 pose: translate by radius along z, rotate by `rot`, then subtract `center`."""
        # first move camera to radius
        res = np.eye(4, dtype=np.float32)
        res[2, 3] = self.radius
        # rotate
        rot = np.eye(4, dtype=np.float32)
        rot[:3, :3] = self.rot.as_matrix()
        res = rot @ res
        # translate
        res[:3, 3] -= self.center
        return res

    # intrinsics
    @property
    def intrinsics(self):
        """Return `[focal, focal, W // 2, H // 2]`, with focal derived from H and the vertical FoV."""
        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2])

    @property
    def mvp(self):
        """Return the 4x4 product of the perspective projection and the inverse of `pose`."""
        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
        projection = np.array([
            [2*focal/self.W, 0, 0, 0],
            [0, -2*focal/self.H, 0, 0],
            [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],
            [0, 0, -1, 0]
        ], dtype=np.float32)

        return projection @ np.linalg.inv(self.pose) # [4, 4]

    def orbit(self, dx, dy):
        """Rotate by -0.1 degree per unit of `dx` about `up` and per unit of `dy` about the side axis."""
        # rotate along camera up/side axis!
        side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized.
        rotvec_x = self.up * np.deg2rad(-0.1 * dx)
        rotvec_y = side * np.deg2rad(-0.1 * dy)
        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

    def scale(self, delta):
        """Multiply the radius by `1.1 ** (-delta)`."""
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        """Shift `center` in place by the rotated `[dx, -dy, dz]` offset, scaled by 0.0005."""
        # pan in camera coordinate system (careful on the sensitivity!)
        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, -dy, dz])


class NeRFGUI:
    """dearpygui front-end that renders through `trainer.test_gui` and optionally trains through `trainer.train_gui`."""

    def __init__(self, opt, trainer, loader=None, debug=True):
        self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
        self.debug = debug
        self.bg_color = torch.ones(3, dtype=torch.float32) # default white bg
        self.training = False
        self.step = 0 # training step

        self.trainer = trainer
        self.loader = loader
        # NOTE(review): allocated as (W, H, 3); it is replaced by the first
        # rendered frame in test_step() -- confirm whether (H, W, 3) was intended.
        self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True # camera moved, should reset accumulation
        self.spp = 1 # sample per pixel
        self.light_dir = np.array([opt.light_theta, opt.light_phi])
        self.ambient_ratio = 1.0
        self.mode = 'image' # choose from ['image', 'depth']
        self.shading = 'albedo'

        # dynamic resolution is enabled by default unless opt.dmtet is set
        self.dynamic_resolution = not self.opt.dmtet
        self.downscale = 1
        self.train_steps = 16

        dpg.create_context()
        self.register_dpg()
        self.test_step()


    def __del__(self):
        dpg.destroy_context()


    def train_step(self):
        """Run `self.train_steps` training steps, log timing and loss, then re-tune the step count."""

        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()

        outputs = self.trainer.train_gui(self.loader, step=self.train_steps)

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.step += self.train_steps
        self.need_update = True

        dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
        dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')

        # dynamic train steps
        # max allowed train time per-frame is 500 ms
        full_t = t / self.train_steps * 16
        train_steps = min(16, max(4, int(16 * 500 / full_t)))
        # only switch when the target differs by more than 20%
        if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
            self.train_steps = train_steps


    def prepare_buffer(self, outputs):
        """Return the float32 frame for the current mode; depth is min-max normalised and repeated to 3 channels."""
        if self.mode == 'image':
            return outputs['image'].astype(np.float32)
        else:
            depth = outputs['depth'].astype(np.float32)
            depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
            return np.expand_dims(depth, -1).repeat(3, -1)


    def test_step(self):
        """Render one frame if needed and push it, plus timing/resolution/spp readouts, to the GUI."""

        if self.need_update or self.spp < self.opt.max_spp:

            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
            starter.record()

            outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.cam.mvp, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)

            ender.record()
            torch.cuda.synchronize()
            t = starter.elapsed_time(ender)

            # update dynamic resolution
            if self.dynamic_resolution:
                # max allowed infer time per-frame is 200 ms
                full_t = t / (self.downscale ** 2)
                downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
                if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
                    self.downscale = downscale

            if self.need_update:
                self.render_buffer = self.prepare_buffer(outputs)
                self.spp = 1
                self.need_update = False
            else:
                # running average over the frames accumulated so far
                self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
                self.spp += 1

            dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
            dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
            dpg.set_value("_log_spp", self.spp)
            dpg.set_value("_texture", self.render_buffer)


    def register_dpg(self):
        """Create the texture, windows, widgets, mouse handlers and viewport, and wire up their callbacks."""

        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")

        ### register window

        # the rendered image, as the primary window
        with dpg.window(tag="_primary_window", width=self.W, height=self.H):

            # add the texture
            dpg.add_image("_texture")

        dpg.set_primary_window("_primary_window", True)

        # control window
        with dpg.window(label="Control", tag="_control_window", width=400, height=300):

            # text prompt
            if self.opt.text is not None:
                dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")

            if self.opt.negative != '':
                dpg.add_text("negative text: " + self.opt.negative, tag="_log_prompt_negative_text")

            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # time
            if not self.opt.test:
                with dpg.group(horizontal=True):
                    dpg.add_text("Train time: ")
                    dpg.add_text("no data", tag="_log_train_time")

            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            with dpg.group(horizontal=True):
                dpg.add_text("SPP: ")
                dpg.add_text("1", tag="_log_spp")

            # train button
            if not self.opt.test:
                with dpg.collapsing_header(label="Train", default_open=True):
                    with dpg.group(horizontal=True):
                        dpg.add_text("Train: ")

                        def callback_train(sender, app_data):
                            if self.training:
                                self.training = False
                                dpg.configure_item("_button_train", label="start")
                            else:
                                self.training = True
                                dpg.configure_item("_button_train", label="stop")

                        dpg.add_button(label="start", tag="_button_train", callback=callback_train)
                        dpg.bind_item_theme("_button_train", theme_button)

                        def callback_reset(sender, app_data):
                            @torch.no_grad()
                            def weight_reset(m: nn.Module):
                                reset_parameters = getattr(m, "reset_parameters", None)
                                if callable(reset_parameters):
                                    m.reset_parameters()
                            self.trainer.model.apply(fn=weight_reset)
                            self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter
                            self.need_update = True

                        dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
                        dpg.bind_item_theme("_button_reset", theme_button)


                    with dpg.group(horizontal=True):
                        dpg.add_text("Checkpoint: ")

                        def callback_save(sender, app_data):
                            self.trainer.save_checkpoint(full=True, best=False)
                            dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
                            self.trainer.epoch += 1 # use epoch to indicate different calls.

                        dpg.add_button(label="save", tag="_button_save", callback=callback_save)
                        dpg.bind_item_theme("_button_save", theme_button)

                        dpg.add_text("", tag="_log_ckpt")

                    # save mesh
                    with dpg.group(horizontal=True):
                        dpg.add_text("Marching Cubes: ")

                        def callback_mesh(sender, app_data):
                            self.trainer.save_mesh()
                            dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
                            self.trainer.epoch += 1 # use epoch to indicate different calls.

                        dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
                        dpg.bind_item_theme("_button_mesh", theme_button)

                        dpg.add_text("", tag="_log_mesh")

                    with dpg.group(horizontal=True):
                        dpg.add_text("", tag="_log_train_log")


            # rendering options
            with dpg.collapsing_header(label="Options", default_open=True):

                # dynamic rendering resolution
                with dpg.group(horizontal=True):

                    def callback_set_dynamic_resolution(sender, app_data):
                        if self.dynamic_resolution:
                            self.dynamic_resolution = False
                            self.downscale = 1
                        else:
                            self.dynamic_resolution = True
                        self.need_update = True

                    dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
                    dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")

                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)

                # bg_color picker
                def callback_change_bg(sender, app_data):
                    self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1]
                    self.need_update = True

                dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)

                # fov slider
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = app_data
                    self.need_update = True

                dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)

                # dt_gamma slider
                def callback_set_dt_gamma(sender, app_data):
                    self.opt.dt_gamma = app_data
                    self.need_update = True

                dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)

                # max_steps slider
                def callback_set_max_steps(sender, app_data):
                    self.opt.max_steps = app_data
                    self.need_update = True

                dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)

                # aabb slider
                def callback_set_aabb(sender, app_data, user_data):
                    # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
                    self.trainer.model.aabb_infer[user_data] = app_data

                    # also change train aabb ? [better not...]
                    #self.trainer.model.aabb_train[user_data] = app_data

                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Axis-aligned bounding box:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)

                # light dir
                def callback_set_light_dir(sender, app_data, user_data):
                    self.light_dir[user_data] = app_data
                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Plane Light Direction:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)

                # ambient ratio
                def callback_set_abm_ratio(sender, app_data):
                    self.ambient_ratio = app_data
                    self.need_update = True

                dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)

                # shading mode
                def callback_change_shading(sender, app_data):
                    self.shading = app_data
                    self.need_update = True

                dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)


            # debug info
            if self.debug:
                with dpg.collapsing_header(label="Debug"):
                    # pose
                    dpg.add_separator()
                    dpg.add_text("Camera Pose:")
                    dpg.add_text(str(self.cam.pose), tag="_log_pose")


        ### register camera handler

        def callback_camera_drag_rotate(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.orbit(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))


        def callback_camera_wheel_scale(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            delta = app_data

            self.cam.scale(delta)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))


        def callback_camera_drag_pan(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.pan(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))


        with dpg.handler_registry():
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Right, callback=callback_camera_drag_pan)


        dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)

        # TODO: seems dearpygui doesn't support resizing texture...
        # def callback_resize(sender, app_data):
        #     self.W = app_data[0]
        #     self.H = app_data[1]
        #     # how to reload texture ???

        # dpg.set_viewport_resize_callback(callback_resize)

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        #dpg.show_metrics()

        dpg.show_viewport()


    def render(self):
        """Main loop: while the GUI is running, optionally train, then render and draw a frame."""

        while dpg.is_dearpygui_running():
            # update texture every frame
            if self.training:
                self.train_step()
            self.test_step()
            dpg.render_dearpygui_frame()
# \ No newline at end of file
# diff --git a/stable-dreamfusion-3DPortrait/nerf/network.py b/stable-dreamfusion-3DPortrait/nerf/network.py
# new file mode 100644
# index 0000000..aceea26
# --- /dev/null
# +++ b/stable-dreamfusion-3DPortrait/nerf/network.py
# @@ -0,0 +1,241 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from activation import trunc_exp
from .renderer import NeRFRenderer

import numpy as np
from encoding import get_encoder

from .utils import safe_normalize

# TODO: not sure about the details...
+class ResBlock(nn.Module): + def __init__(self, dim_in, dim_out, bias=True): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + + self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias) + self.norm = nn.LayerNorm(self.dim_out) + self.activation = nn.SiLU(inplace=True) + + if self.dim_in != self.dim_out: + self.skip = nn.Linear(self.dim_in, self.dim_out, bias=False) + else: + self.skip = None + + def forward(self, x): + # x: [B, C] + identity = x + + out = self.dense(x) + out = self.norm(out) + + if self.skip is not None: + identity = self.skip(identity) + + out += identity + out = self.activation(out) + + return out + +class BasicBlock(nn.Module): + def __init__(self, dim_in, dim_out, bias=True): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + + self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias) + self.activation = nn.ReLU(inplace=True) + + def forward(self, x): + # x: [B, C] + + out = self.dense(x) + out = self.activation(out) + + return out + +class MLP(nn.Module): + def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True, block=BasicBlock): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + self.dim_hidden = dim_hidden + self.num_layers = num_layers + + net = [] + for l in range(num_layers): + if l == 0: + net.append(BasicBlock(self.dim_in, self.dim_hidden, bias=bias)) + elif l != num_layers - 1: + net.append(block(self.dim_hidden, self.dim_hidden, bias=bias)) + else: + net.append(nn.Linear(self.dim_hidden, self.dim_out, bias=bias)) + + self.net = nn.ModuleList(net) + + + def forward(self, x): + + for l in range(self.num_layers): + x = self.net[l](x) + + return x + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + opt, + num_layers=5, # 5 in paper + hidden_dim=64, # 128 in paper + num_layers_bg=2, # 3 in paper + hidden_dim_bg=32, # 64 in paper + encoding='frequency_torch', # pure pytorch + ): + + super().__init__(opt) + + self.num_layers = num_layers + 
self.hidden_dim = hidden_dim + self.encoder, self.in_dim = get_encoder(encoding, input_dim=3, multires=12) + self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True, block=ResBlock) + + self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else F.softplus + + # background network + if self.opt.bg_radius > 0: + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + self.encoder_bg, self.in_dim_bg = get_encoder(encoding, input_dim=3, multires=4) + self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) + + else: + self.bg_net = None + + def common_forward(self, x): + # x: [N, 3], in [-bound, bound] + + # sigma + enc = self.encoder(x, bound=self.bound, max_level=self.max_level) + + h = self.sigma_net(enc) + + sigma = self.density_activation(h[..., 0] + self.density_blob(x)) + albedo = torch.sigmoid(h[..., 1:]) + + return sigma, albedo + + # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192 + def finite_difference_normal(self, x, epsilon=1e-2): + # x: [N, 3] + dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound)) + dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound)) + + normal = torch.stack([ + 0.5 * (dx_pos - dx_neg) / epsilon, + 0.5 * (dy_pos - dy_neg) / epsilon, + 
0.5 * (dz_pos - dz_neg) / epsilon + ], dim=-1) + + return -normal + + def normal(self, x): + + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x.requires_grad_(True) + sigma, albedo = self.common_forward(x) + # query gradient + normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3] + + # normal = self.finite_difference_normal(x) + normal = safe_normalize(normal) + normal = torch.nan_to_num(normal) + + return normal + + def forward(self, x, d, l=None, ratio=1, shading='albedo'): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], view direction, nomalized in [-1, 1] + # l: [3], plane light direction, nomalized in [-1, 1] + # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless) + + if shading == 'albedo': + # no need to query normal + sigma, color = self.common_forward(x) + normal = None + + else: + # query normal + + # sigma, albedo = self.common_forward(x) + # normal = self.normal(x) + + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x.requires_grad_(True) + sigma, albedo = self.common_forward(x) + # query gradient + normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3] + normal = safe_normalize(normal) + normal = torch.nan_to_num(normal) + + # lambertian shading + lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,] + + if shading == 'textureless': + color = lambertian.unsqueeze(-1).repeat(1, 3) + elif shading == 'normal': + color = (normal + 1) / 2 + else: # 'lambertian' + color = albedo * lambertian.unsqueeze(-1) + + return sigma, color, normal + + + def density(self, x): + # x: [N, 3], in [-bound, bound] + + sigma, albedo = self.common_forward(x) + + return { + 'sigma': sigma, + 'albedo': albedo, + } + + + def background(self, d): + + h = self.encoder_bg(d) # [N, C] + + h = self.bg_net(h) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + # optimizer utils + def 
get_params(self, lr): + + params = [ + # {'params': self.encoder.parameters(), 'lr': lr * 10}, + {'params': self.sigma_net.parameters(), 'lr': lr}, + ] + + if self.opt.bg_radius > 0: + # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10}) + params.append({'params': self.bg_net.parameters(), 'lr': lr}) + + if self.opt.dmtet and not self.opt.lock_geo: + params.append({'params': self.sdf, 'lr': lr}) + params.append({'params': self.deform, 'lr': lr}) + + return params \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/nerf/network_grid.py b/stable-dreamfusion-3DPortrait/nerf/network_grid.py new file mode 100644 index 0000000..c308f3d --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/network_grid.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from activation import trunc_exp, biased_softplus +from .renderer import NeRFRenderer + +import numpy as np +from encoding import get_encoder + +from .utils import safe_normalize + +class MLP(nn.Module): + def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + self.dim_hidden = dim_hidden + self.num_layers = num_layers + + net = [] + for l in range(num_layers): + net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) + + self.net = nn.ModuleList(net) + + def forward(self, x): + for l in range(self.num_layers): + x = self.net[l](x) + if l != self.num_layers - 1: + x = F.relu(x, inplace=True) + return x + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + opt, + num_layers=3, + hidden_dim=64, + num_layers_bg=2, + hidden_dim_bg=32, + ): + + super().__init__(opt) + + self.num_layers = num_layers + self.hidden_dim = hidden_dim + + self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, 
interpolation='smoothstep') + + self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True) + # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True) + + self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus + + # background network + if self.opt.bg_radius > 0: + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + + # use a very simple network to avoid it learning the prompt... + self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6) + self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) + + else: + self.bg_net = None + + def common_forward(self, x): + + # sigma + enc = self.encoder(x, bound=self.bound, max_level=self.max_level) + + h = self.sigma_net(enc) + + sigma = self.density_activation(h[..., 0] + self.density_blob(x)) + albedo = torch.sigmoid(h[..., 1:]) + + return sigma, albedo + + # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192 + def finite_difference_normal(self, x, epsilon=1e-2): + # x: [N, 3] + dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound)) + dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound)) + + normal = torch.stack([ + 0.5 * (dx_pos - dx_neg) / epsilon, + 0.5 * (dy_pos 
- dy_neg) / epsilon, + 0.5 * (dz_pos - dz_neg) / epsilon + ], dim=-1) + + return -normal + + def normal(self, x): + normal = self.finite_difference_normal(x) + normal = safe_normalize(normal) + normal = torch.nan_to_num(normal) + return normal + + def forward(self, x, d, l=None, ratio=1, shading='albedo'): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], view direction, nomalized in [-1, 1] + # l: [3], plane light direction, nomalized in [-1, 1] + # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless) + + sigma, albedo = self.common_forward(x) + + if shading == 'albedo': + normal = None + color = albedo + + else: # lambertian shading + + # normal = self.normal_net(enc) + normal = self.normal(x) + + lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,] + + if shading == 'textureless': + color = lambertian.unsqueeze(-1).repeat(1, 3) + elif shading == 'normal': + color = (normal + 1) / 2 + else: # 'lambertian' + color = albedo * lambertian.unsqueeze(-1) + + return sigma, color, normal + + + def density(self, x): + # x: [N, 3], in [-bound, bound] + + sigma, albedo = self.common_forward(x) + + return { + 'sigma': sigma, + 'albedo': albedo, + } + + + def background(self, d): + + h = self.encoder_bg(d) # [N, C] + + h = self.bg_net(h) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + # optimizer utils + def get_params(self, lr): + + params = [ + {'params': self.encoder.parameters(), 'lr': lr * 10}, + {'params': self.sigma_net.parameters(), 'lr': lr}, + # {'params': self.normal_net.parameters(), 'lr': lr}, + ] + + if self.opt.bg_radius > 0: + # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10}) + params.append({'params': self.bg_net.parameters(), 'lr': lr}) + + if self.opt.dmtet and not self.opt.lock_geo: + params.append({'params': self.sdf, 'lr': lr}) + params.append({'params': self.deform, 'lr': lr}) + + return params \ No newline at end of file diff --git 
a/stable-dreamfusion-3DPortrait/nerf/network_grid_taichi.py b/stable-dreamfusion-3DPortrait/nerf/network_grid_taichi.py new file mode 100644 index 0000000..8fa2efd --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/network_grid_taichi.py @@ -0,0 +1,170 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from activation import trunc_exp +from .renderer import NeRFRenderer + +import numpy as np +from encoding import get_encoder + +from .utils import safe_normalize + +class MLP(nn.Module): + def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + self.dim_hidden = dim_hidden + self.num_layers = num_layers + + net = [] + for l in range(num_layers): + net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) + + self.net = nn.ModuleList(net) + + def forward(self, x): + for l in range(self.num_layers): + x = self.net[l](x) + if l != self.num_layers - 1: + x = F.relu(x, inplace=True) + return x + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + opt, + num_layers=2, + hidden_dim=32, + num_layers_bg=2, + hidden_dim_bg=16, + ): + + super().__init__(opt) + + self.num_layers = num_layers + self.hidden_dim = hidden_dim + + self.encoder, self.in_dim = get_encoder('hashgrid_taichi', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep') + + self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True) + # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True) + + self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else F.softplus + + # background network + if self.opt.bg_radius > 0: + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + # use a very simple network to avoid it learning the prompt... 
+ self.encoder_bg, self.in_dim_bg = get_encoder('frequency_torch', input_dim=3, multires=4) # TODO: freq encoder can be replaced by a Taichi implementation + self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) + + else: + self.bg_net = None + + def common_forward(self, x): + + # sigma + enc = self.encoder(x, bound=self.bound) + + h = self.sigma_net(enc) + + sigma = self.density_activation(h[..., 0] + self.density_blob(x)) + albedo = torch.sigmoid(h[..., 1:]) + + return sigma, albedo + + # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192 + def finite_difference_normal(self, x, epsilon=1e-2): + # x: [N, 3] + dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound)) + dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound)) + + normal = torch.stack([ + 0.5 * (dx_pos - dx_neg) / epsilon, + 0.5 * (dy_pos - dy_neg) / epsilon, + 0.5 * (dz_pos - dz_neg) / epsilon + ], dim=-1) + + return -normal + + def normal(self, x): + normal = self.finite_difference_normal(x) + normal = safe_normalize(normal) + normal = torch.nan_to_num(normal) + return normal + + def forward(self, x, d, l=None, ratio=1, shading='albedo'): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], view direction, nomalized in [-1, 1] + # l: [3], plane light direction, 
nomalized in [-1, 1] + # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless) + + sigma, albedo = self.common_forward(x) + + if shading == 'albedo': + normal = None + color = albedo + + else: # lambertian shading + # normal = self.normal_net(enc) + normal = self.normal(x) + + lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,] + + if shading == 'textureless': + color = lambertian.unsqueeze(-1).repeat(1, 3) + elif shading == 'normal': + color = (normal + 1) / 2 + else: # 'lambertian' + color = albedo * lambertian.unsqueeze(-1) + + return sigma, color, normal + + + def density(self, x): + # x: [N, 3], in [-bound, bound] + + sigma, albedo = self.common_forward(x) + + return { + 'sigma': sigma, + 'albedo': albedo, + } + + + def background(self, d): + + h = self.encoder_bg(d) # [N, C] + + h = self.bg_net(h) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + # optimizer utils + def get_params(self, lr): + + params = [ + {'params': self.encoder.parameters(), 'lr': lr * 10}, + {'params': self.sigma_net.parameters(), 'lr': lr}, + # {'params': self.normal_net.parameters(), 'lr': lr}, + ] + + if self.opt.bg_radius > 0: + # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10}) + params.append({'params': self.bg_net.parameters(), 'lr': lr}) + + if self.opt.dmtet and not self.opt.lock_geo: + params.append({'params': self.sdf, 'lr': lr}) + params.append({'params': self.deform, 'lr': lr}) + + return params \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/nerf/network_grid_tcnn.py b/stable-dreamfusion-3DPortrait/nerf/network_grid_tcnn.py new file mode 100644 index 0000000..e270789 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/network_grid_tcnn.py @@ -0,0 +1,178 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from activation import trunc_exp, biased_softplus +from .renderer import NeRFRenderer + +import numpy as np 
+from encoding import get_encoder + +from .utils import safe_normalize + +import tinycudann as tcnn + +class MLP(nn.Module): + def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + self.dim_hidden = dim_hidden + self.num_layers = num_layers + + net = [] + for l in range(num_layers): + net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) + + self.net = nn.ModuleList(net) + + def forward(self, x): + for l in range(self.num_layers): + x = self.net[l](x) + if l != self.num_layers - 1: + x = F.relu(x, inplace=True) + return x + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + opt, + num_layers=3, + hidden_dim=64, + num_layers_bg=2, + hidden_dim_bg=32, + ): + + super().__init__(opt) + + self.num_layers = num_layers + self.hidden_dim = hidden_dim + + self.encoder = tcnn.Encoding( + n_input_dims=3, + encoding_config={ + "otype": "HashGrid", + "n_levels": 16, + "n_features_per_level": 2, + "log2_hashmap_size": 19, + "base_resolution": 16, + "interpolation": "Smoothstep", + "per_level_scale": np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1)), + }, + dtype=torch.float32, # ENHANCE: default float16 seems unstable... + ) + self.in_dim = self.encoder.n_output_dims + # use torch MLP, as tcnn MLP doesn't impl second-order derivative + self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True) + + self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus + + # background network + if self.opt.bg_radius > 0: + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + + # use a very simple network to avoid it learning the prompt... 
+ self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6) + self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) + + else: + self.bg_net = None + + def common_forward(self, x): + + # sigma + enc = self.encoder((x + self.bound) / (2 * self.bound)).float() + h = self.sigma_net(enc) + + sigma = self.density_activation(h[..., 0] + self.density_blob(x)) + albedo = torch.sigmoid(h[..., 1:]) + + return sigma, albedo + + def normal(self, x): + + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x.requires_grad_(True) + sigma, albedo = self.common_forward(x) + # query gradient + normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3] + + # normal = self.finite_difference_normal(x) + normal = safe_normalize(normal) + normal = torch.nan_to_num(normal) + + return normal + + def forward(self, x, d, l=None, ratio=1, shading='albedo'): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], view direction, nomalized in [-1, 1] + # l: [3], plane light direction, nomalized in [-1, 1] + # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless) + + + if shading == 'albedo': + sigma, albedo = self.common_forward(x) + normal = None + color = albedo + + else: # lambertian shading + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x.requires_grad_(True) + sigma, albedo = self.common_forward(x) + normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3] + normal = safe_normalize(normal) + normal = torch.nan_to_num(normal) + + lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,] + + if shading == 'textureless': + color = lambertian.unsqueeze(-1).repeat(1, 3) + elif shading == 'normal': + color = (normal + 1) / 2 + else: # 'lambertian' + color = albedo * lambertian.unsqueeze(-1) + + return sigma, color, normal + + + def density(self, x): + # x: [N, 3], in [-bound, bound] + + sigma, 
albedo = self.common_forward(x) + + return { + 'sigma': sigma, + 'albedo': albedo, + } + + + def background(self, d): + + h = self.encoder_bg(d) # [N, C] + + h = self.bg_net(h) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + # optimizer utils + def get_params(self, lr): + + params = [ + {'params': self.encoder.parameters(), 'lr': lr * 10}, + {'params': self.sigma_net.parameters(), 'lr': lr}, + ] + + if self.opt.bg_radius > 0: + params.append({'params': self.bg_net.parameters(), 'lr': lr}) + + if self.opt.dmtet and not self.opt.lock_geo: + params.append({'params': self.sdf, 'lr': lr}) + params.append({'params': self.deform, 'lr': lr}) + + return params \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/nerf/network_trigrid_heirarchy.py b/stable-dreamfusion-3DPortrait/nerf/network_trigrid_heirarchy.py new file mode 100644 index 0000000..b582fe8 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/network_trigrid_heirarchy.py @@ -0,0 +1,413 @@ +from .trigrid_rendering.neural_render import NeuralRender as TrigridNeRFRenderer +from .renderer import NeRFRenderer + +import torch +import torch.nn as nn +import torch.nn.functional as F +from encoding import get_encoder + +from .utils import safe_normalize + +import numpy as np +class MLP(nn.Module): + def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + self.dim_hidden = dim_hidden + self.num_layers = num_layers + + net = [] + for l in range(num_layers): + net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, + self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) + + self.net = nn.ModuleList(net) + + def forward(self, x): + for l in range(self.num_layers): + x = self.net[l](x) + if l != self.num_layers - 1: + x = F.relu(x, inplace=True) + return x + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + opt, + device, + trigrid_shapes, + 
num_layers_bg=2, + hidden_dim_bg=32, + ): + super().__init__(opt) + + + self.triplane_names = {} + for k in trigrid_shapes: + self.register_parameter(k, torch.nn.Parameter(torch.randn(trigrid_shapes[k]))) + + if k.startswith('trigrid'): + res = int(k.split('_')[1]) + self.triplane_names[res] = k + + # sort the triplane names by resolution + self.triplane_names = {k: self.triplane_names[k] for k in sorted(self.triplane_names.keys())} + + params = {'z_dim': 512, 'c_dim': 25, 'w_dim': 512, 'img_resolution': 512, 'img_channels': 3, + 'rendering_kwargs': {'image_resolution': 512, 'disparity_space_sampling': False, + 'clamp_mode': 'softplus', + 'superresolution_module': 'training.superresolution.SuperresolutionHybrid8XDC', + 'c_gen_conditioning_zero': False, 'gpc_reg_prob': 0.7, + 'decoder_activation': 'none', 'use_torgb_raw': True, + 'use_background': True, 'triplane_depth': 3, 'c_scale': 1.0, + 'superresolution_noise_mode': 'none', 'density_reg': 0.0, + 'density_reg_p_dist': 0.004, 'reg_type': 'l1', 'decoder_lr_mul': 1.0, + 'sr_antialias': True, 'radius_scale': 0.7, + 'depth_resolution': 48, 'depth_resolution_importance': 48, + 'ray_start': 2.3850000000000002, 'ray_end': 3.12, + 'box_warp': 0.7, 'density_noise': 0.0}, + 'batch_size': 1, 'thickness': 0.25,"apply_deformation": self.opt.use_body_pose, + } + self.model = TrigridNeRFRenderer(**params).to(device) + # self.trigrid_4 = torch.nn.Parameter(torch.randn([1, 3, 16 * 3 , 4, 4])) + # self.trigrid_8 = torch.nn.Parameter(torch.randn([1, 3, 16 * 3 , 8, 8])) + # self.trigrid_16=torch.nn.Parameter(torch.randn([1, 3, 16 * 3 , 16, 16])) + # self.trigrid_32 =torch.nn.Parameter(torch.randn([1, 3, 16* 3, 32, 32])) + # self.trigrid_64 =torch.nn.Parameter(torch.randn([1, 3, 16* 3, 64, 64])) + # self.trigrid_128 =torch.nn.Parameter(torch.randn([1,3, 16* 3, 128, 128])) + # self.trigrid_256 = torch.nn.Parameter(torch.randn([1, 3, 32*3, 256, 256])) + + # self.ws = torch.nn.Parameter(torch.randn([1, 14, 512])) + + # background 
network + if self.opt.bg_radius > 0: + + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + + # use a very simple network to avoid it learning the prompt... + self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6) + self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) + + else: + assert self.opt.learnable_bg == False + self.bg_net = None + + + + self.train_decoder = opt.train_decoder + + def common_forward(self, x): + + # # sigma + # enc = self.encoder(x, bound=self.bound, max_level=self.max_level) + # + # h = self.sigma_net(enc) + # + # sigma = self.density_activation(h[..., 0] + self.density_blob(x)) + # albedo = torch.sigmoid(h[..., 1:]) + + + # return sigma, albedo + + # TODO + raise NotImplementedError + + # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192 + def finite_difference_normal(self, x, epsilon=1e-2): + # x: [N, 3] + # dx_pos, _ = self.common_forward( + # (x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + # dx_neg, _ = self.common_forward( + # (x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + # dy_pos, _ = self.common_forward( + # (x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + # dy_neg, _ = self.common_forward( + # (x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) + # dz_pos, _ = self.common_forward( + # (x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound)) + # dz_neg, _ = self.common_forward( + # (x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound)) + # + # normal = torch.stack([ + # 0.5 * (dx_pos - dx_neg) / epsilon, + # 0.5 * (dy_pos - dy_neg) / epsilon, + # 0.5 * (dz_pos - dz_neg) / epsilon + # ], dim=-1) + + # return -normal + + # 
TODO + raise NotImplementedError + + def normal(self, x): + # normal = self.finite_difference_normal(x) + # normal = safe_normalize(normal) + # normal = torch.nan_to_num(normal) + + # return normal + + # TODO + raise NotImplementedError + + def forward(self, x, d, l=None, ratio=1, shading='albedo'): + ''' + x: [N, 3], in [-bound, bound] + d: [N, 3], view direction, nomalized in [-1, 1] + l: [3], plane light direction, nomalized in [-1, 1] + ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless) + ''' + + # sigma, albedo = self.common_forward(x) + # + # if shading == 'albedo': + # normal = None + # color = albedo + # + # else: # lambertian shading + # + # # normal = self.normal_net(enc) + # normal = self.normal(x) + # + # lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,] + # + # if shading == 'textureless': + # color = lambertian.unsqueeze(-1).repeat(1, 3) + # elif shading == 'normal': + # color = (normal + 1) / 2 + # else: # 'lambertian' + # color = albedo * lambertian.unsqueeze(-1) + + # return sigma, color, normal + + # TODO + raise NotImplementedError + + def density(self, x): + # x: [N, 3], in [-bound, bound] + + # sigma, albedo = self.common_forward(x) + + + + # return { + # 'sigma': sigma, + # 'albedo': albedo, + # } + + # TODO + raise NotImplementedError + + def background(self, d): + + h = self.encoder_bg(d) # [N, C] + + h = self.bg_net(h) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + # optimizer utils + def get_params(self, lr,trigrid_lr_ratio): + params =[] + assert len(trigrid_lr_ratio) == len(self.triplane_names) + resolutions = list(self.triplane_names.keys()) + for i in range(len(trigrid_lr_ratio)): + print(f'{self.triplane_names[resolutions[i]]} lr: {lr*trigrid_lr_ratio[i]}') + + params.append({'params': getattr(self, self.triplane_names[resolutions[i]]), 'lr': lr*trigrid_lr_ratio[i]}) + + # params.append({'params': self.ws, 'lr': lr*0.1}) + + if 
self.train_decoder: + params.append({'params': self.model.parameters(lr), 'lr': lr}) + + + return params + + @torch.no_grad() + def export_mesh(self, path, resolution=None, decimate_target=-1, S=128): + raise NotImplementedError + + + + + def render(self, rays_o, rays_d, poses, h, w, staged=False, max_ray_batch=4096, bg_color = None,bg_rays_o=None,bg_rays_d=None, **kwargs): + cam2world_pose = poses.clone() + cam2world_pose[:, :3, :3] = cam2world_pose[:, :3, :3] * -1 + cam2world_pose[:, 0, 1] *= -1 + cam2world_pose[:, 0, 2] *= -1 + cam2world_pose[:, 1, 0] *= -1 + cam2world_pose[:, 2, 0] *= -1 + cam2world_pose[:, 0, 3] *= -1 + + intrinsics = [6.510416666666667, + 0.0, + 0.5, + 0.0, + 6.510416666666667, + 0.5, + 0.0, + 0.0, + 1.0] + intrinsics = torch.tensor(intrinsics).to(cam2world_pose.device) + camera_params = torch.cat([cam2world_pose.reshape(1, 16), intrinsics.reshape(1, 9)], 1) + + # rays_o, rays_d: [B, N, 3] + # return: pred_rgb: [B, N, 3] + #B, N = rays_o.shape[:2] + H = h + W = w + + if self.opt.learnable_bg: + assert bg_color is None, 'bg_color should be None when learnable_bg is True' + bg_color = self.background(rays_d.contiguous().view(-1, 3)) # [BHW, 3] + # from [BHW, 3] to [B, H, W, 3] + bg_color = bg_color.view(-1, H, W, 3).clamp(0, 1.0) + + + device = rays_o.device + N, M, _ = rays_o.shape + + planes = {} + for res in self.triplane_names: + planes[res] = getattr(self, self.triplane_names[res]) + + if self.opt.use_body_pose: + # apply_def=apply_def, ws=None, pose_params=pose_params + pose_params = self.model.sample_pose_params(camera_params) + apply_def = True + else: + pose_params = None + apply_def = False + + + + if staged: + + + depth = torch.empty((N,M,1), device=device) + image = torch.empty((N,M, 32), device=device) + weights_sum = torch.empty((N,M,1), device=device) + + for b in range(N): + head = 0 + while head < M: + tail = min(head + max_ray_batch, M) + + render_output = self.model.renderer(planes, self.model.decoder, rays_o[b:b + 1, 
head:tail], + rays_d[b:b + 1, head:tail], self.model.rendering_kwargs, apply_def=apply_def, ws=None, pose_params=pose_params) # channels last + # {'rgb_final': rgb_final, 'depth_final': depth_final, 'weights': weights.sum(2)} + feature_samples = render_output['rgb_final'] # max_ray_batch,32 + depth_samples = render_output['depth_final'] # 1, max_ray_batch + weights_samples = render_output['weights'] # 1, max_ray_batch, depth + + weights_sum_samples = weights_samples.sum(2) # 1, max_ray_batch,1 + + + + depth[b:b + 1, head:tail] = depth_samples + weights_sum[b:b + 1, head:tail] = weights_sum_samples + image[b:b + 1, head:tail] = feature_samples + head += max_ray_batch + + feature_samples = image + depth_samples = depth + weights_sum_samples = weights_sum + + feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous() + depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) + weights_sum_samples = weights_sum_samples.permute(0, 2, 1).reshape(N, 1, H, W) + + # Run superresolution to get final image + if self.model.decoder.activation == "sigmoid": + assert self.model.decoder.out_channels == 3 + feature_image = feature_image * 2 - 1 # Scale to (-1, 1), taken from ray marcher + + # Generate Raw image + if self.model.torgb: + rgb_image = self.model.torgb(feature_image, self.ws[:, -1], fused_modconv=False) + rgb_image = rgb_image.to(dtype=torch.float32, memory_format=torch.contiguous_format) + + weights_sum_samples = weights_sum_samples * (1 + 2 * 0.001) - 0.001 + + # from [B,C,H,W] to [B, H, W, C] + rgb_image = (rgb_image.permute(0, 2, 3, 1) * 0.5 + 0.5).clamp(0, 1.0).contiguous() + depth_image = depth_image.permute(0, 2, 3, 1).contiguous().squeeze(-1) # [B, H, W] + weights_sum_samples = weights_sum_samples.permute(0, 2, 3, 1).contiguous().squeeze(-1) # [B, H, W] + + + + if bg_color is not None and self.opt.learnable_bg: + assert bg_color.shape == rgb_image.shape, f'bg_color.shape {bg_color.shape} should be equal to 
rgb_image.shape {rgb_image.shape}' + rgb_image = rgb_image + (1 - weights_sum_samples).unsqueeze(-1) * bg_color + + + return {'image': rgb_image, 'depth': depth_image, + "weights_sum": weights_sum_samples} + + + else: + + + # Create triplanes by running StyleGAN backbone + + + # Reshape output into three D*32-channel planes, where D=self.rendering_kwargs['triplane_depth'], defines the depth of the tri-grid + + #self.trigrid.register_hook(lambda grad: print(grad,grad.abs().sum(), grad.abs().max(),grad.abs().min())) + # Perform volume rendering + render_output = self.model.renderer(planes, self.model.decoder, rays_o, + rays_d, self.model.rendering_kwargs, apply_def=apply_def, ws=None, pose_params=pose_params) # channels last + + + # {'rgb_final': rgb_final, 'depth_final': depth_final, 'weights': weights.sum(2)} + feature_samples = render_output['rgb_final'] + depth_samples = render_output['depth_final'] + weights_samples = render_output['weights'] # 1, max_ray_batch, depth,1 + weights_sum_samples = weights_samples.sum(2) # 1, max_ray_batch,1 + + # Reshape into 'raw' neural-rendered image + + feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous() + depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) + weights_sum_samples = weights_sum_samples.permute(0, 2, 1).reshape(N, 1, H, W) + depth = weights_samples.shape[-2] + weights_samples = weights_samples.squeeze(-1).permute(0, 2, 1).reshape(N,depth, H, W) + + # Run superresolution to get final image + if self.model.decoder.activation == "sigmoid": + assert self.model.decoder.out_channels == 3 + feature_image = feature_image * 2 - 1 # Scale to (-1, 1), taken from ray marcher + feature_image.register_hook(lambda x:print(f'in sigmoid, feature_image.grad = {x}')) + + # Generate Raw image + if self.model.torgb: + rgb_image = self.model.torgb(feature_image, self.ws[:, -1], fused_modconv=False) + + rgb_image = rgb_image.to(dtype=torch.float32, 
memory_format=torch.contiguous_format) + + weights_sum_samples = weights_sum_samples * (1 + 2 * 0.001) - 0.001 + + + # from [B,C,H,W] to [B, H, W, C] + rgb_image = (rgb_image.permute(0, 2, 3, 1) * 0.5 + 0.5).clamp(0, 1.0).contiguous() + depth_image = depth_image.permute(0, 2, 3, 1).contiguous().squeeze(-1) + weights_sum_samples = weights_sum_samples.permute(0, 2, 3, 1).contiguous().squeeze(-1) + weights_samples = weights_samples.permute(0, 2, 3, 1).contiguous() # B, H, W, D + + + if bg_color is not None and self.opt.learnable_bg: + assert bg_color.shape == rgb_image.shape, f'bg_color.shape {bg_color.shape} should be equal to rgb_image.shape {rgb_image.shape}' + rgb_image = rgb_image + (1 - weights_sum_samples).unsqueeze(-1) * bg_color + + return {'image': rgb_image, 'depth': depth_image,"weights":weights_samples, "weights_sum": weights_sum_samples} + + + #results = self.run(rays_o, rays_d, **kwargs) + + + #return results \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/nerf/provider.py b/stable-dreamfusion-3DPortrait/nerf/provider.py new file mode 100644 index 0000000..764efcb --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/provider.py @@ -0,0 +1,357 @@ +import os +import cv2 +import glob +import json +import tqdm +import random +import numpy as np +from scipy.spatial.transform import Slerp, Rotation + +import trimesh + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from .utils import get_rays, safe_normalize + +DIR_COLORS = np.array([ + [255, 0, 0, 255], # front + [0, 255, 0, 255], # side + [0, 0, 255, 255], # back + [255, 255, 0, 255], # side + [255, 0, 255, 255], # overhead + [0, 255, 255, 255], # bottom +], dtype=np.uint8) + +def visualize_poses(poses, dirs, size=0.1): + # poses: [B, 4, 4], dirs: [B] + + axes = trimesh.creation.axis(axis_length=4) + sphere = trimesh.creation.icosphere(radius=1) + objects = [axes, sphere] + + for pose, dir in zip(poses, dirs): + # a camera is visualized 
with 8 line segments. + pos = pose[:3, 3] + a = pos + size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2] + b = pos - size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2] + c = pos - size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2] + d = pos + size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2] + + segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]]) + segs = trimesh.load_path(segs) + + # different color for different dirs + segs.colors = DIR_COLORS[[dir]].repeat(len(segs.entities), 0) + + objects.append(segs) + + trimesh.Scene(objects).show() + +def get_view_direction(thetas, phis, overhead, front): + # phis: [B,]; thetas: [B,] + # front = 0 [-front/2, front/2) + # side (cam left) = 1 [front/2, 180-front/2) + # back = 2 [180-front/2, 180+front/2) + # side (cam right) = 3 [180+front/2, 360-front/2) + # top = 4 [0, overhead] + # bottom = 5 [180-overhead, 180] + res = torch.zeros(thetas.shape[0], dtype=torch.long) + # first determine by phis + phis = phis % (2 * np.pi) + res[(phis < front / 2) | (phis >= 2 * np.pi - front / 2)] = 0 + res[(phis >= front / 2) & (phis < np.pi - front / 2)] = 1 + res[(phis >= np.pi - front / 2) & (phis < np.pi + front / 2)] = 2 + res[(phis >= np.pi + front / 2) & (phis < 2 * np.pi - front / 2)] = 3 + # override by thetas + res[thetas <= overhead] = 4 + res[thetas >= (np.pi - overhead)] = 5 + return res + + +def rand_poses(size, device, opt, radius_range=[1, 1.5], theta_range=[0, 120], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, uniform_sphere_rate=0.5): + ''' generate random poses from an orbit camera + Args: + size: batch size of generated poses. + device: where to allocate the output. 
+ radius: camera radius + theta_range: [min, max], should be in [0, pi] + phi_range: [min, max], should be in [0, 2 * pi] + Return: + poses: [size, 4, 4] + ''' + + theta_range = np.array(theta_range) / 180 * np.pi + phi_range = np.array(phi_range) / 180 * np.pi + angle_overhead = angle_overhead / 180 * np.pi + angle_front = angle_front / 180 * np.pi + + radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0] + + if random.random() < uniform_sphere_rate: + unit_centers = F.normalize( + torch.stack([ + torch.randn(size, device=device), + torch.abs(torch.randn(size, device=device)), + torch.randn(size, device=device), + ], dim=-1), p=2, dim=1 + ) + thetas = torch.acos(unit_centers[:,1]) + phis = torch.atan2(unit_centers[:,0], unit_centers[:,2]) + phis[phis < 0] += 2 * np.pi + centers = unit_centers * radius.unsqueeze(-1) + else: + thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] + phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] + phis[phis < 0] += 2 * np.pi + + centers = torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + radius * torch.cos(thetas), + radius * torch.sin(thetas) * torch.cos(phis), + ], dim=-1) # [B, 3] + + targets = 0 + + # jitters + if opt.jitter_pose: + jit_center = opt.jitter_center # 0.015 # was 0.2 + jit_target = opt.jitter_target + centers += torch.rand_like(centers) * jit_center - jit_center/2.0 + targets += torch.randn_like(centers) * jit_target + + # lookat + forward_vector = safe_normalize(centers - targets) + up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1) + right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1)) + + if opt.jitter_pose: + up_noise = torch.randn_like(up_vector) * opt.jitter_up + else: + up_noise = 0 + + up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise) + + poses = torch.eye(4, dtype=torch.float, 
def circle_poses(device, radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), return_dirs=False, angle_overhead=30, angle_front=60):
    """Build look-at poses for cameras placed on a sphere around the origin.

    Args:
        device: device on which the pose matrices are allocated.
        radius: tensor of camera distances from the origin.
        theta: tensor of polar angles in degrees.
        phi: tensor of azimuth angles in degrees.
        return_dirs: when True, also classify every pose with ``get_view_direction``.
        angle_overhead: polar threshold (degrees) used for the top / bottom labels.
        angle_front: azimuth width (degrees) of the front sector.

    Returns:
        (poses, dirs): ``poses`` is [B, 4, 4] with the right / up / forward vectors as
        rotation columns and the camera centre as translation; ``dirs`` is the label
        tensor, or None when ``return_dirs`` is False.
    """
    # Everything below works in radians.
    theta_rad = theta / 180 * np.pi
    phi_rad = phi / 180 * np.pi
    overhead_rad = angle_overhead / 180 * np.pi
    front_rad = angle_front / 180 * np.pi

    # Spherical -> cartesian, with y as the polar axis.
    sin_theta = torch.sin(theta_rad)
    centers = torch.stack([
        radius * sin_theta * torch.sin(phi_rad),
        radius * torch.cos(theta_rad),
        radius * sin_theta * torch.cos(phi_rad),
    ], dim=-1)  # [B, 3]
    num_poses = len(centers)

    # Look-at frame: the forward axis points from the origin to the camera.
    forward_vector = safe_normalize(centers)
    world_up = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(num_poses, 1)
    right_vector = safe_normalize(torch.cross(forward_vector, world_up, dim=-1))
    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(num_poses, 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    dirs = get_view_direction(theta_rad, phi_rad, overhead_rad, front_rad) if return_dirs else None
    return poses, dirs
self.opt.min_near + self.far = 1000 # infinite + + # [debug] visualize poses + # poses, dirs, _, _, _ = rand_poses(100, self.device, opt, radius_range=self.opt.radius_range, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose, uniform_sphere_rate=1) + # visualize_poses(poses.detach().cpu().numpy(), dirs.detach().cpu().numpy()) + + def get_default_view_data(self): + + H = int(self.opt.known_view_scale * self.H) + W = int(self.opt.known_view_scale * self.W) + cx = H / 2 + cy = W / 2 + + radii = torch.FloatTensor(self.opt.ref_radii).to(self.device) + thetas = torch.FloatTensor(self.opt.ref_polars).to(self.device) + phis = torch.FloatTensor(self.opt.ref_azimuths).to(self.device) + poses, dirs = circle_poses(self.device, radius=radii, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front) + fov = self.opt.default_fovy + focal = H / (2 * np.tan(np.deg2rad(fov) / 2)) + intrinsics = np.array([focal, focal, cx, cy]) + + projection = torch.tensor([ + [2*focal/W, 0, 0, 0], + [0, -2*focal/H, 0, 0], + [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)], + [0, 0, -1, 0] + ], dtype=torch.float32, device=self.device).unsqueeze(0).repeat(len(radii), 1, 1) + + mvp = projection @ torch.inverse(poses) # [B, 4, 4] + + # sample a low-resolution but full image + rays = get_rays(poses, intrinsics, H, W, -1) + + if self.teacher_W is not None: + teacher_H =int(self.opt.known_view_scale * self.teacher_H) + teacher_W = int(self.opt.known_view_scale * self.teacher_W) + + teacher_cx = teacher_H / 2 + teacher_cy = teacher_W / 2 + + teacher_focal = teacher_H / (2 * np.tan(np.deg2rad(fov) / 2)) + teacher_intrinsics = np.array([teacher_focal, teacher_focal, teacher_cx, teacher_cy]) + + teacher_rays = get_rays(poses, teacher_intrinsics, teacher_H, teacher_W, -1) + + data = { + 'H': H, + 'W': W, + 'rays_o': rays['rays_o'], + 'rays_d': 
rays['rays_d'], + 'dir': dirs, + 'mvp': mvp, + 'polar': self.opt.ref_polars, + 'azimuth': self.opt.ref_azimuths, + 'radius': self.opt.ref_radii, + + 'teacher_H': self.teacher_W if self.teacher_W is not None else None, + 'teacher_W': self.teacher_W if self.teacher_W is not None else None, + 'teacher_rays_o': teacher_rays['rays_o'] if self.teacher_W is not None else None, + 'teacher_rays_d': teacher_rays['rays_d'] if self.teacher_W is not None else None, + } + + return data + + def collate(self, index): + + B = len(index) + + if self.training: + # random pose on the fly + poses, dirs, thetas, phis, radius = rand_poses(B, self.device, self.opt, radius_range=self.opt.radius_range, theta_range=self.opt.theta_range, phi_range=self.opt.phi_range, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, uniform_sphere_rate=self.opt.uniform_sphere_rate) + + # random focal + fov = random.random() * (self.opt.fovy_range[1] - self.opt.fovy_range[0]) + self.opt.fovy_range[0] + + elif self.type == 'six_views': + # six views + thetas_six = [90, 90, 90, 90, 1e-3, 179.999] + phis_six = [ 0, 90, 180, -90, 0, 0] + thetas = torch.FloatTensor([thetas_six[index[0]]]).to(self.device) + phis = torch.FloatTensor([phis_six[index[0]]]).to(self.device) + radius = torch.FloatTensor([self.opt.default_radius]).to(self.device) + poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front) + + # fixed focal + fov = self.opt.default_fovy + + else: + # circle pose + thetas = torch.FloatTensor([self.opt.default_polar]).to(self.device) + phis = torch.FloatTensor([(index[0] / self.size) * 360]).to(self.device) + radius = torch.FloatTensor([self.opt.default_radius]).to(self.device) + poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front) + + # 
fixed focal + fov = self.opt.default_fovy + + focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2)) + intrinsics = np.array([focal, focal, self.cx, self.cy]) + + projection = torch.tensor([ + [2*focal/self.W, 0, 0, 0], + [0, -2*focal/self.H, 0, 0], + [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)], + [0, 0, -1, 0] + ], dtype=torch.float32, device=self.device).unsqueeze(0) + + mvp = projection @ torch.inverse(poses) # [1, 4, 4] + + # sample a low-resolution but full image + rays = get_rays(poses, intrinsics, self.H, self.W, -1) + + # delta polar/azimuth/radius to default view + delta_polar = thetas - self.opt.default_polar + delta_azimuth = phis - self.opt.default_azimuth + delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180] + delta_radius = radius - self.opt.default_radius + + if self.teacher_W is not None: + teacher_H = self.teacher_H + teacher_W = self.teacher_W + + teacher_cx = teacher_H / 2 + teacher_cy = teacher_W / 2 + + teacher_focal = teacher_H / (2 * np.tan(np.deg2rad(fov) / 2)) + teacher_intrinsics = np.array([teacher_focal, teacher_focal, teacher_cx, teacher_cy]) + + teacher_rays = get_rays(poses, teacher_intrinsics, teacher_H, teacher_W, -1) + + data = { + 'H': self.H, + 'W': self.W, + 'rays_o': rays['rays_o'], + 'rays_d': rays['rays_d'], + 'dir': dirs, + 'mvp': mvp, + 'polar': delta_polar, + 'azimuth': delta_azimuth, + 'radius': delta_radius, + + 'teacher_H': teacher_H if self.teacher_W is not None else None, + 'teacher_W': teacher_W if self.teacher_W is not None else None, + 'teacher_rays_o': teacher_rays['rays_o'] if self.teacher_W is not None else None, + 'teacher_rays_d': teacher_rays['rays_d'] if self.teacher_W is not None else None, + } + + return data + + def dataloader(self, batch_size=None): + batch_size = batch_size or self.opt.batch_size + loader = DataLoader(list(range(self.size)), batch_size=batch_size, collate_fn=self.collate, shuffle=self.training, num_workers=0) + loader._data = 
self + return loader \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/nerf/provider_3DPortraitGAN.py b/stable-dreamfusion-3DPortrait/nerf/provider_3DPortraitGAN.py new file mode 100644 index 0000000..647c6af --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/provider_3DPortraitGAN.py @@ -0,0 +1,332 @@ +import os +import cv2 +import glob +import json +import tqdm +import random +import numpy as np +from scipy.spatial.transform import Slerp, Rotation + +import trimesh + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from .utils import get_rays, safe_normalize + +DIR_COLORS = np.array([ + [255, 0, 0, 255], # front + [0, 255, 0, 255], # side + [0, 0, 255, 255], # back + [255, 255, 0, 255], # side + [255, 0, 255, 255], # overhead + [0, 255, 255, 255], # bottom +], dtype=np.uint8) + +def visualize_poses(poses, dirs, size=0.1): + # poses: [B, 4, 4], dirs: [B] + + axes = trimesh.creation.axis(axis_length=4) + sphere = trimesh.creation.icosphere(radius=1) + objects = [axes, sphere] + + for pose, dir in zip(poses, dirs): + # a camera is visualized with 8 line segments. 
def rand_poses(size, device, opt, radius_range=(1, 1.5), theta_range=(0, 120), phi_range=(0, 360), return_dirs=False, angle_overhead=30, angle_front=60, uniform_sphere_rate=0.5, pivot=None):
    ''' generate random poses from an orbit camera looking at ``pivot``
    Args:
        size: batch size of generated poses.
        device: where to allocate the output.
        opt: options object; reads ``jitter_pose`` and, when it is set,
            ``jitter_center``, ``jitter_target`` and ``jitter_up``.
        radius_range: [min, max] camera distance.
        theta_range: [min, max] polar angle in degrees (converted to radians here).
        phi_range: [min, max] azimuth angle in degrees (converted to radians here).
        return_dirs: also return the view-direction labels from ``get_view_direction``.
        angle_overhead: polar threshold in degrees for the top / bottom labels.
        angle_front: azimuth width in degrees of the front sector.
        uniform_sphere_rate: probability of sampling directions uniformly on the
            upper hemisphere instead of uniformly in (theta, phi).
        pivot: look-at target, a tensor broadcastable against [size, 3]
            (NeRFDataset passes a [1, 3] tensor). Required; it is never modified.
    Return:
        poses: [size, 4, 4]
        dirs: [size] long labels, or None when ``return_dirs`` is False
        thetas: [size] polar angles in degrees (sampled before jitter)
        phis: [size] azimuth angles in degrees (sampled before jitter)
        radius: [size] camera distances
    '''
    assert pivot is not None

    theta_range = np.array(theta_range) / 180 * np.pi
    phi_range = np.array(phi_range) / 180 * np.pi
    angle_overhead = angle_overhead / 180 * np.pi
    angle_front = angle_front / 180 * np.pi

    radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]

    if random.random() < uniform_sphere_rate:
        # Normalised gaussian samples are uniform on the sphere; abs() on y
        # keeps them in the upper hemisphere.
        unit_centers = F.normalize(
            torch.stack([
                torch.randn(size, device=device),
                torch.abs(torch.randn(size, device=device)),
                torch.randn(size, device=device),
            ], dim=-1), p=2, dim=1
        )
        thetas = torch.acos(unit_centers[:, 1])
        phis = torch.atan2(unit_centers[:, 0], unit_centers[:, 2])
        phis[phis < 0] += 2 * np.pi
        centers = unit_centers * radius.unsqueeze(-1)
    else:
        thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
        phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
        phis[phis < 0] += 2 * np.pi

        centers = torch.stack([
            radius * torch.sin(thetas) * torch.sin(phis),
            radius * torch.cos(thetas),
            radius * torch.sin(thetas) * torch.cos(phis),
        ], dim=-1)  # [B, 3]

    targets = pivot

    # jitters
    if opt.jitter_pose:
        jit_center = opt.jitter_center
        jit_target = opt.jitter_target
        centers += torch.rand_like(centers) * jit_center - jit_center / 2.0
        # Out-of-place on purpose: `targets` aliases the caller's pivot tensor. An
        # in-place `+=` would either fail to broadcast ([1, 3] += [size, 3]) or, for
        # size == 1, accumulate noise into the shared pivot on every call.
        targets = targets + torch.randn_like(centers) * jit_target

    # lookat
    forward_vector = safe_normalize(centers - targets)
    up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))

    if opt.jitter_pose:
        up_noise = torch.randn_like(up_vector) * opt.jitter_up
    else:
        up_noise = 0

    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    if return_dirs:
        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
    else:
        dirs = None

    # back to degree
    thetas = thetas / np.pi * 180
    phis = phis / np.pi * 180

    return poses, dirs, thetas, phis, radius
self.opt.min_near + self.far = 1000 # infinite + + self.cam_pivot = torch.tensor([0, 0.0649, 0], device=device).view(1, 3) + + # [debug] visualize poses + # poses, dirs, _, _, _ = rand_poses(100, self.device, opt, radius_range=self.opt.radius_range, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose, uniform_sphere_rate=1) + # visualize_poses(poses.detach().cpu().numpy(), dirs.detach().cpu().numpy()) + + def get_default_view_data(self): + + H = int(self.opt.known_view_scale * self.H) + W = int(self.opt.known_view_scale * self.W) + cx = H / 2 + cy = W / 2 + + radii = torch.FloatTensor([self.opt.default_radius]).to(self.device) + thetas = torch.FloatTensor(self.opt.ref_polars).to(self.device) + phis = torch.FloatTensor(self.opt.ref_azimuths).to(self.device) + + + poses, dirs = circle_poses(self.device, radius=radii, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, pivot = self.cam_pivot) + fov = self.opt.default_fovy + focal = H * float(1 / (np.tan(fov * 3.14159 / 360) * 1.414)) # H / (2 * np.tan(np.deg2rad(fov) / 2)) + + intrinsics = np.array([focal, focal, cx, cy]) + + projection = torch.tensor([ + [2*focal/W, 0, 0, 0], + [0, -2*focal/H, 0, 0], + [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)], + [0, 0, -1, 0] + ], dtype=torch.float32, device=self.device).unsqueeze(0).repeat(len(radii), 1, 1) + + mvp = projection @ torch.inverse(poses) # [B, 4, 4] + + # sample a low-resolution but full image + rays = get_rays(poses, intrinsics, H, W, -1) + + data = { + 'H': H, + 'W': W, + 'rays_o': rays['rays_o'], + 'rays_d': rays['rays_d'], + 'dir': dirs, + 'mvp': mvp, + 'polar': self.opt.ref_polars, + 'azimuth': self.opt.ref_azimuths, + 'radius': self.opt.ref_radii, + } + + return data + + def collate(self, index): + + B = len(index) + + if self.training: + # random pose on the fly + poses, dirs, thetas, phis, 
radius = rand_poses(B, self.device, self.opt, radius_range=self.opt.radius_range, theta_range=self.opt.theta_range, + phi_range=self.opt.phi_range, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, + uniform_sphere_rate=self.opt.uniform_sphere_rate, pivot = self.cam_pivot) + + # random focal + fov = random.random() * (self.opt.fovy_range[1] - self.opt.fovy_range[0]) + self.opt.fovy_range[0] + + elif self.type == 'six_views': + # six views + thetas_six = [90, 90, 90, 90, 1e-3, 179.999] + phis_six = [ 0, 90, 180, -90, 0, 0] + thetas = torch.FloatTensor([thetas_six[index[0]]]).to(self.device) + phis = torch.FloatTensor([phis_six[index[0]]]).to(self.device) + radius = torch.FloatTensor([self.opt.default_radius]).to(self.device) + poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, pivot = self.cam_pivot) + + # fixed focal + fov = self.opt.default_fovy + + else: + # circle pose + thetas = torch.FloatTensor([self.opt.default_polar]).to(self.device) + phis = torch.FloatTensor([(index[0] / self.size) * 360]).to(self.device) + radius = torch.FloatTensor([self.opt.default_radius]).to(self.device) + + poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, pivot = self.cam_pivot) + + # fixed focal + fov = self.opt.default_fovy + + focal = self.H * float(1 / (np.tan(fov * 3.14159 / 360) * 1.414)) #self.H / (2 * np.tan(np.deg2rad(fov) / 2)) + + intrinsics = np.array([focal, focal, self.cx, self.cy]) + + projection = torch.tensor([ + [2*focal/self.W, 0, 0, 0], + [0, -2*focal/self.H, 0, 0], + [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)], + [0, 0, -1, 0] + ], dtype=torch.float32, device=self.device).unsqueeze(0) + + mvp = projection @ torch.inverse(poses) # 
[1, 4, 4] + + # sample a low-resolution but full image + rays = get_rays(poses, intrinsics, self.H, self.W, -1) + + # delta polar/azimuth/radius to default view + delta_polar = thetas - self.opt.default_polar + delta_azimuth = phis - self.opt.default_azimuth + delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180] + delta_radius = radius - self.opt.default_radius + + data = { + 'H': self.H, + 'W': self.W, + 'rays_o': rays['rays_o'], + 'rays_d': rays['rays_d'], + 'dir': dirs, + 'mvp': mvp, + 'poses': poses, + 'polar': delta_polar, + 'azimuth': delta_azimuth, + 'radius': delta_radius, + } + + return data + + def dataloader(self, batch_size=None): + batch_size = batch_size or self.opt.batch_size + loader = DataLoader(list(range(self.size)), batch_size=batch_size, collate_fn=self.collate, shuffle=self.training, num_workers=0) + loader._data = self + return loader \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/nerf/renderer.py b/stable-dreamfusion-3DPortrait/nerf/renderer.py new file mode 100644 index 0000000..d141ae0 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/renderer.py @@ -0,0 +1,1190 @@ +import os +import math +import cv2 +import trimesh +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import nvdiffrast.torch as dr + +import mcubes +import raymarching +from meshutils import decimate_mesh, clean_mesh, poisson_mesh_reconstruction +from .utils import custom_meshgrid, safe_normalize + + +def sample_pdf(bins, weights, n_samples, det=False): + # This implementation is from NeRF + # bins: [B, T], old_z_vals + # weights: [B, T - 1], bin weights. + # return: [B, n_samples], new_z_vals + + # Get pdf + weights = weights + 1e-5 # prevent nans + pdf = weights / torch.sum(weights, -1, keepdim=True) + cdf = torch.cumsum(pdf, -1) + cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) + # Take uniform samples + if det: + u = torch.linspace(0. + 0.5 / n_samples, 1. 
@torch.cuda.amp.autocast(enabled=False)
def near_far_from_bound(rays_o, rays_d, bound, type='cube', min_near=0.05):
    """Compute per-ray near / far distances against an origin-centred bound.

    Args:
        rays_o: [B, N, 3] ray origins.
        rays_d: [B, N, 3] ray directions.
        bound: radius of the ball ('sphere') or half edge length of the cube ('cube').
        type: 'sphere' or 'cube'.
        min_near: lower clamp applied to ``near`` in the 'cube' case only.

    Returns:
        near: [B, N, 1], far: [B, N, 1]. For 'cube', rays that miss the box get
        near == far == 1e9.

    Raises:
        ValueError: if ``type`` is neither 'sphere' nor 'cube'.
    """
    if type == 'sphere':
        radius = rays_o.norm(dim=-1, keepdim=True)
        near = radius - bound  # [B, N, 1]
        far = radius + bound

    elif type == 'cube':
        # Slab test; the epsilon avoids dividing by exactly zero for axis-parallel rays.
        tmin = (-bound - rays_o) / (rays_d + 1e-15)  # [B, N, 3]
        tmax = (bound - rays_o) / (rays_d + 1e-15)
        near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0]
        far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0]
        # if far < near, means no intersection, set both near and far to inf (1e9 here)
        mask = far < near
        near[mask] = 1e9
        far[mask] = 1e9
        # restrict near to a minimal value
        near = torch.clamp(near, min=min_near)

    else:
        # Previously an unknown type fell through and failed with UnboundLocalError
        # at the return statement; fail early with an actionable message instead.
        raise ValueError(f"unknown bound type {type!r}, expected 'sphere' or 'cube'")

    return near, far
+ # axis + axes = trimesh.creation.axis(axis_length=4) + # sphere + sphere = trimesh.creation.icosphere(radius=1) + trimesh.Scene([pc, axes, sphere]).show() + + +class DMTet(): + def __init__(self, device): + self.device = device + self.triangle_table = torch.tensor([ + [-1, -1, -1, -1, -1, -1], + [ 1, 0, 2, -1, -1, -1], + [ 4, 0, 3, -1, -1, -1], + [ 1, 4, 2, 1, 3, 4], + [ 3, 1, 5, -1, -1, -1], + [ 2, 3, 0, 2, 5, 3], + [ 1, 4, 0, 1, 5, 4], + [ 4, 2, 5, -1, -1, -1], + [ 4, 5, 2, -1, -1, -1], + [ 4, 1, 0, 4, 5, 1], + [ 3, 2, 0, 3, 5, 2], + [ 1, 3, 5, -1, -1, -1], + [ 4, 1, 2, 4, 3, 1], + [ 3, 0, 4, -1, -1, -1], + [ 2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1] + ], dtype=torch.long, device=device) + self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device=device) + self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device=device) + + def sort_edges(self, edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:,0] > edges_ex2[:,1]).long() + order = order.unsqueeze(dim=1) + + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1-order, dim=1) + + return torch.stack([a, b],-1) + + def __call__(self, pos_nx3, sdf_n, tet_fx4): + # pos_nx3: [N, 3] + # sdf_n: [N] + # tet_fx4: [F, 4] + + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4) + occ_sum = torch.sum(occ_fx4, -1) # [F,] + valid_tets = (occ_sum>0) & (occ_sum<4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2) + all_edges = self.sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), 
def compute_edge_to_face_mapping(attr_idx):
    """Map every unique undirected edge of a triangle mesh to its adjacent faces.

    Args:
        attr_idx: [F, 3] integer tensor of per-triangle vertex indices.

    Returns:
        [E, 2] int64 tensor on ``attr_idx.device``, one row per unique edge in the
        sorted order produced by ``torch.unique``. Column 0 holds the triangle that
        traverses the edge from the lower to the higher vertex index, column 1 the
        triangle that traverses it in the opposite direction. Slots with no such
        triangle keep the initial value 0.
    """
    with torch.no_grad():
        device = attr_idx.device

        # Create all edges, packed by triangle: rows 3*t .. 3*t+2 belong to triangle t.
        all_edges = torch.cat((
            torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),
            torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),
            torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),
        ), dim=-1).view(-1, 2)

        # Swap edge order so min index is always first; `order` is 1 where a swap happened.
        order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
        sorted_edges = torch.cat((
            torch.gather(all_edges, 1, order),
            torch.gather(all_edges, 1, 1 - order)
        ), dim=-1)

        # Eliminate duplicates and return inverse mapping
        unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)

        # Allocate on the input's device rather than a hard-coded .cuda(), so the
        # function also works on CPU and on a non-default GPU.
        tris = torch.arange(attr_idx.shape[0], device=device).repeat_interleave(3)
        tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64, device=device)

        # Compute edge to face table; the traversal direction selects the column.
        mask0 = order[:, 0] == 0
        mask1 = order[:, 0] == 1
        tris_per_edge[idx_map[mask0], 0] = tris[mask0]
        tris_per_edge[idx_map[mask1], 1] = tris[mask1]

        return tris_per_edge
+ + # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) + # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. + aabb_train = torch.FloatTensor([-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound]) + aabb_infer = aabb_train.clone() + self.register_buffer('aabb_train', aabb_train) + self.register_buffer('aabb_infer', aabb_infer) + + self.glctx = None + + # extra state for cuda raymarching + if self.cuda_ray: + # density grid + density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H] + density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8] + self.register_buffer('density_grid', density_grid) + self.register_buffer('density_bitfield', density_bitfield) + self.mean_density = 0 + self.iter_density = 0 + + if self.dmtet: + # load dmtet vertices + tets = np.load('tets/{}_tets.npz'.format(self.opt.tet_grid_size)) + self.verts = - torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * 2 # covers [-1, 1] + self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda') + self.tet_scale = torch.tensor([1, 1, 1], dtype=torch.float32, device='cuda') + self.dmtet_model = DMTet('cuda') + + # vert sdf and deform + sdf = torch.nn.Parameter(torch.zeros_like(self.verts[..., 0]), requires_grad=True) + self.register_parameter('sdf', sdf) + deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True) + self.register_parameter('deform', deform) + + edges = torch.tensor([0,1, 0,2, 0,3, 1,2, 1,3, 2,3], dtype=torch.long, device="cuda") # six edges for each tetrahedron. 
+ all_edges = self.indices[:,edges].reshape(-1,2) # [M * 6, 2] + all_edges_sorted = torch.sort(all_edges, dim=1)[0] + self.all_edges = torch.unique(all_edges_sorted, dim=0) + + if self.opt.h <= 2048 and self.opt.w <= 2048: + self.glctx = dr.RasterizeCudaContext() + else: + self.glctx = dr.RasterizeGLContext() + + if self.taichi_ray: + from einops import rearrange + from taichi_modules import RayMarcherTaichi + from taichi_modules import VolumeRendererTaichi + from taichi_modules import RayAABBIntersector as RayAABBIntersectorTaichi + from taichi_modules import raymarching_test as raymarching_test_taichi + from taichi_modules import composite_test as composite_test_fw + from taichi_modules import packbits as packbits_taichi + self.rearrange = rearrange + self.packbits_taichi = packbits_taichi + self.ray_aabb_intersector = RayAABBIntersectorTaichi + self.raymarching_test_taichi = raymarching_test_taichi + self.composite_test_fw = composite_test_fw + self.ray_marching = RayMarcherTaichi(batch_size=4096) # TODO: hard encoded batch size + self.volume_render = VolumeRendererTaichi(batch_size=4096) # TODO: hard encoded batch size + # density grid + density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H] + density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8] + self.register_buffer('density_grid', density_grid) + self.register_buffer('density_bitfield', density_bitfield) + self.mean_density = 0 + self.iter_density = 0 + + @torch.no_grad() + def density_blob(self, x): + # x: [B, N, 3] + + d = (x ** 2).sum(-1) + + if self.opt.density_activation == 'exp': + g = self.opt.blob_density * torch.exp(- d / (2 * self.opt.blob_radius ** 2)) + else: + g = self.opt.blob_density * (1 - torch.sqrt(d) / self.opt.blob_radius) + + return g + + def forward(self, x, d): + raise NotImplementedError() + + def density(self, x): + raise NotImplementedError() + + def reset_extra_state(self): + if not 
(self.cuda_ray or self.taichi_ray): + return + # density grid + self.density_grid.zero_() + self.mean_density = 0 + self.iter_density = 0 + + @torch.no_grad() + def export_mesh(self, path, resolution=None, decimate_target=-1, S=128): + + if self.opt.dmtet: + + sdf = self.sdf + deform = torch.tanh(self.deform) / self.opt.tet_grid_size + + vertices, triangles = self.dmtet_model(self.verts + deform, sdf, self.indices) + + vertices = vertices.detach().cpu().numpy() + triangles = triangles.detach().cpu().numpy() + + else: + + if resolution is None: + resolution = self.grid_size + + if self.cuda_ray: + density_thresh = min(self.mean_density, self.density_thresh) \ + if np.greater(self.mean_density, 0) else self.density_thresh + else: + density_thresh = self.density_thresh + + # TODO: use a larger thresh to extract a surface mesh from the density field, but this value is very empirical... + if self.opt.density_activation == 'softplus': + density_thresh = density_thresh * 25 + + sigmas = np.zeros([resolution, resolution, resolution], dtype=np.float32) + + # query + X = torch.linspace(-1, 1, resolution).split(S) + Y = torch.linspace(-1, 1, resolution).split(S) + Z = torch.linspace(-1, 1, resolution).split(S) + + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = custom_meshgrid(xs, ys, zs) + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3] + val = self.density(pts.to(self.aabb_train.device)) + sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z] + + print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})') + + vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh) + vertices = vertices / (resolution - 1.0) * 2 - 1 + + # clean + vertices = vertices.astype(np.float32) + triangles = triangles.astype(np.int32) + 
vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01) + + # decimation + if decimate_target > 0 and triangles.shape[0] > decimate_target: + vertices, triangles = decimate_mesh(vertices, triangles, decimate_target) + + v = torch.from_numpy(vertices).contiguous().float().to(self.aabb_train.device) + f = torch.from_numpy(triangles).contiguous().int().to(self.aabb_train.device) + + # mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault... + # mesh.export(os.path.join(path, f'mesh.ply')) + + def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''): + # v, f: torch Tensor + device = v.device + v_np = v.cpu().numpy() # [N, 3] + f_np = f.cpu().numpy() # [M, 3] + + print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}') + + # unwrap uvs + import xatlas + import nvdiffrast.torch as dr + from sklearn.neighbors import NearestNeighbors + from scipy.ndimage import binary_dilation, binary_erosion + + atlas = xatlas.Atlas() + atlas.add_mesh(v_np, f_np) + chart_options = xatlas.ChartOptions() + chart_options.max_iterations = 4 # for faster unwrap... 
+ atlas.generate(chart_options=chart_options) + vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] + + # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2] + + vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device) + ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device) + + # render uv maps + uv = vt * 2.0 - 1.0 # uvs to range [-1, 1] + uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4] + + if ssaa > 1: + h = int(h0 * ssaa) + w = int(w0 * ssaa) + else: + h, w = h0, w0 + + if self.glctx is None: + if h <= 2048 and w <= 2048: + self.glctx = dr.RasterizeCudaContext() + else: + self.glctx = dr.RasterizeGLContext() + + rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4] + xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3] + mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1] + + # masked query + xyzs = xyzs.view(-1, 3) + mask = (mask > 0).view(-1) + + feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32) + + if mask.any(): + xyzs = xyzs[mask] # [M, 3] + + # batched inference to avoid OOM + all_feats = [] + head = 0 + while head < xyzs.shape[0]: + tail = min(head + 640000, xyzs.shape[0]) + results_ = self.density(xyzs[head:tail]) + all_feats.append(results_['albedo'].float()) + head += 640000 + + feats[mask] = torch.cat(all_feats, dim=0) + + feats = feats.view(h, w, -1) + mask = mask.view(h, w) + + # quantize [0.0, 1.0] to [0, 255] + feats = feats.cpu().numpy() + feats = (feats * 255).astype(np.uint8) + + ### NN search as an antialiasing ... 
+ mask = mask.cpu().numpy() + + inpaint_region = binary_dilation(mask, iterations=3) + inpaint_region[mask] = 0 + + search_region = mask.copy() + not_search_region = binary_erosion(search_region, iterations=2) + search_region[not_search_region] = 0 + + search_coords = np.stack(np.nonzero(search_region), axis=-1) + inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1) + + knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords) + _, indices = knn.kneighbors(inpaint_coords) + + feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)] + + feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR) + + # do ssaa after the NN search, in numpy + if ssaa > 1: + feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR) + + cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats) + + # save obj (v, vt, f /) + obj_file = os.path.join(path, f'{name}mesh.obj') + mtl_file = os.path.join(path, f'{name}mesh.mtl') + + print(f'[INFO] writing obj mesh to {obj_file}') + with open(obj_file, "w") as fp: + fp.write(f'mtllib {name}mesh.mtl \n') + + print(f'[INFO] writing vertices {v_np.shape}') + for v in v_np: + fp.write(f'v {v[0]} {v[1]} {v[2]} \n') + + print(f'[INFO] writing vertices texture coords {vt_np.shape}') + for v in vt_np: + fp.write(f'vt {v[0]} {1 - v[1]} \n') + + print(f'[INFO] writing faces {f_np.shape}') + fp.write(f'usemtl mat0 \n') + for i in range(len(f_np)): + fp.write(f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n") + + with open(mtl_file, "w") as fp: + fp.write(f'newmtl mat0 \n') + fp.write(f'Ka 1.000000 1.000000 1.000000 \n') + fp.write(f'Kd 1.000000 1.000000 1.000000 \n') + fp.write(f'Ks 0.000000 0.000000 0.000000 \n') + fp.write(f'Tr 1.000000 \n') + fp.write(f'illum 1 \n') + fp.write(f'Ns 0.000000 \n') + fp.write(f'map_Kd {name}albedo.png \n') + + _export(v, f) + + def run(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', 
bg_color=None, perturb=False, **kwargs): + # rays_o, rays_d: [B, N, 3] + # bg_color: [BN, 3] in range [0, 1] + # return: image: [B, N, 3], depth: [B, N] + + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + + results = {} + + # choose aabb + aabb = self.aabb_train if self.training else self.aabb_infer + + # sample steps + # nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near) + # nears.unsqueeze_(-1) + # fars.unsqueeze_(-1) + nears, fars = near_far_from_bound(rays_o, rays_d, self.bound, type='sphere', min_near=self.min_near) + + # random sample light_d if not provided + if light_d is None: + # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) + light_d = safe_normalize(rays_o + torch.randn(3, device=rays_o.device)) # [N, 3] + + #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}') + + z_vals = torch.linspace(0.0, 1.0, self.opt.num_steps, device=device).unsqueeze(0) # [1, T] + z_vals = z_vals.expand((N, self.opt.num_steps)) # [N, T] + z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars] + + # perturb z_vals + sample_dist = (fars - nears) / self.opt.num_steps + if perturb: + z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist + #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs. + + # generate xyzs + xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3] + xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip. 
+ + #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) + + # query SDF and RGB + density_outputs = self.density(xyzs.reshape(-1, 3)) + + #sigmas = density_outputs['sigma'].view(N, self.opt.num_steps) # [N, T] + for k, v in density_outputs.items(): + density_outputs[k] = v.view(N, self.opt.num_steps, -1) + + # upsample z_vals (nerf-like) + if self.opt.upsample_steps > 0: + with torch.no_grad(): + + deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1] + deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) + + alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T] + alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1] + weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T] + + # sample new z_vals + z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1] + new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], self.opt.upsample_steps, det=not self.training).detach() # [N, t] + + new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3] + new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip. 
+ + # only forward new points to save computation + new_density_outputs = self.density(new_xyzs.reshape(-1, 3)) + #new_sigmas = new_density_outputs['sigma'].view(N, self.opt.upsample_steps) # [N, t] + for k, v in new_density_outputs.items(): + new_density_outputs[k] = v.view(N, self.opt.upsample_steps, -1) + + # re-order + z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t] + z_vals, z_index = torch.sort(z_vals, dim=1) + + xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3] + xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)) + + for k in density_outputs: + tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1) + density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)) + + deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1] + deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) + alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T+t] + alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1] + weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t] + + dirs = rays_d.view(-1, 1, 3).expand_as(xyzs) + light_d = light_d.view(-1, 1, 3).expand_as(xyzs) + for k, v in density_outputs.items(): + density_outputs[k] = v.view(-1, v.shape[-1]) + + dirs = safe_normalize(dirs) + sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d.reshape(-1, 3), ratio=ambient_ratio, shading=shading) + rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3] + if normals is not None: + normals = normals.view(N, -1, 3) + + # calculate weight_sum (mask) + weights_sum = weights.sum(dim=-1) # [N] + + # calculate depth + depth = torch.sum(weights * z_vals, dim=-1) + + # calculate color + image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1] + + # mix background color + if bg_color is None: + if self.opt.bg_radius > 0: + # use the bg 
model to calculate bg_color + bg_color = self.background(rays_d) # [N, 3] + else: + bg_color = 1 + + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + + image = image.view(*prefix, 3) + depth = depth.view(*prefix) + weights_sum = weights_sum.reshape(*prefix) + + if self.training: + if self.opt.lambda_orient > 0 and normals is not None: + # orientation loss + loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2 + results['loss_orient'] = loss_orient.sum(-1).mean() + + if self.opt.lambda_3d_normal_smooth > 0 and normals is not None: + normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2) + results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean() + + if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None: + normal_image = torch.sum(weights.unsqueeze(-1) * (normals + 1) / 2, dim=-2) # [N, 3], in [0, 1] + results['normal_image'] = normal_image + + results['image'] = image + results['depth'] = depth + results['weights'] = weights + results['weights_sum'] = weights_sum + + return results + + + def run_cuda(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, binarize=False, **kwargs): + # rays_o, rays_d: [B, N, 3] + # return: image: [B, N, 3], depth: [B, N] + + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # B * N, in fact + device = rays_o.device + + # pre-calculate near far + nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer) + + # random sample light_d if not provided + if light_d is None: + # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) + light_d = safe_normalize(rays_o + torch.randn(3, device=rays_o.device)) # [N, 3] + + results = {} + + if self.training: + xyzs, dirs, ts, rays = 
raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb, self.opt.dt_gamma, self.opt.max_steps) + dirs = safe_normalize(dirs) + + if light_d.shape[0] > 1: + flatten_rays = raymarching.flatten_rays(rays, xyzs.shape[0]).long() + light_d = light_d[flatten_rays] + + sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading) + weights, weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ts, rays, T_thresh, binarize) + + # normals related regularizations + if self.opt.lambda_orient > 0 and normals is not None: + # orientation loss + loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2 + results['loss_orient'] = loss_orient.mean() + + if self.opt.lambda_3d_normal_smooth > 0 and normals is not None: + normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2) + results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean() + + if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None: + _, _, _, normal_image = raymarching.composite_rays_train(sigmas.detach(), (normals + 1) / 2, ts, rays, T_thresh, binarize) + results['normal_image'] = normal_image + + # weights normalization + results['weights'] = weights + + else: + + # allocate outputs + dtype = torch.float32 + + weights_sum = torch.zeros(N, dtype=dtype, device=device) + depth = torch.zeros(N, dtype=dtype, device=device) + image = torch.zeros(N, 3, dtype=dtype, device=device) + + n_alive = N + rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] + rays_t = nears.clone() # [N] + + step = 0 + + while step < self.opt.max_steps: # hard coded max step + + # count alive rays + n_alive = rays_alive.shape[0] + + # exit loop + if n_alive <= 0: + break + + # decide compact_steps + n_step = max(min(N // n_alive, 8), 1) + + xyzs, dirs, ts = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, 
rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb if step == 0 else False, self.opt.dt_gamma, self.opt.max_steps) + dirs = safe_normalize(dirs) + sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading) + raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh, binarize) + + rays_alive = rays_alive[rays_alive >= 0] + #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}') + + step += n_step + + # mix background color + if bg_color is None: + if self.opt.bg_radius > 0: + # use the bg model to calculate bg_color + bg_color = self.background(rays_d) # [N, 3] + else: + bg_color = 1 + + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + image = image.view(*prefix, 3) + + depth = depth.view(*prefix) + + weights_sum = weights_sum.reshape(*prefix) + + results['image'] = image + results['depth'] = depth + results['weights_sum'] = weights_sum + + return results + + @torch.no_grad() + def init_tet(self, mesh=None): + + if mesh is not None: + # normalize mesh + scale = 0.8 / np.array(mesh.bounds[1] - mesh.bounds[0]).max() + center = np.array(mesh.bounds[1] + mesh.bounds[0]) / 2 + mesh.vertices = (mesh.vertices - center) * scale + + # init scale + # self.tet_scale = torch.from_numpy(np.abs(mesh.vertices).max(axis=0) + 1e-1).to(self.verts.dtype).cuda() + self.tet_scale = torch.from_numpy(np.array([np.abs(mesh.vertices).max()]) + 1e-1).to(self.verts.dtype).cuda() + self.verts = self.verts * self.tet_scale + + # init sdf + import cubvh + BVH = cubvh.cuBVH(mesh.vertices, mesh.faces) + sdf, _, _ = BVH.signed_distance(self.verts, return_uvw=False, mode='watertight') + sdf *= -10 # INNER is POSITIVE, also make it stronger + self.sdf.data += sdf.to(self.sdf.data.dtype).clamp(-1, 1) + + else: + + if self.cuda_ray: + density_thresh = min(self.mean_density, self.density_thresh) + else: + density_thresh = 
self.density_thresh + + if self.opt.density_activation == 'softplus': + density_thresh = density_thresh * 25 + + # init scale + sigma = self.density(self.verts)['sigma'] # verts covers [-1, 1] now + mask = sigma > density_thresh + valid_verts = self.verts[mask] + self.tet_scale = valid_verts.abs().amax(dim=0) + 1e-1 + self.verts = self.verts * self.tet_scale + + # init sigma + sigma = self.density(self.verts)['sigma'] # new verts + self.sdf.data += (sigma - density_thresh).clamp(-1, 1) + + print(f'[INFO] init dmtet: scale = {self.tet_scale}') + + + def run_dmtet(self, rays_o, rays_d, mvp, h, w, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, **kwargs): + # mvp: [B, 4, 4] + + device = mvp.device + campos = rays_o[:, 0, :] # only need one ray per batch + + # random sample light_d if not provided + if light_d is None: + # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) + light_d = safe_normalize(campos + torch.randn_like(campos)).view(-1, 1, 1, 3) # [B, 1, 1, 3] + + results = {} + + # get mesh + sdf = self.sdf + deform = torch.tanh(self.deform) / self.opt.tet_grid_size + + verts, faces = self.dmtet_model(self.verts + deform, sdf, self.indices) + + # get normals + i0, i1, i2 = faces[:, 0], faces[:, 1], faces[:, 2] + v0, v1, v2 = verts[i0, :], verts[i1, :], verts[i2, :] + + faces = faces.int() + + face_normals = torch.cross(v1 - v0, v2 - v0) + face_normals = safe_normalize(face_normals) + + vn = torch.zeros_like(verts) + vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) + vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) + vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) + + vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device)) + + # rasterization + verts_clip = torch.bmm(F.pad(verts, pad=(0, 1), mode='constant', value=1.0).unsqueeze(0).repeat(mvp.shape[0], 1, 1), + mvp.permute(0,2,1)).float() # [B, 
N, 4] + rast, rast_db = dr.rasterize(self.glctx, verts_clip, faces, (h, w)) + + alpha = (rast[..., 3:] > 0).float() + xyzs, _ = dr.interpolate(verts.unsqueeze(0), rast, faces) # [B, H, W, 3] + normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, faces) + normal = safe_normalize(normal) + + xyzs = xyzs.view(-1, 3) + mask = (rast[..., 3:] > 0).view(-1).detach() + + # do the lighting here since we have normal from mesh now. + albedo = torch.zeros_like(xyzs, dtype=torch.float32) + if mask.any(): + masked_albedo = self.density(xyzs[mask])['albedo'] + albedo[mask] = masked_albedo.float() + albedo = albedo.view(-1, h, w, 3) + + # these two modes lead to no parameters to optimize if using --lock_geo. + if self.opt.lock_geo and shading in ['textureless', 'normal']: + shading = 'lambertian' + + if shading == 'albedo': + color = albedo + elif shading == 'textureless': + lambertian = ambient_ratio + (1 - ambient_ratio) * (normal * light_d).sum(-1).float().clamp(min=0) + color = lambertian.unsqueeze(-1).repeat(1, 1, 1, 3) + elif shading == 'normal': + color = (normal + 1) / 2 + else: # 'lambertian' + lambertian = ambient_ratio + (1 - ambient_ratio) * (normal * light_d).sum(-1).float().clamp(min=0) + color = albedo * lambertian.unsqueeze(-1) + + color = dr.antialias(color, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3] + alpha = dr.antialias(alpha, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 1] + + # mix background color + if bg_color is None: + if self.opt.bg_radius > 0: + # use the bg model to calculate bg_color + bg_color = self.background(rays_d) # [N, 3] + else: + bg_color = 1 + + if torch.is_tensor(bg_color) and len(bg_color.shape) > 1: + bg_color = bg_color.view(-1, h, w, 3) + + depth = rast[:, :, :, [2]] # [B, H, W] + color = color + (1 - alpha) * bg_color + + results['depth'] = depth + results['image'] = color + results['weights_sum'] = alpha.squeeze(-1) + + if self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0: + normal_image = 
dr.antialias((normal + 1) / 2, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3] + results['normal_image'] = normal_image + + # regularizations + if self.training: + if self.opt.lambda_mesh_normal > 0: + results['normal_loss'] = normal_consistency(face_normals, faces) + if self.opt.lambda_mesh_laplacian > 0: + results['lap_loss'] = laplacian_smooth_loss(verts, faces) + + return results + + def run_taichi(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # return: image: [B, N, 3], depth: [B, N] + + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + + # pre-calculate near far + exp_step_factor = kwargs.get('exp_step_factor', 0.) + MAX_SAMPLES = 1024 + NEAR_DISTANCE = 0.01 + center = torch.zeros(1, 3) + half_size = torch.ones(1, 3) + _, hits_t, _ = self.ray_aabb_intersector.apply(rays_o, rays_d, center, half_size, 1) + hits_t[(hits_t[:, 0, 0] >= 0) & (hits_t[:, 0, 0] < NEAR_DISTANCE), 0, 0] = NEAR_DISTANCE + + # TODO: should sample different light_d for each batch... but taichi end doesn't have a flatten_ray implemented currently... 
+ # random sample light_d if not provided + if light_d is None: + # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) + light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float)) + light_d = safe_normalize(light_d) + + results = {} + + if self.training: + rays_a, xyzs, dirs, deltas, ts, _ = self.ray_marching(rays_o, rays_d, hits_t[:, 0], self.density_bitfield, self.cascade, self.bound, exp_step_factor, self.grid_size, MAX_SAMPLES) + dirs = safe_normalize(dirs) + # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) + sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading) + _, weights_sum, depth, image, weights = self.volume_render(sigmas, rgbs, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4)) + + # normals related regularizations + if self.opt.lambda_orient > 0 and normals is not None: + # orientation loss + loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2 + results['loss_orient'] = loss_orient.mean() + + if self.opt.lambda_3d_normal_smooth > 0 and normals is not None: + normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2) + results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean() + + if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None: + _, _, _, normal_image, _ = self.volume_render(sigmas.detach(), (normals + 1) / 2, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4)) + results['normal_image'] = normal_image + + # weights normalization + results['weights'] = weights + + else: + + # allocate outputs + dtype = torch.float32 + + weights_sum = torch.zeros(N, dtype=dtype, device=device) + depth = torch.zeros(N, dtype=dtype, device=device) + image = torch.zeros(N, 3, dtype=dtype, device=device) + + n_alive = N + rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] + rays_t = hits_t[:, 0, 0] + step = 0 + + min_samples = 1 if exp_step_factor == 0 
else 4 + + while step < self.opt.max_steps: # hard coded max step + + # count alive rays + n_alive = rays_alive.shape[0] + + # exit loop + if n_alive <= 0: + break + + # decide compact_steps + # n_step = max(min(N // n_alive, 8), 1) + n_step = max(min(N // n_alive, 64), min_samples) + + xyzs, dirs, deltas, ts, N_eff_samples = \ + self.raymarching_test_taichi(rays_o, rays_d, hits_t[:, 0], rays_alive, + self.density_bitfield, self.cascade, + self.bound, exp_step_factor, + self.grid_size, MAX_SAMPLES, n_step) + + xyzs = self.rearrange(xyzs, 'n1 n2 c -> (n1 n2) c') + dirs = self.rearrange(dirs, 'n1 n2 c -> (n1 n2) c') + dirs = safe_normalize(dirs) + valid_mask = ~torch.all(dirs == 0, dim=1) + if valid_mask.sum() == 0: + break + + sigmas = torch.zeros(len(xyzs), device=device) + rgbs = torch.zeros(len(xyzs), 3, device=device) + normals = torch.zeros(len(xyzs), 3, device=device) + + sigmas[valid_mask], _rgbs, normals = self(xyzs[valid_mask], dirs[valid_mask], light_d, ratio=ambient_ratio, shading=shading) + rgbs[valid_mask] = _rgbs.float() + sigmas = self.rearrange(sigmas, '(n1 n2) -> n1 n2', n2=n_step) + rgbs = self.rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=n_step) + if normals is not None: + normals = self.rearrange(normals, '(n1 n2) c -> n1 n2 c', n2=n_step) + + self.composite_test_fw(sigmas, rgbs, deltas, ts, hits_t[:,0], rays_alive, + kwargs.get('T_threshold', 1e-4), N_eff_samples, + weights_sum, depth, image) + + rays_alive = rays_alive[rays_alive >= 0] + + step += n_step + + # mix background color + if bg_color is None: + if self.opt.bg_radius > 0: + # use the bg model to calculate bg_color + bg_color = self.background(rays_d) # [N, 3] + else: + bg_color = 1 + + image = image + self.rearrange(1 - weights_sum, 'n -> n 1') * bg_color + image = image.view(*prefix, 3) + + depth = depth.view(*prefix) + + weights_sum = weights_sum.reshape(*prefix) + + results['image'] = image + results['depth'] = depth + results['weights_sum'] = weights_sum + + return results + + + 
@torch.no_grad() + def update_extra_state(self, decay=0.95, S=128): + # call before each epoch to update extra states. + + if not (self.cuda_ray or self.taichi_ray): + return + + ### update density grid + tmp_grid = - torch.ones_like(self.density_grid) + + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S) + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_xyzs = xyzs * (bound - half_grid_size) + # add noise in [-hgs, hgs] + cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size + # query density + sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach() + # assign + tmp_grid[cas, indices] = sigmas + # ema update + valid_mask = self.density_grid >= 0 + self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]) + self.mean_density = torch.mean(self.density_grid[valid_mask]).item() + self.iter_density += 1 + + # convert to bitfield + density_thresh = min(self.mean_density, self.density_thresh) + if self.cuda_ray: + self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield) + elif self.taichi_ray: + self.packbits_taichi(self.density_grid.reshape(-1).contiguous(), density_thresh, self.density_bitfield) + + # print(f'[density grid] min={self.density_grid.min().item():.4f}, 
max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f}') + + + def render(self, rays_o, rays_d, mvp, h, w, staged=False, max_ray_batch=4096, **kwargs): + # rays_o, rays_d: [B, N, 3] + # return: pred_rgb: [B, N, 3] + B, N = rays_o.shape[:2] + device = rays_o.device + + if self.dmtet: + results = self.run_dmtet(rays_o, rays_d, mvp, h, w, **kwargs) + elif self.cuda_ray: + results = self.run_cuda(rays_o, rays_d, **kwargs) + elif self.taichi_ray: + results = self.run_taichi(rays_o, rays_d, **kwargs) + else: + if staged: + depth = torch.empty((B, N), device=device) + image = torch.empty((B, N, 3), device=device) + weights_sum = torch.empty((B, N), device=device) + + for b in range(B): + head = 0 + while head < N: + tail = min(head + max_ray_batch, N) + results_ = self.run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs) + depth[b:b+1, head:tail] = results_['depth'] + weights_sum[b:b+1, head:tail] = results_['weights_sum'] + image[b:b+1, head:tail] = results_['image'] + head += max_ray_batch + + results = {} + results['depth'] = depth + results['image'] = image + results['weights_sum'] = weights_sum + + else: + results = self.run(rays_o, rays_d, **kwargs) + + return results diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/__init__.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/__init__.py new file mode 100644 index 0000000..dfebd04 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +# empty diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/custom_ops.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/custom_ops.py new file mode 100644 index 0000000..ed2524f --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/custom_ops.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import glob +import hashlib +import importlib +import os +import re +import shutil +import uuid + +import torch +import torch.utils.cpp_extension +from torch.utils.file_baton import FileBaton + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. 
+ +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- + +def _get_mangled_gpu_name(): + name = torch.cuda.get_device_name().lower() + out = [] + for c in name: + if re.match('[a-z0-9_-]+', c): + out.append(c) + else: + out.append('-') + return ''.join(out) + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + if headers is None: + headers = [] + if source_dir is not None: + sources = [os.path.join(source_dir, fname) for fname in sources] + headers = [os.path.join(source_dir, fname) for fname in headers] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + verbose_build = (verbosity == 'full') + + # Compile and load. + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. 
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either + # break the build or unnecessarily restrict what's available to nvcc. + # Unset it to let nvcc decide based on what's available on the + # machine. + os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + # + # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work + # around the *.cu dependency bug in ninja config. + # + all_source_files = sorted(sources + headers) + all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) + if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): + + # Compute combined hash digest for all source files. + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + + # Select cached build directory name. 
+ source_digest = hash_md5.hexdigest() + build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') + + if not os.path.isdir(cached_build_dir): + tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' + os.makedirs(tmpdir) + for src in all_source_files: + shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) + try: + os.replace(tmpdir, cached_build_dir) # atomic + except OSError: + # source directory already exists, delete tmpdir and its contents. + shutil.rmtree(tmpdir) + if not os.path.isdir(cached_build_dir): raise + + # Compile. + cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, + verbose=verbose_build, sources=cached_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + + # Load. + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache dict. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/misc.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/misc.py new file mode 100644 index 0000000..c0ae67e --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/misc.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import re +import contextlib +import numpy as np +import torch +import warnings +from nerf import dnnlib + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. 
try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        """Fallback used when ``torch.nan_to_num`` is not available.

        Args:
            input: Tensor to sanitize.
            nan: Replacement for NaN; only ``0`` is supported by this fallback.
            posinf: Upper clamp bound; defaults to the largest finite value
                of ``input.dtype``.
            neginf: Lower clamp bound; defaults to the smallest finite value
                of ``input.dtype``.
            out: Optional output tensor forwarded to ``torch.clamp``.

        Returns:
            Tensor with NaNs replaced by zero and values clamped to
            ``[neginf, posinf]``.
        """
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        # nansum over a freshly added singleton axis keeps the shape while
        # treating NaN entries as zero; the clamp then bounds +/-inf.
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to temporarily suppress known warnings in torch.jit.trace().
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672

@contextlib.contextmanager
def suppress_tracer_warnings():
    """Temporarily ignore ``torch.jit.TracerWarning`` inside a ``with`` block.

    An 'ignore' filter is pushed to the front of ``warnings.filters`` on
    entry and removed again on exit. The removal runs in a ``finally``
    clause so that an exception raised inside the block cannot leave the
    filter installed globally.
    """
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    try:
        yield
    finally:
        warnings.filters.remove(flt)

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().
def assert_shape(tensor, ref_shape):
    """Check that ``tensor.shape`` matches ``ref_shape``.

    Args:
        tensor: Tensor whose shape is validated.
        ref_shape: Sequence with one entry per dimension. ``None`` accepts
            any size for that dimension; an entry may also be a
            ``torch.Tensor``, in which case the comparison goes through
            ``symbolic_assert`` instead of a plain Python comparison.

    Raises:
        AssertionError: If the number of dimensions differs, or a plain
            integer size does not equal its reference value.
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            # Wildcard dimension: nothing to verify.
            continue
        if isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    """Wrap ``fn`` so each call runs inside a profiler ``record_function`` scope.

    The scope is labelled with ``fn.__name__``. The wrapper is built with
    ``functools.wraps`` so that, besides ``__name__``, it also keeps the
    wrapped function's ``__doc__``, ``__module__``, ``__qualname__`` and
    exposes the original through ``__wrapped__``.

    Args:
        fn: Callable to instrument; it must have a ``__name__`` attribute.

    Returns:
        A callable with the same call signature and return value as ``fn``.
    """
    import functools # Local import keeps this block independent of the module header.

    # Resolve the label eagerly so a callable without __name__ still fails at
    # decoration time, as it did before.
    label = fn.__name__

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(label):
            return fn(*args, **kwargs)
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.
class InfiniteSampler(torch.utils.data.Sampler):
    """Endless index sampler that can be sharded across replicas.

    Indices are visited cyclically. Every ``num_replicas``-th position is
    yielded, starting at offset ``rank``. With ``shuffle`` enabled the order
    is first permuted by a ``RandomState`` seeded with ``seed``; after each
    visited position, that entry is swapped with a randomly chosen entry up
    to ``window - 1`` positions behind it (wrapping around), where
    ``window = rint(len(dataset) * window_size)``.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        count = len(self.dataset)
        order = np.arange(count)
        rnd, window = None, 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))
        # A window smaller than two leaves nothing to swap with.
        reshuffle = window >= 2

        step = 0
        while True:
            pos = step % count
            if step % self.num_replicas == self.rank:
                yield order[pos]
            if reshuffle:
                # Swap the visited slot with one at most window - 1 places
                # behind it, modulo the dataset size.
                partner = (pos - rnd.randint(window)) % count
                order[pos], order[partner] = order[partner], order[pos]
            step += 1

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.
def params_and_buffers(module):
    """Return the parameters of ``module`` followed by its buffers, as a list."""
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    """Return ``(name, tensor)`` pairs for the parameters, then the buffers, of ``module``."""
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy same-named parameters and buffers from ``src_module`` into ``dst_module``.

    The copy is done in place under ``torch.no_grad()`` so that it also works
    when the destination tensors are leaf parameters with
    ``requires_grad=True``; an in-place write to such a tensor would otherwise
    raise ``RuntimeError``. Each destination tensor keeps its own
    ``requires_grad`` flag.

    Args:
        src_module: Module supplying the values.
        dst_module: Module whose tensors are overwritten.
        require_all: If true, every tensor of ``dst_module`` must exist in
            ``src_module``.

    Raises:
        AssertionError: If an argument is not a ``torch.nn.Module``, or
            ``require_all`` is set and a destination tensor has no source
            counterpart.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    with torch.no_grad():
        for name, tensor in named_params_and_buffers(dst_module):
            assert (name in src_tensors) or (not require_all), f'{name} is missing from src_module'
            if name in src_tensors:
                tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
            else:
                # Tensor keeps whatever values dst_module already had.
                print(f'{name} is not in src_module, init it using random valua!')

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    """Run the ``with`` body with or without DDP gradient synchronization.

    When ``sync`` is false and ``module`` is a ``DistributedDataParallel``
    instance, the body runs inside ``module.no_sync()``. In every other case
    the body runs unchanged.
    """
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    """Assert that this process holds the same tensors as rank 0.

    Every parameter and buffer is compared against the copy broadcast from
    source rank 0. Floating-point tensors go through ``nan_to_num`` first.

    Args:
        module: Module to check.
        ignore_regex: Optional pattern; tensors whose full name
            (``<ClassName>.<tensor name>``) fully matches it are skipped.

    Raises:
        AssertionError: With the tensor's full name if any element differs.
    """
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            tensor = nan_to_num(tensor)
        # Broadcast overwrites `other` with rank 0's values on every process.
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.
+ +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. 
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +#---------------------------------------------------------------------------- diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/__init__.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/__init__.py new file mode 100644 index 0000000..dfebd04 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +# empty diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.cpp b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.cpp new file mode 100644 index 0000000..ee6f6d0 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.cpp @@ -0,0 +1,103 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. 
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? 
dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.cu b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.cu new file mode 100644 index 0000000..71ca390 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.cu @@ -0,0 +1,177 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. 
+ +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? 
one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. 
+ +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.h b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.h new file mode 100644 index 0000000..8994bfb --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.h @@ -0,0 +1,42 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +//------------------------------------------------------------------------ +// CUDA kernel parameters. 
+ +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.py new file mode 100644 index 0000000..b4028ad --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/bias_act.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Custom PyTorch ops for efficient bias and activation.""" + +import os +import numpy as np +import torch +from nerf import dnnlib + +from .. import custom_ops +from .. 
import misc + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_plugin = None +_null_tensor = torch.empty([0]) + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='bias_act_plugin', + sources=['bias_act.cpp', 'bias_act.cu'], + headers=['bias_act.h'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + 
r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. 
@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.

    Args:
        x:      Input activation tensor.
        b:      1D bias tensor whose length equals `x.shape[dim]`, or `None` to skip the bias.
        dim:    Dimension of `x` that `b` is broadcast along. Only checked when `b` is given.
        act:    Key into `activation_funcs` naming the activation to evaluate.
        alpha:  Shape parameter passed to the activation, or `None` for the activation's default.
        gain:   Multiplier applied after the activation, or `None` for the activation's default.
        clamp:  Non-negative bound; the result is clamped to `[-clamp, +clamp]`. `None` disables clamping.

    Returns:
        The tensor obtained by applying bias, activation, gain and clamp to `x`, in that order.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1) # -1 encodes "no clamping".

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        # Reshape to [1, ..., -1, ..., 1] so that `b` broadcasts along `dim` only.
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
    return x
+ key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. 
+ class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/conv2d_gradfix.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/conv2d_gradfix.py new file mode 100644 index 0000000..9a177cc --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/conv2d_gradfix.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
enabled = False                     # Enable the custom op by setting this to true.
weight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights.

@contextlib.contextmanager
def no_weight_gradients(disable=True):
    """Temporarily force weight gradients off for the custom conv2d op.

    While the context is active (and `disable` is true) the module-level flag
    `weight_gradients_disabled` is set, which makes the backward pass of the
    custom op skip the gradient with respect to the weights.

    Args:
        disable: When false, the context manager leaves the flag as it found it.

    The previous flag value is restored on exit, including when the body of the
    `with` statement raises, so a failed step cannot leave weight gradients
    disabled for the rest of the process.
    """
    global weight_gradients_disabled
    old = weight_gradients_disabled
    if disable:
        weight_gradients_disabled = True
    try:
        yield
    finally:
        # Restore unconditionally; nested uses unwind to the correct outer value.
        weight_gradients_disabled = old
groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + return True + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() +_null_tensor = torch.empty([0]) + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. + assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. 
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + ctx.save_for_backward( + input if weight.requires_grad else _null_tensor, + weight if input.requires_grad else _null_tensor, + ) + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). + if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): + a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) + c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) + c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) + c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. 
+ if transpose: + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + input_shape = ctx.input_shape + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad_input = op.apply(grad_output, weight, None) + assert grad_input.shape == input_shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input, weight) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. + class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, weight): + ctx.save_for_backward( + grad_output if input.requires_grad else _null_tensor, + input if grad_output.requires_grad else _null_tensor, + ) + ctx.grad_output_shape = grad_output.shape + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): + a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. + return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1] + + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad_output_shape = ctx.grad_output_shape + input_shape = ctx.input_shape + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output_shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad2_input = op.apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input_shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/conv2d_resample.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/conv2d_resample.py new file mode 100644 index 0000000..d46f4dd --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/conv2d_resample.py @@ -0,0 
def _get_weight_shape(w):
    """Return the dimensions of `w` as a list of plain Python ints.

    The list is also passed to `misc.assert_shape()` together with `w` before
    being returned.
    """
    with misc.suppress_tracer_warnings():
        # Converted under the guard because this value will be treated as a constant.
        dims = list(map(int, w.shape))
    misc.assert_shape(w, dims)
    return dims
@misc.profiled_function
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
    r"""Convolve `x` with `w`, optionally resampling by integer factors.

    The padding is applied a single time up front; nothing is padded between
    the individual resampling and convolution steps.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter used when resampling, as prepared by
                        upfirdn2d.setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding relative to the upsampled image. Either a single
                        number, `[x, y]`, or
                        `[x_before, x_after, y_before, y_after]` (default: 0).
        groups:         Number of groups the input channels are split into (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Grow the padding by the footprint of the resampling filter.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Predicates shared by the fast paths below.
    pointwise = (kw == 1 and kh == 1)
    down_only = (down > 1 and up == 1)
    up_only = (up > 1 and down == 1)

    # 1x1 kernel, downsampling only: resample first so the convolution sees fewer pixels.
    if pointwise and down_only:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        return _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)

    # 1x1 kernel, upsampling only: convolve at the low resolution, then upsample.
    if pointwise and up_only:
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)

    # Downsampling only: filter, then let a strided convolution do the decimation.
    if down_only:
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        return _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)

    # Upsampling (optionally followed by downsampling): transposed strided convolution.
    if up > 1:
        # Swap the in/out channel axes of the weight, separately within each group.
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        # Portion of the (now possibly negative) padding handed to the transposed conv itself.
        conv_pad_x = max(min(-px0, -px1), 0)
        conv_pad_y = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[conv_pad_y,conv_pad_x], groups=groups, transpose=True, flip_weight=(not flip_weight))
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+conv_pad_x,px1+conv_pad_x,py0+conv_pad_y,py1+conv_pad_y], gain=up**2, flip_filter=flip_filter)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
        return x

    # No resampling and symmetric, non-negative padding: a plain conv2d handles it directly.
    if up == 1 and down == 1 and px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
        return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)

    # Fallback: Generic reference implementation.
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
    return x
// Host-side launcher for the fused filtered leaky-ReLU CUDA kernels.
//
// Validates the inputs, sizes and allocates the output tensor `y` (and, when
// `writeSigns` is set, a uint8 sign tensor), fills in the kernel parameter
// struct, picks a kernel via choose_filtered_lrelu_kernel(), then launches the
// filter setup kernel followed by the main kernel on the current CUDA stream.
//
// Returns a tuple (y, so, code):
//   - code == 0:  `y` is the output; `so` is the newly allocated sign tensor
//                 when `writeSigns` is true and an undefined tensor otherwise.
//   - code == -1: no kernel exists for this up/down/filter-shape combination;
//                 both tensors are undefined and nothing was launched.
//
// A non-empty `si` is treated as a sign tensor to read from.
//
// NOTE(review): angle-bracket template argument lists (on std::tuple,
// choose_filtered_lrelu_kernel, copy_filters, and possibly data_ptr) are not
// visible in this copy of the source -- presumably lost in extraction; confirm
// against the original file before compiling.
static std::tuple filtered_lrelu(
    torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
    int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
{
    // Set CUDA device.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    // Validate arguments.
    TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
    TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
    TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
    TORCH_CHECK(x.numel() > 0, "x is empty");
    TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
    TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
    TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
    TORCH_CHECK(fu.numel() > 0, "fu is empty");
    TORCH_CHECK(fd.numel() > 0, "fd is empty");
    TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
    TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");

    // Figure out how much shared memory is available on the device.
    int maxSharedBytes = 0;
    AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
    int sharedKB = maxSharedBytes >> 10;

    // Populate enough launch parameters to check if a CUDA kernel exists.
    filtered_lrelu_kernel_params p;
    p.up = up;
    p.down = down;
    p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
    p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
    filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel(p, sharedKB);
    if (!test_spec.exec)
    {
        // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
        return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
    }

    // Input/output element size.
    int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;

    // Input sizes.
    int64_t xw = (int)x.size(3);
    int64_t xh = (int)x.size(2);
    int64_t fut_w = (int)fu.size(-1) - 1;
    int64_t fut_h = (int)fu.size(0) - 1;
    int64_t fdt_w = (int)fd.size(-1) - 1;
    int64_t fdt_h = (int)fd.size(0) - 1;

    // Logical size of upsampled buffer.
    int64_t cw = xw * up + (px0 + px1) - fut_w;
    int64_t ch = xh * up + (py0 + py1) - fut_h;
    TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
    TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");

    // Compute output size and allocate.
    int64_t yw = (cw - fdt_w + (down - 1)) / down;
    int64_t yh = (ch - fdt_h + (down - 1)) / down;
    TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
    TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());

    // Allocate sign tensor.
    torch::Tensor so;
    torch::Tensor s = si;
    bool readSigns = !!s.numel(); // A non-empty input sign tensor means signs are read from it.
    int64_t sw_active = 0; // Active width of sign tensor.
    if (writeSigns)
    {
        sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
        int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
        int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
        TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
        s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
    }
    else if (readSigns)
        sw_active = s.size(3) << 2;

    // Validate sign tensor if in use.
    if (readSigns || writeSigns)
    {
        TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
        TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
        TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
        TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
        TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
        TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
    }

    // Populate rest of CUDA kernel parameters.
    p.x = x.data_ptr();
    p.y = y.data_ptr();
    p.b = b.data_ptr();
    p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
    p.fu = fu.data_ptr();
    p.fd = fd.data_ptr();
    p.pad0 = make_int2(px0, py0);
    p.gain = gain;
    p.slope = slope;
    p.clamp = clamp;
    p.flip = (flip_filters) ? 1 : 0;
    p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
    p.sOfs = make_int2(sx, sy);
    p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.

    // x, y, b strides are in bytes.
    p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
    p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
    p.bStride = sz * b.stride(0);

    // fu, fd strides are in elements.
    p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
    p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);

    // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
    bool index64b = false;
    if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
    if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
    if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
    if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
    if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
    if (s.numel() > INT_MAX) index64b = true;

    // Choose CUDA kernel.
    filtered_lrelu_kernel_spec spec = { 0 };
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
    {
        if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
        {
            // Choose kernel based on index type, datatype and sign read/write modes.
            // Note that the combination writeSigns && readSigns has no branch and leaves `spec` zeroed.
            if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
            else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
            else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
            else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
            else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
            else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
        }
    });
    TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.

    // Launch CUDA kernel.
    void* args[] = {&p};
    int bx = spec.numWarps * 32;
    int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
    int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
    int gz = p.yShape.z * p.yShape.w;

    // Repeat multiple horizontal tiles in a CTA?
    if (spec.xrep)
    {
        p.tilesXrep = spec.xrep;
        p.tilesXdim = gx;

        gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
        std::swap(gx, gy);
    }
    else
    {
        p.tilesXrep = 0;
        p.tilesXdim = 0;
    }

    // Launch filter setup kernel.
    AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));

    // Copy kernels to constant memory.
    if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
    else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
    else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));

    // Set cache and shared memory configurations for main kernel.
    AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
    if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
        AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
    AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));

    // Launch main kernel.
    const int maxSubGz = 65535; // CUDA maximum for block z dimension.
    for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
    {
        p.blockZofs = zofs; // Tells the kernel which slice of the z range this launch covers.
        int subGz = std::min(maxSubGz, gz - zofs);
        AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
    }

    // Done.
    return std::make_tuple(y, so, 0);
}
+ if (readSigns || writeSigns) + { + TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); + TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); + TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); + TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); + TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); + TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large"); + } + + // Initialize CUDA kernel parameters. + filtered_lrelu_act_kernel_params p; + p.x = x.data_ptr(); + p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; + p.gain = gain; + p.slope = slope; + p.clamp = clamp; + p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); + p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous. + p.sOfs = make_int2(sx, sy); + + // Choose CUDA kernel. + void* func = 0; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&] + { + if (writeSigns) + func = choose_filtered_lrelu_act_kernel(); + else if (readSigns) + func = choose_filtered_lrelu_act_kernel(); + else + func = choose_filtered_lrelu_act_kernel(); + }); + TORCH_CHECK(func, "internal error - CUDA kernel not found"); + + // Launch CUDA kernel. + void* args[] = {&p}; + int bx = 128; // 4 warps per block. + + // Logical size of launch = writeSigns ? p.s : p.x + uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x; + uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y; + uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use. + gx = (gx - 1) / bx + 1; + + // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest. 
+ const uint32_t gmax = 65535; + gy = std::min(gy, gmax); + gz = std::min(gz, gmax); + + // Launch. + AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream())); + return so; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("filtered_lrelu", &filtered_lrelu); // The whole thing. + m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place. +} + +//------------------------------------------------------------------------ diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu.cu b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu.cu new file mode 100644 index 0000000..aaac954 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu.cu @@ -0,0 +1,1288 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include "filtered_lrelu.h" +#include + +//------------------------------------------------------------------------ +// Helpers. + +enum // Filter modes. +{ + MODE_SUSD = 0, // Separable upsampling, separable downsampling. + MODE_FUSD = 1, // Full upsampling, separable downsampling. + MODE_SUFD = 2, // Separable upsampling, full downsampling. + MODE_FUFD = 3, // Full upsampling, full downsampling. 
+}; + +template struct InternalType; +template <> struct InternalType +{ + typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); } + __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); } +}; +template <> struct InternalType +{ + typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); } + __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); } +}; +template <> struct InternalType +{ + typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); } + __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); } +}; + +#define MIN(A, B) ((A) < (B) ? (A) : (B)) +#define MAX(A, B) ((A) > (B) ? (A) : (B)) +#define CEIL_DIV(A, B) (((B)==1) ? (A) : \ + ((B)==2) ? ((int)((A)+1) >> 1) : \ + ((B)==4) ? ((int)((A)+3) >> 2) : \ + (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B))) + +// This works only up to blocks of size 256 x 256 and for all N that are powers of two. +template __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i) +{ + if ((N & (N-1)) && N <= 256) + y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256. + else + y = i/N; + + x = i - y*N; +} + +// Type cast stride before reading it. 
+template __device__ __forceinline__ T get_stride(const int64_t& x) +{ + return *reinterpret_cast(&x); +} + +//------------------------------------------------------------------------ +// Filters, setup kernel, copying function. + +#define MAX_FILTER_SIZE 32 + +// Combined up/down filter buffers so that transfer can be done with one copy. +__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel. +__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel. + +// Accessors to combined buffers to index up/down filters individually. +#define c_fu (c_fbuf) +#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) +#define g_fu (g_fbuf) +#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) + +// Set up filters into global memory buffer. +static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p) +{ + for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x) + { + int x, y; + fast_div_mod(x, y, idx); + + int fu_x = p.flip ? x : (p.fuShape.x - 1 - x); + int fu_y = p.flip ? y : (p.fuShape.y - 1 - y); + if (p.fuShape.y > 0) + g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y]; + else + g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x]; + + int fd_x = p.flip ? x : (p.fdShape.x - 1 - x); + int fd_y = p.flip ? y : (p.fdShape.y - 1 - y); + if (p.fdShape.y > 0) + g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y]; + else + g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x]; + } +} + +// Host function to copy filters written by setup kernel into constant buffer for main kernel. 
+template static cudaError_t copy_filters(cudaStream_t stream) +{ + void* src = 0; + cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf); + if (err) return err; + return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream); +} + +//------------------------------------------------------------------------ +// Coordinate spaces: +// - Relative to input tensor: inX, inY, tileInX, tileInY +// - Relative to input tile: relInX, relInY, tileInW, tileInH +// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH +// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH +// - Relative to output tensor: outX, outY, tileOutX, tileOutY +// +// Relationships between coordinate spaces: +// - inX = tileInX + relInX +// - inY = tileInY + relInY +// - relUpX = relInX * up + phaseInX +// - relUpY = relInY * up + phaseInY +// - relUpX = relOutX * down +// - relUpY = relOutY * down +// - outX = tileOutX + relOutX +// - outY = tileOutY + relOutY + +extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer. + +template +static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) +{ + // Check that we don't try to support non-existing filter modes. 
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported"); + static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported"); + static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor"); + static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor"); + static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor"); + static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor"); + static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE"); + static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters"); + static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters"); + static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4"); + static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4"); + + // Static definitions. + typedef typename InternalType::scalar_t scalar_t; + typedef typename InternalType::vec2_t vec2_t; + typedef typename InternalType::vec4_t vec4_t; + const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4. + const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height. + const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width. + const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height. + const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up. 
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4. + + // Merge 1x1 downsampling into last upsampling step for upf1 and ups2. + const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD)); + + // Sizes of logical buffers. + const int szIn = tileInH_up * tileInW; + const int szUpX = tileInH_up * tileUpW; + const int szUpXY = downInline ? 0 : (tileUpH * tileUpW); + const int szDownX = tileUpH * tileOutW; + + // Sizes for shared memory arrays. + const int s_buf0_size_base = + (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) : + (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) : + (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) : + (filterMode == MODE_FUFD) ? szIn : + -1; + const int s_buf1_size_base = + (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) : + (filterMode == MODE_FUSD) ? szUpXY : + (filterMode == MODE_SUFD) ? szUpX : + (filterMode == MODE_FUFD) ? szUpXY : + -1; + + // Ensure U128 alignment. + const int s_buf0_size = (s_buf0_size_base + 3) & ~3; + const int s_buf1_size = (s_buf1_size_base + 3) & ~3; + + // Check at compile time that we don't use too much shared memory. + static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow"); + + // Declare shared memory arrays. + scalar_t* s_buf0; + scalar_t* s_buf1; + if (sharedKB <= 48) + { + // Allocate shared memory arrays here. + __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused. + s_buf0 = s_buf0_st; + s_buf1 = s_buf0 + s_buf0_size; + } + else + { + // Use the dynamically allocated shared memory array. + s_buf0 = (scalar_t*)s_buf_raw; + s_buf1 = s_buf0 + s_buf0_size; + } + + // Pointers to the buffers. 
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY] + scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX] + scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX] + scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX] + if (filterMode == MODE_SUSD) + { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + s_tileDownX = s_buf1; + } + else if (filterMode == MODE_FUSD) + { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + s_tileDownX = s_buf0; + } + else if (filterMode == MODE_SUFD) + { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + } + else if (filterMode == MODE_FUFD) + { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + } + + // Allow large grids in z direction via per-launch offset. + int channelIdx = blockIdx.z + p.blockZofs; + int batchIdx = channelIdx / p.yShape.z; + channelIdx -= batchIdx * p.yShape.z; + + // Offset to output feature map. In bytes. + index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w); + + // Sign shift amount. + uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6; + + // Inner tile loop. + #pragma unroll 1 + for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++) + { + // Locate output tile. + int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x; + int tileOutX = tileX * tileOutW; + int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH; + + // Locate input tile. + int tmpX = tileOutX * down - p.pad0.x; + int tmpY = tileOutY * down - p.pad0.y; + int tileInX = CEIL_DIV(tmpX, up); + int tileInY = CEIL_DIV(tmpY, up); + const int phaseInX = tileInX * up - tmpX; + const int phaseInY = tileInY * up - tmpY; + + // Extra sync if input and output buffers are the same and we are not on first tile. 
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline))) + __syncthreads(); + + // Load input tile & apply bias. Unrolled. + scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride))); + index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w); + int idx = threadIdx.x; + const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock); + #pragma unroll + for (int loop = 0; loop < loopCountIN; loop++) + { + int relInX, relInY; + fast_div_mod(relInX, relInY, idx); + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + + if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y) + v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b; + + bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH); + if (!skip) + s_tileIn[idx] = v; + + idx += threadsPerBlock; + } + + if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter. + { + // Horizontal upsampling. 
+ __syncthreads(); + if (up == 4) + { + for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) + { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + scalar_t a = s_tileIn[src0]; + if (phaseInX == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } + else if (phaseInX == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } + else if (phaseInX == 2) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } + else // (phaseInX == 3) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst+0] = v.x; + s_tileUpX[dst+1] = v.y; + s_tileUpX[dst+2] = v.z; + s_tileUpX[dst+3] = v.w; + } + } + else if (up == 2) + { + bool p0 = (phaseInX == 0); + for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) + { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW 
* relInY; + int dst = relInY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + scalar_t a = s_tileIn[src0]; + if (p0) // (phaseInX == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } + else // (phaseInX == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst+0] = v.x; + s_tileUpX[dst+1] = v.y; + } + } + + // Vertical upsampling & nonlinearity. + + __syncthreads(); + int groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. + int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes. + if (up == 4) + { + minY -= 3; // Adjust according to block height. 
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) + { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec4_t v = InternalType::zero_vec4(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } + else if (phaseInY == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } + else if (phaseInY == 2) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } + else // (phaseInY == 3) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + index_t si2 = si0 + p.sShape.x * 2; + 
index_t si3 = si0 + p.sShape.x * 3; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } + if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } + if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } + } + } + else + { + // Determine and write signs. 
+ if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } + if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } + + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } + if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + } + } + else if (signRead) // Read signs and apply. 
+ { + if ((uint32_t)signXb < p.swLimit) + { + int ss = (signX & 3) << 1; + if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } + if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } + if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; } + if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; } + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[dst + 0 * tileUpW] = v.x; + if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y; + if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z; + if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w; + } + } + else if (up == 2) + { + minY -= 1; // Adjust according to block height. 
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) + { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec2_t v = InternalType::zero_vec2(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } + else // (phaseInY == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. 
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + } + } + else + { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + } + } + } + else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) + { + if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } + if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + } + + if (!downInline) + { + // Write into temporary buffer. + s_tileUpXY[dst] = v.x; + if (relUpY0 < tileUpH - 1) + s_tileUpXY[dst + tileUpW] = v.y; + } + else + { + // Write directly into output buffer. 
+ if ((uint32_t)x < p.yShape.x) + { + int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down); + index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut; + if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]); + if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]); + } + } + } + } + } + else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) + { + // Full upsampling filter. + + if (up == 2) + { + // 2 x 2-wide. + __syncthreads(); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs. + for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4) + { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up); + int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up); + int src0 = relInX0 + tileInW * relInY0; + int tap0y = (relInY0 * up + phaseInY - relUpY0); + + #define X_LOOP(TAPY, PX) \ + for (int sx = 0; sx < fuSize / up; sx++) \ + { \ + v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ + v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ + } + + vec4_t v = InternalType::zero_vec4(); + if (tap0y == 0 && phaseInX == 0) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(0, 0) } + if (tap0y == 0 && phaseInX == 1) + #pragma unroll + for (int sy = 
0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(0, 1) } + if (tap0y == 1 && phaseInX == 0) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(1, 0) } + if (tap0y == 1 && phaseInX == 1) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(1, 1) } + + #undef X_LOOP + + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } + if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } + if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } + if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + } + else + { + // Determine and write signs. 
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } + if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } + if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } + if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } + + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + } + } + else if (signRead) // Read sign and apply. + { + if ((uint32_t)signY < p.sShape.y) + { + int s = 0; + if ((uint32_t)signXb < p.swLimit) s = p.s[si]; + if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8; + s >>= (signX & 3) << 1; + if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f; + if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f; + if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f; + if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f; + } + } + else // Forward pass with no sign write. 
+ { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[idx + 0] = v.x; + s_tileUpXY[idx + 1] = v.y; + s_tileUpXY[idx + 2] = v.z; + s_tileUpXY[idx + 3] = v.w; + } + } + else if (up == 1) + { + __syncthreads(); + uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. + for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x) + { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter. + + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + v *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write sign. + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) + { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) + { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } + } + else + { + // Determine and write sign. 
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) + { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) + { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } + else + { + // Just compute the value. + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + } + } + else if (signRead) + { + // Read sign and apply if within sign tensor bounds. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) + { + int s = p.s[si]; + s >>= signXo; + if (s & 1) v *= p.slope; + if (s & 2) v = 0.f; + } + } + else // Forward pass with no sign write. + { + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + + if (!downInline) // Write into temporary buffer. + s_tileUpXY[idx] = v; + else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer + *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]); + } + } + } + + // Downsampling. + if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) + { + // Horizontal downsampling. + __syncthreads(); + if (down == 4 && tileOutW % 4 == 0) + { + // Calculate 4 pixels at a time. 
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + #pragma unroll + for (int step = 0; step < fdSize; step++) + { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step]; + v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step]; + v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx+0] = v.x; + s_tileDownX[idx+1] = v.y; + s_tileDownX[idx+2] = v.z; + s_tileDownX[idx+3] = v.w; + } + } + else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) + { + // Calculate 2 pixels at a time. + for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + #pragma unroll + for (int step = 0; step < fdSize; step++) + { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx+0] = v.x; + s_tileDownX[idx+1] = v.y; + } + } + else + { + // Calculate 1 pixel at a time. + for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src = relUpY * tileUpW + relUpX0; + scalar_t v = 0.f; + #pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileUpXY[src + step] * (scalar_t)c_fd[step]; + s_tileDownX[idx] = v; + } + } + + // Vertical downsampling & store output tile. 
+ __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) + { + int relOutX, relOutY0; + fast_div_mod(relOutX, relOutY0, idx); + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileOutW + relOutX; + scalar_t v = 0; + #pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step]; + + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY0; + + if (outX < p.yShape.x & outY < p.yShape.y) + *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; + } + } + else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) + { + // Full downsampling filter. + if (down == 2) + { + // 2-wide. + __syncthreads(); + for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2) + { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + int relUpX0 = relOutX0 * down; + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + #pragma unroll + for (int sy = 0; sy < fdSize; sy++) + #pragma unroll + for (int sx = 0; sx < fdSize; sx++) + { + v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + } + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outY < p.yShape.y) + { + index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut; + if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x; + if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y; + } + } + } + else if (down == 1 && !downInline) + { + // Thread per pixel. 
+ __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) + { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter. + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y) + *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; + } + } + } + + if (!enableXrep) + break; + } +} + +//------------------------------------------------------------------------ +// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant. +// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used. + +template +static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Indexing. + int32_t x = threadIdx.x + blockIdx.x * blockDim.x; + int32_t ymax = signWrite ? p.sShape.y : p.xShape.y; + int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index. + + // Loop to accommodate oversized tensors. + for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z) + for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y) + { + // Extract z and w (channel, minibatch index). + int32_t w = q / p.xShape.z; + int32_t z = q - w * p.xShape.z; + + // Choose behavior based on sign read/write mode. + if (signWrite) + { + // Process value if in p.x. + uint32_t s = 0; + if (x < p.xShape.x && y < p.xShape.y) + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + + // Gain, LReLU, clamp. + v *= p.gain; + if (v < 0.f) + { + v *= p.slope; + s = 1; // Sign. 
+ } + if (fabsf(v) > p.clamp) + { + v = InternalType::clamp(v, p.clamp); + s = 2; // Clamp. + } + + *pv = (T)v; // Write value. + } + + // Coalesce into threads 0 and 16 of warp. + uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu; + s <<= ((threadIdx.x & 15) << 1); // Shift into place. + s |= __shfl_xor_sync(m, s, 1); // Distribute. + s |= __shfl_xor_sync(m, s, 2); + s |= __shfl_xor_sync(m, s, 4); + s |= __shfl_xor_sync(m, s, 8); + + // Write signs if leader and in p.s. + if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in. + { + uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous. + ((uint32_t*)p.s)[is >> 4] = s; + } + } + else if (signRead) + { + // Process value if in p.x. + if (x < p.xShape.x) // y is always in. + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + v *= p.gain; + + // Apply sign buffer offset. + uint32_t sx = x + p.sOfs.x; + uint32_t sy = y + p.sOfs.y; + + // Read and apply signs if we land inside valid region of sign buffer. + if (sx < p.sShape.x && sy < p.sShape.y) + { + uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous. + unsigned char s = p.s[is]; + s >>= (sx & 3) << 1; // Shift into place. + if (s & 1) // Sign? + v *= p.slope; + if (s & 2) // Clamp? + v = 0.f; + } + + *pv = (T)v; // Write value. + } + } + else + { + // Forward pass with no sign write. Process value if in p.x. + if (x < p.xShape.x) // y is always in. + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + v *= p.gain; + if (v < 0.f) + v *= p.slope; + if (fabsf(v) > p.clamp) + v = InternalType::clamp(v, p.clamp); + *pv = (T)v; // Write value. 
+ } + } + } +} + +template void* choose_filtered_lrelu_act_kernel(void) +{ + return (void*)filtered_lrelu_act_kernel; +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB) +{ + filtered_lrelu_kernel_spec s = { 0 }; + + // Return the first matching kernel. +#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \ + if (sharedKB >= SH) \ + if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \ + if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \ + if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \ + { \ + static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \ + static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \ + static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \ + s.setup = (void*)setup_filters_kernel; \ + s.exec = (void*)filtered_lrelu_kernel; \ + s.tileOut = make_int2(TW, TH); \ + s.numWarps = W; \ + s.xrep = XR; \ + s.dynamicSharedKB = (SH == 48) ? 0 : SH; \ + return s; \ + } + + // Launch parameters for various kernel specializations. + // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first. + // Kernels that use more shared memory must be listed before those that use less, for the same reason. 
+ + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1 + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2 + CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2 + CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1 + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 
28, 16, 0, 0) // 6t-ups2-downf2 + CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2 + CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4 + CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1 + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2 + CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2 + CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2 + CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4 + CASE(/*sharedKB*/96, /*up,fu*/2,16, 
/*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4 + + #undef CASE + return s; // No kernel found. +} + +//------------------------------------------------------------------------ diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu.h b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu.h new file mode 100644 index 0000000..f2bfd1d --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu.h @@ -0,0 +1,94 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct filtered_lrelu_kernel_params +{ + // These parameters decide which kernel to use. + int up; // upsampling ratio (1, 2, 4) + int down; // downsampling ratio (1, 2, 4) + int2 fuShape; // [size, 1] | [size, size] + int2 fdShape; // [size, 1] | [size, size] + + int _dummy; // Alignment. + + // Rest of the parameters. + const void* x; // Input tensor. + void* y; // Output tensor. + const void* b; // Bias tensor. + unsigned char* s; // Sign tensor in/out. NULL if unused. + const float* fu; // Upsampling filter. + const float* fd; // Downsampling filter. + + int2 pad0; // Left/top padding. 
+ float gain; // Additional gain factor. + float slope; // Leaky ReLU slope on negative side. + float clamp; // Clamp after nonlinearity. + int flip; // Filter kernel flip for gradient computation. + + int tilesXdim; // Original number of horizontal output tiles. + int tilesXrep; // Number of horizontal tiles per CTA. + int blockZofs; // Block z offset to support large minibatch, channel dimensions. + + int4 xShape; // [width, height, channel, batch] + int4 yShape; // [width, height, channel, batch] + int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. + int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. + int swLimit; // Active width of sign tensor in bytes. + + longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. + longlong4 yStride; // + int64_t bStride; // + longlong3 fuStride; // + longlong3 fdStride; // +}; + +struct filtered_lrelu_act_kernel_params +{ + void* x; // Input/output, modified in-place. + unsigned char* s; // Sign tensor in/out. NULL if unused. + + float gain; // Additional gain factor. + float slope; // Leaky ReLU slope on negative side. + float clamp; // Clamp after nonlinearity. + + int4 xShape; // [width, height, channel, batch] + longlong4 xStride; // Input/output tensor strides, same order as in shape. + int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. + int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct filtered_lrelu_kernel_spec +{ + void* setup; // Function for filter kernel setup. + void* exec; // Function for main operation. + int2 tileOut; // Width/height of launch tile. + int numWarps; // Number of warps per thread block, determines launch block size. + int xrep; // For processing multiple horizontal tiles per thread block. 
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template void* choose_filtered_lrelu_act_kernel(void); +template cudaError_t copy_filters(cudaStream_t stream); + +//------------------------------------------------------------------------ diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu.py new file mode 100644 index 0000000..2047b7e --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu.py @@ -0,0 +1,276 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import warnings + +from .. import custom_ops +from .. import misc +from . import upfirdn2d +from . 
import bias_act  # tail of the `from . import bias_act` statement that was split by the line wrap

#----------------------------------------------------------------------------

# Handle of the compiled extension; stays None until `_init()` has run once.
_plugin = None

def _init():
    """Obtain the `filtered_lrelu_plugin` extension on first use.

    The handle returned by `custom_ops.get_plugin()` is cached in the
    module-level `_plugin`, so later calls do no work. Always returns True,
    which lets callers use the call inside a boolean condition.
    """
    global _plugin
    if _plugin is None:
        _plugin = custom_ops.get_plugin(
            module_name='filtered_lrelu_plugin',
            sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
            headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
            source_dir=os.path.dirname(__file__),
            extra_cuda_cflags=['--use_fast_math'],
        )
    return True

def _get_filter_size(f):
    """Return `(width, height)` of FIR filter `f`; `None` counts as a 1x1 filter.

    For a 1D filter both values are the number of taps, because `shape[-1]`
    and `shape[0]` refer to the same dimension.
    """
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor)
    assert 1 <= f.ndim <= 2
    return f.shape[-1], f.shape[0] # width, height

def _parse_padding(padding):
    """Normalize `padding` to the 4-tuple `(px0, px1, py0, py1)`.

    Accepts a single int, a 2-element `[x, y]` sequence, or a 4-element
    `[x_before, x_after, y_before, y_after]` sequence of ints. Any other
    length fails at the final unpacking.
    """
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, (int, np.integer)) for x in padding)
    padding = [int(x) for x in padding]
    if len(padding) == 2:
        px, py = padding
        padding = [px, px, py, py]
    px0, px1, py0, py1 = padding
    return px0, px1, py0, py1

#----------------------------------------------------------------------------

def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
    r"""Filtered leaky ReLU for a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Add channel-specific bias if provided (`b`).

    2. Upsample the image by inserting N-1 zeros after each pixel (`up`).

    3. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    5. Multiply each value by the provided gain factor (`gain`).

    6. Apply leaky ReLU activation function to each value.

    7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.

    8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
       it so that the footprint of all output pixels lies within the input image.

    9. Downsample the image by keeping every Nth pixel (`down`).

    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x:           Float32/float16/float64 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        fu:          Float32 upsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        fd:          Float32 downsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        b:           Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                     as `x`. The length of the vector must match the channel dimension of `x`.
        up:          Integer upsampling factor (default: 1).
        down:        Integer downsampling factor. (default: 1).
        padding:     Padding with respect to the upsampled image. Can be a single number
                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        gain:        Overall scaling factor for signal magnitude (default: sqrt(2)).
        slope:       Slope on the negative side of leaky ReLU (default: 0.2).
        clamp:       Maximum magnitude for leaky ReLU output (default: None).
        flip_filter: False = convolution, True = correlation (default: False).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    # The custom op is used only when requested AND the input lives on a CUDA device;
    # everything else (including impl='cuda' with a CPU tensor) takes the reference path.
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
    return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)

#----------------------------------------------------------------------------

@misc.profiled_function
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
    existing `upfirdn2d()` and `bias_act()` ops.
    """
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    fu_w, fu_h = _get_filter_size(fu)
    fd_w, fd_h = _get_filter_size(fd)
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
        misc.assert_shape(b, [x.shape[1]])
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    assert slope == float(slope) and slope >= 0
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)

    # Calculate output size.
    # This expected shape is only used by the check at the end of the function.
    batch_size, channels, in_h, in_w = x.shape
    in_dtype = x.dtype
    out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
    out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down

    # Compute using existing ops.
    x = bias_act.bias_act(x=x, b=b) # Apply bias.
    x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
    x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
    x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.

    # Check output shape & dtype.
    misc.assert_shape(x, [batch_size, channels, out_h, out_w])
    assert x.dtype == in_dtype
    return x

#----------------------------------------------------------------------------

# Maps a tuple of op parameters to the autograd Function class specialized for them.
_filtered_lrelu_cuda_cache = dict()

def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Fast CUDA implementation of `filtered_lrelu()` using custom ops.

    Returns a `torch.autograd.Function` subclass specialized for the given
    parameters; the caller invokes `.apply(x, fu, fd, b, si, sx, sy)` on it.
    One class is created per distinct parameter tuple and then reused from
    `_filtered_lrelu_cuda_cache`.
    """
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    gain = float(gain)
    assert slope == float(slope) and slope >= 0
    slope = float(slope)
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)
    clamp = float(clamp if clamp is not None else 'inf') # No clamp is represented as +inf.

    # Lookup from cache.
    key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
    if key in _filtered_lrelu_cuda_cache:
        return _filtered_lrelu_cuda_cache[key]

    # Forward op.
    class FilteredLReluCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4

            # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
            if fu is None:
                fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            if fd is None:
                fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            assert 1 <= fu.ndim <= 2
            assert 1 <= fd.ndim <= 2

            # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
            if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
                fu = fu.square()[None]
            if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
                fd = fd.square()[None]

            # Missing sign input tensor.
            # An empty tensor (numel() == 0) is the "no sign tensor supplied" marker tested below.
            if si is None:
                si = torch.empty([0])

            # Missing bias tensor.
            if b is None:
                b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)

            # Construct internal sign tensor only if gradients are needed.
            write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)

            # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
            # Note: `a` and `b` inside the generator expression are local to it and do not touch the bias `b`.
            strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
            if any(a < b for a, b in zip(strides[:-1], strides[1:])):
                warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)

            # Call C++/Cuda plugin if datatype is supported.
            if x.dtype in [torch.float16, torch.float32]:
                if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
                    warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
                y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
            else:
                return_code = -1 # Other dtypes (e.g. float64) always take the generic fallback below.

            # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
            # only the bit-packed sign tensor is retained for gradient computation.
            if return_code < 0:
                warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)

                y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
                y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
                so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
                y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.

            # Prepare for gradient computation.
            # Keep the caller-supplied sign tensor if there was one, otherwise the one produced by the op.
            ctx.save_for_backward(fu, fd, (si if si.numel() else so))
            ctx.x_shape = x.shape
            ctx.y_shape = y.shape
            ctx.s_ofs = sx, sy
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            fu, fd, si = ctx.saved_tensors
            _, _, xh, xw = ctx.x_shape
            _, _, yh, yw = ctx.y_shape
            sx, sy = ctx.s_ofs
            # One gradient slot per forward() input; only x (0) and b (3) are differentiable.
            dx  = None # 0
            dfu = None; assert not ctx.needs_input_grad[1]
            dfd = None; assert not ctx.needs_input_grad[2]
            db  = None # 3
            dsi = None; assert not ctx.needs_input_grad[4]
            dsx = None; assert not ctx.needs_input_grad[5]
            dsy = None; assert not ctx.needs_input_grad[6]

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
                # The gradient is evaluated with the same op: up/down factors and the two filters are
                # swapped, the filter flip is inverted, no bias or clamp is applied, and the saved
                # sign tensor is passed in as `si` together with adjusted sign offsets.
                pp = [
                    (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
                    xw * up - yw * down + px0 - (up - 1),
                    (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
                    xh * up - yh * down + py0 - (up - 1),
                ]
                gg = gain * (up ** 2) / (down ** 2)
                ff = (not flip_filter)
                sx = sx - (fu.shape[-1] - 1) + px0
                sy = sy - (fu.shape[0]  - 1) + py0
                dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)

            if ctx.needs_input_grad[3]:
                # Bias gradient: reduce dx over every dimension except channels.
                db = dx.sum([0, 2, 3])

            return dx, dfu, dfd, db, dsi, dsx, dsy

    # Add to cache.
    _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
    return FilteredLReluCuda

#----------------------------------------------------------------------------
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for no signs mode (no gradients required). + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. +template cudaError_t copy_filters(cudaStream_t stream); diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu_rd.cu b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu_rd.cu new file mode 100644 index 0000000..3cd43ec --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu_rd.cu @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for sign read mode. + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. +template cudaError_t copy_filters(cudaStream_t stream); diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu_wr.cu b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu_wr.cu new file mode 100644 index 0000000..bc2fa06 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/filtered_lrelu_wr.cu @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for sign write mode. + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. +template cudaError_t copy_filters(cudaStream_t stream); diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/fma.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/fma.py new file mode 100644 index 0000000..5458116 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/fma.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" + +import torch + +#---------------------------------------------------------------------------- + +def fma(a, b, c): # => a * b + c + return _FusedMultiplyAdd.apply(a, b, c) + +#---------------------------------------------------------------------------- + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + +#---------------------------------------------------------------------------- + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims+1:]) + assert x.shape == shape + return x + +#---------------------------------------------------------------------------- diff --git 
a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/grid_sample_gradfix.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/grid_sample_gradfix.py new file mode 100644 index 0000000..35d9472 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/grid_sample_gradfix.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample` that +supports arbitrarily high order gradients between the input and output. +Only works on 2D images and assumes +`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" + +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. 
+ +#---------------------------------------------------------------------------- + +def grid_sample(input, grid): + if _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(): + return enabled + +#---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + +#---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +#---------------------------------------------------------------------------- diff --git 
a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.cpp b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.cpp new file mode 100644 index 0000000..c1769c3 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.cpp @@ -0,0 +1,111 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. 
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.numel() > 0, "x has zero size"); + TORCH_CHECK(f.numel() > 0, "f has zero size"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 
1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. 
+ void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d); +} + +//------------------------------------------------------------------------ diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.cu b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.cu new file mode 100644 index 0000000..7d182d7 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.cu @@ -0,0 +1,388 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. 
+ +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. + int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? 
p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. 
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. 
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + // No up/downsampling. + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 
16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x upsampling. 
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 
32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x downsampling. + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = 
{(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + + // 4x upsampling. + if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 
1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 4x downsampling (inefficient). + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. 
+ +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.h b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.h new file mode 100644 index 0000000..d5de893 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.h @@ -0,0 +1,63 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: LicenseRef-NvidiaProprietary + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. 
+ +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.py new file mode 100644 index 0000000..5d63471 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/ops/upfirdn2d.py @@ -0,0 +1,391 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import os +import numpy as np +import torch + +from .. import custom_ops +from .. import misc +from . 
import conv2d_gradfix + +#---------------------------------------------------------------------------- + +_plugin = None + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='upfirdn2d_plugin', + sources=['upfirdn2d.cpp', 'upfirdn2d.cu'], + headers=['upfirdn2d.h'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + +#---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). 
+ gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. 
+ f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Check that upsampled buffer is not smaller than the filter. 
+ upW = in_width * upx + padx0 + padx1 + upH = in_height * upy + pady0 + pady1 + assert upW >= f.shape[-1] and upH >= f.shape[0] + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. 
+ class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + if f.ndim == 1 and f.shape[0] == 1: + f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1. + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. 
+ + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). 
+ flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) + +#---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
+ """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/persistence.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/persistence.py new file mode 100644 index 0000000..9e110d6 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/persistence.py @@ -0,0 +1,255 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Facilities for pickling Python code alongside other data. + +The pickled code is automatically imported into a separate Python module +during unpickling. 
def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. A typical use case is to first unpickle a previous
    instance of a persistent class, and then upgrade it to use the latest
    version of the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)

    Args:
        orig_class: The class to decorate. Must be a class object whose
                    defining module is registered in `sys.modules`.

    Returns:
        `orig_class` unchanged if it is already persistent, otherwise a
        new subclass of `orig_class` that carries the module source code.
    """
    assert isinstance(orig_class, type)
    # Decorating twice is a no-op: an already persistent class is returned as is.
    if is_persistent(orig_class):
        return orig_class

    # Capture the source text of the module that defines the class; this is
    # what gets embedded into every pickle of a decorated instance.
    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    class Decorator(orig_class):
        # Stored on the class so that `__reduce__` can reach them from any instance.
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Deep copies, so later mutation of the caller's objects does not
            # alter the recorded constructor arguments.
            self._init_args = copy.deepcopy(args)
            self._init_kwargs = copy.deepcopy(kwargs)
            # The class must be reachable by name in its module, because
            # unpickling looks it up as `module.__dict__[class_name]`.
            assert orig_class.__name__ in orig_module.__dict__
            # Fail at construction time rather than at pickling time.
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            # Returns a copy so that callers cannot modify the recorded arguments.
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            # Returns a copy, wrapped in `dnnlib.EasyDict`.
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            # Pad to at least (callable, args, state) so that fields[2] exists.
            fields += [None] * max(3 - len(fields), 0)
            # Skip the rewrite when the parent's `__reduce__` already routes
            # through `_reconstruct_persistent_obj`.
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    # Keep the original class name and register the subclass so that
    # `is_persistent()` recognizes it.
    Decorator.__name__ = orig_class.__name__
    _decorators.add(Decorator)
    return Decorator
+ """ + meta = dnnlib.EasyDict(meta) + meta.state = dnnlib.EasyDict(meta.state) + for hook in _import_hooks: + meta = hook(meta) + assert meta is not None + + assert meta.version == _version + module = _src_to_module(meta.module_src) + + assert meta.type == 'class' + orig_class = module.__dict__[meta.class_name] + decorator_class = persistent_class(orig_class) + obj = decorator_class.__new__(decorator_class) + + setstate = getattr(obj, '__setstate__', None) + if callable(setstate): + setstate(meta.state) # pylint: disable=not-callable + else: + obj.__dict__.update(meta.state) + return obj + +#---------------------------------------------------------------------------- + +def _module_to_src(module): + r"""Query the source code of a given Python module. + """ + src = _module_to_src_dict.get(module, None) + if src is None: + src = inspect.getsource(module) + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + return src + +def _src_to_module(src): + r"""Get or create a Python module for the given source code. + """ + module = _src_to_module_dict.get(src, None) + if module is None: + module_name = "_imported_module_" + uuid.uuid4().hex + module = types.ModuleType(module_name) + sys.modules[module_name] = module + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + exec(src, module.__dict__) # pylint: disable=exec-used + return module + +#---------------------------------------------------------------------------- + +def _check_pickleable(obj): + r"""Check that the given object is pickleable, raising an exception if + it is not. This function is expected to be considerably more efficient + than actually pickling the object. + """ + def recurse(obj): + if isinstance(obj, (list, tuple, set)): + return [recurse(x) for x in obj] + if isinstance(obj, dict): + return [[recurse(x), recurse(y)] for x, y in obj.items()] + if isinstance(obj, (str, int, float, bool, bytes, bytearray)): + return None # Python primitive types are pickleable. 
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: + return None # NumPy arrays and PyTorch tensors are pickleable. + if is_persistent(obj): + return None # Persistent objects are pickleable, by virtue of the constructor check. + return obj + with io.BytesIO() as f: + pickle.dump(recurse(obj), f) + +#---------------------------------------------------------------------------- diff --git a/stable-dreamfusion-3DPortrait/nerf/torch_utils/training_stats.py b/stable-dreamfusion-3DPortrait/nerf/torch_utils/training_stats.py new file mode 100644 index 0000000..636dd7f --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/torch_utils/training_stats.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +"""Facilities for reporting and collecting training statistics across +multiple processes and devices. The interface is designed to minimize +synchronization overhead as well as the amount of boilerplate in user +code.""" + +import re +import numpy as np +import torch +import dnnlib + +from . import misc + +#---------------------------------------------------------------------------- + +_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] +_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. +_counter_dtype = torch.float64 # Data type to use for the internal counters. +_rank = 0 # Rank of the current process. 
def init_multiprocessing(rank, sync_device):
    r"""Configure `torch_utils.training_stats` for multi-process collection.

    Call this after `torch.distributed.init_process_group()` and before
    the first `Collector.update()`. The call can be skipped entirely when
    statistics only need to be gathered within a single process.

    Args:
        rank:           Rank of the current process.
        sync_device:    PyTorch device to use for inter-process
                        communication, or None to disable multi-process
                        collection. Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    # The settings must be in place before the first synchronization.
    assert not _sync_called
    _rank, _sync_device = rank, sync_device
class Collector:
    r"""Gathers the scalars broadcasted through `report()` and `report0()`
    and exposes their averages (mean and standard deviation) over
    user-defined periods of time.

    Incoming scalars first land in internal counters that the user never
    sees. Each call to `update()` copies the amount accumulated since the
    previous call into the user-visible state, which can then be read via
    `num()`, `mean()`, `std()`, `as_dict()` or `collector[name]`. The
    visible state therefore always describes the interval between the
    last two calls to `update()`.

    Args:
        regex:          Regular expression selecting the statistics to
                        collect. The default collects everything.
        keep_previous:  Whether to keep the previous averages for a
                        statistic that received no scalars in a given
                        round (default: True).
    """

    def __init__(self, regex='.*', keep_previous=True):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = {}
        self._moments = {}
        # The first update records the current cumulative totals as the
        # baseline; its deltas are discarded right away.
        self.update()
        self._moments.clear()

    def names(self):
        r"""Returns the names of all statistics broadcasted so far whose
        full name matches the regular expression given at construction.
        """
        return list(filter(self._regex.fullmatch, _counters))

    def update(self):
        r"""Moves the current contents of the internal counters into the
        user-visible state and starts a new accumulation round.

        With `keep_previous=True`, a statistic that has received no
        scalars since the last update keeps its previous averages.

        The work is delegated to `_sync()`; the method is intended to be
        called periodically from the main training loop, typically once
        every N training steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for stat_name, total in _sync(self.names()):
            seen = self._cumulative.get(stat_name)
            if seen is None:
                seen = torch.zeros([_num_moments], dtype=_counter_dtype)
                self._cumulative[stat_name] = seen
            increment = total - seen
            seen.copy_(total)
            # Entry 0 is the number of scalars; zero means nothing new arrived.
            if float(increment[0]) != 0:
                self._moments[stat_name] = increment

    def _get_delta(self, name):
        r"""Returns the raw moments accumulated for the given statistic
        between the last two calls to `update()`, or zeros if no scalars
        were collected.
        """
        assert self._regex.fullmatch(name)
        try:
            return self._moments[name]
        except KeyError:
            empty = torch.zeros([_num_moments], dtype=_counter_dtype)
            self._moments[name] = empty
            return empty

    def num(self, name):
        r"""Returns how many scalars were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        none were collected.
        """
        return int(self._get_delta(name)[0])

    def mean(self, name):
        r"""Returns the mean of the scalars accumulated for the given
        statistic between the last two calls to `update()`, or NaN if
        none were collected.
        """
        moments = self._get_delta(name)
        if int(moments[0]) != 0:
            return float(moments[1] / moments[0])
        return float('nan')

    def std(self, name):
        r"""Returns the standard deviation of the scalars accumulated for
        the given statistic between the last two calls to `update()`, or
        NaN if none were collected.
        """
        moments = self._get_delta(name)
        count = int(moments[0])
        if count == 0 or not np.isfinite(float(moments[1])):
            return float('nan')
        if count == 1:
            return float(0)
        average = float(moments[1] / moments[0])
        mean_of_squares = float(moments[2] / moments[0])
        # Clamp at zero: rounding can push the difference slightly negative.
        variance = max(mean_of_squares - np.square(average), 0)
        return np.sqrt(variance)

    def as_dict(self):
        r"""Returns the averages accumulated between the last two calls to
        `update()` as a `dnnlib.EasyDict` of the form:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=INT, mean=FLOAT, std=FLOAT),
                ...
            )
        """
        summary = dnnlib.EasyDict()
        for stat_name in self.names():
            entry = dnnlib.EasyDict(
                num=self.num(stat_name),
                mean=self.mean(stat_name),
                std=self.std(stat_name),
            )
            summary[stat_name] = entry
        return summary

    def __getitem__(self, name):
        r"""Convenience getter.
        `collector[name]` is a synonym for `collector.mean(name)`.
        """
        return self.mean(name)
Called internally by `Collector.update()`. + """ + if len(names) == 0: + return [] + global _sync_called + _sync_called = True + + # Collect deltas within current rank. + deltas = [] + device = _sync_device if _sync_device is not None else torch.device('cpu') + for name in names: + delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) + for counter in _counters[name].values(): + delta.add_(counter.to(device)) + counter.copy_(torch.zeros_like(counter)) + deltas.append(delta) + deltas = torch.stack(deltas) + + # Sum deltas across ranks. + if _sync_device is not None: + torch.distributed.all_reduce(deltas) + + # Update cumulative values. + deltas = deltas.cpu() + for idx, name in enumerate(names): + if name not in _cumulative: + _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + _cumulative[name].add_(deltas[idx]) + + # Return name-value pairs. + return [(name, _cumulative[name]) for name in names] + +#---------------------------------------------------------------------------- diff --git a/stable-dreamfusion-3DPortrait/nerf/trigrid_rendering/__init__.py b/stable-dreamfusion-3DPortrait/nerf/trigrid_rendering/__init__.py new file mode 100644 index 0000000..dfebd04 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/trigrid_rendering/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
class AlignedSMPL(torch.nn.Module):
    """SMPL body wrapper that re-centres the mesh on one regressed joint and
    renders it off-screen with pyrender.

    Only the default shape is supported: passing `betas`, `scale` or
    `transl` overrides raises `NotImplementedError`, and the pose input is
    limited to six values that fill body-pose slots 11 and 14.

    Args:
        model:      SMPL model object. The code reads `model.faces`,
                    `model.shapedirs.device` and calls the model with
                    `betas`, `expression`, `return_verts`, `body_pose` and
                    `return_shaped` keyword arguments.
        batch_size: Number of bodies processed per call.
    """

    def __init__(self, model, batch_size):
        super().__init__()
        self.batch_size = batch_size
        # NOTE(review): the regressor path is relative to the working
        # directory and the buffers are placed on the default CUDA device.
        smpl_joint_regressor = torch.from_numpy(
            np.load('transfer_data/smpl_joint_regressor.npy')).float().cuda().contiguous()
        self.register_buffer('smpl_joint_regressor', smpl_joint_regressor)

        self.model = model
        faces = torch.from_numpy(self.model.faces.astype(np.int32)).cuda().long().contiguous()
        self.register_buffer('faces', faces)

    def set_model(self, model):
        """Replace the wrapped SMPL model."""
        self.model = model

    def set_batch_size(self, batch_size):
        """Change the batch size assumed by the generate/render methods."""
        self.batch_size = batch_size

    def get_align_coordinate(self, vertices):
        """Regress joints from `vertices` and return joint 12 as the alignment origin.

        Args:
            vertices: Tensor of shape (batch, num_vertices, 3).

        Returns:
            Tensor of shape (batch, 1, 3).
        """
        batch_size = vertices.shape[0]
        # The original author notes the regressor as "30 x 6890" and joint 12 as the neck.
        regressor = self.smpl_joint_regressor[None, :, :].repeat(batch_size, 1, 1)
        smpl_joints = torch.bmm(regressor, vertices)
        return smpl_joints[:, 12, None, :]

    # ------------------------------------------------------------------
    # Rendering helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_scene(mesh, face, cam_param, color, cam_pose):
        """Create a pyrender scene holding the mesh (rotated 180 degrees about x) and the camera."""
        mesh = trimesh.Trimesh(mesh, face)
        rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
        mesh.apply_transform(rot)
        material = pyrender.MetallicRoughnessMaterial(metallicFactor=0.0, alphaMode='OPAQUE', baseColorFactor=color)
        mesh = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=False)
        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.3, 0.3, 0.3))
        scene.add(mesh, 'mesh')

        focal, princpt = cam_param['focal'], cam_param['princpt']
        camera = pyrender.IntrinsicsCamera(fx=focal[0], fy=focal[1], cx=princpt[0], cy=princpt[1])
        # `pose=None` is the default of `Scene.add`, so one call covers both cases.
        scene.add(camera, pose=cam_pose)
        return scene

    @staticmethod
    def _render_scene(scene, img):
        """Render `scene` at the resolution of `img` and return `(rgba, depth)`.

        The off-screen renderer is always deleted afterwards; otherwise each
        call would leave its OpenGL resources allocated.
        """
        renderer = pyrender.OffscreenRenderer(viewport_width=img.shape[1], viewport_height=img.shape[0], point_size=1.0)
        try:
            return renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
        finally:
            renderer.delete()

    @staticmethod
    def _translation_pose(position):
        """Return a fresh 4x4 pose with identity rotation and the given translation."""
        pose = np.eye(4)
        pose[:3, 3] = np.array(position)
        return pose

    def render_mesh(self, img, mesh, face, cam_param, color=(1.0, 1.0, 0.9, 1.0), cam_pose=None):
        """Render the shaded mesh on top of `img`.

        Args:
            img:       Background image of shape (H, W, 3); its size sets the viewport.
            mesh:      Vertex array accepted by `trimesh.Trimesh`.
            face:      Face index array accepted by `trimesh.Trimesh`.
            cam_param: Dict with `'focal'` and `'princpt'` pairs (fx, fy) / (cx, cy).
            color:     RGBA base colour of the mesh material.
            cam_pose:  Optional 4x4 camera pose; None uses pyrender's default.

        Returns:
            uint8 array shaped like `img`: rendered colour where the depth
            buffer is positive, the background elsewhere.
        """
        scene = AlignedSMPL._build_scene(mesh, face, cam_param, color, cam_pose)

        # Same light nodes, in the same order, as before. Each node now gets
        # its own pose array instead of one array mutated between additions.
        light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
        scene.add(light, pose=AlignedSMPL._translation_pose([0, 0, -1]))
        for _ in range(3):
            scene.add(light, pose=cam_pose)
        for position in ([1, 1, -4], [-1, 0, -1], [0.2469, 1.8828, -2.4473]):
            scene.add(light, pose=AlignedSMPL._translation_pose(position))

        rgb, depth = AlignedSMPL._render_scene(scene, img)
        rgb = rgb[:, :, :3].astype(np.float32)
        valid_mask = (depth > 0)[:, :, None]

        # Keep the background wherever nothing was rendered.
        img = np.where(valid_mask, rgb, img)
        return img.astype(np.uint8)

    def render_depth(self, img, mesh, face, cam_param, color=(1.0, 1.0, 0.9, 1.0), cam_pose=None):
        """Render the mesh depth on top of `img`.

        Takes the same arguments as `render_mesh`; `img` must be (H, W, C).

        Returns:
            uint8 array shaped like `img`: the depth value (repeated over the
            channel axis) where the depth buffer is positive, the background
            elsewhere.
        """
        # NOTE(review): casting to uint8 truncates depth to whole numbers —
        # confirm this is the precision callers expect.
        scene = AlignedSMPL._build_scene(mesh, face, cam_param, color, cam_pose)
        _, depth = AlignedSMPL._render_scene(scene, img)

        # The depth buffer is (H, W). It needs a channel axis before it is
        # combined with the (H, W, C) background; without it the (H, W) map
        # broadcasts against the (H, W, 1) mask into (H, W, W) and the
        # composite fails.
        depth = depth[:, :, None]
        valid_mask = depth > 0
        depth = np.where(valid_mask, depth, img)
        return depth.astype(np.uint8)

    def get_projected_vertex(self, mesh, world2screen_matrix):
        """Project 3D points to 2D with a 4x4 world-to-screen matrix.

        Args:
            mesh:                Tensor of shape (N, 3).
            world2screen_matrix: Tensor of shape (4, 4).

        Returns:
            Tensor of shape (N, 2): x and y divided by the projected z.
        """
        ones = torch.ones((mesh.shape[0], 1), device=mesh.device)
        homogeneous = torch.cat([mesh, ones], dim=1)  # N x 4
        points_image = (world2screen_matrix @ homogeneous.T)[:3, :]  # 3 x N
        points_image = points_image / points_image[2, :]
        return points_image[:2, :].T  # N x 2

    # ------------------------------------------------------------------
    # SMPL generation helpers
    # ------------------------------------------------------------------

    def _default_scale_transl(self, betas, scale, transl):
        """Reject unsupported overrides and return the identity scale and zero translation."""
        if betas is not None or scale is not None or transl is not None:
            raise NotImplementedError('custom betas / scale / transl are not supported')
        device = self.model.shapedirs.device
        scale = torch.ones([self.batch_size, 1]).to(device)
        transl = torch.zeros([self.batch_size, 3]).to(device)
        return scale, transl

    def _align_in_place(self, points, align_joint_coordinate, transl, scale):
        """Shift `points` to the alignment origin, then translate and scale, all in place."""
        points -= align_joint_coordinate
        points += transl.view(self.batch_size, 1, 3)
        points *= scale.view(self.batch_size, 1, 1)

    @staticmethod
    def _face_points(joints):
        """Gather the 16 joints used as face landmarks (4 nose, 6 right eye, 6 left eye)."""
        nose = joints[:, 86:90, :]  # B, 4, 3
        eye_right = joints[:, 95:101, :]  # B, 6, 3
        eye_left = joints[:, 101:107, :]  # B, 6, 3
        return torch.cat([nose, eye_right, eye_left], dim=1)  # B, 16, 3

    def _fill_body_pose(self, body_pose):
        """Expand the (batch, 6) pose into a flattened (batch, 69) body pose.

        Only slots 11 and 14 are set; every other slot stays zero.
        """
        misc.assert_shape(body_pose, [self.batch_size, 6])
        body_pose_fill = torch.zeros((self.batch_size, 23, 3)).to(self.model.shapedirs.device)
        body_pose_fill[:, 11, :] = body_pose[:, :3]
        body_pose_fill[:, 14, :] = body_pose[:, 3:]
        return body_pose_fill.reshape(self.batch_size, -1)

    def generate_shaped_smpl(self, betas, scale, transl):
        """Run the model without a body pose and align the result.

        Args:
            betas, scale, transl: Must all be None; anything else raises
                `NotImplementedError`.

        Returns:
            Dict with `'vertices'` (aligned vertices), `'align_joint_coordinate'`
            (the origin subtracted from them, shape (batch, 1, 3)) and
            `'face_points'` (batch, 16, 3).
        """
        scale, transl = self._default_scale_transl(betas, scale, transl)

        shaped_output = self.model(betas=None,
                                   expression=None,
                                   return_verts=True,
                                   body_pose=None,
                                   return_shaped=False
                                   )
        vertices_no_pose = shaped_output.vertices
        joints_no_pose = shaped_output.joints

        # The origin must be computed before the vertices are shifted in place.
        align_joint_coordinate = self.get_align_coordinate(vertices_no_pose)  # B, 1, 3
        self._align_in_place(vertices_no_pose, align_joint_coordinate, transl, scale)
        self._align_in_place(joints_no_pose, align_joint_coordinate, transl, scale)

        return {
            'vertices': vertices_no_pose,
            'align_joint_coordinate': align_joint_coordinate,
            'face_points': self._face_points(joints_no_pose),
        }

    def generate_posed_smpl(self, betas, scale, transl, body_pose, align_joint_coordinate):
        """Run the model with the given pose and align the result.

        Args:
            betas, scale, transl:   Must all be None; anything else raises
                                    `NotImplementedError`.
            body_pose:              Tensor of shape (batch, 6) for slots 11 and 14.
            align_joint_coordinate: Origin to subtract, as returned by
                                    `generate_shaped_smpl`.

        Returns:
            Dict with `'vertices'` (aligned posed vertices) and
            `'face_points'` (batch, 16, 3).
        """
        scale, transl = self._default_scale_transl(betas, scale, transl)
        body_pose_fill = self._fill_body_pose(body_pose)

        output = self.model(betas=None,
                            expression=None,
                            return_verts=True,
                            body_pose=body_pose_fill,
                            return_shaped=True
                            )
        vertices = output.vertices
        joints = output.joints

        self._align_in_place(vertices, align_joint_coordinate, transl, scale)
        self._align_in_place(joints, align_joint_coordinate, transl, scale)

        return {
            'vertices': vertices,
            'face_points': self._face_points(joints),
        }

    # ------------------------------------------------------------------
    # Batch rendering
    # ------------------------------------------------------------------

    @staticmethod
    def _cameras_to_numpy(cameras):
        """Return `cameras` as a NumPy array, accepting tensors on any device."""
        if cameras is None:
            raise ValueError('cameras must be provided')
        if torch.is_tensor(cameras):
            return cameras.detach().cpu().numpy()
        return np.asarray(cameras)

    def _render_batch(self, vertices, resolution, cameras, render_fn):
        """Render every batch item with `render_fn` and stack the results.

        Args:
            vertices:   Tensor of shape (batch, num_vertices, 3).
            resolution: Side length of the square output images.
            cameras:    NumPy array whose first 16 values per row form a 4x4 camera pose.
            render_fn:  `render_mesh` or `render_depth`.

        Returns:
            uint8 array of shape (batch, 3, resolution, resolution).
        """
        faces = self.model.faces
        focal = 5000.0 / 1024 * resolution / 0.75
        intrinsics = {"focal": [focal, focal], "princpt": [resolution / 2, resolution / 2]}

        # Rotation by pi about the x axis, applied to the vertices before rendering.
        angle = np.pi
        rotation = np.array([
            [1.0, 0.0, 0.0],
            [0.0, np.cos(angle), -np.sin(angle)],
            [0.0, np.sin(angle), np.cos(angle)],
        ])
        rotation = torch.from_numpy(rotation).float().to(self.model.shapedirs.device)
        vertices_pyrender = torch.matmul(vertices, rotation).detach().cpu().numpy()

        images = []
        for i in range(self.batch_size):
            camera_pose = cameras[i, :16].reshape(4, 4)

            # Negate rows 1 and 2 of the inverse pose, then invert back, to
            # move the camera into the convention used for rendering.
            camerapose_in_pyrender = np.linalg.inv(camera_pose)
            camerapose_in_pyrender[[1, 2]] *= -1
            camerapose_in_pyrender = np.linalg.inv(camerapose_in_pyrender)

            image = render_fn(np.ones((resolution, resolution, 3)) * 255,
                              vertices_pyrender[i], faces,
                              intrinsics,
                              color=(0.4, 0.5, 0.9, 1.0),
                              cam_pose=camerapose_in_pyrender)

            image = image[None, :, :, :]  # 1 x H x W x 3
            images.append(np.transpose(image, (0, 3, 1, 2)))  # 1 x 3 x H x W

        return np.concatenate(images, axis=0)

    def get_depth(self, vert, resolution=256, cameras=None):
        """Render depth images of already generated vertices.

        Args:
            vert:       Tensor of shape (batch, num_vertices, 3).
            resolution: Side length of the square output images.
            cameras:    Tensor or array whose first 16 values per row form a
                        4x4 camera pose. Required despite the None default.

        Returns:
            uint8 array of shape (batch, 3, resolution, resolution).
        """
        cameras = self._cameras_to_numpy(cameras)
        return self._render_batch(vert, resolution, cameras, self.render_depth)

    def get_visualization(self, shape_pose_params, resolution=256, cameras=None):
        """Pose the default body and render shaded images of it.

        Args:
            shape_pose_params: Dict with a `'pose'` tensor of shape (batch, 6).
                               The keys `'betas'`, `'scale'` and `'transl'`
                               are not supported and raise `NotImplementedError`.
            resolution:        Side length of the square output images.
            cameras:           Tensor or array whose first 16 values per row
                               form a 4x4 camera pose. Required despite the
                               None default.

        Returns:
            uint8 array of shape (batch, 3, resolution, resolution).
        """
        for key in ('betas', 'scale', 'transl'):
            if key in shape_pose_params:
                raise NotImplementedError(f'{key} is not supported')

        body_pose = shape_pose_params['pose']
        misc.assert_shape(body_pose, [self.batch_size, 6])
        cameras = self._cameras_to_numpy(cameras)

        # The unposed body supplies the alignment origin for the posed one.
        shaped = self.generate_shaped_smpl(None, None, None)
        posed = self.generate_posed_smpl(None, None, None, body_pose, shaped['align_joint_coordinate'])

        return self._render_batch(posed['vertices'], resolution, cameras, self.render_mesh)
#----------------------------------------------------------------------------
# Coefficients of various wavelet decomposition low-pass filters.
#
# Maps a wavelet name to its list of low-pass filter taps. Within this table
# 'haar' and 'db1' are identical, as are 'db2'/'sym2' and 'db3'/'sym3'.
# AugmentPipe in this file reads 'sym6' (geometric resampling filter) and
# 'sym2' (image-space filter bank); the other entries are not referenced in
# the visible code.

wavelets = {
    'haar': [0.7071067811865476, 0.7071067811865476],
    'db1': [0.7071067811865476, 0.7071067811865476],
    'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
    'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
    'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
    'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
    'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
    'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
    'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
    'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
    'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
    'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}
def matrix(*rows, device=None):
    """Assemble a (possibly batched) matrix from rows of scalars and tensors.

    All rows must have the same length. When no entry is a tensor the rows
    are turned into a constant via ``misc.constant``. Otherwise every
    non-tensor entry is expanded to the shape of the first tensor entry and
    the result has shape ``first_tensor.shape + (len(rows), row_length)``.
    """
    assert all(len(row) == len(rows[0]) for row in rows)
    flat = [entry for row in rows for entry in row]
    tensors = [entry for entry in flat if isinstance(entry, torch.Tensor)]
    if not tensors:
        return misc.constant(np.asarray(rows), device=device)

    template = tensors[0]
    assert device is None or device == template.device
    columns = []
    for entry in flat:
        if not isinstance(entry, torch.Tensor):
            # Broadcast plain numbers to the batch shape of the tensor entries.
            entry = misc.constant(entry, shape=template.shape, device=template.device)
        columns.append(entry)
    stacked = torch.stack(columns, dim=-1)
    return stacked.reshape(template.shape + (len(rows), -1))


def translate2d(tx, ty, **kwargs):
    """Return the 3x3 homogeneous 2D translation by (tx, ty)."""
    row_x = [1, 0, tx]
    row_y = [0, 1, ty]
    row_w = [0, 0, 1]
    return matrix(row_x, row_y, row_w, **kwargs)


def translate3d(tx, ty, tz, **kwargs):
    """Return the 4x4 homogeneous 3D translation by (tx, ty, tz)."""
    row_x = [1, 0, 0, tx]
    row_y = [0, 1, 0, ty]
    row_z = [0, 0, 1, tz]
    row_w = [0, 0, 0, 1]
    return matrix(row_x, row_y, row_z, row_w, **kwargs)


def scale2d(sx, sy, **kwargs):
    """Return the 3x3 homogeneous 2D scaling by (sx, sy)."""
    row_x = [sx, 0, 0]
    row_y = [0, sy, 0]
    row_w = [0, 0, 1]
    return matrix(row_x, row_y, row_w, **kwargs)


def scale3d(sx, sy, sz, **kwargs):
    """Return the 4x4 homogeneous 3D scaling by (sx, sy, sz)."""
    row_x = [sx, 0, 0, 0]
    row_y = [0, sy, 0, 0]
    row_z = [0, 0, sz, 0]
    row_w = [0, 0, 0, 1]
    return matrix(row_x, row_y, row_z, row_w, **kwargs)


def rotate2d(theta, **kwargs):
    """Return the 3x3 homogeneous 2D rotation by the tensor angle ``theta``."""
    cos_t = torch.cos(theta)
    row_x = [cos_t, torch.sin(-theta), 0]
    row_y = [torch.sin(theta), torch.cos(theta), 0]
    row_w = [0, 0, 1]
    return matrix(row_x, row_y, row_w, **kwargs)


def rotate3d(v, theta, **kwargs):
    """Return the 4x4 homogeneous rotation by ``theta`` around axis ``v``.

    The axis components are read from the last dimension of ``v``.
    """
    vx, vy, vz = v[..., 0], v[..., 1], v[..., 2]
    s = torch.sin(theta)
    c = torch.cos(theta)
    cc = 1 - c
    row_x = [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0]
    row_y = [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0]
    row_z = [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0]
    row_w = [0, 0, 0, 1]
    return matrix(row_x, row_y, row_z, row_w, **kwargs)


def translate2d_inv(tx, ty, **kwargs):
    """Return the inverse of ``translate2d(tx, ty)``."""
    neg_tx = -tx
    neg_ty = -ty
    return translate2d(neg_tx, neg_ty, **kwargs)


def scale2d_inv(sx, sy, **kwargs):
    """Return the inverse of ``scale2d(sx, sy)``."""
    inv_sx = 1 / sx
    inv_sy = 1 / sy
    return scale2d(inv_sx, inv_sy, **kwargs)


def rotate2d_inv(theta, **kwargs):
    """Return the inverse of ``rotate2d(theta)``."""
    neg_theta = -theta
    return rotate2d(neg_theta, **kwargs)
@persistence.persistent_class
class AugmentPipe(torch.nn.Module):
    """Image augmentation pipeline from "Training Generative Adversarial
    Networks with Limited Data".

    All augmentations are disabled by default; an individual augmentation is
    enabled by setting its probability multiplier to a value above zero
    (typically 1). The buffer ``p`` is an overall multiplier applied on top of
    every per-augmentation probability.

    ``forward`` accepts 1-channel, 3-channel, or 6-channel (two stacked
    3-channel images sharing one color transform) inputs for the color stage.
    """

    def __init__(self,
        xflip=0, rotate90=0, xint=0, xint_max=0.125,
        scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
        brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
        imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
        noise=0, cutout=0, noise_std=0.1, cutout_size=0.5,
    ):
        super().__init__()
        self.register_buffer('p', torch.ones([]))       # Overall multiplier for augmentation probability.

        # Pixel blitting.
        self.xflip            = float(xflip)            # Probability multiplier for x-flip.
        self.rotate90         = float(rotate90)         # Probability multiplier for 90 degree rotations.
        self.xint             = float(xint)             # Probability multiplier for integer translation.
        self.xint_max         = float(xint_max)         # Range of integer translation, relative to image dimensions.

        # General geometric transformations.
        self.scale            = float(scale)            # Probability multiplier for isotropic scaling.
        self.rotate           = float(rotate)           # Probability multiplier for arbitrary rotation.
        self.aniso            = float(aniso)            # Probability multiplier for anisotropic scaling.
        self.xfrac            = float(xfrac)            # Probability multiplier for fractional translation.
        self.scale_std        = float(scale_std)        # Log2 standard deviation of isotropic scaling.
        self.rotate_max       = float(rotate_max)       # Range of arbitrary rotation, 1 = full circle.
        self.aniso_std        = float(aniso_std)        # Log2 standard deviation of anisotropic scaling.
        self.xfrac_std        = float(xfrac_std)        # Standard deviation of fractional translation, relative to image dimensions.

        # Color transformations.
        self.brightness       = float(brightness)       # Probability multiplier for brightness.
        self.contrast         = float(contrast)         # Probability multiplier for contrast.
        self.lumaflip         = float(lumaflip)         # Probability multiplier for luma flip.
        self.hue              = float(hue)              # Probability multiplier for hue rotation.
        self.saturation       = float(saturation)       # Probability multiplier for saturation.
        self.brightness_std   = float(brightness_std)   # Standard deviation of brightness.
        self.contrast_std     = float(contrast_std)     # Log2 standard deviation of contrast.
        self.hue_max          = float(hue_max)          # Range of hue rotation, 1 = full circle.
        self.saturation_std   = float(saturation_std)   # Log2 standard deviation of saturation.

        # Image-space filtering.
        self.imgfilter        = float(imgfilter)        # Probability multiplier for image-space filtering.
        self.imgfilter_bands  = list(imgfilter_bands)   # Probability multipliers for individual frequency bands (copied, so the shared default list is never mutated).
        self.imgfilter_std    = float(imgfilter_std)    # Log2 standard deviation of image-space filter amplification.

        # Image-space corruptions.
        self.noise            = float(noise)            # Probability multiplier for additive RGB noise.
        self.cutout           = float(cutout)           # Probability multiplier for cutout.
        self.noise_std        = float(noise_std)        # Standard deviation of additive RGB noise.
        self.cutout_size      = float(cutout_size)      # Size of the cutout rectangle, relative to image dimensions.

        # Setup orthogonal lowpass filter for geometric augmentations.
        self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))

        # Construct filter bank for image-space filtering.
        Hz_lo = np.asarray(wavelets['sym2'])            # H(z)
        Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
        Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2    # H(z) * H(z^-1) / 2
        Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2    # H(-z) * H(-z^-1) / 2
        Hz_fbank = np.eye(4, 1)                         # Bandpass(H(z), b_i)
        for i in range(1, Hz_fbank.shape[0]):
            Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
            Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
            Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
        self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))

    def forward(self, images, debug_percentile=None):
        """Apply the enabled augmentations to a batch of images.

        Args:
            images: 4-D tensor laid out as (batch, channels, height, width).
            debug_percentile: when given, every random augmentation parameter
                is replaced by the value at this percentile of its
                distribution, which makes the output deterministic apart from
                the additive noise samples.

        Returns:
            The augmented images with the same shape as the input. The input
            tensor itself is not modified.

        Raises:
            ValueError: if a color transform is active and the channel count
                is not 1, 3 or 6.
        """
        assert isinstance(images, torch.Tensor) and images.ndim == 4
        batch_size, num_channels, height, width = images.shape
        device = images.device
        if debug_percentile is not None:
            debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)

        # -------------------------------------
        # Select parameters for pixel blitting.
        # -------------------------------------

        # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        # Apply x-flip with probability (xflip * strength).
        if self.xflip > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 2)
            i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)

        # Apply 90 degree rotations with probability (rotate90 * strength).
        if self.rotate90 > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 4)
            i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 4))
            G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)

        # Apply integer translation with probability (xint * strength).
        if self.xint > 0:
            t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
            G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))

        # --------------------------------------------------------
        # Select parameters for general geometric transformations.
        # --------------------------------------------------------

        # Apply isotropic scaling with probability (scale * strength).
        if self.scale > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
            G_inv = G_inv @ scale2d_inv(s, s)

        # Apply pre-rotation with probability p_rot.
        p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
            G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.

        # Apply anisotropic scaling with probability (aniso * strength).
        if self.aniso > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
            G_inv = G_inv @ scale2d_inv(s, 1 / s)

        # Apply post-rotation with probability p_rot.
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.zeros_like(theta)
            G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.

        # Apply fractional translation with probability (xfrac * strength).
        if self.xfrac > 0:
            t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
            G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        # Execute if the transform is not identity (identity check: G_inv is
        # still the very same object only when no branch above replaced it).
        if G_inv is not I_3:

            # Calculate padding.
            cx = (width - 1) / 2
            cy = (height - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
            cp = G_inv @ cp.t() # [batch, xyz, idx]
            Hz_pad = self.Hz_geom.shape[0] // 4
            margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
            margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
            margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)

            # Execute transformation.
            shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
            G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
            images = grid_sample_gradfix.grid_sample(images, grid)

            # Downsample and crop.
            images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)

        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
        I_4 = torch.eye(4, device=device)
        C = I_4

        # Apply brightness with probability (brightness * strength).
        if self.brightness > 0:
            b = torch.randn([batch_size], device=device) * self.brightness_std
            b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
            if debug_percentile is not None:
                b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
            C = translate3d(b, b, b) @ C

        # Apply contrast with probability (contrast * strength).
        if self.contrast > 0:
            c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
            c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
            if debug_percentile is not None:
                c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
            C = scale3d(c, c, c) @ C

        # Apply luma flip with probability (lumaflip * strength).
        v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
        if self.lumaflip > 0:
            i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
            i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection.

        # Apply hue rotation with probability (hue * strength).
        if self.hue > 0 and num_channels > 1:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
            theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
            C = rotate3d(v, theta) @ C # Rotate around v.

        # Apply saturation with probability (saturation * strength).
        if self.saturation > 0 and num_channels > 1:
            s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
            s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
            C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        # Execute if the transform is not identity.
        if C is not I_4:
            images = images.reshape([batch_size, num_channels, height * width])
            if num_channels == 3:
                images = C[:, :3, :3] @ images + C[:, :3, 3:]
            elif num_channels == 1:
                C = C[:, :3, :].mean(dim=1, keepdims=True)
                images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
            elif num_channels == 6:
                # Two stacked 3-channel images share one color transform.
                # Build the result out-of-place: `reshape` above may return a
                # view of the caller's tensor, so writing into it in place
                # would silently modify the input batch.
                first_rgb = C[:, :3, :3] @ images[:, :3] + C[:, :3, 3:]
                second_rgb = C[:, :3, :3] @ images[:, 3:] + C[:, :3, 3:]
                images = torch.cat([first_rgb, second_rgb], dim=1)
            else:
                raise ValueError('Image must be RGB (3 channels), L (1 channel) or two stacked RGB images (6 channels)')
            images = images.reshape([batch_size, num_channels, height, width])

        # ----------------------
        # Image-space filtering.
        # ----------------------

        if self.imgfilter > 0:
            num_bands = self.Hz_fbank.shape[0]
            assert len(self.imgfilter_bands) == num_bands
            expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).

            # Apply amplification for each band with probability (imgfilter * strength * band_strength).
            g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
            for i, band_strength in enumerate(self.imgfilter_bands):
                t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
                t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
                if debug_percentile is not None:
                    t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
                t = torch.ones([batch_size, num_bands], device=device)                  # Temporary gain vector.
                t[:, i] = t_i                                                           # Replace i'th element.
                t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
                g = g * t                                                               # Accumulate into global gain.

            # Construct combined amplification filter.
            Hz_prime = g @ self.Hz_fbank                                    # [batch, tap]
            Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1])   # [batch, channels, tap]
            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]

            # Apply filter.
            p = self.Hz_fbank.shape[1] // 2
            images = images.reshape([1, batch_size * num_channels, height, width])
            images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
            images = images.reshape([batch_size, num_channels, height, width])

        # ------------------------
        # Image-space corruptions.
        # ------------------------

        # Apply additive RGB noise with probability (noise * strength).
        if self.noise > 0:
            sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
            sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma))
            if debug_percentile is not None:
                sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std)
            images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma

        # Apply cutout with probability (cutout * strength).
        if self.cutout > 0:
            size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
            size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size))
            center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
            if debug_percentile is not None:
                size = torch.full_like(size, self.cutout_size)
                center = torch.full_like(center, debug_percentile)
            coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
            coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
            mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2)
            mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2)
            mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
            images = images * mask

        return images
def sample_cross_section(G, ws, resolution=256, w=1.2):
    """Sample the generator's density on a square slice where the first coordinate is zero.

    A ``resolution`` x ``resolution`` grid spanning ``[-w/2, w/2]`` is laid out
    over the second and third coordinates (the second runs from ``+w/2`` down
    to ``-w/2`` along rows), the first coordinate is fixed to zero, and
    ``G.sample_mixed`` is queried with random view directions. The returned
    ``'sigma'`` entry is reshaped to ``(-1, 1, resolution, resolution)``.
    """
    half_extent = w / 2
    row_values = torch.linspace(half_extent, -half_extent, resolution, device=ws.device)
    col_values = torch.linspace(-half_extent, half_extent, resolution, device=ws.device)
    grid_rows, grid_cols = torch.meshgrid(row_values, col_values, indexing='ij')

    grid_rows = grid_rows.reshape(-1, 1)
    grid_cols = grid_cols.reshape(-1, 1)
    fixed_axis = torch.zeros_like(grid_rows)

    # The zero column goes first, i.e. the slice plane is the one where the
    # first coordinate vanishes.
    points = torch.cat([fixed_axis, grid_rows, grid_cols], dim=-1)
    points = points.expand(ws.shape[0], -1, -1)

    directions = torch.randn_like(points)
    sampled = G.sample_mixed(points, directions, ws)
    return sampled['sigma'].reshape(-1, 1, resolution, resolution)
+ +"""Streaming images and labels from datasets created with dataset_tool.py.""" + +import os +import numpy as np +import zipfile +import PIL.Image +import json +import torch +import dnnlib +try: + import pyspng +except ImportError: + pyspng = None + +#---------------------------------------------------------------------------- + +def matrix2angle(R): + """ + https://github.com/sizhean/panohead + compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf + refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv + todo: check and debug + Args: + R: (3,3). rotation matrix + Returns: + x: yaw + y: pitch + z: roll + """ + if R[2, 0] > 0.998: + z = 0 + x = np.pi / 2 + y = z + atan2(-R[0, 1], -R[0, 2]) + elif R[2, 0] < -0.998: + z = 0 + x = -np.pi / 2 + y = -z + atan2(R[0, 1], R[0, 2]) + else: + x = asin(R[2, 0]) + y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x)) + z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x)) + + if abs(y) > np.pi/2: + if x > 0: + x = np.pi - x + else: + x = -np.pi - x + y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x)) + z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x)) + return x, y, z + + +def get_poseangle(eg3dparams): + ''' + https://github.com/sizhean/panohead + ''' + convert = np.array([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ]).astype(np.float32) + + entry_cam = np.array([float(p) for p in eg3dparams][:16]).reshape((4,4)) + + world2cam = np.linalg.inv(entry_cam@convert) + pose = matrix2angle(world2cam[:3,:3]) + angle = [p * 180 / np.pi for p in pose] + + return angle + + + +class Dataset(torch.utils.data.Dataset): + def __init__(self, + name, # Name of the dataset. + raw_shape, # Shape of the raw image data (NCHW). 
+ max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. + use_labels = False, # Enable conditioning labels? False = label dimension is zero. + xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. + random_seed = 0, # Random seed to use when applying max_size. + rebal_raw_idx = None, # Rebalance the dataset by sampling from the raw_idx list + data_rebalance=False, # Rebalance the dataset by sampling from the raw_idx list + ): + self._name = name + self._raw_shape = list(raw_shape) + self._use_labels = use_labels + self._raw_labels = None + self._raw_poses = None + self._label_shape = None + self._pose_shape = None + + + if data_rebalance: + raise NotImplementedError + assert rebal_raw_idx is not None, "rebal_raw_idx must be provided if data_rebalance is True" + self._raw_idx = rebal_raw_idx + else: + self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) + + + self._raw_idx = self._filter_samples() + + # Apply max_size. + if (max_size is not None) and (self._raw_idx.size > max_size): + raise NotImplementedError + np.random.RandomState(random_seed).shuffle(self._raw_idx) + self._raw_idx = np.sort(self._raw_idx[:max_size]) + + # Apply xflip. 
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) + if xflip: + self._raw_idx = np.tile(self._raw_idx, 2) + self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) + + def _filter_samples(self): # to be overridden by subclass + raise NotImplementedError + + + def _get_raw_labels(self): + if self._raw_labels is None: + self._raw_labels,self._raw_poses = self._load_raw_labels() if self._use_labels else None + + if self._raw_labels is None: + raise Exception("_raw_labels is None") + self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) + + assert isinstance(self._raw_labels, np.ndarray) + assert self._raw_labels.shape[0] == self._raw_shape[0] + assert self._raw_labels.dtype in [np.float32, np.int64] + if self._raw_labels.dtype == np.int64: + assert self._raw_labels.ndim == 1 + assert np.all(self._raw_labels >= 0) + self._raw_labels_std = self._raw_labels.std(0) + + + if self._raw_poses is None: + raise Exception("_raw_poses is None") + self._raw_poses = np.zeros([self._raw_poses[0], 0], dtype=np.float32) + + assert isinstance(self._raw_poses, np.ndarray) + assert self._raw_poses.shape[0] == self._raw_shape[0] + assert self._raw_poses.dtype in [np.float32, np.int64] + if self._raw_poses.dtype == np.int64: + assert self._raw_poses.ndim == 1 + assert np.all(self._raw_poses >= 0) + self._raw_poses_std = self._raw_poses.std(0) + + return self._raw_labels + + def _get_raw_poses(self): + if self._raw_poses is None: + _ = self._get_raw_labels() + #raise Exception("please run _get_raw_labels first") + + return self._raw_poses + + + def close(self): # to be overridden by subclass + pass + + def _load_raw_image(self, raw_idx): # to be overridden by subclass + raise NotImplementedError + + def _load_raw_labels(self): # to be overridden by subclass + raise NotImplementedError + + + def __getstate__(self): + return dict(self.__dict__, _raw_labels=None, _raw_poses=None) + + def __del__(self): + try: + self.close() + except: + pass + + def 
def get_coarse_pose(self, idx):
    """Return a copy of the coarse pose vector for dataset item ``idx``.

    The pose is looked up through ``self._raw_idx`` so that filtered and
    x-flipped items resolve to the right raw sample. For x-flipped items the
    components at positions 1, 2, 4 and 5 are negated.

    Raises:
        TypeError: if the stored poses are int64 instead of floating point.
    """
    pose = self._get_raw_poses()[self._raw_idx[idx]].copy()
    if pose.dtype == np.int64:
        raise TypeError("pose should be float32")

    if self._xflip[idx]:
        # `pose` is already a private copy, so it can be edited in place.
        pose[[1, 2, 4, 5]] *= -1
    return pose
def get_label_std(self):
    """Return the per-column standard deviation of the raw labels.

    ``_raw_labels_std`` is only assigned inside ``_get_raw_labels()``, so the
    labels are loaded here on demand when it has not been computed yet.
    """
    if getattr(self, "_raw_labels_std", None) is None:
        self._get_raw_labels()
    return self._raw_labels_std
def _filter_samples(self):
    """Return the raw indices to use, oversampling back-facing views.

    With ``back_repeat <= 1`` the current ``self._raw_idx`` is returned
    untouched. Otherwise every sample is classified from the angles that
    ``get_poseangle`` derives from its label:

    * valid: ``min_yaw <= |yaw| <= max_yaw`` and ``|pitch| <= max_pitch``
    * back:  valid and ``|yaw| >= max(90, min_yaw)``
    * front: valid but not back

    Back samples are repeated ``max(int(front / 2 / back), back_repeat)``
    times and appended after the front samples.

    Returns:
        1-D integer array of raw indices (front samples in dataset order,
        followed by the repeated back samples).
    """
    if not (self.back_repeat > 1):
        return self._raw_idx

    raw_labels = self._get_raw_labels()[self._raw_idx]
    poses = np.array([get_poseangle(entry) for entry in raw_labels])
    abs_yaw = np.abs(poses[:, 0])
    abs_pitch = np.abs(poses[:, 1])

    # find [min_yaw, max_yaw] boolean
    valid = (abs_yaw >= self.min_yaw) & (abs_yaw <= self.max_yaw) & (abs_pitch <= self.max_pitch)
    # find back boolean: [max(90, self.min_yaw), max_yaw]
    back_valid = valid & (abs_yaw >= max(90, self.min_yaw))
    if not np.all(valid):
        print(f"filtering samples by pose: ratio = {valid.sum()}/{len(self._raw_idx)}")

    # Boolean masks keep the integer dtype and a deterministic order, even
    # when one of the groups is empty (a set difference does neither).
    back_idx = self._raw_idx[back_valid]
    front_idx = self._raw_idx[valid & ~back_valid]

    if len(back_idx) == 0:
        # Nothing to repeat; dividing by the back count below would fail.
        print("no back images found, skip repeating")
        return front_idx

    front_num = len(front_idx)
    front_back_ratio_min = front_num / 2 / len(back_idx)
    print(f"if back num be the half of front num, at least repeat ({int(front_back_ratio_min)}) times.")
    back_repeat = max(int(front_back_ratio_min), self.back_repeat)

    # TODO: support the repeat times < 1
    # repeat [max(90, self.min_yaw), max_yaw] for multiple times
    back_repeat_idx = np.tile(back_idx, back_repeat)
    # merge front index and repeated back index
    new_idx = np.concatenate((front_idx, back_repeat_idx))
    print(f"Repeat {len(back_idx)} back images till abs({self.max_yaw}) degree {back_repeat} times")
    return new_idx
def _load_raw_labels(self):
    """Load per-image labels and poses from ``dataset.json``.

    The file's ``'labels'`` entry is read as ``(filename, values)`` pairs and
    reordered to follow ``self._image_fnames``. Each row is split into the
    first 25 columns (returned as labels) and the remaining columns
    (returned as poses).

    Returns:
        ``(labels, poses)`` as float32 arrays with one row per image, or
        ``(None, None)`` when ``dataset.json`` is absent or its ``'labels'``
        entry is null. A pair is returned in every case because the caller
        unpacks two values.
    """
    fname = 'dataset.json'
    if fname not in self._all_fnames:
        return None, None
    with self._open_file(fname) as f:
        labels = json.load(f)['labels']
    if labels is None:
        return None, None

    by_name = dict(labels)
    rows = [by_name[image_name.replace('\\', '/')] for image_name in self._image_fnames]
    # Force an (N, D) layout. Squeezing instead would collapse a one-image
    # dataset to 1-D and break the column split below.
    table = np.array(rows)
    table = table.reshape(table.shape[0], -1).astype(np.float32)

    poses = table[:, 25:]
    labels = table[:, :25]
    return labels, poses
+ ): + self.min_yaw = 0 + self.max_yaw = 180 + self.max_pitch = 90 + self.back_repeat = 1 if back_repeat is None else back_repeat + super().__init__(path=img_path, back_repeat = None,**super_kwargs) + + self._seg_dataset = ImageFolderDataset(seg_path, **super_kwargs) + + # Build the mapping from seg fname to seg raw index + seg_dict = {os.path.splitext(fname)[0]: idx for idx, fname in enumerate(self._seg_dataset._image_fnames)} + + # Build the mapping from index to seg raw index + self._seg_raw_idx = [] + for raw_idx in self._raw_idx: + fname = self._image_fnames[raw_idx] + key = os.path.splitext(fname)[0] + self._seg_raw_idx.append(seg_dict[key]) + self._seg_raw_idx = np.array(self._seg_raw_idx) + + def _filter_samples(self): + if self.back_repeat>1: + raw_labels = self._get_raw_labels()[self._raw_idx] + label_list = [] + for entry in raw_labels: + label_list.append(get_poseangle(entry)) + poses = np.array(label_list) + # find [min_yaw, max_yaw] boolean + valid = (np.abs(poses[:,0])>=self.min_yaw) & (np.abs(poses[:,0])<=self.max_yaw) & (np.abs(poses[:,1])<=self.max_pitch) + # find back boolean: [max(90, self.min_yaw), max_yaw] + back_valid = (np.abs(poses[:,0])>= max(90, self.min_yaw)) & (np.abs(poses[:,0])<=self.max_yaw) & (np.abs(poses[:,1])<=self.max_pitch) + if not np.all(valid): + print(f"filtering samples by pose: ratio = {valid.sum()}/{len(self._raw_idx)}") + # boolean to index + valid_idx = self._raw_idx[valid] + back_idx = self._raw_idx[back_valid] + front_idx = np.array(list(set(valid_idx) - set(back_idx))) + + front_num = valid.sum()-len(back_idx) + front_back_ratio_min = front_num/2/len(back_idx) + print(f"if back num be the half of front num, at least repeat ({int(front_back_ratio_min)}) times.") + back_repeat = max(int(front_num/2/len(back_idx)), self.back_repeat) + + + + + # TODO: support the repeat times < 1 + # repeat [max(90, self.min_yaw), max_yaw] for multiple times + back_repeat_idx = np.tile(back_idx, back_repeat) + # merge front index and 
def __getitem__(self, idx):
    """Return ``(image, mask, label, pose)`` for dataset item ``idx``.

    The mask is looked up through ``self._seg_raw_idx``, the mapping that
    ``__init__`` builds by matching image and segmentation file names. This
    keeps image and mask paired even if the segmentation archive is ordered
    differently or the image indices were filtered or repeated. The mask is
    mirrored whenever the image item is an x-flipped one.
    """
    image = self.get_image(idx)  # already flipped by get_image when needed

    mask = self._seg_dataset._load_raw_image(self._seg_raw_idx[idx])
    assert isinstance(mask, np.ndarray)
    assert list(mask.shape) == self._seg_dataset.image_shape
    assert mask.dtype == np.uint8
    if self._xflip[idx]:
        assert mask.ndim == 3  # CHW
        mask = mask[:, :, ::-1]

    label = self.get_label(idx)
    pose = self.get_coarse_pose(idx)

    # .copy() turns the reversed view into a contiguous, independent array.
    return image.copy(), mask.copy(), label, pose
+ architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=4, # Use FP16 for the N highest resolutions. + conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + sr_upsample_factor=1, # Ignored for SingleDiscriminator + block_kwargs={}, # Arguments for DiscriminatorBlock. + mapping_kwargs={}, # Arguments for MappingNetwork. + epilogue_kwargs={}, # Arguments for DiscriminatorEpilogue. + ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, + **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, + **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, 
def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
    """Resize a batch of images to ``size`` x ``size``.

    Args:
        image_orig_tensor: Image batch passed to
            ``torch.nn.functional.interpolate`` (spatial dims last).
        size: Target height and width.
        f: Resampling filter, only used by the ``'classic'`` mode
            (forwarded to ``upfirdn2d``).
        filter_mode: One of
            ``'antialiased'`` - bilinear interpolation with antialiasing;
            ``'classic'`` - 2x filtered upsample, bilinear resize to
            ``2 * size + 2``, then 2x filtered downsample;
            ``'none'`` - plain bilinear interpolation;
            a float ``t`` with ``0 < t < 1`` - blend
            ``(1 - t) * plain + t * antialiased``.

    Returns:
        The resized tensor.

    Raises:
        AssertionError: if a float ``filter_mode`` lies outside ``(0, 1)``.
        ValueError: if ``filter_mode`` is not one of the supported modes.
    """
    interpolate = torch.nn.functional.interpolate
    target = (size, size)

    if filter_mode == 'antialiased':
        return interpolate(image_orig_tensor, size=target, mode='bilinear',
                           align_corners=False, antialias=True)

    if filter_mode == 'classic':
        resized = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
        resized = interpolate(resized, size=(size * 2 + 2, size * 2 + 2),
                              mode='bilinear', align_corners=False)
        return upfirdn2d.downsample2d(resized, f, down=2, flip_filter=True, padding=-1)

    if filter_mode == 'none':
        return interpolate(image_orig_tensor, size=target, mode='bilinear',
                           align_corners=False)

    if isinstance(filter_mode, float):
        assert 0 < filter_mode < 1
        filtered = interpolate(image_orig_tensor, size=target, mode='bilinear',
                               align_corners=False, antialias=True)
        aliased = interpolate(image_orig_tensor, size=target, mode='bilinear',
                              align_corners=False, antialias=False)
        return (1 - filter_mode) * aliased + (filter_mode) * filtered

    # Fail loudly instead of running into an unbound local at the return.
    raise ValueError(f"Unsupported filter_mode: {filter_mode!r}")
+ c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base=32768, # Overall multiplier for the number of channels. + channel_max=512, # Maximum number of channels in any layer. + num_fp16_res=4, # Use FP16 for the N highest resolutions. + conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. + disc_c_noise=0, # Corrupt camera parameters with X std dev of noise before disc. pose conditioning. + block_kwargs={}, # Arguments for DiscriminatorBlock. + mapping_kwargs={}, # Arguments for MappingNetwork. + epilogue_kwargs={}, # Arguments for DiscriminatorEpilogue. + ): + super().__init__() + img_channels *= 2 + + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, + **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + 
def forward(self, img, c, update_emas=False, **block_kwargs):
    """Score a batch from its final image and its raw render.

    Args:
        img: Mapping with ``'image'`` and ``'image_raw'`` tensors. The raw
            image is resized to the width of ``'image'`` and concatenated
            with it along dim 1.
        c: Conditioning tensor, only used when ``self.c_dim > 0``.
        update_emas: Accepted for interface compatibility; unused.
        **block_kwargs: Forwarded to every discriminator block.

    Returns:
        The output of the epilogue ``self.b4``.
    """
    image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
    img = torch.cat([img['image'], image_raw], 1)

    _ = update_emas  # unused
    x = None
    for res in self.block_resolutions:
        block = getattr(self, f'b{res}')
        x, img = block(x, img, **block_kwargs)

    cmap = None
    if self.c_dim > 0:
        if self.disc_c_noise > 0:
            # Out-of-place add: `c += ...` would write the noise back into the
            # caller's tensor, corrupting it for any later use.
            c = c + torch.randn_like(c) * c.std(0) * self.disc_c_noise
        cmap = self.mapping(None, c)
    x = self.b4(x, img, cmap)
    return x
def forward(self, img, c, update_emas=False, **block_kwargs):
    """Score a batch while fading out the raw-render input.

    Each call lowers ``self.raw_fade`` by ``1 / (500000 / 32)``, clamped at
    zero. The resized ``'image_raw'`` is multiplied by that factor before it
    is concatenated with ``'image'`` along dim 1.

    Args:
        img: Mapping with ``'image'`` and ``'image_raw'`` tensors.
        c: Conditioning tensor, only used when ``self.c_dim > 0``.
        update_emas: Accepted for interface compatibility; unused.
        **block_kwargs: Forwarded to every discriminator block.

    Returns:
        The output of the epilogue ``self.b4``.
    """
    _ = update_emas  # unused

    fade_step = 1 / (500000 / 32)
    self.raw_fade = max(0, self.raw_fade - fade_step)

    final_image = img['image']
    faded_raw = filtered_resizing(img['image_raw'], size=final_image.shape[-1],
                                  f=self.resample_filter) * self.raw_fade
    stacked = torch.cat([final_image, faded_raw], 1)

    features = None
    for res in self.block_resolutions:
        stage = getattr(self, f'b{res}')
        features, stacked = stage(features, stacked, **block_kwargs)

    cmap = self.mapping(None, c) if self.c_dim > 0 else None
    return self.b4(features, stacked, cmap)
+ ): + super().__init__() + img_channels = img_channels * 2 + seg_channels + self.camera_param_dim = c_dim + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + self.pose_branch = DPoseBranch(num_betas=10, in_channel=channels_dict[4]*4*4) + self.c_dim += self.pose_branch.output_dim + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if self.c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, + **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if self.c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=self.c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, + **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, + **common_kwargs) + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1])) + self.disc_c_noise = disc_c_noise + + self.explicitly_symmetry = explicitly_symmetry + + def flip_yaw(self, matrix): + flipped_matrix = matrix.clone() + flipped = flipped_matrix[:, :16].reshape(-1, 4, 4) + flipped[:, 0, 1] *= -1 + flipped[:, 0, 2] *= -1 + flipped[:, 1, 0] *= -1 + flipped[:, 2, 0] *= -1 + flipped[:, 0, 3] 
def predict_pose(self, img, c, update_emas=False, **block_kwargs):
    """Predict pose parameters from an image batch and its camera parameters.

    Only the ``explicitly_symmetry`` mode is implemented. Samples whose
    ``atan2(c[:, 11], c[:, 3])`` lies in ``[-pi/2, pi/2]`` are treated as
    "left": their image, raw image and mask are mirrored along the last axis
    before the discriminator blocks run, and the predicted pose components
    1, 2, 4 and 5 are negated afterwards.

    Args:
        img: Mapping with ``'image'``, ``'image_raw'`` and ``'image_mask'``
            tensors. Assumed NCHW, since the flip acts on dim 3 - confirm
            against the caller.
        c: Camera parameter tensor; columns 3 and 11 are read here and the
            whole tensor is passed to ``self.pose_branch``.
        update_emas: Accepted for interface compatibility; unused.
        **block_kwargs: Forwarded to every discriminator block.

    Returns:
        ``(pose_params, pose_branch_input_feature)``, where the feature is
        the flattened input produced by ``self.b4.get_flatten_x``.

    Raises:
        NotImplementedError: if ``self.explicitly_symmetry`` is false.
    """
    if not self.explicitly_symmetry:
        raise NotImplementedError

    theta = torch.atan2(c[:, [11]], c[:, [3]])  # math.atan2(z, x)
    is_left = (theta >= -np.pi / 2) & (theta <= np.pi / 2)
    # One broadcastable mask serves all three inputs.
    flip_mask = is_left.unsqueeze(2).unsqueeze(3)

    mirrored = {}
    for key in ('image', 'image_raw', 'image_mask'):
        flipped = torch.flip(img[key], dims=[3])
        mirrored[key] = torch.where(flip_mask, flipped, img[key])  # if left, flip
        misc.assert_shape(mirrored[key], flipped.shape)

    target_size = mirrored['image'].shape[-1]
    image_raw = filtered_resizing(mirrored['image_raw'], size=target_size, f=self.resample_filter)
    seg = filtered_resizing(mirrored['image_mask'], size=target_size, f=self.resample_filter)
    seg = 2 * seg - 1  # normalize to [-1,1]
    stacked = torch.cat([mirrored['image'], image_raw, seg], 1)

    _ = update_emas  # unused
    x = None
    for res in self.block_resolutions:
        block = getattr(self, f'b{res}')
        x, stacked = block(x, stacked, **block_kwargs)

    pose_branch_input_feature = self.b4.get_flatten_x(x, stacked)
    pose_params = self.pose_branch(pose_branch_input_feature, c)

    # Undo the mirroring on the prediction for the samples that were flipped.
    flip_pose_params = pose_params.clone()
    flip_pose_params[:, [1, 2, 4, 5]] *= -1  # flip y and z axis angles
    pose_params = torch.where(is_left, flip_pose_params, pose_params)

    return pose_params, pose_branch_input_feature
class DPoseBranch(torch.nn.Module):
    """Fully connected head that regresses pose parameters.

    The head receives a flattened feature vector concatenated with the
    camera parameters and maps it through five fully connected layers
    (2048 -> 512 -> 128 -> 32 -> ``output_dim``).
    """

    def __init__(self, num_betas, in_channel, c_dim=25):
        """Build the regression head.

        Args:
            num_betas: Stored on the instance; not used to size the network.
            in_channel: Width of the flattened feature vector.
            c_dim: Width of the camera parameter vector appended to the
                feature. Defaults to 25, the value that was hard-coded
                before.
        """
        super().__init__()
        self.num_betas = num_betas
        # Width of the network input: feature plus camera parameters.
        self.in_channel = in_channel + c_dim
        self.output_dim = 6

        self.net = torch.nn.Sequential(
            FullyConnectedLayer(self.in_channel, 2048, activation='lrelu'),
            FullyConnectedLayer(2048, 512, activation='lrelu'),
            FullyConnectedLayer(512, 128, activation='lrelu'),
            FullyConnectedLayer(128, 32, activation='lrelu'),
            FullyConnectedLayer(32, self.output_dim)
        )

    def forward(self, feature, camera_parameters):
        """Return pose parameters of shape ``(B, output_dim)``.

        Args:
            feature: Flattened features, one row per sample.
            camera_parameters: Camera parameters, one row per sample;
                concatenated to ``feature`` along dim 1.
        """
        net_input = torch.cat([feature, camera_parameters], dim=1)
        return self.net(net_input)
def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, r1_gamma_seg=1000, style_mixing_prob=0,
             pl_weight=0, density_noise_fade_kimg=0, pl_batch_shrink=2, pl_decay=0.01,
             pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0, r1_gamma_init=0,
             r1_gamma_fade_kimg=0,
             neural_rendering_resolution_initial=64, neural_rendering_resolution_final=None,
             neural_rendering_resolution_fade_kimg=0,
             gpc_reg_fade_kimg=1000, gpc_reg_prob=None, dual_discrimination=False,
             filter_mode='antialiased', thickness=None,
             pose_loss_weight=None, input_pose_params_reg_loss_weight=None,
             input_pose_params_reg_loss_kimg=None, rank=None, bcg_reg_prob=0):
    """Store the networks and every loss hyper-parameter on the instance.

    Each argument is kept under an attribute of the same name. In addition
    the constructor creates ``pl_mean`` (a scalar zero tensor on ``device``)
    and ``resample_filter`` (built from ``[1, 3, 3, 1]`` on ``device``), sets
    ``blur_raw_target`` to ``True`` and initialises the snapshot fields
    ``swapping_prob``, ``neural_rendering_resolution`` and ``blur_sigma``
    to ``None``.

    Raises:
        AssertionError: If ``gpc_reg_prob`` is neither ``None`` nor within
            ``[0, 1]``.
    """
    super().__init__()

    # Networks, device and process rank.
    self.device = device
    self.G = G
    self.D = D
    self.augment_pipe = augment_pipe
    self.rank = rank

    # R1 regularisation weights (image and segmentation) and their schedule.
    self.r1_gamma = r1_gamma
    self.r1_gamma_seg = r1_gamma_seg
    self.r1_gamma_init = r1_gamma_init
    self.r1_gamma_fade_kimg = r1_gamma_fade_kimg

    # Style mixing and path-length settings.
    self.style_mixing_prob = style_mixing_prob
    self.pl_weight = pl_weight
    self.pl_batch_shrink = pl_batch_shrink
    self.pl_decay = pl_decay
    self.pl_no_weight_grad = pl_no_weight_grad
    self.pl_mean = torch.zeros([], device=device)

    # Blur and density-noise schedules.
    self.blur_init_sigma = blur_init_sigma
    self.blur_fade_kimg = blur_fade_kimg
    self.blur_raw_target = True
    self.density_noise_fade_kimg = density_noise_fade_kimg

    # Neural rendering resolution schedule.
    self.neural_rendering_resolution_initial = neural_rendering_resolution_initial
    self.neural_rendering_resolution_final = neural_rendering_resolution_final
    self.neural_rendering_resolution_fade_kimg = neural_rendering_resolution_fade_kimg

    # Conditioning-swap and background regularisation probabilities.
    self.gpc_reg_fade_kimg = gpc_reg_fade_kimg
    self.gpc_reg_prob = gpc_reg_prob
    self.bcg_reg_prob = bcg_reg_prob

    # Discriminator input handling.
    self.dual_discrimination = dual_discrimination
    self.filter_mode = filter_mode
    self.resample_filter = upfirdn2d.setup_filter([1, 3, 3, 1], device=device)

    assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1)

    # Pose-related loss weights.
    self.thickness = thickness
    self.pose_loss_weight = pose_loss_weight
    self.input_pose_params_reg_loss_weight = input_pose_params_reg_loss_weight
    self.input_pose_params_reg_loss_kimg = input_pose_params_reg_loss_kimg

    # Values of the most recent accumulate_gradients() call, kept for snapshots.
    self.swapping_prob = None
    self.neural_rendering_resolution = None
    self.blur_sigma = None
def run_D(self, img, c, gt_pose=None, blur_sigma=0, blur_sigma_raw=0, update_emas=False):
    """Run the discriminator on an image dict, optionally blurring it first.

    Args:
        img: Mapping holding at least ``'image'``. NOTE: when blurring is
            active, ``img['image']`` is replaced in place, so the caller's
            dict is modified.
        c: Conditioning passed to ``self.D``.
        gt_pose: Forwarded to ``self.D`` as ``gt_pose``.
        blur_sigma: Blur strength; a kernel of half-width
            ``floor(3 * blur_sigma)`` is applied when that is positive.
        blur_sigma_raw: Accepted for interface compatibility; not used.
        update_emas: Forwarded to ``self.D``.

    Returns:
        The ``(logits, pose)`` pair returned by ``self.D``.

    Raises:
        NotImplementedError: If an augmentation pipeline is configured.
    """
    blur_size = np.floor(blur_sigma * 3)
    if blur_size > 0:
        with torch.autograd.profiler.record_function('blur'):
            # Unnormalised kernel 2^-(t / sigma)^2 sampled on [-blur_size, blur_size].
            f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(
                blur_sigma).square().neg().exp2()
            img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum())

    if self.augment_pipe is not None:
        # The augmentation code that used to follow this raise could never
        # execute; it was removed rather than left as dead code.
        raise NotImplementedError('augment_pipe is not supported by run_D')

    logits, pose = self.D(img, c, gt_pose=gt_pose, update_emas=update_emas)
    return logits, pose
def get_pose_params_D(self, real_img, real_seg, real_c, cur_nimg):
    """Predict pose parameters for real images with the discriminator.

    Args:
        real_img: Either an image tensor, or a mapping that already holds
            ``'image'``, ``'image_raw'`` and ``'image_mask'``.
        real_seg: Segmentation tensor; only used when ``real_img`` is not a
            mapping.
        real_c: Conditioning passed to ``self.run_D_pose_prediction``.
        cur_nimg: Training progress counter driving the blur and resolution
            schedules.

    Returns:
        Whatever ``self.run_D_pose_prediction`` returns for the prepared
        inputs.

    Raises:
        AssertionError: If ``real_img`` is a mapping that lacks one of the
            three required keys.
    """
    if self.blur_fade_kimg > 0:
        blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma
    else:
        blur_sigma = 0

    if not isinstance(real_img, dict):
        if self.neural_rendering_resolution_final is not None:
            fade_kimg = self.neural_rendering_resolution_fade_kimg
            # A non-positive fade length means "no fade": go straight to the
            # final resolution instead of dividing by zero. This mirrors the
            # guards on the other schedules.
            # NOTE(review): accumulate_gradients measures this fade from
            # cur_nimg_start, which is not available here -- confirm intent.
            alpha = min(cur_nimg / (fade_kimg * 1e3), 1) if fade_kimg > 0 else 1
            neural_rendering_resolution = int(np.rint(
                self.neural_rendering_resolution_initial * (1 - alpha)
                + self.neural_rendering_resolution_final * alpha))
        else:
            neural_rendering_resolution = self.neural_rendering_resolution_initial

        real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=self.resample_filter,
                                         filter_mode=self.filter_mode)
        real_seg_raw = filtered_resizing(real_seg, size=neural_rendering_resolution, f=self.resample_filter,
                                         filter_mode=self.filter_mode)
        if self.blur_raw_target:
            blur_size = np.floor(blur_sigma * 3)
            if blur_size > 0:
                f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div(
                    blur_sigma).square().neg().exp2()
                real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum())

        real_img = {'image': real_img, 'image_raw': real_img_raw, 'image_mask': real_seg_raw}
    else:
        # All three entries are read below, so check all three up front.
        for key in ('image_raw', 'image', 'image_mask'):
            assert key in real_img, f'{key} is not in real_img.keys()'

    # Hand the discriminator detached copies that require grad; the caller's
    # tensors are left untouched.
    real_img_tmp = {
        key: real_img[key].detach().requires_grad_(True)
        for key in ('image', 'image_raw', 'image_mask')
    }

    return self.run_D_pose_prediction(real_img_tmp, real_c, blur_sigma=blur_sigma)
self.swapping_prob = swapping_prob + + if self.neural_rendering_resolution_final is not None: + alpha = min((cur_nimg-cur_nimg_start) / (self.neural_rendering_resolution_fade_kimg * 1e3), 1) + neural_rendering_resolution = int(np.rint(self.neural_rendering_resolution_initial * ( + 1 - alpha) + self.neural_rendering_resolution_final * alpha)) + else: + neural_rendering_resolution = self.neural_rendering_resolution_initial + + self.neural_rendering_resolution = neural_rendering_resolution + + real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=self.resample_filter, + filter_mode=self.filter_mode) + real_seg_raw = filtered_resizing(real_seg, size=neural_rendering_resolution, f=self.resample_filter, + filter_mode=self.filter_mode) + + + if self.blur_raw_target: + blur_size = np.floor(blur_sigma * 3) + if blur_size > 0: + f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div( + blur_sigma).square().neg().exp2() + real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum()) + + real_img = {'image': real_img, 'image_raw': real_img_raw, 'image_mask': real_seg_raw} + + + input_pose_params = self.get_pose_params_G(gen_z,gen_c) + + + for i in range(input_pose_params.shape[1]): + training_stats.report('pose_scale/input_pose_params_{}'.format(i), + (input_pose_params[:, i]).abs().mean() / np.pi * 180) + + + # Gmain: Maximize logits for generated images. 
+ if phase in ['Gmain', 'Gboth']: + with torch.autograd.profiler.record_function('Gmain_forward'): + gen_img, _gen_ws = self.run_G(gen_z, gen_c, input_pose_params, swapping_prob=swapping_prob, + neural_rendering_resolution=neural_rendering_resolution) + + + gen_logits, predict_gen_pose = self.run_D(gen_img, gen_c, gt_pose=None, blur_sigma=blur_sigma) + training_stats.report('Loss/scores/fake_posed', gen_logits) + training_stats.report('Loss/signs/fake_posed', gen_logits.sign()) + loss_Gmain = torch.nn.functional.softplus(-gen_logits) + + # Lpreg + if self.input_pose_params_reg_loss_weight>0 and cur_nimg<(self.input_pose_params_reg_loss_kimg+200) * 1e3: + + if cur_nimg 0 and self.G.rendering_kwargs[ + 'reg_type'] == 'l1': + if swapping_prob is not None: + # c_swapped = torch.roll(gen_c.clone(), 1, 0) + # c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) + c_swapped = torch.roll(gen_c.clone(), 1, 0) + p_swapped = torch.roll(input_pose_params.clone(), 1, 0) + rand_ = torch.rand([], device=gen_c.device) + c_gen_conditioning = torch.where( rand_< swapping_prob, c_swapped, gen_c) + pose_params_conditioning = torch.where(rand_ < swapping_prob, p_swapped, input_pose_params) + else: + c_gen_conditioning = torch.zeros_like(gen_c) + pose_params_conditioning = torch.zeros([gen_c.shape[0],6]).to(gen_c.device) + + + ws = self.G.mapping(gen_z, c_gen_conditioning, pose_params_conditioning,update_emas=False) + if self.style_mixing_prob > 0: + with torch.autograd.profiler.record_function('style_mixing'): + cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) + cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, + torch.full_like(cutoff, ws.shape[1])) + ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, input_pose_params,update_emas=False)[:, cutoff:] + initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1 + perturbed_coordinates 
= initial_coordinates + torch.randn_like(initial_coordinates) * self.G.rendering_kwargs['density_reg_p_dist'] + all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) + sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)[ + 'sigma'] + sigma_initial = sigma[:, :sigma.shape[1] // 2] + sigma_perturbed = sigma[:, sigma.shape[1] // 2:] + + TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs[ + 'density_reg'] + training_stats.report('Loss/G_reg/TVloss_L1', TVloss) + TVloss.mul(gain).backward() + + # Alternative density regularization + if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs[ + 'reg_type'] == 'monotonic-detach': + if swapping_prob is not None: + # c_swapped = torch.roll(gen_c.clone(), 1, 0) + # c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) + c_swapped = torch.roll(gen_c.clone(), 1, 0) + p_swapped = torch.roll(input_pose_params.clone(), 1, 0) + rand_ = torch.rand([], device=gen_c.device) + c_gen_conditioning = torch.where( rand_< swapping_prob, c_swapped, gen_c) + pose_params_conditioning = torch.where(rand_ < swapping_prob, p_swapped, input_pose_params) + else: + c_gen_conditioning = torch.zeros_like(gen_c) + pose_params_conditioning = torch.zeros([gen_c.shape[0],6]).to(gen_c.device) + + ws = self.G.mapping(gen_z, c_gen_conditioning, pose_params_conditioning,update_emas=False) + + initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front + + perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind + all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) + sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)[ + 'sigma'] + 
sigma_initial = sigma[:, :sigma.shape[1] // 2] + sigma_perturbed = sigma[:, sigma.shape[1] // 2:] + + monotonic_loss = torch.relu(sigma_initial.detach() - sigma_perturbed).mean() * 10 + monotonic_loss.mul(gain).backward() + + if swapping_prob is not None: + # c_swapped = torch.roll(gen_c.clone(), 1, 0) + # c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) + c_swapped = torch.roll(gen_c.clone(), 1, 0) + p_swapped = torch.roll(input_pose_params.clone(), 1, 0) + rand_ = torch.rand([], device=gen_c.device) + c_gen_conditioning = torch.where( rand_< swapping_prob, c_swapped, gen_c) + pose_params_conditioning = torch.where(rand_ < swapping_prob, p_swapped, input_pose_params) + else: + c_gen_conditioning = torch.zeros_like(gen_c) + pose_params_conditioning = torch.zeros([gen_c.shape[0],6]).to(gen_c.device) + + ws = self.G.mapping(gen_z, c_gen_conditioning,pose_params_conditioning, update_emas=False) + if self.style_mixing_prob > 0: + with torch.autograd.profiler.record_function('style_mixing'): + cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) + cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, + torch.full_like(cutoff, ws.shape[1])) + ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, input_pose_params,update_emas=False)[:, cutoff:] + initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1 + perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp'] + all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) + sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)[ + 'sigma'] + sigma_initial = sigma[:, :sigma.shape[1] // 2] + sigma_perturbed = sigma[:, sigma.shape[1] // 2:] + + TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * 
self.G.rendering_kwargs[ + 'density_reg'] + training_stats.report('Loss/G_reg/TVloss_monotonic-detach', TVloss) + TVloss.mul(gain).backward() + + # Alternative density regularization + if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs[ + 'reg_type'] == 'monotonic-fixed': + if swapping_prob is not None: + # c_swapped = torch.roll(gen_c.clone(), 1, 0) + # c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) + c_swapped = torch.roll(gen_c.clone(), 1, 0) + p_swapped = torch.roll(input_pose_params.clone(), 1, 0) + rand_ = torch.rand([], device=gen_c.device) + c_gen_conditioning = torch.where( rand_< swapping_prob, c_swapped, gen_c) + pose_params_conditioning = torch.where(rand_ < swapping_prob, p_swapped, input_pose_params) + else: + c_gen_conditioning = torch.zeros_like(gen_c) + pose_params_conditioning = torch.zeros([gen_c.shape[0],6]).to(gen_c.device) + + ws = self.G.mapping(gen_z, c_gen_conditioning, pose_params_conditioning,update_emas=False) + + initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front + + perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind + all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) + sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)[ + 'sigma'] + sigma_initial = sigma[:, :sigma.shape[1] // 2] + sigma_perturbed = sigma[:, sigma.shape[1] // 2:] + + monotonic_loss = torch.relu(sigma_initial - sigma_perturbed).mean() * 10 + monotonic_loss.mul(gain).backward() + + if swapping_prob is not None: + # c_swapped = torch.roll(gen_c.clone(), 1, 0) + # c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) + c_swapped = torch.roll(gen_c.clone(), 1, 0) + p_swapped = 
torch.roll(input_pose_params.clone(), 1, 0) + rand_ = torch.rand([], device=gen_c.device) + c_gen_conditioning = torch.where( rand_< swapping_prob, c_swapped, gen_c) + pose_params_conditioning = torch.where(rand_ < swapping_prob, p_swapped, input_pose_params) + else: + c_gen_conditioning = torch.zeros_like(gen_c) + pose_params_conditioning = torch.zeros([gen_c.shape[0],6]).to(gen_c.device) + + + ws = self.G.mapping(gen_z, c_gen_conditioning, pose_params_conditioning,update_emas=False) + if self.style_mixing_prob > 0: + with torch.autograd.profiler.record_function('style_mixing'): + cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) + cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, + torch.full_like(cutoff, ws.shape[1])) + ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, input_pose_params,update_emas=False)[:, cutoff:] + initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1 + perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp'] + all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) + sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)[ + 'sigma'] + sigma_initial = sigma[:, :sigma.shape[1] // 2] + sigma_perturbed = sigma[:, sigma.shape[1] // 2:] + + TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs[ + 'density_reg'] + training_stats.report('Loss/G_reg/TVloss_monotonic-fixed', TVloss) + TVloss.mul(gain).backward() + + # Dmain: Minimize logits for generated images. 
+ loss_Dgen = 0 + if phase in ['Dmain', 'Dboth']: + with torch.autograd.profiler.record_function('Dgen_forward'): + + gen_img, _gen_ws = self.run_G(gen_z, gen_c, input_pose_params, swapping_prob=swapping_prob, + neural_rendering_resolution=neural_rendering_resolution, update_emas=True) + gen_logits, predict_gen_pose = self.run_D(gen_img, gen_c, gt_pose=None, blur_sigma=blur_sigma, + update_emas=True) + + training_stats.report('Loss/scores/fake', gen_logits) + training_stats.report('Loss/signs/fake', gen_logits.sign()) + loss_Dgen = torch.nn.functional.softplus( gen_logits) # -log (1 - sigmoid(gen_logits)) = log (1 + exp(gen_logits)) = softplus(gen_logits) + + pose_param_loss = (predict_gen_pose - input_pose_params).square().sum([1]) * self.pose_loss_weight + training_stats.report('Loss/D/Poseloss', pose_param_loss) + + for i in range(predict_gen_pose.shape[1]): + training_stats.report('Loss_pose/fake_{}'.format(i), + (predict_gen_pose[:, i] - input_pose_params[:, i]).abs().mean() / np.pi * 180) + training_stats.report('pose_scale/fake_{}'.format(i), + (predict_gen_pose[:, i]).abs().mean() / np.pi * 180) + + + + + with torch.autograd.profiler.record_function('Dgen_backward'): + (loss_Dgen + pose_param_loss).mean().mul(gain).backward() + + + # Dmain: Maximize logits for real images. + # Dr1: Apply R1 regularization. 
+ if phase in ['Dmain', 'Dreg', 'Dboth']: + name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1' + with torch.autograd.profiler.record_function(name + '_forward'): + real_img_tmp_image = real_img['image'].detach().requires_grad_(phase in ['Dreg', 'Dboth']) + real_img_tmp_image_raw = real_img['image_raw'].detach().requires_grad_(phase in ['Dreg', 'Dboth']) + real_img_tmp_image_mask = real_img['image_mask'].detach().requires_grad_(phase in ['Dreg', 'Dboth']) + real_img_tmp = {'image': real_img_tmp_image, 'image_raw': real_img_tmp_image_raw, 'image_mask': real_img_tmp_image_mask} + + real_logits, predicted_real_pose = self.run_D(real_img_tmp, real_c, + gt_pose=None, + blur_sigma=blur_sigma) + + training_stats.report('Loss/scores/real', real_logits) + training_stats.report('Loss/signs/real', real_logits.sign()) + + + for i in range(predicted_real_pose.shape[1]): + training_stats.report('Loss_pose/real_{}'.format(i), ( + predicted_real_pose[:, i] - real_pose[:, i]).abs().mean() / np.pi * 180) + training_stats.report('pose_scale/real_{}'.format(i), + (predicted_real_pose[:, i]).abs().mean() / np.pi * 180) + + + loss_Dreal = 0 + if phase in ['Dmain', 'Dboth']: + loss_Dreal = torch.nn.functional.softplus( + -real_logits) # - log sigmoid(real_logits) = log (1 + exp(-real_logits)) = softplus(-real_logits) + training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) + training_stats.report('Loss/D/loss_gen', loss_Dgen) + training_stats.report('Loss/D/loss_real', loss_Dreal) + + + # + + loss_Dr1 = 0 + if phase in ['Dreg', 'Dboth']: + if self.dual_discrimination: + with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): + r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], + inputs=[real_img_tmp['image'], real_img_tmp['image_raw'], real_img_tmp['image_mask']], + create_graph=True, only_inputs=True) + r1_grads_image = r1_grads[0] + r1_grads_image_raw = r1_grads[1] + r1_grads_image_mask = r1_grads[2] + 
r1_penalty = r1_grads_image.square().sum([1,2,3]) + r1_grads_image_raw.square().sum([1,2,3]) + r1_penalty_seg = r1_grads_image_mask.square().sum([1, 2, 3]) + else: # single discrimination + with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): + r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image'], real_img_tmp['image_mask']], + create_graph=True, only_inputs=True) + r1_grads_image = r1_grads[0] + r1_grads_image_mask = r1_grads[1] + r1_penalty = r1_grads_image.square().sum([1, 2, 3]) + r1_penalty_seg = r1_grads_image_mask.square().sum([1, 2, 3]) + loss_Dr1 = r1_penalty * (self.r1_gamma / 2) + r1_penalty_seg * (self.r1_gamma_seg / 2) + training_stats.report('Loss/r1_penalty', r1_penalty) + training_stats.report('Loss/r1_penalty_seg', r1_penalty_seg) + training_stats.report('Loss/D/reg', loss_Dr1) + + + with torch.autograd.profiler.record_function(name + '_backward'): + (loss_Dreal + loss_Dr1).mean().mul(gain).backward() + +# ---------------------------------------------------------------------------- diff --git a/stable-dreamfusion-3DPortrait/nerf/trigrid_rendering/networks_stylegan2.py b/stable-dreamfusion-3DPortrait/nerf/trigrid_rendering/networks_stylegan2.py new file mode 100644 index 0000000..c3189f5 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/nerf/trigrid_rendering/networks_stylegan2.py @@ -0,0 +1,1131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
@misc.profiled_function
def modulated_conv2d(
        x,  # Input tensor of shape [batch_size, in_channels, in_height, in_width].
        weight,  # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
        styles,  # Modulation coefficients of shape [batch_size, in_channels].
        noise=None,  # Optional noise tensor to add to the output activations.
        up=1,  # Integer upsampling factor.
        down=1,  # Integer downsampling factor.
        padding=0,  # Padding with respect to the upsampled image.
        resample_filter=None,
        # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
        demodulate=True,  # Apply weight demodulation?
        flip_weight=True,  # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
        fused_modconv=True,  # Perform modulation, convolution, and demodulation as a single fused operation?
):
    """Convolve `x` with `weight` scaled per sample by `styles`.

    Two execution paths produce the result:

    * ``fused_modconv=False``: the activations are multiplied by the styles,
      convolved with the shared `weight`, then multiplied by the
      demodulation coefficients (when `demodulate`) and offset by `noise`.
    * ``fused_modconv=True``: per-sample weights are built (modulated and,
      when `demodulate`, demodulated), the batch is folded into the channel
      dimension and a single grouped convolution with ``groups=batch_size``
      is run; `noise` is added afterwards.

    Returns the convolved tensor, reshaped back to one entry per sample.
    """
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw])  # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None])  # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels])  # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    # Both rescalings use the max-abs (infinity) norm; this branch only runs together with demodulation.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3],
                                                                            keepdim=True))  # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)  # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)  # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # [NOIkk]
    if demodulate:
        # Reciprocal L2 norm of each modulated output filter; 1e-8 guards against division by zero.
        dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1)  # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down,
                                            padding=padding, flip_weight=flip_weight)
        # Demodulation and noise are combined in one fma call when both apply.
        if demodulate and noise is not None:
            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings():  # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    # Fold the batch into the channel axis so each sample becomes one convolution group.
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding,
                                        groups=batch_size, flip_weight=flip_weight)
    # Unfold the groups back into a batch dimension.
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x
+ ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + + +# ---------------------------------------------------------------------------- + +@persistence.persistent_class +class Conv2dLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias=True, # Apply additive bias before the activation function? + activation='linear', # Activation function: 'relu', 'lrelu', etc. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output to +-X, None = disable clamping. + channels_last=False, # Expect the input to have memory_format=channels_last? + trainable=True, # Update the weights of this layer during training? 
@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
    """Map a latent ``z`` and a conditioning label ``c`` to intermediate latents.

    The inputs are normalised with ``normalize_2nd_moment``, concatenated and
    passed through ``num_layers`` fully connected layers (``fc0`` .. ``fcN``).
    The result is optionally broadcast to ``num_ws`` copies and optionally
    truncated towards the tracked average ``w_avg``.
    """

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.998,    # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        # Resolve defaults; without a label there is no embedding at all.
        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        # NOTE(review): w_avg is only registered when num_ws is not None, yet
        # forward() reads self.w_avg whenever w_avg_beta is set. Callers using
        # num_ws=None presumably also pass w_avg_beta=None -- confirm.
        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        """Return the mapped latents for ``z`` / ``c``.

        ``truncation_psi != 1`` interpolates the output towards ``w_avg``;
        ``truncation_cutoff`` limits that to the first N broadcast latents.
        ``update_emas=True`` updates ``w_avg`` from the current batch.
        """
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if update_emas and self.w_avg_beta is not None:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x

    def extra_repr(self):
        """Return the summary fragment shown by ``repr(module)``."""
        # num_ws may legitimately be None ("do not broadcast"); the ':d' format
        # spec raises TypeError on None, so num_ws is formatted without it.
        # Integer values render exactly as before.
        return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws}'
+ roll_out=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.resolution = resolution + self.up = up + self.use_noise = use_noise + self.activation = activation + self.conv_clamp = conv_clamp + self.roll_out = roll_out + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.act_gain = bias_act.activation_funcs[activation].def_gain + + affine_scale = 1 + if self.roll_out in ['b', 'a']: + affine_scale = 9 + elif self.roll_out in ['s']: + affine_scale = 3 + self.affine = FullyConnectedLayer(w_dim, in_channels * affine_scale, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn( + [out_channels, in_channels * (1, 3)[self.roll_out in ['b', 'a']], + kernel_size, kernel_size]).to(memory_format=memory_format)) + if use_noise: + self.register_buffer('noise_const', torch.randn([resolution, resolution * (1, 3)[self.roll_out == 'w']])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + + def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1, **_): + assert noise_mode in ['random', 'const', 'none'] + # noise_mode = 'const' + in_resolution = self.resolution // self.up + misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution * (1, 3)[self.roll_out == 'w']]) + styles = self.affine(w) + if self.roll_out in ['b', 'a', 's']: + styles = styles.view(styles.shape[0], 3, styles.shape[1] // 3).view(styles.shape[0] * 3, + styles.shape[1] // 3) + if self.roll_out in ['b', 'a', ]: + x = aware3d_att(x) if self.roll_out == 'a' else aware3d(x) + noise = None + if self.use_noise and noise_mode == 'random': + noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution * (1, 3)[self.roll_out == 'w']], + device=x.device) * self.noise_strength + if 
def aware3d(x):
    """Concatenate each of three planes with axis-pooled copies of the other two.

    ``x`` is either a list of three 4-D tensors or a single 4-D tensor whose
    batch axis interleaves three planes (it is viewed as ``[-1, 3, C, H, W]``).
    For every plane, the two other planes are transposed in their last two
    axes, averaged along one axis and tiled back to full size, then joined to
    the plane along the channel axis. The three results are stacked and
    flattened back to ``[B, -1, H, W]``.
    """
    if isinstance(x, list):
        plane_xy, plane_yz, plane_zx = x
        batch, _, height, width = plane_xy.shape
        batch = batch * 3  # three planes share the flattened batch axis
    else:
        grouped = x.view(-1, 3, x.shape[1], x.shape[2], x.shape[3])
        plane_xy = grouped[:, 0]
        plane_yz = grouped[:, 1]
        plane_zx = grouped[:, 2]
        batch, _, height, width = x.shape

    def transposed(plane):
        # Swap the two spatial axes.
        return plane.permute(0, 1, 3, 2)

    def pooled_over_last(plane, like):
        # Mean over the last axis, tiled back across the width of `like`.
        return plane.mean(dim=-1, keepdim=True).repeat(1, 1, 1, like.shape[-1])

    def pooled_over_rows(plane, like):
        # Mean over the second-to-last axis, tiled back across the height of `like`.
        return plane.mean(dim=-2, keepdim=True).repeat(1, 1, like.shape[-2], 1)

    plane_zy = transposed(plane_yz)
    plane_xz = transposed(plane_zx)
    plane_yx = transposed(plane_xy)

    # The channel order inside each fused plane matters to the consuming conv.
    fused_xy = torch.cat(
        [plane_xy, pooled_over_last(plane_zy, plane_xy), pooled_over_rows(plane_xz, plane_xy)], 1)
    fused_yz = torch.cat(
        [pooled_over_rows(plane_yx, plane_yz), plane_yz, pooled_over_last(plane_xz, plane_yz)], 1)
    fused_zx = torch.cat(
        [pooled_over_last(plane_yx, plane_zx), pooled_over_rows(plane_zy, plane_zx), plane_zx], 1)

    stacked = torch.stack([fused_xy, fused_yz, fused_zx], dim=1)
    return stacked.view(batch, -1, height, width)
x_yxc = x_cyx.permute(0, 2, 3, 1) + x_ycz = x_czy.permute(0, 3, 1, 2) + x_yzc = x_czy.permute(0, 3, 2, 1) + x_yxz = torch.einsum('byxc,bycz->byxz', x_yxc, x_ycz) + x_yxz = torch.softmax(x_yxz, dim=-1) + x_cyx_f_czy = torch.einsum('byxz,byzc->byxc', x_yxz, x_yzc).permute(0, 3, 1, 2) + x_xyc = x_cyx.permute(0, 3, 2, 1) + x_xcz = x_cxz.permute(0, 2, 1, 3) + x_xzc = x_cxz.permute(0, 2, 3, 1) + x_xyz = torch.einsum('bxyc,bxcz->bxyz', x_xyc, x_xcz) + x_xyz = torch.softmax(x_xyz, dim=-1) + x_cyx_f_cxz = torch.einsum('bxyz,bxzc->bxyc', x_xyz, x_xzc).permute(0, 3, 2, 1) + x_cyx_ = torch.cat([x_cyx, x_cyx_f_czy, x_cyx_f_cxz], 1) + + x_zyc = x_czy.permute(0, 2, 3, 1) + x_zcx = x_cxz.permute(0, 3, 1, 2) + x_zxc = x_cxz.permute(0, 3, 2, 1) + x_zyx = torch.einsum('bzyc,bzcx->bzyx', x_zyc, x_zcx) + x_zyx = torch.softmax(x_zyx, dim=-1) + x_czy_f_cxz = torch.einsum('bzyx,bzxc->bzyc', x_zyx, x_zxc).permute(0, 3, 1, 2) + x_ycx = x_cyx.permute(0, 2, 1, 3) + x_yzx = torch.einsum('byzc,bycx->byzx', x_yzc, x_ycx) + x_yzx = torch.softmax(x_yzx, dim=-1) + x_czy_f_cyx = torch.einsum('byzx,byxc->byzc', x_yzx, x_yxc).permute(0, 3, 2, 1) + x_czy_ = torch.cat([x_czy, x_czy_f_cxz, x_czy_f_cyx], 1) + + x_xcy = x_cyx.permute(0, 3, 1, 2) + x_xzy = torch.einsum('bxzc,bxcy->bxzy', x_xzc, x_xcy) + x_xzy = torch.softmax(x_xzy, dim=-1) + x_cxz_f_cyx = torch.einsum('bxzy,bxyc->bxzc', x_xzy, x_xyc).permute(0, 3, 1, 2) + x_zcy = x_czy.permute(0, 2, 1, 3) + x_zxy = torch.einsum('bzxc,bzcy->bzxy', x_zxc, x_zcy) + x_zxy = torch.softmax(x_zxy, dim=-1) + x_cxz_f_czy = torch.einsum('bzxy,bzyc->bzxc', x_zxy, x_zyc).permute(0, 3, 2, 1) + x_cxz_ = torch.cat([x_cxz, x_cxz_f_cyx, x_cxz_f_czy], 1) + + x = torch.cat([x_cyx_[:, None], x_czy_[:, None], x_cxz_[:, None]], 1).view(x.shape[0], -1, x.shape[2], x.shape[3]) + return x + + +# ---------------------------------------------------------------------------- + +@persistence.persistent_class +class ToRGBLayer(torch.nn.Module): + def __init__(self, in_channels, 
@persistence.persistent_class
class SynthesisBlock(torch.nn.Module):
    """One resolution level of the synthesis network.

    A block owns up to two ``SynthesisLayer`` convolutions (``conv0`` is
    omitted for the first block, which starts from a learned constant), an
    optional ``ToRGBLayer`` and, for the 'resnet' architecture, a 1x1 skip
    convolution. ``forward`` returns the feature tensor and the accumulated
    image.
    """

    def __init__(self,
        in_channels,                            # Number of input channels, 0 = first block.
        out_channels,                           # Number of output channels.
        w_dim,                                  # Intermediate latent (W) dimensionality.
        resolution,                             # Resolution of this block.
        img_channels,                           # Number of output color channels.
        is_last,                                # Is this the last block?
        up                      = 2,            # Integer upsampling factor applied by conv0 (and the resnet skip).
        architecture            = 'skip',       # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter         = [1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
        conv_clamp              = 256,          # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16                = False,        # Use FP16 for this block?
        fp16_channels_last      = False,        # Use channels-last memory format with FP16?
        fused_modconv_default   = True,         # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
        roll_out                = None,         # Plane layout mode forwarded to the layers ('b', 'a', 's', 'w' or None).
        **layer_kwargs,                         # Arguments for SynthesisLayer.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.up = up
        self.roll_out = roll_out
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.fused_modconv_default = fused_modconv_default
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0

        # First block: learned constant input. For roll_out 'b'/'a' it carries
        # three entries on the leading axis; for 'w' the width is tripled.
        if in_channels == 0:
            self.const = torch.nn.Parameter(torch.randn([(1, 3)[self.roll_out in ['b', 'a']], out_channels, resolution,
                                                         resolution * (1, 3)[self.roll_out == 'w']]))

        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=self.up,
                                        roll_out=roll_out,
                                        resample_filter=resample_filter, conv_clamp=conv_clamp,
                                        channels_last=self.channels_last, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, roll_out=roll_out,
                                    conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
                                    conv_clamp=conv_clamp, channels_last=self.channels_last, roll_out=self.roll_out)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            # The skip path must upsample by the same factor as conv0; a fixed
            # factor of 2 produced mismatched shapes in forward() whenever the
            # block was built with up != 2. Identical to before for up == 2.
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=self.up,
                                    resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
        """Run the block.

        ``ws`` must hold one latent per conv/ToRGB layer of this block. FP32 is
        forced when ``ws`` is not on a CUDA device. Returns ``(x, img)`` where
        ``img`` is float32 or ``None``.
        """
        _ = update_emas  # unused
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        if ws.device.type != 'cuda':
            force_fp32 = True
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            fused_modconv = self.fused_modconv_default
        if fused_modconv == 'inference_only':
            fused_modconv = (not self.training)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution // self.up,
                                  self.resolution // self.up * (1, 3)[self.roll_out == 'w']])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            # Both branches are scaled by sqrt(0.5) before being summed.
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None and self.up > 1:
            misc.assert_shape(img, [None, self.img_channels, self.resolution // self.up,
                                    self.resolution // self.up * (1, 3)[self.roll_out == 'w']])
            # NOTE(review): upsample2d is called without an explicit factor;
            # presumably its default matches up == 2 -- confirm before using up > 2.
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img

    def extra_repr(self):
        """Return the summary fragment shown by ``repr(module)``."""
        return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+ ): + + aware3d_att=False + aware3d_res = [4,8,16,32,64,128,256] + add_block = 0 + + assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 + super().__init__() + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.num_fp16_res = num_fp16_res + self.add_block = add_block + self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)] + # channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + + self.num_ws = 0 + for res in self.block_resolutions: + in_channels = img_channels if res > 4 else 0 + out_channels = img_channels + use_fp16 = (res >= fp16_resolution) + is_last = (res == self.img_resolution) and self.add_block == 0 + block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, + img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs) + self.num_ws += block.num_conv + if is_last: + self.num_ws += block.num_torgb + setattr(self, f'b{res}', block) + if res in aware3d_res: + block3d = Aware3DBlock(img_channels, res, w_dim, aware3d_att, + block_kwargs.copy()) + setattr(self, f'b3d{res}', block3d) + + + def forward(self, ws, **block_kwargs): + block_ws = [] + with torch.autograd.profiler.record_function('split_ws'): + misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) + ws = ws.to(torch.float32) + w_idx = 0 + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) + w_idx += block.num_conv + + x = img = img3d = None + feature_maps = {} + last_has_block3d = False + for res, cur_ws in zip(self.block_resolutions, block_ws): + block = getattr(self, f'b{res}') + block3d = getattr(self, f'b3d{res}', None) + if last_has_block3d and block3d is None: + assert NotImplementedError 
@persistence.persistent_class
class SR3DBlock(torch.nn.Module):
    """Two chained ``SynthesisBlock``s applied per plane with ``roll_out='s'``.

    The channel axis of the input is split into three planes that are moved to
    the batch axis, run through ``block2`` (up=2) and ``block3`` (up=1), and
    merged back into the channel axis.
    """

    def __init__(self, img_channels, img_resolution, w_dim, block_kwargs):
        super().__init__()
        # Build a private copy instead of writing 'roll_out' into the caller's
        # dict, which would leak the setting into every other user of it.
        block_kwargs = dict(block_kwargs, roll_out='s')
        plane_channels = img_channels // 3
        self.block2 = SynthesisBlock(plane_channels, plane_channels, w_dim=w_dim, resolution=img_resolution * 2,
                                     up=2,
                                     img_channels=32, is_last=True, use_fp16=False, **block_kwargs)
        self.block3 = SynthesisBlock(plane_channels, plane_channels, w_dim=w_dim, resolution=img_resolution * 2,
                                     up=1,
                                     img_channels=32, is_last=True, use_fp16=False, **block_kwargs)

    def forward(self, img, ws):
        """Return ``[img2, img3]``, the ToRGB outputs of ``block2`` and ``block3``."""
        # Every layer of both blocks is driven by the last latent: 2 convs + 1 ToRGB.
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        # [N, 3*C, H, W] -> [N*3, C, H, W]
        num_samples, height, width = img.shape[0], img.shape[-2], img.shape[-1]
        img = img.view(num_samples, 3, -1, height, width).view(num_samples * 3, -1, height, width)

        # Each block receives its input as the feature tensor and starts with no image.
        _, img2 = self.block2(img, None, ws)
        _, img3 = self.block3(img2, None, ws)

        # [N*3, C', H', W'] -> [N, 3*C', H', W']. Each output is reshaped with
        # its OWN dimensions; the original used the input's width for img3,
        # which only worked because the second view undid the intermediate one.
        img2 = img2.view(-1, 3, img2.shape[-3], img2.shape[-2], img2.shape[-1]).view(-1, 3 * img2.shape[-3],
                                                                                      img2.shape[-2], img2.shape[-1])
        img3 = img3.view(-1, 3, img3.shape[-3], img3.shape[-2], img3.shape[-1]).view(-1, 3 * img3.shape[-3],
                                                                                      img3.shape[-2], img3.shape[-1])

        return [img2, img3]
@persistence.persistent_class
class Hierarchy3DAwareGenerator(torch.nn.Module):
    """Mapping network followed by a ``Hierarchy3DAwareSynthesisNetwork``.

    ``forward`` returns whatever the synthesis network returns for the mapped
    latents.
    """

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        mapping_kwargs      = {},   # Arguments for MappingNetwork.
        **synthesis_kwargs,         # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # The synthesis network is built first: it determines how many
        # latents the mapping network has to emit.
        synthesis = Hierarchy3DAwareSynthesisNetwork(
            w_dim=w_dim,
            img_resolution=img_resolution,
            img_channels=img_channels,
            **synthesis_kwargs,
        )
        self.synthesis = synthesis
        self.num_ws = synthesis.num_ws
        self.mapping = MappingNetwork(
            z_dim=z_dim,
            c_dim=c_dim,
            w_dim=w_dim,
            num_ws=self.num_ws,
            **mapping_kwargs,
        )

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        """Map ``(z, c)`` to latents and synthesize from them."""
        latents = self.mapping(
            z,
            c,
            truncation_psi=truncation_psi,
            truncation_cutoff=truncation_cutoff,
            update_emas=update_emas,
        )
        return self.synthesis(latents, update_emas=update_emas, **synthesis_kwargs)
@persistence.persistent_class
class Generator(torch.nn.Module):
    """Mapping network followed by a ``SynthesisNetwork``."""

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        mapping_kwargs      = {},   # Arguments for MappingNetwork.
        **synthesis_kwargs,         # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # The synthesis network is built first: it determines how many
        # latents the mapping network has to emit.
        synthesis = SynthesisNetwork(
            w_dim=w_dim,
            img_resolution=img_resolution,
            img_channels=img_channels,
            **synthesis_kwargs,
        )
        self.synthesis = synthesis
        self.num_ws = synthesis.num_ws
        self.mapping = MappingNetwork(
            z_dim=z_dim,
            c_dim=c_dim,
            w_dim=w_dim,
            num_ws=self.num_ws,
            **mapping_kwargs,
        )

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        """Map ``(z, c)`` to latents and return the synthesized image."""
        latents = self.mapping(
            z,
            c,
            truncation_psi=truncation_psi,
            truncation_cutoff=truncation_cutoff,
            update_emas=update_emas,
        )
        return self.synthesis(latents, update_emas=update_emas, **synthesis_kwargs)
+ fp16_channels_last=False, # Use channels-last memory format with FP16? + freeze_layers=0, # Freeze-D: Number of layers to freeze. + ): + assert in_channels in [0, tmp_channels] + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.resolution = resolution + self.img_channels = img_channels + self.first_layer_idx = first_layer_idx + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + + self.num_layers = 0 + + def trainable_gen(): + while True: + layer_idx = self.first_layer_idx + self.num_layers + trainable = (layer_idx >= freeze_layers) + self.num_layers += 1 + yield trainable + + trainable_iter = trainable_gen() + + if in_channels == 0 or architecture == 'skip': + self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation, + trainable=next(trainable_iter), conv_clamp=conv_clamp, + channels_last=self.channels_last) + + self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation, + trainable=next(trainable_iter), conv_clamp=conv_clamp, + channels_last=self.channels_last) + + self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2, + trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, + channels_last=self.channels_last) + + if architecture == 'resnet': + self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2, + trainable=next(trainable_iter), resample_filter=resample_filter, + channels_last=self.channels_last) + + def forward(self, x, img, force_fp32=False): + if (x if x is not None else img).device.type != 'cuda': + force_fp32 = True + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 
else torch.contiguous_format + + # Input. + if x is not None: + misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # FromRGB. + if self.in_channels == 0 or self.architecture == 'skip': + misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) + img = img.to(dtype=dtype, memory_format=memory_format) + y = self.fromrgb(img) + x = x + y if x is not None else y + img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None + + # Main layers. + if self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x) + x = self.conv1(x, gain=np.sqrt(0.5)) + x = y.add_(x) + else: + x = self.conv0(x) + x = self.conv1(x) + + assert x.dtype == dtype + return x, img + + def extra_repr(self): + return f'resolution={self.resolution:d}, architecture={self.architecture:s}' + + + +#---------------------------------------------------------------------------- + + +@persistence.persistent_class +class MinibatchStdLayer(torch.nn.Module): + def __init__(self, group_size, num_channels=1): + super().__init__() + self.group_size = group_size + self.num_channels = num_channels + + def forward(self, x): + N, C, H, W = x.shape + with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants + G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N + F = self.num_channels + c = C // F + + y = x.reshape(G, -1, F, c, H, + W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. + y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. + y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. + y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. + y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels. + y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. 
@persistence.persistent_class
class DiscriminatorEpilogue(torch.nn.Module):
    """Last discriminator stage, split into two calls.

    `get_flatten_x` runs the convolutional part (optional FromRGB, optional
    minibatch-stddev, 3x3 conv) and returns the flattened activations;
    `forward` consumes those flattened activations and produces the score
    through the two fully connected layers plus optional label projection.
    """

    def __init__(self,
        in_channels,                    # Number of input channels.
        cmap_dim,                       # Dimensionality of mapped conditioning label, 0 = no label.
        resolution,                     # Resolution of this block.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        mbstd_group_size    = 4,        # Group size for the minibatch standard deviation layer, None = entire minibatch.
        mbstd_num_channels  = 1,        # Number of features for the minibatch standard deviation layer, 0 = disable.
        activation          = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.cmap_dim = cmap_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.architecture = architecture

        # Sub-modules are created in a fixed order: fromrgb, mbstd, conv, fc, out.
        if architecture == 'skip':
            self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
        if mbstd_num_channels > 0:
            self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)
        else:
            self.mbstd = None
        conv_in_channels = in_channels + mbstd_num_channels
        self.conv = Conv2dLayer(conv_in_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
        self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
        # One scalar without conditioning, otherwise a cmap_dim vector that
        # forward() projects onto the label embedding.
        self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)

    def get_flatten_x(self, x, img, force_fp32=False):
        """Return the flattened conv features [N, in_channels * resolution**2].

        `force_fp32` is accepted for signature compatibility and ignored:
        this stage always runs in float32 / contiguous format.
        """
        misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
        _ = force_fp32 # unused
        cast = dict(dtype=torch.float32, memory_format=torch.contiguous_format)

        # FromRGB.
        x = x.to(**cast)
        if self.architecture == 'skip':
            misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            x = x + self.fromrgb(img.to(**cast))

        # Main layers.
        if self.mbstd is not None:
            x = self.mbstd(x)
        return self.conv(x).flatten(1)

    def forward(self, flatten_x, cmap, force_fp32=False):
        """Map flattened features (see `get_flatten_x`) to the output score.

        Args:
            flatten_x: [N, in_channels * resolution * resolution] tensor.
            cmap: [N, cmap_dim] label embedding; only read when cmap_dim > 0.
            force_fp32: Unused.

        Returns:
            [N, 1] float32 tensor.
        """
        misc.assert_shape(flatten_x, [None, self.in_channels * self.resolution * self.resolution])

        score = self.out(self.fc(flatten_x))

        # Conditioning.
        if self.cmap_dim > 0:
            misc.assert_shape(cmap, [None, self.cmap_dim])
            score = (score * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        assert score.dtype == torch.float32
        return score

    def extra_repr(self):
        return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
def forward(self, img, c, update_emas=False, **block_kwargs):
    """Score a batch of images.

    Args:
        img: Input image batch, consumed by the resolution blocks.
        c: Conditioning labels; only read when `self.c_dim > 0`.
        update_emas: Unused, accepted for interface compatibility.
        **block_kwargs: Forwarded to every resolution block.

    Returns:
        The output of the epilogue `self.b4`.
    """
    _ = update_emas # unused
    x = None
    for res in self.block_resolutions:
        block = getattr(self, f'b{res}')
        x, img = block(x, img, **block_kwargs)

    cmap = None
    if self.c_dim > 0:
        cmap = self.mapping(None, c)

    # The epilogue is split in two steps: get_flatten_x(x, img) runs the conv
    # part, forward(flatten_x, cmap) runs the fully connected part. Calling
    # self.b4(x, img, cmap) directly would bind img to `cmap` and cmap to
    # `force_fp32`, and fail the 2-D shape assert on the 4-D feature map.
    flatten_x = self.b4.get_flatten_x(x, img)
    x = self.b4(flatten_x, cmap)
    return x
@misc.profiled_function
def modulated_conv2d(
    x,                  # Input tensor: [batch_size, in_channels, in_height, in_width]
    w,                  # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width]
    s,                  # Style tensor: [batch_size, in_channels]
    demodulate  = True, # Apply weight demodulation?
    padding     = 0,    # Padding: int or [padH, padW]
    input_gain  = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels]
):
    """Convolve every sample of `x` with its own style-modulated copy of `w`.

    The per-sample weights are built by scaling `w` along the input-channel
    axis with `s` (and optionally `input_gain`), optionally demodulated per
    output channel, and then applied to the whole batch in a single grouped
    convolution (`groups=batch_size`).

    Returns:
        Tensor reshaped to [batch_size, out_channels, out_height, out_width];
        the spatial size is whatever the convolution yields for `padding`.
    """
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        batch_size = int(x.shape[0])
    out_channels, in_channels, kh, kw = w.shape
    misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
    misc.assert_shape(s, [batch_size, in_channels]) # [NI]

    # Pre-normalize inputs.
    # Weights are divided by their RMS per output channel, styles by their
    # RMS over the whole tensor, before the two are multiplied together.
    if demodulate:
        w = w * w.square().mean([1,2,3], keepdim=True).rsqrt()
        s = s * s.square().mean().rsqrt()

    # Modulate weights.
    w = w.unsqueeze(0) # [NOIkk]
    w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]

    # Demodulate weights.
    # 1 / sqrt(sum of squares + eps) per (sample, output channel).
    if demodulate:
        dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
        w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk]

    # Apply input scaling.
    # Folded into the weights (after demodulation) instead of scaling x.
    if input_gain is not None:
        input_gain = input_gain.expand(batch_size, in_channels) # [NI]
        w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]

    # Execute as one fused op using grouped convolution.
    # Batch and channel axes are merged so that each group is one sample.
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    return x
@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
    """Maps latent `z` and/or label `c` to `num_ws` copies of the latent `w`.

    The inputs are second-moment normalised, concatenated, passed through
    `num_layers` fully connected layers, broadcast to `num_ws` copies and
    optionally truncated towards the tracked average `w_avg`.
    """

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no labels.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output.
        num_layers      = 2,        # Number of mapping layers.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.998,    # Decay for tracking the moving average of W during training.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        # Construct layers.
        self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None
        features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers
        for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
            layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)
        self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        """Compute the broadcast latents.

        Args:
            z: [N, z_dim] latent. Ignored (may be None) when `z_dim == 0`.
            c: [N, c_dim] label. Only read when `c_dim > 0`.
            truncation_psi: Interpolation factor towards `w_avg`; 1 disables
                truncation.
            truncation_cutoff: Number of leading latents to truncate;
                None means all `num_ws`.
            update_emas: Update the `w_avg` buffer from this batch.

        Returns:
            [N, num_ws, w_dim] tensor.
        """
        if truncation_cutoff is None:
            truncation_cutoff = self.num_ws

        # Embed, normalize, and concatenate inputs.
        # z is only touched when the network actually has a latent input, so
        # a label-only network (z_dim == 0) can be called with z=None. This
        # is the case the `x is not None` check below exists for.
        x = None
        if self.z_dim > 0:
            misc.assert_shape(z, [None, self.z_dim])
            x = z.to(torch.float32)
            x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
        if self.c_dim > 0:
            misc.assert_shape(c, [None, self.c_dim])
            y = self.embed(c.to(torch.float32))
            y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
            x = torch.cat([x, y], dim=1) if x is not None else y

        # Execute layers.
        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}')(x)

        # Update moving average of W.
        if update_emas:
            self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast and apply truncation.
        x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
        if truncation_psi != 1:
            x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x

    def extra_repr(self):
        return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
def forward(self, w):
    """Generate the Fourier-feature input tensor for latent batch `w`.

    `w` is fed to `self.affine`, whose 4 outputs are read as a rotation
    (cos, sin) and a translation (x, y). These are composed with the
    `transform` buffer and applied to the fixed frequencies/phases before
    the sinusoids are evaluated on a sampling grid and mixed by
    `self.weight`.

    Returns:
        Tensor of shape [w.shape[0], channels, size[1], size[0]].
    """
    # Introduce batch dimension.
    transforms = self.transform.unsqueeze(0) # [batch, row, col]
    freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
    phases = self.phases.unsqueeze(0) # [batch, channel]

    # Apply learned transformation.
    t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
    # Dividing by the norm of the first two entries makes (r_c, r_s) a unit vector.
    t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
    m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
    m_r[:, 0, 0] = t[:, 0]  # r'_c
    m_r[:, 0, 1] = -t[:, 1] # r'_s
    m_r[:, 1, 0] = t[:, 1]  # r'_s
    m_r[:, 1, 1] = t[:, 0]  # r'_c
    m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
    m_t[:, 0, 2] = -t[:, 2] # t'_x
    m_t[:, 1, 2] = -t[:, 3] # t'_y
    transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.

    # Transform frequencies.
    # The translation column shifts the phases; the upper-left 2x2 part
    # transforms the frequency vectors.
    phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
    freqs = freqs @ transforms[:, :2, :2]

    # Dampen out-of-band frequencies that may occur due to the user-specified transform.
    # Linear ramp: 1 at |f| <= bandwidth, 0 at |f| >= sampling_rate / 2.
    amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)

    # Construct sampling grid.
    theta = torch.eye(2, 3, device=w.device)
    theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
    theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
    grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)

    # Compute Fourier features.
    x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
    x = x + phases.unsqueeze(1).unsqueeze(2)
    x = torch.sin(x * (np.pi * 2))
    x = x * amplitudes.unsqueeze(1).unsqueeze(2)

    # Apply trainable mapping.
    weight = self.weight / np.sqrt(self.channels)
    x = x @ weight.t()

    # Ensure correct shape.
    x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
    misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])])
    return x
+ in_sampling_rate, # Input sampling rate (s). + out_sampling_rate, # Output sampling rate (s). + in_cutoff, # Input cutoff frequency (f_c). + out_cutoff, # Output cutoff frequency (f_c). + in_half_width, # Input transition band half-width (f_h). + out_half_width, # Output Transition band half-width (f_h). + + # Hyperparameters. + conv_kernel = 3, # Convolution kernel size. Ignored for final the ToRGB layer. + filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling. + lrelu_upsampling = 2, # Relative sampling rate for leaky ReLU. Ignored for final the ToRGB layer. + use_radial_filters = False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers. + conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping. + magnitude_ema_beta = 0.999, # Decay rate for the moving average of input magnitudes. + ): + super().__init__() + self.w_dim = w_dim + self.is_torgb = is_torgb + self.is_critically_sampled = is_critically_sampled + self.use_fp16 = use_fp16 + self.in_channels = in_channels + self.out_channels = out_channels + self.in_size = np.broadcast_to(np.asarray(in_size), [2]) + self.out_size = np.broadcast_to(np.asarray(out_size), [2]) + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + self.conv_kernel = 1 if is_torgb else conv_kernel + self.conv_clamp = conv_clamp + self.magnitude_ema_beta = magnitude_ema_beta + + # Setup parameters and buffers. 
+ self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1) + self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel])) + self.bias = torch.nn.Parameter(torch.zeros([self.out_channels])) + self.register_buffer('magnitude_ema', torch.ones([])) + + # Design upsampling filter. + self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) + assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate + self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1 + self.register_buffer('up_filter', self.design_lowpass_filter( + numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate)) + + # Design downsampling filter. + self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) + assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate + self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1 + self.down_radial = use_radial_filters and not self.is_critically_sampled + self.register_buffer('down_filter', self.design_lowpass_filter( + numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial)) + + # Compute padding. + pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling. + pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling. + pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters. + pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3). 
+ pad_hi = pad_total - pad_lo + self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] + + def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False): + assert noise_mode in ['random', 'const', 'none'] # unused + misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])]) + misc.assert_shape(w, [x.shape[0], self.w_dim]) + + # Track input magnitude. + if update_emas: + with torch.autograd.profiler.record_function('update_magnitude_ema'): + magnitude_cur = x.detach().to(torch.float32).square().mean() + self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta)) + input_gain = self.magnitude_ema.rsqrt() + + # Execute affine layer. + styles = self.affine(w) + if self.is_torgb: + weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2)) + styles = styles * weight_gain + + # Execute modulated conv2d. + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 + x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles, + padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain) + + # Execute bias, filtered leaky ReLU, and clamping. + gain = 1 if self.is_torgb else np.sqrt(2) + slope = 1 if self.is_torgb else 0.2 + x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype), + up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp) + + # Ensure correct shape and dtype. + misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])]) + assert x.dtype == dtype + return x + + @staticmethod + def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): + assert numtaps >= 1 + + # Identity filter. + if numtaps == 1: + return None + + # Separable Kaiser low-pass filter. 
+ if not radial: + f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) + return torch.as_tensor(f, dtype=torch.float32) + + # Radially symmetric jinc-based filter. + x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs + r = np.hypot(*np.meshgrid(x, x)) + f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) + beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2))) + w = np.kaiser(numtaps, beta) + f *= np.outer(w, w) + f /= np.sum(f) + return torch.as_tensor(f, dtype=torch.float32) + + def extra_repr(self): + return '\n'.join([ + f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},', + f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},', + f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},', + f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},', + f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},', + f'in_size={list(self.in_size)}, out_size={list(self.out_size)},', + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}']) + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisNetwork(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB. + num_critical = 2, # Number of critically sampled layers at the end. + first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}). + first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}). + last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff. 
+ margin_size = 10, # Number of additional pixels outside the image. + output_scale = 0.25, # Scale factor for the output image. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + **layer_kwargs, # Arguments for SynthesisLayer. + ): + super().__init__() + self.w_dim = w_dim + self.num_ws = num_layers + 2 + self.img_resolution = img_resolution + self.img_channels = img_channels + self.num_layers = num_layers + self.num_critical = num_critical + self.margin_size = margin_size + self.output_scale = output_scale + self.num_fp16_res = num_fp16_res + + # Geometric progression of layer cutoffs and min. stopbands. + last_cutoff = self.img_resolution / 2 # f_{c,N} + last_stopband = last_cutoff * last_stopband_rel # f_{t,N} + exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1) + cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i] + stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i] + + # Compute remaining layer parameters. + sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i] + half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i] + sizes = sampling_rates + self.margin_size * 2 + sizes[-2:] = self.img_resolution + channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max)) + channels[-1] = self.img_channels + + # Construct layers. 
+ self.input = SynthesisInput( + w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]), + sampling_rate=sampling_rates[0], bandwidth=cutoffs[0]) + self.layer_names = [] + for idx in range(self.num_layers + 1): + prev = max(idx - 1, 0) + is_torgb = (idx == self.num_layers) + is_critically_sampled = (idx >= self.num_layers - self.num_critical) + use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution) + layer = SynthesisLayer( + w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16, + in_channels=int(channels[prev]), out_channels= int(channels[idx]), + in_size=int(sizes[prev]), out_size=int(sizes[idx]), + in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]), + in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx], + in_half_width=half_widths[prev], out_half_width=half_widths[idx], + **layer_kwargs) + name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}' + setattr(self, name, layer) + self.layer_names.append(name) + + def forward(self, ws, **layer_kwargs): + misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) + ws = ws.to(torch.float32).unbind(dim=1) + + # Execute layers. + x = self.input(ws[0]) + for name, w in zip(self.layer_names, ws[1:]): + x = getattr(self, name)(x, w, **layer_kwargs) + if self.output_scale != 1: + x = x * self.output_scale + + # Ensure correct shape and dtype. 
@persistence.persistent_class
class Generator(torch.nn.Module):
    """Mapping network followed by synthesis network.

    `forward` turns (z, c) into latents with `self.mapping` and renders
    them with `self.synthesis`.
    """

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        mapping_kwargs      = {},   # Arguments for MappingNetwork (only unpacked, never mutated).
        **synthesis_kwargs,         # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim, self.c_dim, self.w_dim = z_dim, c_dim, w_dim
        self.img_resolution, self.img_channels = img_resolution, img_channels

        # The synthesis network is built first because it determines how
        # many latents the mapping network has to emit.
        synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
        self.synthesis = synthesis
        self.num_ws = synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        """Generate images for latents `z` and labels `c`.

        The truncation arguments go to the mapping network, `update_emas`
        goes to both networks, and `synthesis_kwargs` goes to the synthesis
        network.
        """
        mapping_args = dict(truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        ws = self.mapping(z, c, **mapping_args)
        return self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
@persistence.persistent_class
class NeuralRender(torch.nn.Module):
    """Container for the tri-grid rendering components.

    Holds the importance renderer, ray sampler, feature decoder, ToRGB layer
    and the pose-prediction branch, plus a helper that samples pose
    parameters for a batch of 25-dimensional camera labels.
    """

    def __init__(self,
                 z_dim,  # Input latent (Z) dimensionality.
                 c_dim,  # Conditioning label (C) dimensionality.
                 w_dim,  # Intermediate latent (W) dimensionality.

                 img_resolution,  # Output resolution.
                 img_channels,  # Number of output color channels.
                 rendering_kwargs={},  # Must provide 'box_warp' and 'decoder_activation'; 'decoder_lr_mul' is optional.
                 batch_size=1,
                 thickness=0.05,
                 apply_deformation = False
                 ):
        super().__init__()

        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        self.trigrid_channel = 12
        self.decode_channel = 32

        self.batch_size = batch_size
        self.renderer = ImportanceRenderer(w_dim=w_dim, num_ws=14, batch_size=self.batch_size, thickness=thickness,
                                           box_warp=rendering_kwargs['box_warp'],apply_deformation = apply_deformation)  # disable deformation for now
        self.ray_sampler = RaySampler()

        self.decoder = OSGDecoder(self.trigrid_channel, {'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                                                         'decoder_output_dim': self.decode_channel,
                                                         'decoder_activation': rendering_kwargs['decoder_activation']})

        self.torgb = ToRGBLayer(self.decode_channel, 3, w_dim)

        self.rendering_kwargs = rendering_kwargs
        self.neural_rendering_resolution = 64

        self.pose_branch = GPoseBranch(z_dim=z_dim, c_dim=c_dim)

        # 25-value label, presumably a 4x4 camera matrix followed by 3x3
        # intrinsics (TODO confirm). It is placed on CUDA when available, as
        # before, but no longer makes construction fail on CPU-only hosts.
        avg_c_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.avg_c = torch.tensor([[1.0000e+00, 1.0505e-09, 4.3685e-08, -1.1805e-07, 0.0000e+00,
                                    -9.9951e-01, 2.4033e-02, -1.1805e-07, 4.3714e-08, -2.4033e-02,
                                    -9.9951e-01, 2.6992e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                    1.0000e+00, 6.5104e+00, 0.0000e+00, 5.0000e-01, 0.0000e+00,
                                    6.5104e+00, 5.0000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+00]]).float().to(avg_c_device)

    def flip_yaw(self, matrix):
        """Return a copy of `matrix` with selected entries of its 4x4 part negated.

        The first 16 columns of each row are read as a 4x4 matrix; entries
        (0,1), (0,2), (1,0), (2,0) and (0,3) change sign. Remaining columns
        are copied unchanged and the input is not modified.
        """
        flipped_matrix = matrix.clone()
        flipped = flipped_matrix[:, :16].reshape(-1, 4, 4)
        flipped[:, 0, 1] *= -1
        flipped[:, 0, 2] *= -1
        flipped[:, 1, 0] *= -1
        flipped[:, 2, 0] *= -1
        flipped[:, 0, 3] *= -1

        flipped = flipped.reshape(-1, 16)
        flipped_matrix[:, :16] = flipped.clone()

        return flipped_matrix

    def sample_pose_params(self, c):
        """Sample pose parameters for camera labels `c` of shape [N, 25].

        Labels whose angle `atan2(c[:, 11], c[:, 3])` lies in
        [-pi/2, pi/2] are mirrored with `flip_yaw` before being passed to
        the pose branch, and columns 1, 2, 4, 5 of their predicted pose are
        negated afterwards to undo the mirroring.

        Returns:
            The pose tensor produced by `self.pose_branch`, one row per label.
        """
        assert len(c.shape) == 2 and c.shape[1] == 25
        # randomly sample z from Gaussian distribution
        z = torch.randn(c.shape[0], self.z_dim).to(c.device)

        theta = torch.atan2(c[:, [11]], c[:, [3]])  # math.atan2(z, x)
        is_left = (theta >= -np.pi / 2) & (theta <= np.pi / 2)

        flip_c = self.flip_yaw(c)
        input_c = torch.where(is_left, flip_c, c)  # if left, flip c

        pose_params = self.pose_branch(z, input_c)

        flip_pose_params = pose_params.clone()
        flip_pose_params[:, [1, 2, 4, 5]] *= -1  # flip y and z axis angles

        pose_params = torch.where(is_left, flip_pose_params, pose_params)  # if left, flip back pose_params

        return pose_params
@persistence.persistent_class
class TriPlaneGenerator(torch.nn.Module):
    """Tri-grid generator: StyleGAN2 backbone, volume renderer, optional background and super-resolution."""

    def __init__(self,
                 z_dim,                      # Input latent (Z) dimensionality.
                 c_dim,                      # Conditioning label (C) dimensionality.
                 w_dim,                      # Intermediate latent (W) dimensionality.
                 img_resolution,             # Output resolution.
                 img_channels,               # Number of output color channels.
                 sr_num_fp16_res=0,
                 mapping_kwargs=None,        # Arguments for MappingNetwork.
                 rendering_kwargs=None,
                 sr_kwargs=None,
                 batch_size=1,
                 explicitly_symmetry=False,
                 thickness=0.05,
                 **synthesis_kwargs,         # Arguments for SynthesisNetwork.
                 ):
        super().__init__()
        # Literal `{}` defaults would be shared between instances
        # (`rendering_kwargs` is kept on the instance), so build fresh dicts.
        mapping_kwargs = {} if mapping_kwargs is None else mapping_kwargs
        rendering_kwargs = {} if rendering_kwargs is None else rendering_kwargs
        sr_kwargs = {} if sr_kwargs is None else sr_kwargs

        bcg_synthesis_kwargs = synthesis_kwargs.copy()
        bcg_synthesis_kwargs["channel_base"] = 32768
        bcg_synthesis_kwargs["channel_max"] = 512

        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        self.trigrid_channel = 12
        self.decode_channel = 32

        self.batch_size = batch_size
        self.renderer = ImportanceRenderer(w_dim=w_dim, num_ws=14, batch_size=self.batch_size,
                                           thickness=thickness, box_warp=rendering_kwargs['box_warp'])
        self.ray_sampler = RaySampler()
        # The mapping network is conditioned on the camera label plus 6 pose parameters (c_dim + 6).
        self.backbone = StyleGAN2Backbone(z_dim, c_dim + 6, w_dim, img_resolution=256,
                                          img_channels=self.trigrid_channel * 3 * rendering_kwargs['triplane_depth'],
                                          mapping_kwargs=mapping_kwargs, roll_out=None,
                                          **synthesis_kwargs)  # forbid roll_out in main G

        self.superresolution = dnnlib.util.construct_class_by_name(
            class_name=rendering_kwargs['superresolution_module'], channels=self.decode_channel,
            img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
        self.decoder = OSGDecoder(self.trigrid_channel, {'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                                                         'decoder_output_dim': self.decode_channel,
                                                         'decoder_activation': rendering_kwargs['decoder_activation']})

        self.torgb = ToRGBLayer(self.decode_channel, 3, w_dim) if rendering_kwargs.get('use_torgb_raw', False) else None

        self.bcg_synthesis = SynthesisNetwork(w_dim, img_resolution=self.superresolution.input_resolution,
                                              img_channels=self.decode_channel,
                                              **bcg_synthesis_kwargs) if rendering_kwargs.get('use_background',
                                                                                              False) else None

        self.pose_branch = GPoseBranch(z_dim=z_dim, c_dim=c_dim)
        self.neural_rendering_resolution = 64
        self.rendering_kwargs = rendering_kwargs

        # Raw (un-reshaped) backbone output kept by synthesis(cache_backbone=True).
        self._last_planes = None

        self.explicitly_symmetry = explicitly_symmetry

        # Default 25-value camera label (16 extrinsic + 9 intrinsic entries).
        avg_c = torch.tensor([[1.0000e+00, 1.0505e-09, 4.3685e-08, -1.1805e-07, 0.0000e+00,
                               -9.9951e-01, 2.4033e-02, -1.1805e-07, 4.3714e-08, -2.4033e-02,
                               -9.9951e-01, 2.6992e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                               1.0000e+00, 6.5104e+00, 0.0000e+00, 5.0000e-01, 0.0000e+00,
                               6.5104e+00, 5.0000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+00]]).float()
        # An unconditional `.cuda()` made construction fail on CPU-only hosts;
        # keep the GPU placement when CUDA exists and fall back to the CPU otherwise.
        self.avg_c = avg_c.cuda() if torch.cuda.is_available() else avg_c

        # Run the backbone once on dummy latents only to record the per-resolution plane shapes.
        planes = self.backbone.synthesis(torch.zeros(4, self.backbone.synthesis.num_ws, w_dim), update_emas=False,
                                         **synthesis_kwargs)
        planes = self._reshape_planes(planes)
        self.plane_shapes = {res_k: planes[res_k].shape for res_k in planes}

    @staticmethod
    def _reshape_planes(planes):
        """Return a NEW dict with every (3*B, C, H, W) plane viewed as (B, 3, C, H, W).

        The input mapping is left untouched so that a cached backbone output can
        be reshaped again on a later call.
        """
        return {
            res_k: planes[res_k].view(len(planes[res_k]) // 3, 3, planes[res_k].shape[-3],
                                      planes[res_k].shape[-2], planes[res_k].shape[-1])
            for res_k in planes
        }

    def flip_yaw(self, matrix):
        """Return a copy of camera labels `matrix` with the yaw mirrored.

        The first 16 columns are treated as a flattened 4x4 matrix; entries
        (0,1), (0,2), (1,0), (2,0) and (0,3) are negated. The input is not modified.
        """
        flipped_matrix = matrix.clone()
        flipped = flipped_matrix[:, :16].reshape(-1, 4, 4)
        flipped[:, 0, 1] *= -1
        flipped[:, 0, 2] *= -1
        flipped[:, 1, 0] *= -1
        flipped[:, 2, 0] *= -1
        flipped[:, 0, 3] *= -1

        flipped = flipped.reshape(-1, 16)
        flipped_matrix[:, :16] = flipped.clone()

        return flipped_matrix

    def get_pose_params(self, z, c):
        """Predict (B, 6) pose parameters for latents `z` and camera labels `c`.

        Raises:
            NotImplementedError: when `explicitly_symmetry` is disabled.
        """
        if not self.explicitly_symmetry:
            raise NotImplementedError

        # check if c is a left face
        theta = torch.atan2(c[:, [11]], c[:, [3]])  # math.atan2(z, x)
        is_left = (theta >= -np.pi / 2) & (theta <= np.pi / 2)

        flip_c = self.flip_yaw(c)
        input_c = torch.where(is_left, flip_c, c)  # if left, flip c

        pose_params = self.pose_branch(z, input_c)

        flip_pose_params = pose_params.clone()
        flip_pose_params[:, [1, 2, 4, 5]] *= -1  # flip y and z axis angles

        return torch.where(is_left, flip_pose_params, pose_params)  # if left, flip back pose_params

    def set_batch_size(self, batch_size):
        """Forward the batch size to the renderer."""
        self.renderer.set_batch_size(batch_size)

    def render_meshes(self, shape_pose_params, resolution, cameras):
        """Delegate mesh rendering to the renderer."""
        return self.renderer.render_meshes(shape_pose_params, resolution, cameras)

    def mapping(self, z, c, p, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        """Map `z` to `ws`, conditioned on camera label `c` concatenated with pose `p`.

        `p=None` is replaced by six zeros per sample. The conditioning vector is
        multiplied by `rendering_kwargs['c_scale']`, which defaults to 0 when the
        key is missing (i.e. conditioning is zeroed out).

        Raises:
            NotImplementedError: when `rendering_kwargs['c_gen_conditioning_zero']` is set.
        """
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            raise NotImplementedError

        if p is None:
            p = torch.zeros([c.shape[0], 6], device=c.device)
        c = torch.cat([c, p], dim=1)
        return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi,
                                     truncation_cutoff=truncation_cutoff, update_emas=update_emas)

    def synthesis(self, ws, c, neural_rendering_resolution=None, update_emas=False, cache_backbone=False,
                  use_cached_backbone=False,
                  apply_def=False, pose_params=None, ws_bcg=None,
                  **synthesis_kwargs):
        """Render images for latents `ws` seen from camera labels `c`.

        Returns:
            dict with 'image' (super-resolved), 'image_raw', 'image_depth',
            'image_mask' and 'image_background'. 'image_background' is None when
            the generator was built without a background network.
        """
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        intrinsics = c[:, 16:25].view(-1, 3, 3)

        # An explicit resolution is remembered and becomes the new default.
        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        # Create a batch of rays for volume rendering
        ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)

        # Create triplanes by running StyleGAN backbone
        N, M, _ = ray_origins.shape
        if use_cached_backbone and self._last_planes is not None:
            raw_planes = self._last_planes
        else:
            raw_planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)

        if cache_backbone:
            self._last_planes = raw_planes

        # Reshape into a fresh dict. The planes used to be reshaped in place, so
        # the cached dict already held 5-D tensors and the second `view` on a
        # `use_cached_backbone=True` call failed with a shape error.
        planes = self._reshape_planes(raw_planes)

        # Perform volume rendering
        render_output = self.renderer(planes, self.decoder, ray_origins,
                                      ray_directions, self.rendering_kwargs, apply_def=apply_def, ws=ws,
                                      pose_params=pose_params)  # channels last
        feature_samples = render_output['rgb_final']
        depth_samples = render_output['depth_final']
        weights_samples = render_output['weights']

        # Reshape into 'raw' neural-rendered image
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
        weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        if self.decoder.activation == "sigmoid":
            feature_image = feature_image * 2 - 1  # Scale to (-1, 1), taken from ray marcher

        # Generate Background (optional). `bcg_image` stays None without a
        # background network; it used to be left unbound, which crashed below.
        bcg_image = None
        if self.bcg_synthesis is not None:
            ws_bcg = ws[:, :self.bcg_synthesis.num_ws] if ws_bcg is None else ws_bcg[:, :self.bcg_synthesis.num_ws]
            if ws_bcg.size(1) < self.bcg_synthesis.num_ws:
                # Pad by repeating the last latent until the background network has enough.
                ws_bcg = torch.cat([ws_bcg, ws_bcg[:, -1:].repeat(1, self.bcg_synthesis.num_ws - ws_bcg.size(1), 1)], 1)
            bcg_image = self.bcg_synthesis(ws_bcg, update_emas=update_emas, **synthesis_kwargs)
            bcg_image = torch.nn.functional.interpolate(bcg_image, size=feature_image.shape[2:],
                                                        mode='bilinear', align_corners=False,
                                                        antialias=self.rendering_kwargs['sr_antialias'])
            # Composite: the background shows through where accumulated ray weight is low.
            feature_image = feature_image + (1 - weights_samples) * bcg_image

        # Generate Raw image
        bcg_rgb_image = None
        if self.torgb is not None:
            rgb_image = self.torgb(feature_image, ws[:, -1], fused_modconv=False)
            rgb_image = rgb_image.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            if bcg_image is not None:
                bcg_rgb_image = self.torgb(bcg_image, ws_bcg[:, -1], fused_modconv=False)
                bcg_rgb_image = bcg_rgb_image.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        else:
            rgb_image = feature_image[:, :3]
            if bcg_image is not None:
                bcg_rgb_image = bcg_image[:, :3]

        # Run superresolution to get final image
        sr_image = self.superresolution(rgb_image, feature_image, ws,
                                        noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
                                        **{k: synthesis_kwargs[k] for k in synthesis_kwargs.keys() if
                                           k != 'noise_mode'})

        mask_image = weights_samples * (1 + 2 * 0.001) - 0.001

        return {'image': sr_image, 'image_raw': rgb_image, 'image_depth': depth_image, "image_mask": mask_image,
                "image_background": bcg_rgb_image}

    def sample_ws(self, coordinates, directions, ws, update_emas=False, **synthesis_kwargs):
        """Compute RGB features and density at arbitrary 3D coordinates for latents `ws`."""
        planes = self._reshape_planes(self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs))
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)

    def sample(self, coordinates, directions, z, c, p, truncation_psi=1, truncation_cutoff=None, update_emas=False,
               **synthesis_kwargs):
        """Compute RGB features and density at arbitrary 3D coordinates. Mostly used for extracting shapes."""
        ws = self.mapping(z, c, p, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        planes = self._reshape_planes(self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs))
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)

    def sample_mixed(self, coordinates, directions, ws, truncation_psi=1, truncation_cutoff=None, update_emas=False,
                     **synthesis_kwargs):
        """Same as `sample`, but expects latent vectors 'ws' instead of Gaussian noise 'z'.

        `truncation_psi` and `truncation_cutoff` are accepted but not used.
        """
        planes = self._reshape_planes(self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs))
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None,
                update_emas=False, cache_backbone=False, use_cached_backbone=False,
                apply_def=False, pose_params=None,
                **synthesis_kwargs):
        """Render a batch of generated images: `mapping` followed by `synthesis`."""
        ws = self.mapping(z, c, pose_params, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        # TODO
        return self.synthesis(ws, c, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution,
                              cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone,
                              apply_def=apply_def, pose_params=pose_params,
                              **synthesis_kwargs)


from training.networks_stylegan2 import FullyConnectedLayer


class OSGDecoder(torch.nn.Module):
    """Two-layer MLP that turns sampled plane features into colour features and density."""

    def __init__(self, n_features, options):
        """
        Args:
            n_features: channel count of the sampled features.
            options: dict with 'decoder_lr_mul', 'decoder_output_dim' and
                'decoder_activation' ("sigmoid", "lrelu", or anything else for
                no activation on the colour features).
        """
        super().__init__()
        self.hidden_dim = 32

        self.net = torch.nn.Sequential(
            FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']),
            torch.nn.Softplus(),
            FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'],
                                lr_multiplier=options['decoder_lr_mul'])
        )
        self.activation = options['decoder_activation']

    def forward(self, sampled_features, ray_directions):
        """Decode features; returns {'rgb': (N, M, decoder_output_dim), 'sigma': (N, M, 1)}.

        `sampled_features` is averaged over dim 1 before decoding.
        `ray_directions` is accepted for interface compatibility and not used.
        """
        # Aggregate features
        sampled_features = sampled_features.mean(1)
        x = sampled_features

        N, M, C = x.shape
        x = x.view(N * M, C)

        x = self.net(x)
        x = x.view(N, M, -1)
        rgb = x[..., 1:]      # all channels but the first are colour features
        sigma = x[..., 0:1]   # first channel is the density
        if self.activation == "sigmoid":
            # Original EG3D
            rgb = torch.sigmoid(rgb) * (1 + 2 * 0.001) - 0.001
        elif self.activation == "lrelu":
            # StyleGAN2-style, use with toRGB
            rgb = torch.nn.functional.leaky_relu(rgb, 0.2, inplace=True) * math.sqrt(2)
        return {'rgb': rgb, 'sigma': sigma}


import numpy as np


class GPoseBranch(torch.nn.Module):
    """MLP that predicts 6 pose parameters from a latent code and a camera label."""

    def __init__(self, z_dim, c_dim):
        super().__init__()
        self.in_channel = z_dim + c_dim
        self.output_dim = 6
        self.net = torch.nn.Sequential(
            FullyConnectedLayer(self.in_channel, 128, activation='lrelu'),
            FullyConnectedLayer(128, 32, activation='lrelu'),
            FullyConnectedLayer(32, self.output_dim)
        )

    def forward(self, z, c):
        """Concatenate `z` and `c` along dim 1 and return the (B, 6) prediction."""
        feature = torch.cat([z, c], dim=1)
        pose = self.net(feature)  # (B, 6)
        return pose
#----------------------------------------------------------------------------

# for 512x512 generation
@persistence.persistent_class
class SuperresolutionHybrid8X(torch.nn.Module):
    """Super-resolution head for 512x512 output: inputs are resized to 128, then two upsampling blocks."""

    def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
                 num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,  # IGNORE
                 **block_kwargs):
        super().__init__()
        assert img_resolution == 512

        use_fp16 = sr_num_fp16_res > 0
        self.input_resolution = 128
        self.sr_antialias = sr_antialias
        self.block0 = SynthesisBlock(channels, 128, w_dim=512, resolution=256,
                                     img_channels=3, is_last=False, use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None), **block_kwargs)
        self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=512,
                                     img_channels=3, is_last=True, use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None), **block_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        """Upsample raw image `rgb` and feature map `x`; only the last latent of `ws` is used."""
        # Both blocks are fed the same three copies of the last latent.
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        # Resize (up or down) whenever the input is not exactly at the expected resolution.
        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
                                                mode='bilinear', align_corners=False, antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
                                                  mode='bilinear', align_corners=False, antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb

#----------------------------------------------------------------------------

# for 256x256 generation
@persistence.persistent_class
class SuperresolutionHybrid4X(torch.nn.Module):
    """Super-resolution head for 256x256 output: a same-resolution block at 128, then one upsampling block."""

    def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
                 num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,  # IGNORE
                 **block_kwargs):
        super().__init__()
        assert img_resolution == 256
        use_fp16 = sr_num_fp16_res > 0
        self.sr_antialias = sr_antialias
        self.input_resolution = 128
        self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=128,
                                         img_channels=3, is_last=False, use_fp16=use_fp16,
                                         conv_clamp=(256 if use_fp16 else None), **block_kwargs)
        self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=256,
                                     img_channels=3, is_last=True, use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None), **block_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        """Upsample raw image `rgb` and feature map `x`; only the last latent of `ws` is used."""
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        # Only inputs SMALLER than the expected resolution are resized here.
        if x.shape[-1] < self.input_resolution:
            x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
                                                mode='bilinear', align_corners=False, antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
                                                  mode='bilinear', align_corners=False, antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb

#----------------------------------------------------------------------------

# for 128 x 128 generation
@persistence.persistent_class
class SuperresolutionHybrid2X(torch.nn.Module):
    """Super-resolution head for 128x128 output: a same-resolution block at 64, then one upsampling block."""

    def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
                 num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,  # IGNORE
                 **block_kwargs):
        super().__init__()
        assert img_resolution == 128

        use_fp16 = sr_num_fp16_res > 0
        self.input_resolution = 64
        self.sr_antialias = sr_antialias
        self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=64,
                                         img_channels=3, is_last=False, use_fp16=use_fp16,
                                         conv_clamp=(256 if use_fp16 else None), **block_kwargs)
        self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=128,
                                     img_channels=3, is_last=True, use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None), **block_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        """Upsample raw image `rgb` and feature map `x`; only the last latent of `ws` is used."""
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
                                                mode='bilinear', align_corners=False, antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
                                                  mode='bilinear', align_corners=False, antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb

#----------------------------------------------------------------------------

# TODO: Delete (here for backwards compatibility with old 256x256 models)
@persistence.persistent_class
class SuperresolutionHybridDeepfp32(torch.nn.Module):
    """Legacy 256x256 super-resolution head; same layout as the 4X head but without the antialias option."""

    def __init__(self, channels, img_resolution, sr_num_fp16_res,
                 num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,  # IGNORE
                 **block_kwargs):
        super().__init__()
        assert img_resolution == 256
        use_fp16 = sr_num_fp16_res > 0

        self.input_resolution = 128
        self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=128,
                                         img_channels=3, is_last=False, use_fp16=use_fp16,
                                         conv_clamp=(256 if use_fp16 else None), **block_kwargs)
        self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=256,
                                     img_channels=3, is_last=True, use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None), **block_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        """Upsample raw image `rgb` and feature map `x`; only the last latent of `ws` is used."""
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] < self.input_resolution:
            x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
                                                mode='bilinear', align_corners=False)
            rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
                                                  mode='bilinear', align_corners=False)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb

#----------------------------------------------------------------------------

@persistence.persistent_class
class SynthesisBlockNoUp(torch.nn.Module):
    """Synthesis block that keeps its input resolution (no upsampling of features or image)."""

    def __init__(self,
                 in_channels,                    # Number of input channels, 0 = first block.
                 out_channels,                   # Number of output channels.
                 w_dim,                          # Intermediate latent (W) dimensionality.
                 resolution,                     # Resolution of this block.
                 img_channels,                   # Number of output color channels.
                 is_last,                        # Is this the last block?
                 architecture='skip',            # Architecture: 'orig', 'skip', 'resnet'.
                 resample_filter=[1, 3, 3, 1],   # Low-pass filter to apply when resampling activations.
                 conv_clamp=256,                 # Clamp the output of convolution layers to +-X, None = disable clamping.
                 use_fp16=False,                 # Use FP16 for this block?
                 fp16_channels_last=False,       # Use channels-last memory format with FP16?
                 fused_modconv_default=True,     # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
                 **layer_kwargs,                 # Arguments for SynthesisLayer.
                 ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.fused_modconv_default = fused_modconv_default
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0

        if in_channels == 0:
            self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))

        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution,
                                        conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
                                    conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
                                    conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            # NOTE(review): the skip conv used to be built with `up=2`. conv0/conv1
            # in this block stay at `resolution`, so a 2x-upsampled skip branch
            # could never be added to them in forward(). This assumes `up` is
            # Conv2dLayer's upsampling factor, as in upstream StyleGAN2 - confirm.
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False,
                                    resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
        """Run the block at constant resolution.

        Args:
            x: features of shape (B, in_channels, resolution, resolution); ignored when in_channels == 0.
            img: running RGB image or None. It is NOT modified in place.
            ws: latents of shape (B, num_conv + num_torgb, w_dim).

        Returns:
            Tuple (x, img) of output features and the updated float32 image.
        """
        _ = update_emas  # unused
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        if ws.device.type != 'cuda':
            force_fp32 = True
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            fused_modconv = self.fused_modconv_default
        if fused_modconv == 'inference_only':
            fused_modconv = (not self.training)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            # Out-of-place add: `img` is not upsampled in this block, so it is
            # still the caller's tensor and `img.add_(y)` would overwrite it.
            img = (img + y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img

    def extra_repr(self):
        return f'resolution={self.resolution:d}, architecture={self.architecture:s}'


#----------------------------------------------------------------------------

# for 512x512 generation
@persistence.persistent_class
class SuperresolutionHybrid8XDC(torch.nn.Module):
    """512x512 super-resolution head with wider blocks (256 and 128 channels) than SuperresolutionHybrid8X."""

    def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
                 num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,  # IGNORE
                 **block_kwargs):
        super().__init__()
        assert img_resolution == 512

        use_fp16 = sr_num_fp16_res > 0
        self.input_resolution = 128
        self.sr_antialias = sr_antialias
        self.block0 = SynthesisBlock(channels, 256, w_dim=512, resolution=256,
                                     img_channels=3, is_last=False, use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None), **block_kwargs)
        self.block1 = SynthesisBlock(256, 128, w_dim=512, resolution=512,
                                     img_channels=3, is_last=True, use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None), **block_kwargs)

    def forward(self, rgb, x, ws, **block_kwargs):
        """Upsample raw image `rgb` and feature map `x`; only the last latent of `ws` is used."""
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
                                                mode='bilinear', align_corners=False, antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
                                                  mode='bilinear', align_corners=False, antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb

#----------------------------------------------------------------------------
"""Main training loop."""

import os
import random
import time
import copy
import json
import pickle
import psutil
import PIL.Image
import numpy as np
import torch
import dnnlib
from torch_utils import misc
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import grid_sample_gradfix

import legacy
from metrics import metric_main,metric_utils
from camera_utils import LookAtPoseSampler
from training.crosssection_utils import sample_cross_section

#----------------------------------------------------------------------------

def setup_snapshot_image_grid(training_set, random_seed=0):
    """Pick the training samples shown in the periodic snapshot grids.

    Samples are grouped by their raw label, the groups are shuffled with a
    `RandomState` seeded by `random_seed`, and the grid is filled with pairs:
    one sample from a label group followed by the sample located half a
    dataset further on (`idx + len(training_set) // 2`, wrapped around).

    Returns:
        `((gw, gh), images, segs, labels, poses)` where the last four entries
        are the stacked per-sample arrays returned by `training_set[i]`.
    """
    rnd = np.random.RandomState(random_seed)
    h = int(7680 * (training_set.image_shape[2]/512))
    w = int(4320 * (training_set.image_shape[2] / 512))
    gh = np.clip(h // training_set.image_shape[2], 7, 8)
    gw = np.clip(w // training_set.image_shape[1], 4, 4)

    # Group training samples by label (the upstream "no labels" shortcut is
    # intentionally not used here; every dataset goes through label grouping).
    label_groups = dict() # label => [idx, ...]
    for idx in range(len(training_set)):
        label = tuple(training_set.get_details(idx).raw_label.flat[::-1])
        label_groups.setdefault(label, []).append(idx)

    # Reorder.
    label_order = list(label_groups.keys())
    rnd.shuffle(label_order)
    for label in label_order:
        rnd.shuffle(label_groups[label])

    # Organize into grid: each cell pair is (sample, sample half a dataset away).
    num_samples = len(training_set)
    grid_indices = []
    for y in range(gh):
        for x in range(gw//2):
            label = label_order[(y + x*gh) % len(label_order)]
            indices = list(set(label_groups[label]))
            grid_indices += [indices[0], (indices[0] + num_samples//2) % num_samples]
            # Rotate the group so a later visit of the same label starts elsewhere.
            label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))]

    # Load data.
    images, segs, labels, poses = zip(*[training_set[i] for i in grid_indices])
    return (gw, gh), np.stack(images),np.stack(segs), np.stack(labels), np.stack(poses)

#----------------------------------------------------------------------------

def save_image_grid(img, fname, drange, grid_size):
    """Tile a batch of NCHW images into one grid image and save it to `fname`.

    Args:
        img: Array-like of shape `[gw * gh, C, H, W]` with `C` in {1, 3}.
        fname: Output path; the format is chosen by PIL from the extension.
        drange: `(lo, hi)` value range that is mapped linearly onto [0, 255].
        grid_size: `(gw, gh)` number of tiles per row / per column.

    A degenerate range (`lo == hi`, e.g. a constant depth map whose min and
    max are passed as `drange`) is rendered as an all-black image instead of
    dividing by zero.
    """
    lo, hi = drange
    img = np.asarray(img, dtype=np.float32)
    span = hi - lo
    scale = 255 / span if span != 0 else 0.0
    img = (img - lo) * scale
    img = np.rint(img).clip(0, 255).astype(np.uint8)

    gw, gh = grid_size
    _N, C, H, W = img.shape
    img = img.reshape([gh, gw, C, H, W])
    img = img.transpose(0, 3, 1, 4, 2)
    img = img.reshape([gh * H, gw * W, C])

    assert C in [1, 3]
    if C == 1:
        PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
    if C == 3:
        PIL.Image.fromarray(img, 'RGB').save(fname)

#----------------------------------------------------------------------------

def _save_pose_mesh_grid(G_ema, pose_batches, camera_batches, resolution, fname, grid_size):
    """Render one mesh image per (pose, camera) batch with `G_ema.render_meshes` and save them as a grid."""
    meshes = [
        G_ema.render_meshes({'pose': pose}, resolution=resolution, cameras=c)
        for pose, c in zip(pose_batches, camera_batches)
    ]
    save_image_grid(np.concatenate(meshes, axis=0), fname, drange=[0, 255], grid_size=grid_size)

#----------------------------------------------------------------------------

def training_loop(
    run_dir                 = '.',      # Output directory.
    training_set_kwargs     = {},       # Options for training set.
    data_loader_kwargs      = {},       # Options for torch.utils.data.DataLoader.
    G_kwargs                = {},       # Options for generator network.
    D_kwargs                = {},       # Options for discriminator network.
    G_opt_kwargs            = {},       # Options for generator optimizer.
    D_opt_kwargs            = {},       # Options for discriminator optimizer.
    augment_kwargs          = None,     # Options for augmentation pipeline. None = disable.
    loss_kwargs             = {},       # Options for loss function.
    metrics                 = [],       # Metrics to evaluate during training.
    random_seed             = 0,        # Global random seed.
    num_gpus                = 1,        # Number of GPUs participating in the training.
    rank                    = 0,        # Rank of the current process in [0, num_gpus[.
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu               = 4,        # Number of samples processed at a time by one GPU.
    ema_kimg                = 10,       # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup              = 0.05,     # EMA ramp-up coefficient. None = no rampup.
    G_reg_interval          = None,     # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval          = 16,       # How often to perform regularization for D? None = disable lazy regularization.
    augment_p               = 0,        # Initial value of augmentation probability.
    ada_target              = None,     # ADA target value. None = fixed p.
    ada_interval            = 4,        # How often to perform ADA adjustment?
    ada_kimg                = 500,      # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    kimg_per_tick           = 4,        # Progress snapshot interval.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = disable.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable.
    resume_pkl              = None,     # Network pickle to resume training from.
    resume_kimg             = 0,        # First kimg to report when resuming training.
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
    abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
    train_g_pose_branch     = None,     # Falsy = leave G parameters whose name contains 'pose_branch' out of the G optimizer.
    metric_pose_sample_mode = None,     # Forwarded to metric_main.calc_metric; 'D_predict' also passes D and the pose-prediction settings.
):
    """Run GAN training of G and D on this rank until `total_kimg` is reached or `abort_fn` fires.

    The loop constructs the dataset, networks, loss and optimizers, optionally
    resumes from `resume_pkl`, and then alternates the configured training
    phases. EMA copies of both G and D are maintained. Once per tick it prints
    a status line, and at the configured tick intervals it writes image
    snapshots (rank 0), network pickles plus a `pose_predict_kwargs-*.json`
    side file (rank 0) and evaluates `metrics`.

    The mutable `{}` / `[]` defaults are only read or copied, never mutated.
    """
    print('Random seed: %d' % random_seed)
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.cuda.set_device(device)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = False       # Improves numerical accuracy.
    torch.backends.cudnn.allow_tf32 = False             # Improves numerical accuracy.
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False  # Improves numerical accuracy.
    conv2d_gradfix.enabled = True                       # Improves training speed. # TODO: ENABLE
    grid_sample_gradfix.enabled = False                 # Avoids errors with the augmentation pipe.

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print('Pose shape:', training_set.pose_shape)
        print()
        print('>>>>>>>>>>>>>>> image_snapshot_ticks:', image_snapshot_ticks)
        print('>>>>>>>>>>>>>>> network_snapshot_ticks:', network_snapshot_ticks)

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    G.register_buffer('dataset_label_std', torch.tensor(training_set.get_label_std()).to(device))
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()
    D_ema = copy.deepcopy(D).eval()

    # Resume from existing pickle. Only rank 0 loads; the weights reach the
    # other ranks through the broadcast in "Distribute across GPUs" below.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

        # Pickles without a D_ema entry fall back to D's weights.
        if 'D_ema' in resume_data:
            print(f'copy params of D_ema of "{resume_pkl} to D_ema')
            misc.copy_params_and_buffers(resume_data['D_ema'], D_ema, require_all=False)
        else:
            print(f'copy params of D of "{resume_pkl} to D_ema')
            misc.copy_params_and_buffers(resume_data['D'], D_ema, require_all=False)

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

        print('plane_shapes:')
        for res_k in G.plane_shapes:
            print(res_k, G.plane_shapes[res_k])

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    for module in [G, D, G_ema,D_ema, augment_pipe]:
        if module is not None:
            for param in misc.params_and_buffers(module):
                if param.numel() > 0 and num_gpus > 1:
                    torch.distributed.broadcast(param, src=0)

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(device=device, G=G, D=D, augment_pipe=augment_pipe,rank = rank,**loss_kwargs) # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
        # Select the parameters handed to the optimizer. For G, parameters
        # whose name contains 'aligned_SMPL' are never optimized, and those
        # containing 'pose_branch' only when train_g_pose_branch is truthy.
        params_list = []
        params_name_list = []
        for p_name, p in module.named_parameters():
            if name == 'G':
                if 'aligned_SMPL' not in p_name:
                    if not train_g_pose_branch:
                        if 'pose_branch' not in p_name:
                            params_list.append(p)
                            params_name_list.append(p_name)
                    else:
                        params_list.append(p)
                        params_name_list.append(p_name)
            else:
                params_list.append(p)
                params_name_list.append(p_name)

        if rank ==0:
            print(f'params_name_list of {name}:',params_name_list)

        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params=params_list, **opt_kwargs) # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]
        else: # Lazy regularization.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(params=params_list, **opt_kwargs) # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)]

    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)
            print('phase: ',phase.name)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images,segs, labels,poses = setup_snapshot_image_grid(training_set=training_set,random_seed=random.randint(0, 1000000))
        save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
        save_image_grid(segs, os.path.join(run_dir, 'segs.jpg'), drange=[0, 255], grid_size=grid_size)
        grid_images = (torch.from_numpy(images).to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
        grid_segs = (torch.from_numpy(segs).to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)

        if G.rendering_kwargs['c_gen_conditioning_zero']:
            raise NotImplementedError
        # The grid holds consecutive pairs of samples (see
        # setup_snapshot_image_grid); both members of a pair share one latent.
        grid_z = []
        for i in range(labels.shape[0]//2):
            sample_z = torch.randn([1, G.z_dim], device=device)
            grid_z.append(sample_z)
            grid_z.append(sample_z)
        grid_z = torch.cat(grid_z,dim=0).split(batch_gpu)

        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        grid_poses = torch.from_numpy(poses).to(device).split(batch_gpu)

        # Reference rendering of the poses stored with the dataset samples.
        _save_pose_mesh_grid(G_ema, grid_poses, grid_c, training_set.image_shape[2],
                             os.path.join(run_dir, 'mesh_coarse_real_pose.png'), grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = resume_kimg * 1000
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)

    while True:
        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_seg, phase_real_c, phase_real_pose = next(training_set_iterator)

            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            phase_real_seg = (phase_real_seg.to(device).to(torch.float32) / 255.0).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            phase_real_pose = phase_real_pose.to(device).split(batch_gpu)

            all_gen_z = torch.randn([len(phases) * (batch_size // num_gpus), G.z_dim], device=device) # 4 * 8
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split((batch_size // num_gpus))]

            # Labels and coarse poses for the generator are drawn from the
            # same randomly chosen dataset entries.
            random_idx = [np.random.randint(len(training_set)) for _ in range(len(phases) * (batch_size // num_gpus))]

            all_gen_c = [training_set.get_label(gen_idx) for gen_idx in random_idx]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split((batch_size // num_gpus))]

            all_gen_pose = [training_set.get_coarse_pose(gen_idx) for gen_idx in random_idx]
            all_gen_pose = torch.from_numpy(np.stack(all_gen_pose)).pin_memory().to(device)
            all_gen_pose = [phase_gen_pose.split(batch_gpu) for phase_gen_pose in all_gen_pose.split((batch_size // num_gpus))]

        assert len(phases) == len(all_gen_z) == len(all_gen_c) ==len(all_gen_pose)

        # Execute training phases.
        for phase, phase_gen_z,phase_gen_c,phase_gen_pose in zip(phases, all_gen_z,all_gen_c,all_gen_pose): # 4
            if batch_idx % phase.interval != 0:
                continue

            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))

            # Accumulate gradients.
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)
            for real_img, real_seg, real_c,real_pose, gen_z,gen_c,gen_pose in \
                    zip(phase_real_img, phase_real_seg, phase_real_c, phase_real_pose, phase_gen_z,phase_gen_c,phase_gen_pose):
                loss.accumulate_gradients(phase=phase.name, real_img=real_img,real_seg = real_seg, real_c=real_c,real_pose = real_pose,
                                          gen_z=gen_z,gen_c = gen_c, gen_pose = gen_pose,
                                          gain=phase.interval, cur_nimg=cur_nimg,cur_nimg_start = resume_kimg * 1000)
            phase.module.requires_grad_(False)

            # Update weights.
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                params = [param for param in phase.module.parameters() if param.numel() > 0 and param.grad is not None]
                if len(params) > 0:
                    # Average gradients across GPUs in one flat all-reduce and
                    # sanitize NaN/Inf before stepping.
                    flat = torch.cat([param.grad.flatten() for param in params])
                    if num_gpus > 1:
                        torch.distributed.all_reduce(flat)
                        flat /= num_gpus
                    misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
                    grads = flat.split([param.numel() for param in params])
                    for param, grad in zip(params, grads):
                        param.grad = grad.reshape(param.shape)
                phase.opt.step()

            # Phase done.
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)
            G_ema.neural_rendering_resolution = G.neural_rendering_resolution
            G_ema.rendering_kwargs = G.rendering_kwargs.copy()

        # Update D_ema with the same decay as G_ema.
        with torch.autograd.profiler.record_function('Dema'):
            for p_ema, p in zip(D_ema.parameters(), D.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(D_ema.buffers(), D.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
            augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in training_stats.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]

        if loss.swapping_prob is not None:
            fields += [f"swap prob {training_stats.report0('Progress/swap_prob', float(loss.swapping_prob)):.3f}"]
        if loss.neural_rendering_resolution is not None:
            fields += [f"render_res {training_stats.report0('Progress/rendering_res', float(loss.neural_rendering_resolution)):.3f}"]

        training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and ((image_snapshot_ticks is not None) and (done or (cur_tick % image_snapshot_ticks == 0))):
            print('gen images...')
            snapshot_res = training_set.image_shape[2]
            snapshot_prefix = os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}')

            with torch.no_grad():
                # Poses D predicts for the real grid images.
                predicted_real_pose_params_D = []
                for vis_real_img,vis_real_seg, vis_c in zip(grid_images,grid_segs, grid_c):
                    pose_param = loss.get_pose_params_D(vis_real_img,vis_real_seg, vis_c, cur_nimg)
                    predicted_real_pose_params_D.append(pose_param)

                # Poses G predicts for the grid latents.
                predicted_fake_pose_params_G = []
                for vis_z, vis_c in zip(grid_z, grid_c):
                    pose_param = loss.get_pose_params_G(vis_z, vis_c)
                    predicted_fake_pose_params_G.append(pose_param)

            _save_pose_mesh_grid(G_ema, predicted_real_pose_params_D, grid_c, snapshot_res,
                                 snapshot_prefix + '_mesh_real_pose_D.png', grid_size)

            snap_pose = predicted_fake_pose_params_G
            # Fixed 25-value camera label used to condition the mapping
            # network for every grid sample; rendering still uses grid_c.
            cond_c = torch.tensor([[ 1.0000e+00, 1.0505e-09, 4.3685e-08, -1.1805e-07, 0.0000e+00,
                                    -9.9951e-01, 2.4033e-02, -1.1805e-07, 4.3714e-08, -2.4033e-02,
                                    -9.9951e-01, 2.6992e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                     1.0000e+00, 6.5104e+00, 0.0000e+00, 5.0000e-01, 0.0000e+00,
                                     6.5104e+00, 5.0000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+00]]).float().to(device)

            grid_ws = [G_ema.mapping(z, cond_c.expand(z.shape[0], -1),None) for z in grid_z]
            out =[G_ema.synthesis(ws, c=c, noise_mode='const',apply_def = True, pose_params = pose) for ws, c,pose in zip(grid_ws, grid_c,snap_pose)]
            images = torch.cat([o['image'].cpu() for o in out]).numpy()
            images_raw = torch.cat([o['image_raw'].cpu() for o in out]).numpy()
            images_depth = -torch.cat([o['image_depth'].cpu() for o in out]).numpy()
            images_alpha = torch.cat([o['image_mask'].cpu() for o in out]).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_0.png'), drange=[-1,1], grid_size=grid_size)
            save_image_grid(images_raw, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_2_raw.png'), drange=[-1,1], grid_size=grid_size)
            save_image_grid(images_depth, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_4_depth.png'), drange=[images_depth.min(), images_depth.max()], grid_size=grid_size)
            save_image_grid(images_alpha, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_4_alpha.jpg'), drange=[0, 1], grid_size=grid_size)

            with torch.no_grad():
                # Poses D predicts back from the generated images.
                predicted_fake_pose_params_D = []
                for o,vis_c,vis_pose in zip(out,grid_c,snap_pose):
                    pose_param = loss.get_pose_params_D(o['image'],o['image_mask'],vis_c, cur_nimg)
                    predicted_fake_pose_params_D.append(pose_param)

            _save_pose_mesh_grid(G_ema, predicted_fake_pose_params_D, grid_c, snapshot_res,
                                 snapshot_prefix + '_mesh_fake_pose_D.png', grid_size)
            _save_pose_mesh_grid(G_ema, predicted_fake_pose_params_G, grid_c, snapshot_res,
                                 snapshot_prefix + '_mesh_input_pose_G.png', grid_size)

            # Same latents rendered without deformation / pose input.
            no_pose_out =[G_ema.synthesis(ws, c=c, noise_mode='const',apply_def = False, pose_params = None) for ws, c in zip(grid_ws, grid_c)]
            images = torch.cat([o['image'].cpu() for o in no_pose_out]).numpy()
            images_raw = torch.cat([o['image_raw'].cpu() for o in no_pose_out]).numpy()
            images_depth = -torch.cat([o['image_depth'].cpu() for o in no_pose_out]).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_1_no_pose.png'), drange=[-1, 1],
                            grid_size=grid_size)
            save_image_grid(images_raw, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_3_no_pose_raw.png'), drange=[-1, 1],
                            grid_size=grid_size)
            save_image_grid(images_depth, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}_5_no_pose_depth.png'),
                            drange=[images_depth.min(), images_depth.max()], grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('D_ema', D_ema), ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.[^.]+_(avg|ema)')
                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
                snapshot_data[name] = module
                del module # conserve memory
            snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)

                # Loss settings stored next to the pickle as a JSON side file.
                pose_predict_kwargs = {
                    'blur_sigma' : loss.blur_sigma,
                    'neural_rendering_resolution': loss.neural_rendering_resolution,
                    'resample_filter': loss.resample_filter.cpu().numpy().tolist(),
                    'filter_mode': loss.filter_mode
                }
                with open(os.path.join(run_dir, f'pose_predict_kwargs-{cur_nimg//1000:06d}.json'), 'wt') as f:
                    json.dump(pose_predict_kwargs, f, indent=2)

        # Evaluate metrics (skipped on the very first tick).
        if (cur_tick!=0) and (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print(run_dir)
                print('Evaluating metrics...')
            for metric in metrics:
                progress = metric_utils.ProgressMonitor(verbose=True)
                result_dict = metric_main.calc_metric(metric=metric,
                                                      G=snapshot_data['G_ema'],
                                                      dataset_kwargs=training_set_kwargs,
                                                      num_gpus=num_gpus,
                                                      rank=rank,
                                                      device=device,
                                                      metric_pose_sample_mode = metric_pose_sample_mode,
                                                      progress=progress,
                                                      D = snapshot_data['D'] if metric_pose_sample_mode == 'D_predict' else None,
                                                      pose_predict_kwargs = {
                                                          'blur_sigma' : loss.blur_sigma,
                                                          'neural_rendering_resolution': loss.neural_rendering_resolution,
                                                          'resample_filter': loss.resample_filter,
                                                          'filter_mode': loss.filter_mode
                                                      } if metric_pose_sample_mode == 'D_predict' else None
                                                      )

                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done: release the log handles opened on rank 0.
    if stats_jsonl is not None:
        stats_jsonl.close()
    if stats_tfevents is not None:
        stats_tfevents.close()
    if rank == 0:
        print()
        print('Exiting...')

#----------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2022 Petr Kellnhofer
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch


def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
    """
    Apply ``matrix`` to every row vector of ``vectors4``.

    Computes ``vectors4 @ matrix.T``, i.e. each row ``v`` is mapped to
    ``matrix @ v``. For an MxM matrix and NxM vectors the result is NxM.
    """
    return torch.matmul(vectors4, matrix.T)


def normalize_vecs(vectors: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """
    Scale vectors to unit length along the last dimension.

    The norm is clamped from below so that zero-length vectors come back as
    zeros instead of NaN (the unclamped ``v / |v|`` evaluates 0 / 0).

    Args:
        vectors: floating-point tensor; the last dimension is the vector axis.
        eps: lower bound applied to the norm before dividing.
    """
    norms = torch.norm(vectors, dim=-1, keepdim=True)
    # `eps` may underflow to 0 in low-precision dtypes; never clamp below the
    # smallest normal number of the tensor's own dtype.
    floor = max(eps, torch.finfo(norms.dtype).tiny)
    return vectors / norms.clamp_min(floor)


def torch_dot(x: torch.Tensor, y: torch.Tensor):
    """
    Dot product of two tensors along their last dimension.
    """
    return (x * y).sum(-1)


def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
    """
    Author: Petr Kellnhofer
    Intersects rays with the axis-aligned cube of side ``box_side_length``
    centred at the origin (slab method).
    https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection

    Args:
        rays_o: ray origins, shape (..., 3).
        rays_d: ray directions, shape (..., 3).
        box_side_length: edge length of the cube.

    Returns:
        ``(tmin, tmax)``, each of shape ``rays_o.shape[:-1] + (1,)``: the entry
        and exit distances along each ray. Rays that miss the box get
        ``tmin = -1`` and ``tmax = -2``, so ``tmax > tmin`` is False for them.
        Inputs are detached, so the results carry no gradient.
    """
    o_shape = rays_o.shape
    rays_o = rays_o.detach().reshape(-1, 3)
    rays_d = rays_d.detach().reshape(-1, 3)

    half_side = box_side_length / 2
    bounds = torch.tensor([[-half_side] * 3, [half_side] * 3],
                          dtype=rays_o.dtype, device=rays_o.device)
    is_valid = torch.ones(rays_o.shape[:-1], dtype=torch.bool, device=rays_o.device)

    # Precompute inverse for stability. Zero direction components become +/-inf,
    # which the comparisons below handle.
    invdir = 1 / rays_d
    # 1 where the ray travels in the negative direction along an axis: the
    # "near" slab plane is then the max bound instead of the min bound.
    sign = (invdir < 0).long()

    def slab_limits(axis):
        """Return (near, far) distances to the two bounding planes on `axis`."""
        near_plane = bounds.index_select(0, sign[..., axis])[..., axis]
        far_plane = bounds.index_select(0, 1 - sign[..., axis])[..., axis]
        origin = rays_o[..., axis]
        inv = invdir[..., axis]
        return (near_plane - origin) * inv, (far_plane - origin) * inv

    # Start with the x slab, then intersect the running interval with y and z.
    tmin, tmax = slab_limits(0)
    for axis in (1, 2):
        t_near, t_far = slab_limits(axis)

        # The intervals do not overlap: the ray misses the box.
        is_valid[torch.logical_or(tmin > t_far, t_near > tmax)] = False

        # Keep the intersection of the two intervals.
        tmin = torch.max(tmin, t_near)
        tmax = torch.min(tmax, t_far)

    # Mark invalid.
    tmin[torch.logical_not(is_valid)] = -1
    tmax[torch.logical_not(is_valid)] = -2

    return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)


def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
    """
    Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.

    With ``num == 1`` the result is ``start[None]`` (as numpy.linspace does);
    dividing by ``num - 1`` unguarded would produce NaN in that case.
    """
    # create a tensor of 'num' steps from 0 to 1
    steps = torch.arange(num, dtype=torch.float32, device=start.device) / max(num - 1, 1)

    # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
    # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
    #   "cannot statically infer the expected size of a list in this context", hence the code below
    for i in range(start.ndim):
        steps = steps.unsqueeze(-1)

    # the output starts at 'start' and increments until 'stop' in each dimension
    out = start[None] + steps * (stop - start)[None]

    return out
# Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""
The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class MipRayMarcher2(nn.Module):
    """Composite per-sample colors and densities along rays (midpoint rule)."""

    def __init__(self):
        super().__init__()

    def run_forward(self, colors, densities, depths, rendering_options):
        """
        Composite samples along each ray.

        The three inputs are indexed as ``[:, :, sample]`` with a trailing
        channel dimension; S samples per ray produce S - 1 intervals.

        Args:
            colors: per-sample colors, sliced along dim 2.
            densities: per-sample raw densities (pre-activation), sliced along dim 2.
            depths: per-sample depths, sliced along dim 2.
            rendering_options: mapping that must contain ``'clamp_mode'``
                (only ``'softplus'`` is supported) and may contain
                ``'white_back'``.

        Returns:
            ``(composite_rgb, composite_depth, weights)`` where ``weights``
            holds one value per interval.

        Raises:
            KeyError: if ``'clamp_mode'`` is missing from ``rendering_options``.
            AssertionError: if ``clamp_mode`` is not ``'softplus'``.
        """
        # Validate before doing any work. Raised explicitly rather than through
        # an `assert` statement so the check is not stripped under `python -O`;
        # the exception type is kept for backward compatibility.
        if rendering_options['clamp_mode'] != 'softplus':
            raise AssertionError("MipRayMarcher only supports `clamp_mode`=`softplus`!")

        # Interval lengths and midpoint values between consecutive samples.
        deltas = depths[:, :, 1:] - depths[:, :, :-1]
        colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
        densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
        depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2

        densities_mid = F.softplus(densities_mid - 1)  # activation bias of -1 makes things initialize better

        density_delta = densities_mid * deltas

        # Opacity of each interval.
        alpha = 1 - torch.exp(-density_delta)

        # Transmittance reaching each interval: exclusive cumulative product of
        # (1 - alpha); the 1e-10 keeps the product from collapsing to exactly 0.
        alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2)
        weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]

        composite_rgb = torch.sum(weights * colors_mid, -2)
        weight_total = weights.sum(2)
        # 0 / 0 for rays that hit nothing; resolved just below.
        composite_depth = torch.sum(weights * depths_mid, -2) / weight_total

        # clip the composite to min/max range of depths
        # (NaN -> +inf -> clamped to the largest depth)
        composite_depth = torch.nan_to_num(composite_depth, float('inf'))
        composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))

        if rendering_options.get('white_back', False):
            # Fill the remaining transmittance with white (value 1).
            composite_rgb = composite_rgb + 1 - weight_total

        return composite_rgb, composite_depth, weights

    def forward(self, colors, densities, depths, rendering_options):
        """Delegate to :meth:`run_forward`; same arguments and return value."""
        return self.run_forward(colors, densities, depths, rendering_options)
"""
The ray sampler is a module that takes in camera matrices and a resolution and returns batches of rays.
Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
"""

import torch


class RaySampler(torch.nn.Module):
    """Generate one ray per pixel of a square image for a batch of cameras."""

    def __init__(self):
        super().__init__()
        # Placeholders; not read or written by `forward` below.
        self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None

    def forward(self, cam2world_matrix, intrinsics, resolution):
        """
        Create batches of rays and return origins and directions.

        All computation is done in the dtype and on the device of
        ``cam2world_matrix``; ``intrinsics`` is cast to match, so mixed
        float32 / float64 inputs no longer fail in ``torch.bmm``.

        cam2world_matrix: (N, 4, 4)
        intrinsics: (N, 3, 3)
        resolution: int

        ray_origins: (N, M, 3)
        ray_dirs: (N, M, 3), unit length

        where M = resolution**2 and pixels are enumerated with the x
        coordinate varying fastest.
        """
        device, dtype = cam2world_matrix.device, cam2world_matrix.dtype
        N, M = cam2world_matrix.shape[0], resolution**2
        cam_locs_world = cam2world_matrix[:, :3, 3]

        intrinsics = intrinsics.to(dtype)
        # (N, 1) so they broadcast against the (N, M) pixel coordinates.
        fx = intrinsics[:, 0, 0].unsqueeze(-1)
        fy = intrinsics[:, 1, 1].unsqueeze(-1)
        cx = intrinsics[:, 0, 2].unsqueeze(-1)
        cy = intrinsics[:, 1, 2].unsqueeze(-1)
        sk = intrinsics[:, 0, 1].unsqueeze(-1)

        # Pixel centres in [0, 1]: (i + 0.5) / resolution.
        centres = torch.arange(resolution, dtype=dtype, device=device) * (1. / resolution) + (0.5 / resolution)
        rows, cols = torch.meshgrid(centres, centres, indexing='ij')
        # Broadcast views over the batch; the grid is identical for every camera.
        x_cam = cols.reshape(1, -1).expand(N, -1)
        y_cam = rows.reshape(1, -1).expand(N, -1)
        z_cam = torch.ones((N, M), dtype=dtype, device=device)

        # Un-project pixel coordinates to camera space at depth z_cam = 1
        # (pinhole model with skew).
        x_lift = (x_cam - cx + cy * sk / fy - sk * y_cam / fy) / fx * z_cam
        y_lift = (y_cam - cy) / fy * z_cam

        # Homogeneous camera-space points, (N, M, 4).
        cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)

        world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]

        ray_dirs = world_rel_points - cam_locs_world[:, None, :]
        ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)

        # Materialised (not a broadcast view) so callers may safely write to it.
        ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)

        return ray_origins, ray_dirs
# correct tri-planes, see https://github.com/NVlabs/eg3d/issues/67
def generate_planes():
    """
    Defines planes by the three vectors that form the "axes" of the
    plane. Should work with arbitrary number of planes and planes of
    arbitrary orientation.

    Returns a float32 tensor of shape (3, 3, 3): one 3x3 axis matrix per plane.
    """
    return torch.tensor([[[1, 0, 0],
                          [0, 1, 0],
                          [0, 0, 1]],
                         [[1, 0, 0],
                          [0, 0, 1],
                          [0, 1, 0]],
                         [[0, 1, 0],
                          [0, 0, 1],
                          [1, 0, 0]]], dtype=torch.float32)


def project_onto_planes(planes, coordinates):
    """
    Expresses 3D points in the local frame of each plane.

    Takes plane axes of shape n_planes, 3, 3
    Takes coordinates of shape N, M, 3
    Returns projections of shape N*n_planes, M, 3 (all three components are
    kept; nothing is dropped to 2D here).
    """
    N, M, _ = coordinates.shape
    n_planes, _, _ = planes.shape
    coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
    inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
    return torch.bmm(coordinates, inv_planes)


def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None, triplane_depth=1,render_high_freq = True):
    """
    Sample features at 3D points from a pyramid of tri-plane feature volumes
    and average the result over the pyramid levels.

    Args:
        plane_axes: (n_planes, 3, 3) plane axis matrices.
        plane_features: non-empty mapping whose values are feature tensors of
            shape (N, n_planes, C * triplane_depth, H, W); only the values
            are used.
        coordinates: (N, M, 3) sample positions.
        mode: interpolation mode passed to ``grid_sample``.
        padding_mode: must be ``'zeros'``.
        box_warp: side length of the sampled box; coordinates are scaled by
            ``2 / box_warp`` before sampling. Required despite the ``None``
            default.
        triplane_depth: number of depth slices packed into the channel axis.
        render_high_freq: unused; kept for interface compatibility.

    Returns:
        Tensor of shape (N, n_planes, M, C).

    Raises:
        AssertionError: if ``padding_mode`` is not ``'zeros'``.
        TypeError: if ``box_warp`` is None.
        ValueError: if ``plane_features`` is empty.
    """
    # Explicit raise (same exception type as the former `assert`) so the check
    # survives `python -O`.
    if padding_mode != 'zeros':
        raise AssertionError("sample_from_planes only supports padding_mode='zeros'")
    if box_warp is None:
        raise TypeError("sample_from_planes() requires `box_warp` (got None)")
    if not plane_features:
        raise ValueError("sample_from_planes() needs at least one feature plane")

    _, M, _ = coordinates.shape
    coordinates = (2 / box_warp) * coordinates  # TODO: add specific box bounds
    # The sampling grid is shared by every pyramid level: (N x n_planes) x 1 x 1 x M x 3
    grid = project_onto_planes(plane_axes, coordinates).unsqueeze(1).unsqueeze(2).float()

    output_features = None
    for plane_feature in plane_features.values():
        N, n_planes, CD, H, W = plane_feature.shape
        C, D = CD // triplane_depth, triplane_depth
        # Unpack the depth slices from the channel axis into a 3D volume.
        volume = plane_feature.view(N * n_planes, C, D, H, W)

        sampled = torch.nn.functional.grid_sample(volume, grid, mode=mode,
                                                  padding_mode=padding_mode, align_corners=False)
        # (N*n_planes, C, 1, 1, M) -> (N, n_planes, M, C)
        sampled = sampled.permute(0, 4, 3, 2, 1).reshape(N, n_planes, M, C)

        output_features = sampled if output_features is None else output_features + sampled

    return output_features / len(plane_features)


def sample_from_3dgrid(grid, coordinates):
    """
    Expects coordinates in shape (batch_size, num_points_per_batch, 3)
    Expects grid in shape (1, channels, H, W, D)
    (Also works if grid has batch size)
    Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
    """
    batch_size, n_coords, n_dims = coordinates.shape
    sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1),
                                                       coordinates.reshape(batch_size, 1, 1, -1, n_dims),
                                                       mode='bilinear', padding_mode='zeros', align_corners=False)
    N, C, H, W, D = sampled_features.shape
    sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
    return sampled_features


def triplane_crop_mask(xyz_unformatted, thresh, boxwarp, allow_bottom=True):
    """
    Mask of points lying outside the cropped x/z extent of the box.

    A point is kept (mask False) when both ``|x|`` and ``|z|`` are at most
    ``boxwarp / 2 - boxwarp * thresh``; the y coordinate is not constrained.

    Args:
        xyz_unformatted: (N, M, 3) points.
        thresh: crop margin as a fraction of ``boxwarp``.
        boxwarp: side length of the box.
        allow_bottom: currently has no effect on the result (see note below);
            kept for interface compatibility.

    Returns:
        Bool tensor of shape (N, M, 1), True where the point is cropped.
    """
    limit = boxwarp / 2 - boxwarp * thresh
    inside = (xyz_unformatted[:, :, [0, 2]].abs() <= limit).all(dim=-1, keepdim=True)
    # NOTE(review): the previous implementation flipped the sign of x and z and,
    # when `allow_bottom` was set, OR-ed in `(y <= -limit) & inside`. That term
    # is a subset of `inside`, and the sign flips vanish under abs(), so the
    # result was always exactly `~inside`. If the intent was to keep points
    # below the box regardless of x/z, that behaviour still needs implementing.
    return ~inside


def cull_clouds_mask(denities, thresh):
    """
    Mask of samples whose unit-length opacity falls below ``thresh``.

    Applies the same ``softplus(density - 1)`` activation used by the ray
    marcher, converts it to ``alpha = 1 - exp(-density)`` and returns
    ``alpha < thresh``.
    """
    activated = torch.nn.functional.softplus(denities - 1)  # activation bias of -1 makes things initialize better
    alpha = 1 - torch.exp(-activated)
    return alpha < thresh
bias of -1 makes things initialize better + alpha = 1 - torch.exp(-denities) + return alpha < thresh + + + +class ImportanceRenderer(torch.nn.Module): + def __init__(self, w_dim, num_ws,batch_size,thickness,box_warp,apply_deformation = True): + super().__init__() + self.ray_marcher = MipRayMarcher2() + self.plane_axes = generate_planes() + self.batch_size = batch_size + self.num_betas = 10 + self.apply_deformation = apply_deformation + if apply_deformation: + body_model_smpl = smplx.create('./smplx_models', + model_type='smpl', + gender='neutral', + use_compressed=False, + use_face_contour=True, + num_betas=self.num_betas, + num_expression_coeffs=10, + ext='npz', + batch_size = batch_size + ).cuda() + self.aligned_SMPL = AlignedSMPL(model=body_model_smpl,batch_size=batch_size) + + + + shaped_smpl_data = self.aligned_SMPL.generate_shaped_smpl( + betas=None, + scale=None, # shape_params['scale'], + transl=None, # shape_params['transl'] + ) + shaped_smpl = shaped_smpl_data['vertices'].detach().contiguous() + align_points = shaped_smpl_data['align_joint_coordinate'].detach().contiguous() + + self.register_buffer('shaped_smpl', shaped_smpl) + self.register_buffer('align_points', align_points) + + # shaped_smpl [B,N,3] + # filter points that outside box + box_side_length = box_warp + # shaped_smpl: B,N,3 + point_mask = shaped_smpl[0:1,:,0] > -box_side_length/2 # 1,N + point_mask = point_mask & (shaped_smpl[0:1,:,0] < box_side_length/2) + point_mask = point_mask & (shaped_smpl[0:1,:,1] > -box_side_length/2) + point_mask = point_mask & (shaped_smpl[0:1,:,1] < box_side_length/2) + point_mask = point_mask & (shaped_smpl[0:1,:,2] > -box_side_length/2) + point_mask = point_mask & (shaped_smpl[0:1,:,2] < box_side_length/2) + point_mask = point_mask.squeeze(0).cuda() # N + + faces = self.aligned_SMPL.faces # [20908, 3] + face_mask = torch.ones(faces.shape[0],dtype=torch.bool).cuda() # [20908] + for i in range(faces.shape[0]): + face_mask[i] = point_mask[faces[i,0]] and 
point_mask[faces[i,1]] and point_mask[faces[i,2]] + self.register_buffer('face_mask', face_mask) + + self.thickness = thickness + + # shaped_smpl [B,N,3] + # filter points that not on the head + # shaped_smpl: B,N,3 + + # + # point_mask = shaped_smpl[0:1, :, 1] > 0 # 1,N + + point_mask = shaped_smpl[0:1, :, 1] > 0.06 # 1,N + point_mask = point_mask & (shaped_smpl[0:1, :, 2] < -0.0) + + point_mask = point_mask.squeeze(0).cuda() # N + + faces = self.aligned_SMPL.faces # [20908, 3] + head_face_mask = torch.ones(faces.shape[0], dtype=torch.bool).cuda() # [20908] + for i in range(faces.shape[0]): + head_face_mask[i] = point_mask[faces[i, 0]] and point_mask[faces[i, 1]] and point_mask[faces[i, 2]] + self.register_buffer('head_face_mask', head_face_mask) + + self.back_head_depth = None + # + # print('head_face_mask shape:',head_face_mask.shape) + + + def set_batch_size(self,batch_size): + self.batch_size = batch_size + body_model_smpl = smplx.create('./smplx_models', + model_type='smpl', + gender='neutral', + use_compressed=False, + use_face_contour=True, + num_betas=self.num_betas, + num_expression_coeffs=10, + ext='npz', + batch_size=batch_size + ).to(self.aligned_SMPL.model.shapedirs.device) + self.aligned_SMPL.set_model(body_model_smpl) + self.aligned_SMPL.set_batch_size(batch_size) + shaped_smpl_data = self.aligned_SMPL.generate_shaped_smpl( + betas=None, + scale=None, # shape_params['scale'], + transl=None, # shape_params['transl'] + ) + shaped_smpl = shaped_smpl_data['vertices'].detach().contiguous() + align_points = shaped_smpl_data['align_joint_coordinate'].detach().contiguous() + self.register_buffer('shaped_smpl', shaped_smpl) + self.register_buffer('align_points', align_points) + + + def render_meshes(self, shape_pose_params,resolution,cameras): + images = self.aligned_SMPL.get_visualization(shape_pose_params, resolution, cameras) + return images + + + def get_deformed_coordinate(self, ws, pose_params, original_coordinate): + + + posed_smpl = 
self.aligned_SMPL.generate_posed_smpl(betas=None, + body_pose=pose_params, + scale=None, # shape_params['scale'], + transl=None, # shape_params['transl'], + align_joint_coordinate=self.align_points)['vertices'] + # misc.assert_shape(posed_smpl, [None, 10475, 3]) + + + mode = 'kaolin' + if mode == 'pytorch3d': + raise NotImplementedError + import pytorch3d.ops + #raise NotImplementedError + with torch.no_grad(): + + smpl_def_on_mesh = self.shaped_smpl - posed_smpl # [B, , 3] + + # find the nearest face in posed_smpl for each vertex in original_coordinate + knn_res = pytorch3d.ops.knn_points(p1=original_coordinate, p2=posed_smpl, K=1) + distance = knn_res[0] # [B, N, 1] + p1_index = knn_res[1].repeat(1, 1, 3) # [B, N, 3] + misc.assert_shape(p1_index, [original_coordinate.shape[0], original_coordinate.shape[1],3]) + + + DistToMesh = distance.squeeze(-1) # [B, N] + + SmplDef = smpl_def_on_mesh.gather(1, p1_index) # [B, N, 3] + mask = DistToMesh < self.thickness# [B, N] + + + scale = 5. + SmplDef1 = SmplDef / torch.exp(DistToMesh.unsqueeze(-1) * scale) # [B, N, 3] + + scale = DistToMesh.unsqueeze(-1) / (self.thickness * 2) * 20 + SmplDef2 = torch.zeros_like(SmplDef).to(SmplDef.device) + + SmplDef = torch.where(mask.unsqueeze(-1), SmplDef1, SmplDef2) # [B, N, 3] + elif mode == 'kaolin': + faces = self.aligned_SMPL.faces.clone() # [20908, 3] + faces = faces[self.face_mask, :] + # find the nearest face in shaped_smplx for each vertex in original_coordinate + vertex_faces = posed_smpl.clone() # [B, 6085, 3] + + with torch.no_grad(): + face_vertices = index_vertices_by_faces(vertex_faces, faces) + distance, index, dist_type = point_to_mesh_distance(original_coordinate, face_vertices) # B, N + distance = torch.sqrt(distance) # [B, N, 1] + selected_posed_smpl_vertices = [] + selected_shaped_smpl_vertices = [] + + for i in range(original_coordinate.shape[0]): + selected_face = faces[index[i]] + selected_posed_smpl_vertices.append(index_vertices_by_faces(posed_smpl[i:i + 1], + 
selected_face)) # [1, N, 3, 3] + selected_shaped_smpl_vertices.append(index_vertices_by_faces(self.shaped_smpl[i:i + 1], + selected_face)) # [1, N, 3, 3] + + selected_posed_smpl_vertices = torch.cat(selected_posed_smpl_vertices, dim=0) # [B, N, 3, 3] + selected_shaped_smpl_vertices = torch.cat(selected_shaped_smpl_vertices, dim=0) # [B, N, 3, 3] + + y_axes = torch.cross(selected_posed_smpl_vertices[:, :, 1, :] - selected_posed_smpl_vertices[:, :, 0, :], + selected_posed_smpl_vertices[:, :, 2, :] - selected_posed_smpl_vertices[:, :, 0, + :]) # [B, N, 3] + y_axes = y_axes / torch.norm(y_axes, dim=2, keepdim=True) # [B, N, 3] + + x_axes = selected_posed_smpl_vertices[:, :, 1, :] - selected_posed_smpl_vertices[:, :, 0, :] # [B, N, 3] + x_axes = x_axes / torch.norm(x_axes, dim=2, keepdim=True) # [B, N, 3] + + z_axes = torch.cross(x_axes, y_axes) # [B, N, 3] + + posed_smpl_coordinate = torch.stack( + [torch.sum((original_coordinate - selected_posed_smpl_vertices[:, :, 0, :]) * x_axes, dim=2), + torch.sum((original_coordinate - selected_posed_smpl_vertices[:, :, 0, :]) * y_axes, dim=2), + torch.sum((original_coordinate - selected_posed_smpl_vertices[:, :, 0, :]) * z_axes, dim=2)], + dim=2) # [B, N, 3] + del x_axes, y_axes, z_axes + y_axes = torch.cross(selected_shaped_smpl_vertices[:, :, 1, :] - selected_shaped_smpl_vertices[:, :, 0, :], + selected_shaped_smpl_vertices[:, :, 2, :] - selected_shaped_smpl_vertices[:, :, 0, :]) + y_axes = y_axes / torch.norm(y_axes, dim=2, keepdim=True) + + x_axes = selected_shaped_smpl_vertices[:, :, 1, :] - selected_shaped_smpl_vertices[:, :, 0, :] + x_axes = x_axes / torch.norm(x_axes, dim=2, keepdim=True) + + z_axes = torch.cross(x_axes, y_axes) + + new_coordinate = posed_smpl_coordinate[:, :, 0:1] * x_axes + \ + posed_smpl_coordinate[:, :, 1:2] * y_axes + \ + posed_smpl_coordinate[:, :, 2:3] * z_axes + \ + selected_shaped_smpl_vertices[:, :, 0, :] # [B, N, 3] + + SmplDef = new_coordinate - original_coordinate # [B, N, 3] + + DistToMesh 
= distance.unsqueeze(-1) # [B, N, 1] + + mask = DistToMesh < self.thickness # [B, N,1] + + SmplDef2 = torch.zeros_like(SmplDef).to(SmplDef.device) + SmplDef = torch.where(mask, SmplDef, SmplDef2) # [B, N, 3] + + else: + raise NotImplementedError + + original_coordinate = original_coordinate + SmplDef + return original_coordinate + + def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options, apply_def = False, ws = None, pose_params = None, triplane_crop=0.1, cull_clouds=None, binarize_clouds=None ): + _ = ws + if apply_def: + assert pose_params is not None + else: + assert pose_params is None + + self.plane_axes = self.plane_axes.to(ray_origins.device) + + # check if grad = 0 + + if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto': + ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) + is_ray_valid = ray_end > ray_start + if torch.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + else: + # Create stratified depth samples + depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + + batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape + + # Coarse Pass + sample_coordinates = (ray_origins.unsqueeze(-2) + depths_coarse * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) + # deform the sample_coordinates + if apply_def: + sample_coordinates = self.get_deformed_coordinate(None, pose_params, sample_coordinates) + + + out = 
self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) + colors_coarse = out['rgb'] + densities_coarse = out['sigma'] + + xyz_coarse = out['xyz'] + + if triplane_crop: + # print(xyz_fine.amin(dim=(0,1))) + # print(xyz_fine.amax(dim=(0,1))) + cropmask = triplane_crop_mask(xyz_coarse, triplane_crop, rendering_options['box_warp']) + densities_coarse[cropmask] = -1e3 + if binarize_clouds: + ccmask = cull_clouds_mask(densities_coarse, binarize_clouds) + densities_coarse[ccmask] = -1e3 + densities_coarse[~ccmask] = 1e3 + elif cull_clouds: + ccmask = cull_clouds_mask(densities_coarse, cull_clouds) + densities_coarse[ccmask] = -1e3 + + colors_coarse = colors_coarse.reshape(batch_size, num_rays, samples_per_ray, colors_coarse.shape[-1]) + densities_coarse = densities_coarse.reshape(batch_size, num_rays, samples_per_ray, 1) + xyz_coarse = xyz_coarse.reshape(batch_size, num_rays, samples_per_ray, xyz_coarse.shape[-1]) + + # Fine Pass + N_importance = rendering_options['depth_resolution_importance'] + if N_importance > 0: + _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + depths_fine = self.sample_importance(depths_coarse, weights, N_importance) + + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, N_importance, -1).reshape(batch_size, -1, 3) + sample_coordinates = (ray_origins.unsqueeze(-2) + depths_fine * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + # deform the sample_coordinates + if apply_def: + sample_coordinates = self.get_deformed_coordinate(None, pose_params, sample_coordinates) + + out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) + colors_fine = out['rgb'] + densities_fine = out['sigma'] + xyz_fine = out['xyz'] + if triplane_crop: + # print(xyz_fine.amin(dim=(0,1))) + # print(xyz_fine.amax(dim=(0,1))) + cropmask = triplane_crop_mask(xyz_fine, triplane_crop, rendering_options['box_warp']) + 
densities_fine[cropmask] = -1e3 + if binarize_clouds: + ccmask = cull_clouds_mask(densities_fine, binarize_clouds) + densities_fine[ccmask] = -1e3 + densities_fine[~ccmask] = 1e3 + elif cull_clouds: + ccmask = cull_clouds_mask(densities_fine, cull_clouds) + densities_fine[ccmask] = -1e3 + xyz_fine = xyz_fine.reshape(batch_size, num_rays, N_importance, xyz_fine.shape[-1]) + colors_fine = colors_fine.reshape(batch_size, num_rays, N_importance, colors_fine.shape[-1]) + densities_fine = densities_fine.reshape(batch_size, num_rays, N_importance, 1) + + # all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, + # depths_fine, colors_fine, densities_fine) + all_depths, all_colors, all_densities, all_xyz = self.unify_samples( + depths_coarse, colors_coarse, densities_coarse, xyz_coarse, + depths_fine, colors_fine, densities_fine, xyz_fine, + ) + + # Aggregate + # rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options) + + all_colors_ = torch.cat([all_colors, all_xyz], dim=-1) + rgb_final_, depth_final, weights = self.ray_marcher(all_colors_, all_densities, all_depths, rendering_options) + rgb_final = rgb_final_[...,:-3] + xyz_final = rgb_final_[...,-3:] + else: + # rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + colors_coarse_ = torch.cat([colors_coarse, xyz_coarse], dim=-1) + rgb_final_, depth_final, weights = self.ray_marcher(colors_coarse_, densities_coarse, depths_coarse, rendering_options) + rgb_final = rgb_final_[...,:-3] + xyz_final = rgb_final_[...,-3:] + + + output = {'rgb_final': rgb_final, 'depth_final': depth_final, 'weights': weights} + + return output + + def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): + self.plane_axes = self.plane_axes.to(planes[list(planes.keys())[0]].device) + sampled_features = sample_from_planes(self.plane_axes, planes, 
sample_coordinates, padding_mode='zeros', + box_warp=options['box_warp'], triplane_depth=options['triplane_depth']) + + out = decoder(sampled_features, sample_directions) + + if options.get('density_noise', 0) > 0: + out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] + out['xyz'] = sample_coordinates#.permute(0,2,1)[...,None] + return out + + def sort_samples(self, all_depths, all_colors, all_densities): + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + return all_depths, all_colors, all_densities + + # def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2): + # all_depths = torch.cat([depths1, depths2], dim = -2) + # all_colors = torch.cat([colors1, colors2], dim = -2) + # all_densities = torch.cat([densities1, densities2], dim = -2) + + # _, indices = torch.sort(all_depths, dim=-2) + # all_depths = torch.gather(all_depths, -2, indices) + # all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + # all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + + # return all_depths, all_colors, all_densities + def unify_samples(self, depths1, colors1, densities1, xyz1, depths2, colors2, densities2, xyz2): + all_depths = torch.cat([depths1, depths2], dim = -2) + all_colors = torch.cat([colors1, colors2], dim = -2) + all_xyz = torch.cat([xyz1, xyz2], dim = -2) + all_densities = torch.cat([densities1, densities2], dim = -2) + + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_xyz = torch.gather(all_xyz, -2, indices.expand(-1, -1, -1, all_xyz.shape[-1])) + all_densities = 
torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + + return all_depths, all_colors, all_densities, all_xyz + + def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False): + """ + Return depths of approximately uniformly spaced samples along rays. + """ + N, M, _ = ray_origins.shape + if disparity_space_sampling: + depths_coarse = torch.linspace(0, + 1, + depth_resolution, + device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = 1/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse) + else: + if type(ray_start) == torch.Tensor: + depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] + else: + depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = (ray_end - ray_start)/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + + return depths_coarse + + def sample_importance(self, z_vals, weights, N_importance): + """ + Return depths of importance sampled points along rays. See NeRF importance sampling for more. 
+ """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher + + # smooth weights + weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1) + weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], + N_importance).detach().reshape(batch_size, num_rays, N_importance, 1) + return importance_z_vals + + def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + Outputs: + samples: the sampled samples + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) 
+ pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = torch.linspace(0, 1, N_importance, device=bins.device) + u = u.expand(N_rays, N_importance) + else: + u = torch.rand(N_rays, N_importance, device=bins.device) + u = u.contiguous() + + inds = torch.searchsorted(cdf, u, right=True) + below = torch.clamp_min(inds-1, 0) + above = torch.clamp_max(inds, N_samples_) + + inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) + cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) + bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[...,1]-cdf_g[...,0] + denom[denom= -90 and azimuth_val < 90: + if azimuth_val >= 0: + r = 1 - azimuth_val / 90 + else: + r = 1 + azimuth_val / 90 + start_z = embeddings['front'] + end_z = embeddings['side'] + # if random.random() < 0.3: + # r = r + random.gauss(0, 0.08) + pos_z = r * start_z + (1 - r) * end_z + text_z = torch.cat([pos_z, embeddings['front'], embeddings['side']], dim=0) + if r > 0.8: + front_neg_w = 0.0 + else: + front_neg_w = math.exp(-r * opt.front_decay_factor) * opt.negative_w + if r < 0.2: + side_neg_w = 0.0 + else: + side_neg_w = math.exp(-(1 - r) * opt.side_decay_factor) * opt.negative_w + + weights = torch.tensor([1.0, front_neg_w, side_neg_w]) + else: + if azimuth_val >= 0: + r = 1 - (azimuth_val - 90) / 90 + else: + r = 1 + (azimuth_val + 90) / 90 + start_z = embeddings['side'] + end_z = embeddings['back'] + # if random.random() < 0.3: + # r = r + random.gauss(0, 0.08) + pos_z = r * start_z + (1 - r) * end_z + text_z = torch.cat([pos_z, embeddings['side'], embeddings['front']], dim=0) + front_neg_w = opt.negative_w + if r > 0.8: + side_neg_w = 0.0 + else: + side_neg_w = math.exp(-r * 
def safe_normalize(x, eps=1e-20):
    """Scale ``x`` to unit length along its last dimension.

    The squared norm is clamped to at least ``eps`` before the square root,
    so an all-zero vector yields zeros instead of a division by zero.
    """
    squared_norm = (x * x).sum(-1, keepdim=True)
    return x / squared_norm.clamp(min=eps).sqrt()
@torch.jit.script
def linear_to_srgb(x: torch.Tensor) -> torch.Tensor:
    """Apply the piecewise sRGB transfer curve to linear-light values.

    Values below 0.0031308 use the linear segment ``12.92 * x``; the rest use
    ``1.055 * x ** (1 / 2.4) - 0.055``.
    """
    # torch.where evaluates both branches for every element. Feeding zeros or
    # negatives into the fractional power yields inf/NaN derivatives that
    # poison the gradient even where the linear segment is selected, so the
    # power branch only ever sees values at or above the threshold.
    power_input = torch.clamp(x, min=0.0031308)
    # Exact 1/2.4 (instead of a truncated 0.41666) keeps this the inverse of
    # a 2.4-exponent sRGB-to-linear curve.
    curved = 1.055 * torch.pow(power_input, 1.0 / 2.4) - 0.055
    return torch.where(x < 0.0031308, 12.92 * x, curved)
+ local_rank=0, # which GPU am I + world_size=1, # total num of GPUs + device=None, # device to use, usually setting to None is OK. (auto choose device) + mute=False, # whether to mute all print + fp16=False, # amp optimize level + max_keep_ckpt=2, # max num of saved ckpts in disk + workspace='workspace', # workspace to save logs & ckpts + best_mode='min', # the smaller/larger result, the better + use_loss_as_metric=True, # use loss as the first metric + report_metric_at_train=False, # also report metrics at training + use_checkpoint="latest", # which ckpt to use at init time + use_tensorboardX=True, # whether to use tensorboard for logging + scheduler_update_every_step=False, # whether to call scheduler.step() after every train step + ): + + self.argv = argv + self.name = name + self.opt = opt + self.mute = mute + self.metrics = metrics + self.local_rank = local_rank + self.world_size = world_size + self.workspace = workspace + self.ema_decay = ema_decay + self.fp16 = fp16 + self.best_mode = best_mode + self.use_loss_as_metric = use_loss_as_metric + self.report_metric_at_train = report_metric_at_train + self.max_keep_ckpt = max_keep_ckpt + self.use_checkpoint = use_checkpoint + self.use_tensorboardX = use_tensorboardX + self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") + self.scheduler_update_every_step = scheduler_update_every_step + self.device = device if device is not None else torch.device( + f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu') + self.console = Console() + + self.as_latent = True + self.vgg16 = None + model.to(self.device) + teacher_model.to(self.device) + if self.world_size > 1: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) + + teacher_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model) + teacher_model = torch.nn.parallel.DistributedDataParallel(teacher_model, device_ids=[local_rank]) + self.model = model + 
self.teacher_model = teacher_model + + # guide model + self.guidance = guidance + self.embeddings = {} + + # text prompt / images + if self.guidance is not None: + for key in self.guidance: + for p in self.guidance[key].parameters(): + p.requires_grad = False + self.embeddings[key] = {} + self.prepare_embeddings() + + if isinstance(criterion, nn.Module): + criterion.to(self.device) + self.criterion = criterion + + if self.opt.images is not None: + self.pearson = PearsonCorrCoef().to(self.device) + + if optimizer is None: + self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam + else: + self.optimizer = optimizer(self.model) + + if lr_scheduler is None: + self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler + else: + self.lr_scheduler = lr_scheduler(self.optimizer) + + if ema_decay is not None: + self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay) + else: + self.ema = None + + self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) + + # variable init + self.total_train_t = 0 + self.epoch = 0 + self.global_step = 0 + self.local_step = 0 + self.stats = { + "loss": [], + "valid_loss": [], + "results": [], # metrics[0], or valid_loss + "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt + "best_result": None, + } + + # auto fix + if len(metrics) == 0 or self.use_loss_as_metric: + self.best_mode = 'min' + + # workspace prepare + self.log_ptr = None + if self.workspace is not None: + os.makedirs(self.workspace, exist_ok=True) + self.log_path = os.path.join(workspace, f"log_{self.name}.txt") + self.log_ptr = open(self.log_path, "a+") + + self.ckpt_path = os.path.join(self.workspace, 'latent_trigrid_fit_checkpoints') + self.best_path = f"{self.ckpt_path}/{self.name}.pth" + os.makedirs(self.ckpt_path, exist_ok=True) + + # Save a copy of image_config in the experiment workspace + if opt.image_config is not None: + 
shutil.copyfile(opt.image_config, os.path.join(self.workspace, os.path.basename(opt.image_config))) + + # Save a copy of images in the experiment workspace + if opt.images is not None: + for image_file in opt.images: + shutil.copyfile(image_file, os.path.join(self.workspace, os.path.basename(image_file))) + + self.log(f'[INFO] Cmdline: {self.argv}') + self.log(f'[INFO] opt: {self.opt}') + self.log( + f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}') + self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}') + + if self.workspace is not None: + if self.use_checkpoint == "scratch": + self.log("[INFO] Training from scratch ...") + elif self.use_checkpoint == "latest": + self.log("[INFO] Loading latest checkpoint ...") + self.load_checkpoint() + elif self.use_checkpoint == "latest_model": + self.log("[INFO] Loading latest checkpoint (model only)...") + self.load_checkpoint(model_only=True) + elif self.use_checkpoint == "best": + if os.path.exists(self.best_path): + self.log("[INFO] Loading best checkpoint ...") + self.load_checkpoint(self.best_path) + else: + self.log(f"[INFO] {self.best_path} not found, loading latest ...") + self.load_checkpoint() + else: # path to ckpt + self.log(f"[INFO] Loading {self.use_checkpoint} ...") + self.load_checkpoint(self.use_checkpoint) + + # calculate the text embs. 
+ @torch.no_grad() + def prepare_embeddings(self): + + # text embeddings (stable-diffusion) + if self.opt.text is not None: + + if 'SD' in self.guidance: + self.embeddings['SD']['default'] = self.guidance['SD'].get_text_embeds([self.opt.text]) + self.embeddings['SD']['uncond'] = self.guidance['SD'].get_text_embeds([self.opt.negative]) + + for d in ['front', 'side', 'back']: + self.embeddings['SD'][d] = self.guidance['SD'].get_text_embeds([f"{self.opt.text}, {d} view"]) + + if 'IF' in self.guidance: + self.embeddings['IF']['default'] = self.guidance['IF'].get_text_embeds([self.opt.text]) + self.embeddings['IF']['uncond'] = self.guidance['IF'].get_text_embeds([self.opt.negative]) + + for d in ['front', 'side', 'back']: + self.embeddings['IF'][d] = self.guidance['IF'].get_text_embeds([f"{self.opt.text}, {d} view"]) + + if 'clip' in self.guidance: + self.embeddings['clip']['text'] = self.guidance['clip'].get_text_embeds(self.opt.text) + + if self.opt.images is not None: + + h = int(self.opt.known_view_scale * self.opt.h) + w = int(self.opt.known_view_scale * self.opt.w) + + # load processed image + for image in self.opt.images: + assert image.endswith( + '_rgba.png') # the rest of this code assumes that the _rgba image has been passed. 
+ rgbas = [cv2.cvtColor(cv2.imread(image, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA) for image in + self.opt.images] + rgba_hw = np.stack( + [cv2.resize(rgba, (w, h), interpolation=cv2.INTER_AREA).astype(np.float32) / 255 for rgba in rgbas]) + rgb_hw = rgba_hw[..., :3] * rgba_hw[..., 3:] + (1 - rgba_hw[..., 3:]) + self.rgb = torch.from_numpy(rgb_hw).permute(0, 3, 1, 2).contiguous().to(self.device) + self.mask = torch.from_numpy(rgba_hw[..., 3] > 0.5).to(self.device) + print(f'[INFO] dataset: load image prompt {self.opt.images} {self.rgb.shape}') + + # load depth + depth_paths = [image.replace('_rgba.png', '_depth.png') for image in self.opt.images] + depths = [cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) for depth_path in depth_paths] + depth = np.stack([cv2.resize(depth, (w, h), interpolation=cv2.INTER_AREA) for depth in depths]) + self.depth = torch.from_numpy(depth.astype(np.float32) / 255).to( + self.device) # TODO: this should be mapped to FP16 + print(f'[INFO] dataset: load depth prompt {depth_paths} {self.depth.shape}') + + # load normal # TODO: don't load if normal loss is 0 + normal_paths = [image.replace('_rgba.png', '_normal.png') for image in self.opt.images] + normals = [cv2.imread(normal_path, cv2.IMREAD_UNCHANGED) for normal_path in normal_paths] + normal = np.stack([cv2.resize(normal, (w, h), interpolation=cv2.INTER_AREA) for normal in normals]) + self.normal = torch.from_numpy(normal.astype(np.float32) / 255).to(self.device) + print(f'[INFO] dataset: load normal prompt {normal_paths} {self.normal.shape}') + + # encode embeddings for zero123 + if 'zero123' in self.guidance: + rgba_256 = np.stack( + [cv2.resize(rgba, (256, 256), interpolation=cv2.INTER_AREA).astype(np.float32) / 255 for rgba in + rgbas]) + rgbs_256 = rgba_256[..., :3] * rgba_256[..., 3:] + (1 - rgba_256[..., 3:]) + rgb_256 = torch.from_numpy(rgbs_256).permute(0, 3, 1, 2).contiguous().to(self.device) + guidance_embeds = self.guidance['zero123'].get_img_embeds(rgb_256) + 
def log(self, *args, **kwargs):
    """Emit a message to the console and to the log file (rank 0 only).

    ``*args`` and ``**kwargs`` go to ``self.console.print`` unless the
    trainer is muted. The log file receives only the positional arguments
    and is flushed after every message.
    """
    if self.local_rank != 0:
        return

    if not self.mute:
        self.console.print(*args, **kwargs)

    log_file = self.log_ptr
    if log_file:
        print(*args, file=log_file)
        log_file.flush()  # write immediately to file
+ exp_iter_ratio = (self.global_step - self.opt.exp_start_iter) / ( + self.opt.exp_end_iter - self.opt.exp_start_iter) + + # progressively relaxing view range + if self.opt.progressive_view: + r = min(1.0, self.opt.progressive_view_init_ratio + 2.0 * exp_iter_ratio) + self.opt.phi_range = [self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[0] * r, + self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[1] * r] + self.opt.theta_range = [self.opt.default_polar * (1 - r) + self.opt.full_theta_range[0] * r, + self.opt.default_polar * (1 - r) + self.opt.full_theta_range[1] * r] + self.opt.radius_range = [self.opt.default_radius * (1 - r) + self.opt.full_radius_range[0] * r, + self.opt.default_radius * (1 - r) + self.opt.full_radius_range[1] * r] + self.opt.fovy_range = [self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[0] * r, + self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[1] * r] + + # progressively increase max_level + if self.opt.progressive_level: + self.model.max_level = min(1.0, 0.25 + 2.0 * exp_iter_ratio) + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + mvp = data['mvp'] # [B, 4, 4] + + B, N = rays_o.shape[:2] + H, W = data['H'], data['W'] + + teacher_rays_o = data['teacher_rays_o'] # [B, N, 3] + teacher_rays_d = data['teacher_rays_d'] # [B, N, 3] + teacher_H = data['teacher_H'] + teacher_W = data['teacher_W'] + + # When ref_data has B images > opt.batch_size + if B > self.opt.batch_size: + # choose batch_size images out of those B images + choice = torch.randperm(B)[:self.opt.batch_size] + B = self.opt.batch_size + rays_o = rays_o[choice] + rays_d = rays_d[choice] + mvp = mvp[choice] + + if do_rgbd_loss: + ambient_ratio = 1.0 + shading = 'lambertian' # use lambertian instead of albedo to get normal + binarize = False + bg_color = torch.rand((B * N, 3), device=rays_o.device) + + # add camera noise to avoid grid-like artifact + if self.opt.known_view_noise_scale > 0: + noise_scale = 
self.opt.known_view_noise_scale # * (1 - self.global_step / self.opt.iters) + rays_o = rays_o + torch.randn(3, device=self.device) * noise_scale + rays_d = rays_d + torch.randn(3, device=self.device) * noise_scale + + elif exp_iter_ratio <= self.opt.latent_iter_ratio: + ambient_ratio = 1.0 + shading = 'normal' + binarize = False + bg_color = None + + else: + if exp_iter_ratio <= self.opt.albedo_iter_ratio: + ambient_ratio = 1.0 + shading = 'albedo' + else: + # random shading + ambient_ratio = self.opt.min_ambient_ratio + (1.0 - self.opt.min_ambient_ratio) * random.random() + rand = random.random() + if rand >= (1.0 - self.opt.textureless_ratio): + shading = 'textureless' + else: + shading = 'lambertian' + + # random weights binarization (like mobile-nerf) [NOT WORKING NOW] + # binarize_thresh = min(0.5, -0.5 + self.global_step / self.opt.iters) + # binarize = random.random() < binarize_thresh + binarize = False + + # random background + rand = random.random() + # if self.opt.bg_radius > 0 and rand > 0.5: + if self.opt.learnable_bg: + bg_color = None # use bg_net + elif self.opt.noise_bg: + # B, 3, H, W + # bg_color = torch.randn(B, 3, H, W).to(self.device) + # bg_color = bg_color * + # self.guidance['SD']. 
+ raise NotImplementedError + else: + bg_color = torch.rand(3).to(self.device) # single color random bg + + outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=False, perturb=True, bg_color=bg_color, + ambient_ratio=ambient_ratio, shading=shading, binarize=binarize, as_latent=True) + + if self.as_latent: + # abuse normal & mask as latent code for faster geometry initialization (ref: fantasia3D) + pred_latent = outputs['image'].reshape(B, H, W, 4).permute(0, 3, 1, 2).contiguous() # [B, 4, H, W] + pred_rgb = self.guidance['SD'].decode_latents(pred_latent).permute(0, 2, 3, 1).contiguous() # [B, H, W, 3] + else: + raise NotImplementedError + + pred_depth = outputs['depth'].squeeze(-1) # .reshape(B, H, W) + + with torch.no_grad(): + teacher_output = self.teacher_model.render(teacher_rays_o, teacher_rays_d, mvp, teacher_H, teacher_W, + staged=True, perturb=True, bg_color=bg_color, + ambient_ratio=ambient_ratio, shading=shading, binarize=binarize, + as_latent=False) + + teacher_rgb = teacher_output['image'] + teacher_rgb = teacher_rgb # .reshape(B, H, W, 3) + + teacher_latent = self.guidance['SD'].encode_imgs( + teacher_rgb.permute(0, 3, 1, 2).contiguous()) # [B, 4, H, W] + + teacher_depth = teacher_output['depth'].squeeze(-1) # [B, 1, H, W] + teacher_depth = F.interpolate(teacher_depth.unsqueeze(1), size=pred_depth.shape[1:3], + mode='nearest').squeeze(1) # [B, H, W] + + assert teacher_latent.shape == pred_latent.shape, f"teacher_latent.shape {teacher_latent.shape} != pred_rgb.shape {pred_latent.shape}" + assert teacher_depth.shape == pred_depth.shape, f"teacher_depth.shape {teacher_depth.shape} != pred_depth.shape {pred_depth.shape}" + + loss = 0 + losses = {} + #print(pred_latent.min(), pred_latent.max(), teacher_latent.min(), teacher_latent.max()) + latent_mse_loss = F.mse_loss(pred_latent, teacher_latent) + loss = loss + latent_mse_loss + losses['latent_mse_loss'] = latent_mse_loss + + rgb_mse_loss = F.mse_loss(pred_rgb, teacher_rgb)*20 + loss = loss + 
def perceptual_loss(self, synth_images, target):
    '''
    Feature-space distance between two image batches, via ``self.vgg16``.

    :param synth_images: channel-last batch (it is permuted with
        ``(0, 3, 1, 2)`` before use); multiplied by 255 before the network,
        so presumably in [0, 1] -- confirm against the caller
    :param target: same layout and range as ``synth_images``
    :return: scalar tensor, the sum of squared feature differences times 0.1
    '''
    # Move channels first for both inputs before anything else happens.
    synth_nchw = synth_images.permute(0, 3, 1, 2).contiguous()
    target_nchw = target.permute(0, 3, 1, 2).contiguous()

    # Load the scripted network on first use and cache it.
    if self.vgg16 is None:
        with open('./pretrained/vgg16.pt', 'rb') as f:
            self.vgg16 = torch.jit.load(f).eval().to(self.device)

    def _prepare(batch):
        # Scale by 255; shrink to 256x256 only when the height exceeds 256.
        batch = batch * 255
        if batch.shape[2] > 256:
            batch = F.interpolate(batch, size=(256, 256), mode='area')
        return batch

    target_features = self.vgg16(_prepare(target_nchw), resize_images=False, return_lpips=True)
    synth_features = self.vgg16(_prepare(synth_nchw), resize_images=False, return_lpips=True)

    return (target_features - synth_features).square().sum() * 0.1
def eval_step(self, data):
    """Render one evaluation batch with both the student and the teacher.

    The student (``self.model``) renders with ``as_latent=True``. While
    ``self.as_latent`` is set, its image is decoded through
    ``self.guidance['SD'].decode_latents``. The teacher renders under
    ``torch.no_grad()`` and its depth is resized (nearest) to the student's
    depth resolution.

    Args:
        data: mapping with ``rays_o``, ``rays_d``, ``mvp``, ``H``, ``W`` and
            the ``teacher_*`` counterparts; optional ``shading``,
            ``ambient_ratio`` and ``light_d`` default to ``'albedo'``,
            ``1.0`` and ``None``.

    Returns:
        ``(pred_rgb, pred_depth, teacher_rgb, teacher_depth, loss)`` where
        ``loss`` is a dummy one-element zero tensor.
    """
    rays_o, rays_d, mvp = data['rays_o'], data['rays_d'], data['mvp']
    # Kept from the original: unpacking fails fast when rays_o is not at
    # least 2-D, even though the two sizes are not used afterwards.
    _batch, _num_rays = rays_o.shape[:2]
    height, width = data['H'], data['W']

    teacher_rays_o, teacher_rays_d = data['teacher_rays_o'], data['teacher_rays_d']
    teacher_height, teacher_width = data['teacher_H'], data['teacher_W']

    # Render settings shared by student and teacher.
    render_kwargs = dict(
        staged=True,
        perturb=False,
        light_d=data.get('light_d', None),
        ambient_ratio=data.get('ambient_ratio', 1.0),
        shading=data.get('shading', 'albedo'),
        bg_color=None,
    )

    outputs = self.model.render(rays_o, rays_d, mvp, height, width, **render_kwargs, as_latent=True)

    pred_rgb = outputs['image']
    if self.as_latent:  # always True
        # B, H, W, C -> B, C, H, W for the decoder, then back to channel-last.
        latents = pred_rgb.permute(0, 3, 1, 2).contiguous()
        decoded = self.guidance['SD'].decode_latents(latents)
        pred_rgb = decoded.permute(0, 2, 3, 1).contiguous()

    pred_depth = outputs['depth'].squeeze(-1)

    with torch.no_grad():
        teacher_output = self.teacher_model.render(
            teacher_rays_o, teacher_rays_d, mvp, teacher_height, teacher_width,
            **render_kwargs, as_latent=False)

        teacher_rgb = teacher_output['image']

        # Match the teacher depth map to the student's spatial size.
        teacher_depth = teacher_output['depth'].squeeze(-1)
        teacher_depth = F.interpolate(
            teacher_depth.unsqueeze(1), size=pred_depth.shape[1:3], mode='nearest'
        ).squeeze(1)

    assert teacher_rgb.shape == pred_rgb.shape, f"teacher_rgb.shape {teacher_rgb.shape} != pred_rgb.shape {pred_rgb.shape}"
    assert teacher_depth.shape == pred_depth.shape, f"teacher_depth.shape {teacher_depth.shape} != pred_depth.shape {pred_depth.shape}"

    # dummy
    loss = torch.zeros([1], device=pred_rgb.device, dtype=pred_rgb.dtype)

    return pred_rgb, pred_depth, teacher_rgb, teacher_depth, loss
def save_mesh(self, loader=None, save_path=None):
    """Export the model's mesh into a directory.

    Args:
        loader: accepted but not used by this method.
        save_path: target directory; defaults to ``<workspace>/mesh``. It is
            created if missing.
    """
    out_dir = os.path.join(self.workspace, 'mesh') if save_path is None else save_path

    self.log(f"==> Saving mesh to {out_dir}")

    os.makedirs(out_dir, exist_ok=True)

    export_options = {
        'resolution': self.opt.mcubes_resolution,
        'decimate_target': self.opt.decimate_target,
    }
    self.model.export_mesh(out_dir, **export_options)

    self.log("==> Finished saving mesh.")
def evaluate(self, loader, name=None):
    """Run one evaluation epoch with TensorBoard logging switched off.

    ``self.use_tensorboardX`` is forced to ``False`` for the duration of
    ``evaluate_one_epoch`` and restored afterwards. The restore sits in a
    ``finally`` clause so an exception during evaluation cannot leave
    logging disabled for the rest of the trainer's lifetime.

    Args:
        loader: data loader forwarded to ``evaluate_one_epoch``.
        name: optional name forwarded to ``evaluate_one_epoch``.
    """
    use_tensorboardX = self.use_tensorboardX
    self.use_tensorboardX = False
    try:
        self.evaluate_one_epoch(loader, name)
    finally:
        self.use_tensorboardX = use_tensorboardX
if write_video: + all_preds.append(pred) + all_preds_depth.append(pred_depth) + else: + cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_rgb.png'), + cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)) + cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_depth.png'), pred_depth) + + pbar.update(loader.batch_size) + + if write_video: + all_preds = np.stack(all_preds, axis=0) + all_preds_depth = np.stack(all_preds_depth, axis=0) + print('save video...', os.path.join(save_path, f'{name}_rgb.mp4'), + os.path.join(save_path, f'{name}_depth.mp4')) + imageio.mimwrite(os.path.join(save_path, f'{name}_rgb.mp4'), all_preds, fps=25, quality=8, + macro_block_size=1) + imageio.mimwrite(os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, + macro_block_size=1) + + self.log(f"==> Finished Test.") + + def train_one_epoch(self, loader, max_epochs): + self.log( + f"==> [{time.strftime('%Y-%m-%d_%H-%M-%S')}] Start Training {self.workspace} Epoch {self.epoch}/{max_epochs}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...") + + total_loss = 0 + total_latent_mse_loss = 0 + total_rgb_mse_loss = 0 + # total_rgb_perceptual_loss = 0 + # total_depth_mse_loss = 0 + + if self.local_rank == 0 and self.report_metric_at_train: + for metric in self.metrics: + metric.clear() + + self.model.train() + + # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs + # ref: https://pytorch.org/docs/stable/data.html + if self.world_size > 1: + loader.sampler.set_epoch(self.epoch) + + if self.local_rank == 0: + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, + bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + + self.local_step = 0 + + if self.opt.save_guidance: + save_guidance_folder = Path(self.workspace) / 'guidance' + save_guidance_folder.mkdir(parents=True, exist_ok=True) + + for data in loader: + + # update grid every 16 steps + if ( + 
self.model.cuda_ray or self.model.taichi_ray) and self.global_step % self.opt.update_extra_interval == 0: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() + + self.local_step += 1 + self.global_step += 1 + + self.optimizer.zero_grad() + + with torch.cuda.amp.autocast(enabled=self.fp16): + if self.opt.save_guidance and (self.global_step % self.opt.save_guidance_interval == 0): + save_guidance_path = save_guidance_folder / f'step_{self.global_step:07d}.png' + else: + save_guidance_path = None + pred_rgbs, pred_depths, teacher_rgbs, teacher_depths, loss, losses = self.train_step(data, + save_guidance_path=save_guidance_path) + + # hooked grad clipping for RGB space + if self.opt.grad_clip_rgb >= 0: + def _hook(grad): + if self.opt.fp16: + # correctly handle the scale + grad_scale = self.scaler._get_scale_async() + return grad.clamp(grad_scale * -self.opt.grad_clip_rgb, grad_scale * self.opt.grad_clip_rgb) + else: + return grad.clamp(-self.opt.grad_clip_rgb, self.opt.grad_clip_rgb) + + pred_rgbs.register_hook(_hook) + # pred_rgbs.retain_grad() + + self.scaler.scale(loss).backward() + + self.post_train_step() + self.scaler.step(self.optimizer) + self.scaler.update() + + if self.scheduler_update_every_step: + self.lr_scheduler.step() + + loss_val = loss.item() + total_loss += loss_val + total_latent_mse_loss += losses['latent_mse_loss'].item() + total_rgb_mse_loss += losses['rgb_mse_loss'].item() + # total_rgb_perceptual_loss += losses['rgb_perceptual_loss'].item() + # total_depth_mse_loss += losses['depth_mse_loss'].item() + + if self.local_rank == 0: + # if self.report_metric_at_train: + # for metric in self.metrics: + # metric.update(preds, truths) + + if self.use_tensorboardX: + self.writer.add_scalar("train/loss", loss_val, self.global_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step) + + if self.scheduler_update_every_step: + pbar.set_description( + f"loss={loss_val:.4f} 
({total_loss / self.local_step:.4f}), " + f"latent_mse_loss={losses['latent_mse_loss'].item():.4f} ({total_latent_mse_loss / self.local_step:.4f}), " + f"rgb_mse_loss={losses['rgb_mse_loss'].item():.4f} ({total_rgb_mse_loss / self.local_step:.4f}), " + # f"rgb_perceptual_loss={losses['rgb_perceptual_loss'].item():.4f} ({total_rgb_perceptual_loss / self.local_step:.4f}), " + # f"depth_mse_loss={losses['depth_mse_loss'].item():.4f} ({total_depth_mse_loss / self.local_step:.4f}), " + f"lr={self.optimizer.param_groups[0]['lr']:.6f}") + else: + pbar.set_description(f"loss={loss_val:.4f} ({total_loss / self.local_step:.4f})") + pbar.update(loader.batch_size) + + if self.ema is not None: + self.ema.update() + + average_loss = total_loss / self.local_step + self.stats["loss"].append(average_loss) + + if self.local_rank == 0: + pbar.close() + if self.report_metric_at_train: + for metric in self.metrics: + self.log(metric.report(), style="red") + if self.use_tensorboardX: + metric.write(self.writer, self.epoch, prefix="train") + metric.clear() + + if not self.scheduler_update_every_step: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(average_loss) + else: + self.lr_scheduler.step() + + cpu_mem, gpu_mem = get_CPU_mem(), get_GPU_mem()[0] + self.log( + f"==> [{time.strftime('%Y-%m-%d_%H-%M-%S')}] Finished Epoch {self.epoch}/{max_epochs}. 
def evaluate_one_epoch(self, loader, name=None):
    """Run ``eval_step`` over every batch of ``loader`` and record the mean loss.

    For each batch the first predicted RGB/depth image is written next to the
    teacher's RGB/depth image as PNG files under
    ``<workspace>/latent_trigrid_fit_validation`` (rank 0 only).  The average
    loss is appended to ``self.stats["valid_loss"]``; rank 0 additionally
    appends either the first metric's value or the average loss to
    ``self.stats["results"]``.

    Args:
        loader: iterable of batches; must expose ``__len__`` and ``batch_size``
            (both are used to size the progress bar).
        name: filename prefix for the saved images; defaults to
            ``f'{self.name}_ep{self.epoch:04d}'``.
    """
    self.log(f"++> Evaluate {self.workspace} at epoch {self.epoch} ...")

    if name is None:
        name = f'{self.name}_ep{self.epoch:04d}'

    total_loss = 0
    if self.local_rank == 0:
        for metric in self.metrics:
            metric.clear()

    self.model.eval()

    # presumably swaps the EMA-averaged weights into the model for evaluation
    # (store/copy_to here, restore at the end) — TODO confirm against the EMA class
    if self.ema is not None:
        self.ema.store()
        self.ema.copy_to()

    # The progress bar exists on rank 0 only; every later use of `pbar` is
    # guarded by the same rank check.
    if self.local_rank == 0:
        pbar = tqdm.tqdm(total=len(loader) * loader.batch_size,
                         bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

    with torch.no_grad():
        self.local_step = 0

        for data in loader:
            self.local_step += 1

            with torch.cuda.amp.autocast(enabled=self.fp16):
                preds, preds_depth, teacher_rgb, teacher_depth, loss = self.eval_step(data)

            # all_gather/reduce the statistics (NCCL only support all_*)
            if self.world_size > 1:
                # average the loss over all ranks
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss = loss / self.world_size

                preds_list = [torch.zeros_like(preds).to(self.device) for _ in
                              range(self.world_size)]  # [[B, ...], [B, ...], ...]
                dist.all_gather(preds_list, preds)
                preds = torch.cat(preds_list, dim=0)

                preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in
                                    range(self.world_size)]  # [[B, ...], [B, ...], ...]
                dist.all_gather(preds_depth_list, preds_depth)
                preds_depth = torch.cat(preds_depth_list, dim=0)
                # NOTE(review): teacher_rgb / teacher_depth are not gathered; only
                # index 0 of each tensor is used below.

            loss_val = loss.item()
            total_loss += loss_val

            # only rank = 0 will perform evaluation.
            if self.local_rank == 0:
                # save image
                save_path = os.path.join(self.workspace, 'latent_trigrid_fit_validation',
                                         f'{name}_{self.local_step:04d}_rgb.png')
                save_path_depth = os.path.join(self.workspace, 'latent_trigrid_fit_validation',
                                               f'{name}_{self.local_step:04d}_depth.png')

                # self.log(f"==> Saving validation image to {save_path}")
                os.makedirs(os.path.dirname(save_path), exist_ok=True)

                # assumes predictions are float images in [0, 1] — TODO confirm
                pred = preds[0].detach().cpu().numpy()
                pred = (pred * 255).astype(np.uint8)

                # depth is min-max normalised per image before quantising;
                # the 1e-6 avoids division by zero on a constant depth map
                pred_depth = preds_depth[0].detach().cpu().numpy()
                pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - pred_depth.min() + 1e-6)
                pred_depth = (pred_depth * 255).astype(np.uint8)

                teacher_rgb = teacher_rgb[0].detach().cpu().numpy()
                teacher_rgb = (teacher_rgb * 255).astype(np.uint8)

                teacher_depth = teacher_depth[0].detach().cpu().numpy()
                teacher_depth = (teacher_depth - teacher_depth.min()) / (
                        teacher_depth.max() - teacher_depth.min() + 1e-6)
                teacher_depth = (teacher_depth * 255).astype(np.uint8)

                # prediction and teacher are joined along axis 1 into one image
                pred = np.concatenate((pred, teacher_rgb), axis=1)
                pred_depth = np.concatenate((pred_depth, teacher_depth), axis=1)

                # cv2 expects BGR channel order when writing colour images
                cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
                cv2.imwrite(save_path_depth, pred_depth)

                pbar.set_description(f"loss={loss_val:.4f} ({total_loss / self.local_step:.4f})")
                pbar.update(loader.batch_size)

    # NOTE(review): raises ZeroDivisionError when `loader` yields no batches.
    average_loss = total_loss / self.local_step
    self.stats["valid_loss"].append(average_loss)

    if self.local_rank == 0:
        pbar.close()
        if not self.use_loss_as_metric and len(self.metrics) > 0:
            result = self.metrics[0].measure()
            self.stats["results"].append(
                result if self.best_mode == 'min' else - result)  # if max mode, use -result
        else:
            self.stats["results"].append(average_loss)  # if no metric, choose best by min loss

        # NOTE(review): no metric.update() call is visible in this method, so the
        # metrics reported here only reflect updates made elsewhere (e.g. inside
        # eval_step) — verify.
        for metric in self.metrics:
            self.log(metric.report(), style="blue")
            if self.use_tensorboardX:
                metric.write(self.writer, self.epoch, prefix="evaluate")
            metric.clear()

    if self.ema is not None:
        self.ema.restore()

    self.log(f"++> Evaluate epoch {self.epoch} Finished.")
def load_checkpoint(self, checkpoint=None, model_only=False):
    """Restore model weights and, optionally, training state from a checkpoint.

    Args:
        checkpoint: path of the file to load.  When ``None``, the
            lexicographically last ``*.pth`` file in ``self.ckpt_path`` is
            used; if there is none, a warning is logged and nothing is loaded.
        model_only: when true, return after restoring the model-side state
            (weights, EMA, ``mean_density``, ``tet_scale``) without touching
            ``stats`` / ``epoch`` / ``global_step`` or the
            optimizer / scheduler / scaler.

    Restoring the EMA, optimizer, scheduler and scaler is best-effort: an
    ordinary exception is logged as a warning and loading continues.
    ``KeyboardInterrupt`` / ``SystemExit`` are no longer swallowed.
    """
    if checkpoint is None:
        checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))
        if not checkpoint_list:
            self.log("[WARN] No checkpoint found, model randomly initialized.")
            return
        checkpoint = checkpoint_list[-1]
        self.log(f"[INFO] Latest checkpoint is {checkpoint}")

    checkpoint_dict = torch.load(checkpoint, map_location=self.device)

    # A file without a 'model' entry is treated as a bare state dict and
    # loaded strictly; nothing else can be restored from it.
    if 'model' not in checkpoint_dict:
        self.model.load_state_dict(checkpoint_dict)
        self.log("[INFO] loaded model.")
        return

    missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
    self.log("[INFO] loaded model.")
    if len(missing_keys) > 0:
        self.log(f"[WARN] missing keys: {missing_keys}")
    if len(unexpected_keys) > 0:
        self.log(f"[WARN] unexpected keys: {unexpected_keys}")

    if self.ema is not None and 'ema' in checkpoint_dict:
        try:
            self.ema.load_state_dict(checkpoint_dict['ema'])
            self.log("[INFO] loaded EMA.")
        except Exception as err:
            # Only the exception type is logged, so the message text cannot
            # interfere with whatever formatting self.log applies.
            self.log(f"[WARN] failed to loaded EMA. ({type(err).__name__})")

    if self.model.cuda_ray:
        if 'mean_density' in checkpoint_dict:
            self.model.mean_density = checkpoint_dict['mean_density']

    if self.opt.dmtet:
        if 'tet_scale' in checkpoint_dict:
            new_scale = torch.from_numpy(checkpoint_dict['tet_scale']).to(self.device)
            # rescale the stored vertices from the current scale to the loaded one
            self.model.verts *= new_scale / self.model.tet_scale
            self.model.tet_scale = new_scale

    if model_only:
        return

    self.stats = checkpoint_dict['stats']
    self.epoch = checkpoint_dict['epoch']
    self.global_step = checkpoint_dict['global_step']
    self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")

    # (checkpoint key, object to restore, success message, failure message)
    restorable = (
        ('optimizer', self.optimizer, "[INFO] loaded optimizer.", "[WARN] Failed to load optimizer."),
        ('lr_scheduler', self.lr_scheduler, "[INFO] loaded scheduler.", "[WARN] Failed to load scheduler."),
        ('scaler', self.scaler, "[INFO] loaded scaler.", "[WARN] Failed to load scaler."),
    )
    for key, target, ok_msg, fail_msg in restorable:
        if target and key in checkpoint_dict:
            try:
                target.load_state_dict(checkpoint_dict[key])
                self.log(ok_msg)
            except Exception as err:
                self.log(f"{fail_msg} ({type(err).__name__})")
def adjust_text_embeddings(embeddings, azimuth, opt):
    """Build interleaved positive/negative text embeddings for a batch of views.

    ``get_pos_neg_text_embeddings`` is evaluated once per entry of ``azimuth``;
    the per-view results are then interleaved so that entry ``i`` of every
    view comes before entry ``i + 1`` of any view.

    Args:
        embeddings: mapping with ``'front'``, ``'side'`` and ``'back'`` entries.
        azimuth: tensor whose first dimension indexes the views.
        opt: options object forwarded to ``get_pos_neg_text_embeddings``.

    Returns:
        ``(text_embeddings, weights)`` stacked along a new leading dimension
        (``[B * K, 77, 768]`` and ``[B * K]`` according to the original comments).
    """
    per_view = [get_pos_neg_text_embeddings(embeddings, azimuth[b], opt)
                for b in range(azimuth.shape[0])]
    K = max((w.shape[0] for _, w in per_view), default=0)

    # Interleave across views; a view with fewer than K entries is padded with
    # its first embedding and a zero weight.
    text_embeddings = torch.stack(
        [z[i] if i < len(z) else z[0] for i in range(K) for z, _ in per_view],
        dim=0)
    weights = torch.stack(
        [w[i] if i < len(w) else torch.zeros_like(w[0]) for i in range(K) for _, w in per_view],
        dim=0)
    return text_embeddings, weights


def get_pos_neg_text_embeddings(embeddings, azimuth_val, opt):
    """Return the blended positive embedding plus two negatives and their weights.

    For ``-90 <= azimuth_val < 90`` the positive embedding interpolates between
    ``'front'`` and ``'side'``; otherwise between ``'side'`` and ``'back'``.
    The two negative entries are the plain ``'front'`` / ``'side'`` embeddings,
    weighted from ``opt.negative_w`` and the decay factors in ``opt``.

    Returns:
        ``(text_z, weights)`` where ``text_z`` concatenates
        ``[positive, negative_1, negative_2]`` along dim 0 and ``weights`` is a
        3-element tensor on ``text_z``'s device, starting with ``1.0``.
    """
    in_front_half = azimuth_val >= -90 and azimuth_val < 90

    if in_front_half:
        r = 1 - azimuth_val / 90 if azimuth_val >= 0 else 1 + azimuth_val / 90
        near_key, far_key = 'front', 'side'
    else:
        r = 1 - (azimuth_val - 90) / 90 if azimuth_val >= 0 else 1 + (azimuth_val + 90) / 90
        near_key, far_key = 'side', 'back'

    # if random.random() < 0.3:
    #     r = r + random.gauss(0, 0.08)
    pos_z = r * embeddings[near_key] + (1 - r) * embeddings[far_key]

    if in_front_half:
        negatives = [embeddings['front'], embeddings['side']]
        front_neg_w = 0.0 if r > 0.8 else math.exp(-r * opt.front_decay_factor) * opt.negative_w
        side_neg_w = 0.0 if r < 0.2 else math.exp(-(1 - r) * opt.side_decay_factor) * opt.negative_w
        neg_weights = [front_neg_w, side_neg_w]
    else:
        negatives = [embeddings['side'], embeddings['front']]
        side_neg_w = 0.0 if r > 0.8 else math.exp(-r * opt.side_decay_factor) * opt.negative_w / 2
        neg_weights = [side_neg_w, opt.negative_w]

    text_z = torch.cat([pos_z, *negatives], dim=0)
    weights = torch.tensor([1.0, *neg_weights])
    return text_z, weights.to(text_z.device)


def custom_meshgrid(*args):
    """Call ``torch.meshgrid`` with ``indexing='ij'`` where that keyword exists."""
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) >= pver.parse('1.10'):
        return torch.meshgrid(*args, indexing='ij')
    return torch.meshgrid(*args)


def safe_normalize(x, eps=1e-20):
    """Divide ``x`` by its L2 norm over the last dim, clamping the squared norm at ``eps``."""
    squared_norm = torch.sum(x * x, -1, keepdim=True)
    return x / torch.sqrt(torch.clamp(squared_norm, min=eps))
def seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``.

    Also exports ``seed`` as the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = True


@torch.jit.script
def linear_to_srgb(x):
    """Piecewise linear-to-sRGB encoding: ``12.92 * x`` below 0.0031308, a power curve above."""
    low = 12.92 * x
    # 0.41666 is the exponent used by the original code (a truncated 1 / 2.4)
    high = 1.055 * torch.pow(x, 0.41666) - 0.055
    return torch.where(x < 0.0031308, low, high)


@torch.jit.script
def srgb_to_linear(x):
    """Piecewise sRGB-to-linear decoding: ``x / 12.92`` below 0.04045, a power curve above."""
    low = x / 12.92
    high = torch.pow((x + 0.055) / 1.055, 2.4)
    return torch.where(x < 0.04045, low, high)
+ local_rank=0, # which GPU am I + world_size=1, # total num of GPUs + device=None, # device to use, usually setting to None is OK. (auto choose device) + mute=False, # whether to mute all print + fp16=False, # amp optimize level + max_keep_ckpt=2, # max num of saved ckpts in disk + workspace='workspace', # workspace to save logs & ckpts + best_mode='min', # the smaller/larger result, the better + use_loss_as_metric=True, # use loss as the first metric + report_metric_at_train=False, # also report metrics at training + use_checkpoint="latest", # which ckpt to use at init time + use_tensorboardX=True, # whether to use tensorboard for logging + scheduler_update_every_step=False, # whether to call scheduler.step() after every train step + ): + + self.argv = argv + self.name = name + self.opt = opt + self.mute = mute + self.metrics = metrics + self.local_rank = local_rank + self.world_size = world_size + self.workspace = workspace + self.ema_decay = ema_decay + self.fp16 = fp16 + self.best_mode = best_mode + self.use_loss_as_metric = use_loss_as_metric + self.report_metric_at_train = report_metric_at_train + self.max_keep_ckpt = max_keep_ckpt + self.use_checkpoint = use_checkpoint + self.use_tensorboardX = use_tensorboardX + self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") + self.scheduler_update_every_step = scheduler_update_every_step + self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu') + self.console = Console() + + model.to(self.device) + if self.world_size > 1: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) + self.model = model + + # guide model + self.guidance = guidance + self.embeddings = {} + + # text prompt / images + if self.guidance is not None: + for key in self.guidance: + for p in self.guidance[key].parameters(): + p.requires_grad = False + self.embeddings[key] = {} + 
self.prepare_embeddings() + + if isinstance(criterion, nn.Module): + criterion.to(self.device) + self.criterion = criterion + + if self.opt.images is not None: + self.pearson = PearsonCorrCoef().to(self.device) + + if optimizer is None: + self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam + else: + self.optimizer = optimizer(self.model) + + if lr_scheduler is None: + self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler + else: + self.lr_scheduler = lr_scheduler(self.optimizer) + + if ema_decay is not None: + self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay) + else: + self.ema = None + + self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) + + # variable init + self.total_train_t = 0 + self.epoch = 0 + self.global_step = 0 + self.local_step = 0 + self.stats = { + "loss": [], + "valid_loss": [], + "results": [], # metrics[0], or valid_loss + "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt + "best_result": None, + } + + # auto fix + if len(metrics) == 0 or self.use_loss_as_metric: + self.best_mode = 'min' + + # workspace prepare + self.log_ptr = None + if self.workspace is not None: + os.makedirs(self.workspace, exist_ok=True) + self.log_path = os.path.join(workspace, f"log_{self.name}.txt") + self.log_ptr = open(self.log_path, "a+") + + self.ckpt_path = os.path.join(self.workspace, 'checkpoints') + self.best_path = f"{self.ckpt_path}/{self.name}.pth" + os.makedirs(self.ckpt_path, exist_ok=True) + + # Save a copy of image_config in the experiment workspace + if opt.image_config is not None: + shutil.copyfile(opt.image_config, os.path.join(self.workspace, os.path.basename(opt.image_config))) + + # Save a copy of images in the experiment workspace + if opt.images is not None: + for image_file in opt.images: + shutil.copyfile(image_file, os.path.join(self.workspace, os.path.basename(image_file))) + + 
self.log(f'[INFO] Cmdline: {self.argv}') + self.log(f'[INFO] opt: {self.opt}') + self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}') + self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}') + + if self.workspace is not None: + if self.use_checkpoint == "scratch": + self.log("[INFO] Training from scratch ...") + elif self.use_checkpoint == "latest": + self.log("[INFO] Loading latest checkpoint ...") + self.load_checkpoint() + elif self.use_checkpoint == "latest_model": + self.log("[INFO] Loading latest checkpoint (model only)...") + self.load_checkpoint(model_only=True) + elif self.use_checkpoint == "best": + if os.path.exists(self.best_path): + self.log("[INFO] Loading best checkpoint ...") + self.load_checkpoint(self.best_path) + else: + self.log(f"[INFO] {self.best_path} not found, loading latest ...") + self.load_checkpoint() + else: # path to ckpt + self.log(f"[INFO] Loading {self.use_checkpoint} ...") + self.load_checkpoint(self.use_checkpoint) + + # calculate the text embs. 
+ @torch.no_grad() + def prepare_embeddings(self): + + # text embeddings (stable-diffusion) + if self.opt.text is not None: + + if 'SD' in self.guidance: + self.embeddings['SD']['default'] = self.guidance['SD'].get_text_embeds([self.opt.text]) + self.embeddings['SD']['uncond'] = self.guidance['SD'].get_text_embeds([self.opt.negative]) + + for d in ['front', 'side', 'back']: + self.embeddings['SD'][d] = self.guidance['SD'].get_text_embeds([f"{d} view {self.opt.text}"]) + + print('direction: ', d, 'prompt: ', f"{d} view {self.opt.text}") + + if 'IF' in self.guidance: + self.embeddings['IF']['default'] = self.guidance['IF'].get_text_embeds([self.opt.text]) + self.embeddings['IF']['uncond'] = self.guidance['IF'].get_text_embeds([self.opt.negative]) + + for d in ['front', 'side', 'back']: + self.embeddings['IF'][d] = self.guidance['IF'].get_text_embeds([f"{d} view {self.opt.text}"]) + + if 'clip' in self.guidance: + self.embeddings['clip']['text'] = self.guidance['clip'].get_text_embeds(self.opt.text) + + if self.opt.images is not None: + + h = int(self.opt.known_view_scale * self.opt.h) + w = int(self.opt.known_view_scale * self.opt.w) + + # load processed image + for image in self.opt.images: + assert image.endswith('_rgba.png') # the rest of this code assumes that the _rgba image has been passed. 
+ rgbas = [cv2.cvtColor(cv2.imread(image, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA) for image in self.opt.images] + rgba_hw = np.stack([cv2.resize(rgba, (w, h), interpolation=cv2.INTER_AREA).astype(np.float32) / 255 for rgba in rgbas]) + rgb_hw = rgba_hw[..., :3] * rgba_hw[..., 3:] + (1 - rgba_hw[..., 3:]) + self.rgb = torch.from_numpy(rgb_hw).permute(0,3,1,2).contiguous().to(self.device) + self.mask = torch.from_numpy(rgba_hw[..., 3] > 0.5).to(self.device) + print(f'[INFO] dataset: load image prompt {self.opt.images} {self.rgb.shape}') + + # load depth + depth_paths = [image.replace('_rgba.png', '_depth.png') for image in self.opt.images] + depths = [cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) for depth_path in depth_paths] + depth = np.stack([cv2.resize(depth, (w, h), interpolation=cv2.INTER_AREA) for depth in depths]) + self.depth = torch.from_numpy(depth.astype(np.float32) / 255).to(self.device) # TODO: this should be mapped to FP16 + print(f'[INFO] dataset: load depth prompt {depth_paths} {self.depth.shape}') + + # load normal # TODO: don't load if normal loss is 0 + normal_paths = [image.replace('_rgba.png', '_normal.png') for image in self.opt.images] + normals = [cv2.imread(normal_path, cv2.IMREAD_UNCHANGED) for normal_path in normal_paths] + normal = np.stack([cv2.resize(normal, (w, h), interpolation=cv2.INTER_AREA) for normal in normals]) + self.normal = torch.from_numpy(normal.astype(np.float32) / 255).to(self.device) + print(f'[INFO] dataset: load normal prompt {normal_paths} {self.normal.shape}') + + # encode embeddings for zero123 + if 'zero123' in self.guidance: + rgba_256 = np.stack([cv2.resize(rgba, (256, 256), interpolation=cv2.INTER_AREA).astype(np.float32) / 255 for rgba in rgbas]) + rgbs_256 = rgba_256[..., :3] * rgba_256[..., 3:] + (1 - rgba_256[..., 3:]) + rgb_256 = torch.from_numpy(rgbs_256).permute(0,3,1,2).contiguous().to(self.device) + guidance_embeds = self.guidance['zero123'].get_img_embeds(rgb_256) + 
def __del__(self):
    """Close the log file handle, if one was ever opened.

    ``log_ptr`` is assigned late in ``__init__``, so a constructor that fails
    earlier leaves the attribute missing; it is therefore read with
    ``getattr`` instead of assuming it exists.
    """
    log_ptr = getattr(self, 'log_ptr', None)
    if log_ptr:
        log_ptr.close()
        # Drop the reference so a repeated call does nothing.
        self.log_ptr = None


def log(self, *args, **kwargs):
    """Print a message to the console and append it to the log file (rank 0 only).

    Args:
        *args: message parts; passed to ``self.console.print`` and to the
            built-in ``print`` that writes the log file.
        **kwargs: forwarded to ``self.console.print`` only.
    """
    if self.local_rank != 0:
        return

    if not self.mute:
        #print(*args)
        self.console.print(*args, **kwargs)

    if self.log_ptr:
        print(*args, file=self.log_ptr)
        self.log_ptr.flush()  # write immediately to file
+ exp_iter_ratio = (self.global_step - self.opt.exp_start_iter) / (self.opt.exp_end_iter - self.opt.exp_start_iter) + + # progressively relaxing view range + if self.opt.progressive_view: + r = min(1.0, self.opt.progressive_view_init_ratio + 2.0*exp_iter_ratio) + self.opt.phi_range = [self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[0] * r, + self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[1] * r] + self.opt.theta_range = [self.opt.default_polar * (1 - r) + self.opt.full_theta_range[0] * r, + self.opt.default_polar * (1 - r) + self.opt.full_theta_range[1] * r] + self.opt.radius_range = [self.opt.default_radius * (1 - r) + self.opt.full_radius_range[0] * r, + self.opt.default_radius * (1 - r) + self.opt.full_radius_range[1] * r] + self.opt.fovy_range = [self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[0] * r, + self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[1] * r] + + # progressively increase max_level + if self.opt.progressive_level: + self.model.max_level = min(1.0, 0.25 + 2.0*exp_iter_ratio) + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + mvp = data['mvp'] # [B, 4, 4] + poses = data['poses'] # [B, 4, 4] + + B, N = rays_o.shape[:2] + H, W = data['H'], data['W'] + + # When ref_data has B images > opt.batch_size + if B > self.opt.batch_size: + # choose batch_size images out of those B images + choice = torch.randperm(B)[:self.opt.batch_size] + B = self.opt.batch_size + rays_o = rays_o[choice] + rays_d = rays_d[choice] + mvp = mvp[choice] + + if do_rgbd_loss: + ambient_ratio = 1.0 + shading = 'lambertian' # use lambertian instead of albedo to get normal + as_latent = False + binarize = False + bg_color = torch.rand((B * N, 3), device=rays_o.device) + + # add camera noise to avoid grid-like artifact + if self.opt.known_view_noise_scale > 0: + noise_scale = self.opt.known_view_noise_scale #* (1 - self.global_step / self.opt.iters) + rays_o = rays_o + torch.randn(3, device=self.device) * 
noise_scale + rays_d = rays_d + torch.randn(3, device=self.device) * noise_scale + + elif exp_iter_ratio <= self.opt.latent_iter_ratio: + ambient_ratio = 1.0 + shading = 'normal' + as_latent = True + binarize = False + bg_color = None + + else: + if exp_iter_ratio <= self.opt.albedo_iter_ratio: + ambient_ratio = 1.0 + shading = 'albedo' + else: + # random shading + ambient_ratio = self.opt.min_ambient_ratio + (1.0-self.opt.min_ambient_ratio) * random.random() + rand = random.random() + if rand >= (1.0 - self.opt.textureless_ratio): + shading = 'textureless' + else: + shading = 'lambertian' + + as_latent = False + + # random weights binarization (like mobile-nerf) [NOT WORKING NOW] + # binarize_thresh = min(0.5, -0.5 + self.global_step / self.opt.iters) + # binarize = random.random() < binarize_thresh + binarize = False + + # random background + rand = random.random() + #if self.opt.bg_radius > 0 and rand > 0.5: + if self.opt.learnable_bg: + bg_color = None # use bg_net + else: + bg_color = torch.rand(3).to(self.device) # single color random bg + + outputs = self.model.render(rays_o, rays_d, poses, H, W, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, binarize=binarize) + pred_depth = outputs['depth'].reshape(B, 1, H, W) + pred_mask = outputs['weights_sum'].reshape(B, 1, H, W) + if 'normal_image' in outputs: + pred_normal = outputs['normal_image'].reshape(B, H, W, 3) + + if as_latent: + # abuse normal & mask as latent code for faster geometry initialization (ref: fantasia3D) + pred_rgb = torch.cat([outputs['image'], outputs['weights_sum'].unsqueeze(-1)], dim=-1).reshape(B, H, W, 4).permute(0, 3, 1, 2).contiguous() # [B, 4, H, W] + else: + pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() # [B, 3, H, W] + + + # known view loss + if do_rgbd_loss: + gt_mask = self.mask # [B, H, W] + gt_rgb = self.rgb # [B, 3, H, W] + gt_normal = self.normal # [B, H, W, 3] + gt_depth = self.depth # [B, H, W] + + 
if len(gt_rgb) > self.opt.batch_size: + gt_mask = gt_mask[choice] + gt_rgb = gt_rgb[choice] + gt_normal = gt_normal[choice] + gt_depth = gt_depth[choice] + + # color loss + gt_rgb = gt_rgb * gt_mask[:, None].float() + bg_color.reshape(B, H, W, 3).permute(0,3,1,2).contiguous() * (1 - gt_mask[:, None].float()) + loss = self.opt.lambda_rgb * F.mse_loss(pred_rgb, gt_rgb) + + # mask loss + loss = loss + self.opt.lambda_mask * F.mse_loss(pred_mask[:, 0], gt_mask.float()) + + # normal loss + if self.opt.lambda_normal > 0 and 'normal_image' in outputs: + valid_gt_normal = 1 - 2 * gt_normal[gt_mask] # [B, 3] + valid_pred_normal = 2 * pred_normal[gt_mask] - 1 # [B, 3] + + lambda_normal = self.opt.lambda_normal * min(1, self.global_step / self.opt.iters) + loss = loss + lambda_normal * (1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean()) + + # relative depth loss + if self.opt.lambda_depth > 0: + valid_gt_depth = gt_depth[gt_mask] # [B,] + valid_pred_depth = pred_depth[:, 0][gt_mask] # [B,] + lambda_depth = self.opt.lambda_depth * min(1, self.global_step / self.opt.iters) + loss = loss + lambda_depth * (1 - self.pearson(valid_pred_depth, valid_gt_depth)) + + # # scale-invariant + # with torch.no_grad(): + # A = torch.cat([valid_gt_depth, torch.ones_like(valid_gt_depth)], dim=-1) # [B, 2] + # X = torch.linalg.lstsq(A, valid_pred_depth).solution # [2, 1] + # valid_gt_depth = A @ X # [B, 1] + # lambda_depth = self.opt.lambda_depth #* min(1, self.global_step / self.opt.iters) + # loss = loss + lambda_depth * F.mse_loss(valid_pred_depth, valid_gt_depth) + + # novel view loss + else: + + loss = 0 + + if 'SD' in self.guidance: + # interpolate text_z + azimuth = data['azimuth'] # [-180, 180] + + # ENHANCE: remove loop to handle batch size > 1 + text_z = [self.embeddings['SD']['uncond']] * azimuth.shape[0] + if self.opt.perpneg: + + text_z_comp, weights = adjust_text_embeddings(self.embeddings['SD'], azimuth, self.opt) + text_z.append(text_z_comp) + + else: + for b in 
range(azimuth.shape[0]): + if azimuth[b] >= -90 and azimuth[b] < 90: + if azimuth[b] >= 0: + r = 1 - azimuth[b] / 90 + else: + r = 1 + azimuth[b] / 90 + start_z = self.embeddings['SD']['front'] + end_z = self.embeddings['SD']['side'] + else: + if azimuth[b] >= 0: + r = 1 - (azimuth[b] - 90) / 90 + else: + r = 1 + (azimuth[b] + 90) / 90 + start_z = self.embeddings['SD']['side'] + end_z = self.embeddings['SD']['back'] + text_z.append(r * start_z + (1 - r) * end_z) + + text_z = torch.cat(text_z, dim=0) + if self.opt.perpneg: + loss = loss + self.guidance['SD'].train_step_perpneg(text_z, weights, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance, + save_guidance_path=save_guidance_path,) + else: + loss = loss + self.guidance['SD'].train_step(text_z, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance, + save_guidance_path=save_guidance_path) + + if 'IF' in self.guidance: + # interpolate text_z + azimuth = data['azimuth'] # [-180, 180] + + # ENHANCE: remove loop to handle batch size > 1 + text_z = [self.embeddings['IF']['uncond']] * azimuth.shape[0] + if self.opt.perpneg: + text_z_comp, weights = adjust_text_embeddings(self.embeddings['IF'], azimuth, self.opt) + text_z.append(text_z_comp) + else: + for b in range(azimuth.shape[0]): + if azimuth[b] >= -90 and azimuth[b] < 90: + if azimuth[b] >= 0: + r = 1 - azimuth[b] / 90 + else: + r = 1 + azimuth[b] / 90 + start_z = self.embeddings['IF']['front'] + end_z = self.embeddings['IF']['side'] + else: + if azimuth[b] >= 0: + r = 1 - (azimuth[b] - 90) / 90 + else: + r = 1 + (azimuth[b] + 90) / 90 + start_z = self.embeddings['IF']['side'] + end_z = self.embeddings['IF']['back'] + text_z.append(r * start_z + (1 - r) * end_z) + + text_z = torch.cat(text_z, dim=0) + + if self.opt.perpneg: + loss = loss + self.guidance['IF'].train_step_perpneg(text_z, weights, pred_rgb, guidance_scale=self.opt.guidance_scale, 
grad_scale=self.opt.lambda_guidance) + else: + loss = loss + self.guidance['IF'].train_step(text_z, pred_rgb, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance) + + if 'zero123' in self.guidance: + + polar = data['polar'] + azimuth = data['azimuth'] + radius = data['radius'] + + loss = loss + self.guidance['zero123'].train_step(self.embeddings['zero123']['default'], pred_rgb, polar, azimuth, radius, guidance_scale=self.opt.guidance_scale, + as_latent=as_latent, grad_scale=self.opt.lambda_guidance, save_guidance_path=save_guidance_path) + + if 'clip' in self.guidance: + + # empirical, far view should apply smaller CLIP loss + lambda_guidance = 10 * (1 - abs(azimuth) / 180) * self.opt.lambda_guidance + + loss = loss + self.guidance['clip'].train_step(self.embeddings['clip'], pred_rgb, grad_scale=lambda_guidance) + + # regularizations + if not self.opt.dmtet: + + if self.opt.lambda_opacity > 0: + loss_opacity = (outputs['weights_sum'] ** 2).mean() + loss = loss + self.opt.lambda_opacity * loss_opacity + + if self.opt.lambda_entropy > 0: + alphas = outputs['weights'].clamp(1e-5, 1 - 1e-5) + # alphas = alphas ** 2 # skewed entropy, favors 0 over 1 + loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean() + lambda_entropy = self.opt.lambda_entropy * min(1, 2 * self.global_step / self.opt.iters) + loss = loss + lambda_entropy * loss_entropy + + if self.opt.lambda_2d_normal_smooth > 0 and 'normal_image' in outputs: + # pred_vals = outputs['normal_image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() + # smoothed_vals = TF.gaussian_blur(pred_vals.detach(), kernel_size=9) + # loss_smooth = F.mse_loss(pred_vals, smoothed_vals) + # total-variation + loss_smooth = (pred_normal[:, 1:, :, :] - pred_normal[:, :-1, :, :]).square().mean() + \ + (pred_normal[:, :, 1:, :] - pred_normal[:, :, :-1, :]).square().mean() + loss = loss + self.opt.lambda_2d_normal_smooth * loss_smooth + + if self.opt.lambda_orient > 0 and 
def post_train_step(self):
    """Unscale the optimizer's gradients, then optionally clamp them element-wise.

    Value clipping is applied to every model parameter only when
    ``self.opt.grad_clip`` is non-negative; otherwise gradients are left as
    they are after unscaling.
    """
    # Gradients have to be unscaled before they are modified, otherwise the
    # clip threshold would be compared against loss-scaled values
    # (see the gradient-clipping section of the PyTorch AMP examples).
    self.scaler.unscale_(self.optimizer)

    clip_limit = self.opt.grad_clip
    clipping_enabled = clip_limit >= 0
    if clipping_enabled:
        torch.nn.utils.clip_grad_value_(self.model.parameters(), clip_limit)
def save_mesh(self, loader=None, save_path=None):
    """Export the model's mesh into a directory, creating the directory if needed.

    Args:
        loader: Not used by this method.
        save_path: Target directory. Defaults to ``<workspace>/mesh``.
    """
    target_dir = os.path.join(self.workspace, 'mesh') if save_path is None else save_path

    self.log(f"==> Saving mesh to {target_dir}")
    os.makedirs(target_dir, exist_ok=True)

    # Extraction settings come straight from the run options.
    export_options = {
        'resolution': self.opt.mcubes_resolution,
        'decimate_target': self.opt.decimate_target,
    }
    self.model.export_mesh(target_dir, **export_options)

    self.log("==> Finished saving mesh.")
def evaluate(self, loader, name=None):
    """Run a single evaluation pass with tensorboard logging switched off.

    Args:
        loader: Data loader forwarded to ``evaluate_one_epoch``.
        name: Optional name forwarded to ``evaluate_one_epoch``.

    ``self.use_tensorboardX`` is forced to ``False`` for the duration of the
    pass and always put back to its previous value, even when the pass raises,
    so a failed stand-alone evaluation cannot silently disable logging for
    later training.
    """
    previous_flag = self.use_tensorboardX
    self.use_tensorboardX = False
    try:
        self.evaluate_one_epoch(loader, name)
    finally:
        self.use_tensorboardX = previous_flag
# [GUI] test on a single image
def test_gui(self, pose, intrinsics, mvp, W, H, bg_color=None, spp=1, downscale=1, light_d=None, ambient_ratio=1.0, shading='albedo'):
    """Render one frame for the GUI from a single camera.

    Args:
        pose: numpy camera pose; a leading batch axis is added before use.
        intrinsics: Camera intrinsics; multiplied by ``downscale``.
        mvp: numpy matrix forwarded (batched) under the ``'mvp'`` key.
        W: Output width in pixels.
        H: Output height in pixels.
        bg_color: Optional background color forwarded to ``test_step``.
        spp: ``1`` disables ray perturbation; any other value is passed on
            as the ``perturb`` argument.
        downscale: Factor applied to the render resolution; the result is
            resized back to ``(H, W)`` with nearest-neighbour interpolation.
        light_d: Optional ``(theta, phi)`` light direction in degrees. With
            ``None`` no light vector is built and ``None`` is forwarded.
        ambient_ratio: Forwarded to ``test_step`` through the batch dict.
        shading: Forwarded to ``test_step`` through the batch dict.

    Returns:
        dict with ``'image'`` and ``'depth'`` numpy arrays for the first
        (only) batch element.
    """
    # render resolution (may be downscaled for a better frame rate)
    rH = int(H * downscale)
    rW = int(W * downscale)
    intrinsics = intrinsics * downscale

    pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)
    mvp = torch.from_numpy(mvp).unsqueeze(0).to(self.device)

    rays = get_rays(pose, intrinsics, rH, rW, -1)

    # The signature defaults light_d to None, which np.deg2rad cannot take;
    # only build the direction vector when angles were actually supplied.
    if light_d is not None:
        # from degree theta/phi to 3D normalized vec
        angles = np.deg2rad(light_d)
        light_d = np.array([
            np.sin(angles[0]) * np.sin(angles[1]),
            np.cos(angles[0]),
            np.sin(angles[0]) * np.cos(angles[1]),
        ], dtype=np.float32)
        light_d = torch.from_numpy(light_d).to(self.device)

    data = {
        'rays_o': rays['rays_o'],
        'rays_d': rays['rays_d'],
        # test_step reads data['poses'] unconditionally, so it must be present.
        # NOTE(review): assumes the renderer accepts the batched camera pose
        # here, as it does for dataset batches - confirm against the provider.
        'poses': pose,
        'mvp': mvp,
        'H': rH,
        'W': rW,
        'light_d': light_d,
        'ambient_ratio': ambient_ratio,
        'shading': shading,
    }

    self.model.eval()

    if self.ema is not None:
        self.ema.store()
        self.ema.copy_to()

    try:
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=self.fp16):
            # here spp is used as perturb random seed!
            preds, preds_depth, _ = self.test_step(data, bg_color=bg_color, perturb=False if spp == 1 else spp)
    finally:
        # Always swap the training weights back, even if rendering failed.
        if self.ema is not None:
            self.ema.restore()

    # interpolation to the original resolution
    if downscale != 1:
        # F.interpolate wants channels-first, hence the permute round trip
        preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous()
        preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)

    outputs = {
        'image': preds[0].detach().cpu().numpy(),
        'depth': preds_depth[0].detach().cpu().numpy(),
    }

    return outputs
self.train_step(data, save_guidance_path=save_guidance_path) + + # hooked grad clipping for RGB space + if self.opt.grad_clip_rgb >= 0: + def _hook(grad): + if self.opt.fp16: + # correctly handle the scale + grad_scale = self.scaler._get_scale_async() + return grad.clamp(grad_scale * -self.opt.grad_clip_rgb, grad_scale * self.opt.grad_clip_rgb) + else: + return grad.clamp(-self.opt.grad_clip_rgb, self.opt.grad_clip_rgb) + pred_rgbs.register_hook(_hook) + # pred_rgbs.retain_grad() + + self.scaler.scale(loss).backward() + + self.post_train_step() + self.scaler.step(self.optimizer) + self.scaler.update() + + if self.scheduler_update_every_step: + self.lr_scheduler.step() + + loss_val = loss.item() + total_loss += loss_val + + if self.local_rank == 0: + # if self.report_metric_at_train: + # for metric in self.metrics: + # metric.update(preds, truths) + + if self.use_tensorboardX: + self.writer.add_scalar("train/loss", loss_val, self.global_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step) + + if self.scheduler_update_every_step: + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}") + else: + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") + pbar.update(loader.batch_size) + + if self.ema is not None: + self.ema.update() + + average_loss = total_loss / self.local_step + self.stats["loss"].append(average_loss) + + if self.local_rank == 0: + pbar.close() + if self.report_metric_at_train: + for metric in self.metrics: + self.log(metric.report(), style="red") + if self.use_tensorboardX: + metric.write(self.writer, self.epoch, prefix="train") + metric.clear() + + if not self.scheduler_update_every_step: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(average_loss) + else: + self.lr_scheduler.step() + + cpu_mem, gpu_mem = get_CPU_mem(), get_GPU_mem()[0] + 
def evaluate_one_epoch(self, loader, name=None):
    """Run one validation pass over ``loader`` and record its statistics.

    For every batch the loss from ``eval_step`` is accumulated; with more than
    one process the loss is averaged and the predictions are gathered across
    ranks. Rank 0 additionally writes an RGB and a min-max normalised depth
    PNG per batch under ``<workspace>/validation`` and reports the metrics.

    Args:
        loader: Validation data loader.
        name: Prefix for the saved images; defaults to
            ``<self.name>_ep<epoch>``.

    The EMA weights (when an EMA is configured) are swapped in for the pass
    and are always swapped back out, even when the pass raises, so a failed
    evaluation cannot leave the model holding the averaged weights.
    """
    self.log(f"++> Evaluate {self.workspace} at epoch {self.epoch} ...")

    if name is None:
        name = f'{self.name}_ep{self.epoch:04d}'

    total_loss = 0
    if self.local_rank == 0:
        for metric in self.metrics:
            metric.clear()

    self.model.eval()

    if self.ema is not None:
        self.ema.store()
        self.ema.copy_to()

    try:
        if self.local_rank == 0:
            pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

        with torch.no_grad():
            self.local_step = 0

            for data in loader:
                self.local_step += 1

                with torch.cuda.amp.autocast(enabled=self.fp16):
                    preds, preds_depth, loss = self.eval_step(data)

                # all_gather/reduce the statistics (NCCL only support all_*)
                if self.world_size > 1:
                    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                    loss = loss / self.world_size

                    preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)]  # [[B, ...], [B, ...], ...]
                    dist.all_gather(preds_list, preds)
                    preds = torch.cat(preds_list, dim=0)

                    preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size)]  # [[B, ...], [B, ...], ...]
                    dist.all_gather(preds_depth_list, preds_depth)
                    preds_depth = torch.cat(preds_depth_list, dim=0)

                loss_val = loss.item()
                total_loss += loss_val

                # only rank = 0 will perform evaluation.
                if self.local_rank == 0:

                    # save image
                    save_path = os.path.join(self.workspace, 'validation', f'{name}_rgb_{self.local_step:04d}.png')
                    save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_depth_{self.local_step:04d}.png')

                    os.makedirs(os.path.dirname(save_path), exist_ok=True)

                    pred = preds[0].detach().cpu().numpy()
                    pred = (pred * 255).astype(np.uint8)

                    # min-max normalise depth; the epsilon guards a constant map
                    pred_depth = preds_depth[0].detach().cpu().numpy()
                    pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - pred_depth.min() + 1e-6)
                    pred_depth = (pred_depth * 255).astype(np.uint8)

                    cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
                    cv2.imwrite(save_path_depth, pred_depth)

                    pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
                    pbar.update(loader.batch_size)

        average_loss = total_loss / self.local_step
        self.stats["valid_loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if not self.use_loss_as_metric and len(self.metrics) > 0:
                result = self.metrics[0].measure()
                self.stats["results"].append(result if self.best_mode == 'min' else - result)  # if max mode, use -result
            else:
                self.stats["results"].append(average_loss)  # if no metric, choose best by min loss

            for metric in self.metrics:
                self.log(metric.report(), style="blue")
                if self.use_tensorboardX:
                    metric.write(self.writer, self.epoch, prefix="evaluate")
                metric.clear()
    finally:
        # Swap the training weights back in no matter how the pass ended.
        if self.ema is not None:
            self.ema.restore()

    self.log(f"++> Evaluate epoch {self.epoch} Finished.")
def load_checkpoint(self, checkpoint=None, model_only=False):
    """Restore model (and optionally training) state from a checkpoint file.

    Args:
        checkpoint: Path of the checkpoint. When ``None`` the last ``*.pth``
            file (in sorted order) under ``self.ckpt_path`` is used; if there
            is none, a warning is logged and nothing is loaded.
        model_only: When true, stop after the model-side state (weights, EMA,
            mean density, tet scale) and leave epoch, step, stats, optimizer,
            scheduler and scaler untouched.

    A checkpoint without a ``'model'`` key is treated as a bare state dict.
    Failures while restoring the EMA, optimizer, scheduler or scaler are
    logged with their reason and do not abort loading; interrupts and other
    non-``Exception`` errors propagate.
    """
    if checkpoint is None:
        checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))
        if checkpoint_list:
            checkpoint = checkpoint_list[-1]
            self.log(f"[INFO] Latest checkpoint is {checkpoint}")
        else:
            self.log("[WARN] No checkpoint found, model randomly initialized.")
            return

    # NOTE: torch.load may unpickle arbitrary objects; only load checkpoints
    # that come from a trusted source.
    checkpoint_dict = torch.load(checkpoint, map_location=self.device)

    if 'model' not in checkpoint_dict:
        self.model.load_state_dict(checkpoint_dict)
        self.log("[INFO] loaded model.")
        return

    missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
    self.log("[INFO] loaded model.")
    if len(missing_keys) > 0:
        self.log(f"[WARN] missing keys: {missing_keys}")
    if len(unexpected_keys) > 0:
        self.log(f"[WARN] unexpected keys: {unexpected_keys}")

    if self.ema is not None and 'ema' in checkpoint_dict:
        try:
            self.ema.load_state_dict(checkpoint_dict['ema'])
            self.log("[INFO] loaded EMA.")
        except Exception as e:
            self.log(f"[WARN] failed to load EMA: {e}")

    if self.model.cuda_ray:
        if 'mean_density' in checkpoint_dict:
            self.model.mean_density = checkpoint_dict['mean_density']

    if self.opt.dmtet:
        if 'tet_scale' in checkpoint_dict:
            new_scale = torch.from_numpy(checkpoint_dict['tet_scale']).to(self.device)
            # rescale the existing vertices from the old scale to the stored one
            self.model.verts *= new_scale / self.model.tet_scale
            self.model.tet_scale = new_scale

    if model_only:
        return

    self.stats = checkpoint_dict['stats']
    self.epoch = checkpoint_dict['epoch']
    self.global_step = checkpoint_dict['global_step']
    self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")

    # Training-side state is restored best-effort: a mismatch in any of these
    # is reported but must not prevent resuming with the loaded weights.
    restorable = (
        ('optimizer', 'optimizer', self.optimizer),
        ('scheduler', 'lr_scheduler', self.lr_scheduler),
        ('scaler', 'scaler', self.scaler),
    )
    for label, key, target in restorable:
        if target and key in checkpoint_dict:
            try:
                target.load_state_dict(checkpoint_dict[key])
                self.log(f"[INFO] loaded {label}.")
            except Exception as e:
                self.log(f"[WARN] Failed to load {label}: {e}")
mode 100644 index 0000000..f5bb64f --- /dev/null +++ b/stable-dreamfusion-3DPortrait/optimizer.py @@ -0,0 +1,325 @@ +# Copyright 2022 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + + +class Adan(Optimizer): + """ + Implements a pytorch variant of Adan + Adan was proposed in + Adan: Adaptive Nesterov Momentum Algorithm for + Faster Optimizing Deep Models[J].arXiv preprint arXiv:2208.06677, 2022. + https://arxiv.org/abs/2208.06677 + Arguments: + params (iterable): iterable of parameters to optimize or + dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float, flot], optional): coefficients used for + first- and second-order moments. (default: (0.98, 0.92, 0.99)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): decoupled weight decay + (L2 penalty) (default: 0) + max_grad_norm (float, optional): value used to clip + global grad norm (default: 0.0 no clip) + no_prox (bool): how to perform the decoupled weight decay + (default: False) + foreach (bool): if True would use torch._foreach implementation. + It's faster but uses slightly more memory. 
(default: True) + """ + def __init__(self, + params, + lr=1e-3, + betas=(0.98, 0.92, 0.99), + eps=1e-8, + weight_decay=0.0, + max_grad_norm=0.0, + no_prox=False, + foreach: bool = True): + if not 0.0 <= max_grad_norm: + raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm)) + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= eps: + raise ValueError('Invalid epsilon value: {}'.format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format( + betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format( + betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError('Invalid beta parameter at index 2: {}'.format( + betas[2])) + defaults = dict(lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + no_prox=no_prox, + foreach=foreach) + super().__init__(params, defaults) + + def __setstate__(self, state): + super(Adan, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('no_prox', False) + + @torch.no_grad() + def restart_opt(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + if p.requires_grad: + state = self.state[p] + # State initialization + + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + # Exponential moving average of gradient difference + state['exp_avg_diff'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step.""" + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if self.defaults['max_grad_norm'] > 0: + device = self.param_groups[0]['params'][0].device + global_grad_norm = torch.zeros(1, device=device) + + max_grad_norm = 
torch.tensor(self.defaults['max_grad_norm'], + device=device) + for group in self.param_groups: + + for p in group['params']: + if p.grad is not None: + grad = p.grad + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + + clip_global_grad_norm = torch.clamp( + max_grad_norm / (global_grad_norm + group['eps']), + max=1.0).item() + else: + clip_global_grad_norm = 1.0 + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + exp_avg_diffs = [] + neg_pre_grads = [] + + beta1, beta2, beta3 = group['betas'] + # assume same step across group now to simplify things + # per parameter step can be easily support + # by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + bias_correction1 = 1.0 - beta1**group['step'] + bias_correction2 = 1.0 - beta2**group['step'] + bias_correction3 = 1.0 - beta3**group['step'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + grads.append(p.grad) + + state = self.state[p] + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + state['exp_avg_diff'] = torch.zeros_like(p) + + if 'neg_pre_grad' not in state or group['step'] == 1: + state['neg_pre_grad'] = p.grad.clone().mul_( + -clip_global_grad_norm) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + exp_avg_diffs.append(state['exp_avg_diff']) + neg_pre_grads.append(state['neg_pre_grad']) + + kwargs = dict( + params=params_with_grad, + grads=grads, + exp_avgs=exp_avgs, + exp_avg_sqs=exp_avg_sqs, + exp_avg_diffs=exp_avg_diffs, + neg_pre_grads=neg_pre_grads, + beta1=beta1, + beta2=beta2, + beta3=beta3, + bias_correction1=bias_correction1, + bias_correction2=bias_correction2, + bias_correction3_sqrt=math.sqrt(bias_correction3), + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + 
no_prox=group['no_prox'], + clip_global_grad_norm=clip_global_grad_norm, + ) + + if group['foreach']: + _multi_tensor_adan(**kwargs) + else: + _single_tensor_adan(**kwargs) + + return loss + + +def _single_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + neg_pre_grads: List[Tensor], + *, + beta1: float, + beta2: float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + exp_avg_diff = exp_avg_diffs[i] + neg_grad_or_diff = neg_pre_grads[i] + + grad.mul_(clip_global_grad_norm) + + # for memory saving, we use `neg_grad_or_diff` + # to get some temp variable in a inplace way + neg_grad_or_diff.add_(grad) + + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t + exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff, + alpha=1 - beta2) # diff_t + + neg_grad_or_diff.mul_(beta2).add_(grad) + exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff, + neg_grad_or_diff, + value=1 - beta3) # n_t + + denom = ((exp_avg_sq).sqrt() / bias_correction3_sqrt).add_(eps) + step_size_diff = lr * beta2 / bias_correction2 + step_size = lr / bias_correction1 + + if no_prox: + param.mul_(1 - lr * weight_decay) + param.addcdiv_(exp_avg, denom, value=-step_size) + param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff) + else: + param.addcdiv_(exp_avg, denom, value=-step_size) + param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff) + param.div_(1 + lr * weight_decay) + + neg_grad_or_diff.zero_().add_(grad, alpha=-1.0) + + +def _multi_tensor_adan( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + exp_avg_diffs: List[Tensor], + neg_pre_grads: List[Tensor], + *, + beta1: float, + beta2: 
float, + beta3: float, + bias_correction1: float, + bias_correction2: float, + bias_correction3_sqrt: float, + lr: float, + weight_decay: float, + eps: float, + no_prox: bool, + clip_global_grad_norm: Tensor, +): + if len(params) == 0: + return + + torch._foreach_mul_(grads, clip_global_grad_norm) + + # for memory saving, we use `neg_pre_grads` + # to get some temp variable in a inplace way + torch._foreach_add_(neg_pre_grads, grads) + + torch._foreach_mul_(exp_avgs, beta1) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t + + torch._foreach_mul_(exp_avg_diffs, beta2) + torch._foreach_add_(exp_avg_diffs, neg_pre_grads, + alpha=1 - beta2) # diff_t + + torch._foreach_mul_(neg_pre_grads, beta2) + torch._foreach_add_(neg_pre_grads, grads) + torch._foreach_mul_(exp_avg_sqs, beta3) + torch._foreach_addcmul_(exp_avg_sqs, + neg_pre_grads, + neg_pre_grads, + value=1 - beta3) # n_t + + denom = torch._foreach_sqrt(exp_avg_sqs) + torch._foreach_div_(denom, bias_correction3_sqrt) + torch._foreach_add_(denom, eps) + + step_size_diff = lr * beta2 / bias_correction2 + step_size = lr / bias_correction1 + + if no_prox: + torch._foreach_mul_(params, 1 - lr * weight_decay) + torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size) + torch._foreach_addcdiv_(params, + exp_avg_diffs, + denom, + value=-step_size_diff) + else: + torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size) + torch._foreach_addcdiv_(params, + exp_avg_diffs, + denom, + value=-step_size_diff) + torch._foreach_div_(params, 1 + lr * weight_decay) + torch._foreach_zero_(neg_pre_grads) + torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0) \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/portrait3d_main.py b/stable-dreamfusion-3DPortrait/portrait3d_main.py new file mode 100644 index 0000000..7248be9 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/portrait3d_main.py @@ -0,0 +1,72 @@ +import os + +import glob +import random +import argparse +# +parser = 
"""Run the text-to-3D-portrait pipeline for every prompt in a data directory.

For each ``<test_data_dir>/<name>/prompt.txt`` that has an inverted tri-grid
and no final video yet, five stages are run in order, alternating between the
``stable-dreamfusion-3DPortrait`` and ``3DPortraitGAN_pyramid`` directories.
"""
import argparse
import glob
import os
import shlex
import subprocess
import sys

exp_name = 'text_to_3dportrait'

_SD_DIRNAME = 'stable-dreamfusion-3DPortrait'
_GAN_DIRNAME = '3DPortraitGAN_pyramid'


def parse_args(argv=None):
    """Parse the command line (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser()
    # These three were optional before, but omitting them only ever produced
    # the literal string "None" in the generated commands.
    parser.add_argument('--trigrid_decoder_ckpt', type=str, required=True)
    parser.add_argument('--inversion_name', type=str, required=True)
    parser.add_argument('--network_path', type=str, required=True)
    parser.add_argument('--test_data_dir', type=str, default='../test_data')
    parser.add_argument('--df_ckpt', type=str, default='SG161222/Realistic_Vision_V5.1_noVAE')
    return parser.parse_args(argv)


def build_pipeline(root, name, prompt, inversion_trigrid, trigrid_decoder_ckpt,
                   network_path, df_ckpt, test_data_dir, python=None):
    """Return the five pipeline stages for one prompt as ``(cwd, argv)`` pairs.

    Commands are argument lists, not shell strings, so the prompt text is
    passed to the child process verbatim whatever quotes or shell
    metacharacters it contains.
    """
    python = python or sys.executable
    sd_dir = os.path.join(root, _SD_DIRNAME)
    gan_dir = os.path.join(root, _GAN_DIRNAME)
    data_dir = os.path.join(sd_dir, 'output', exp_name, name)
    gan_args = [f'--data_dir={data_dir}', '--grid=1x1', f'--network={network_path}']

    return [
        (sd_dir, [
            python, 'main_3DPortraitGAN_cam.py',
            '--workspace', f'output/{exp_name}/{name}',
            '--latent_iter_ratio', '0',
            '--trigrid_lr_ratio', '200', '200', '200', '200', '200', '40', '20',
            '--t_range', '0.02', '0.4',
            '--vram_O',
            '--w', '128', '--h', '128', '--H', '512', '--W', '512',
            '--iters', '2000',
            '--text', prompt,
            '--hf_key', df_ckpt,
            '--trigrid_path', inversion_trigrid,
            '--trigrid_decoder_ckpt', trigrid_decoder_ckpt,
        ]),
        (gan_dir, [python, 'gen_quality_improve_data_from_triplane.py'] + gan_args),
        (sd_dir, [
            python, 'guidance/sdedit.py',
            '--data_dir', data_dir,
            '--hf_key', df_ckpt,
            '-H', '512', '-W', '512',
            '--seed', '42',
            f'--test_data_dir={test_data_dir}',
        ]),
        (gan_dir, [python, 'optimize_trigrid.py'] + gan_args),
        (gan_dir, [python, 'gen_videos_shapes_from_optimized_triplane.py'] + gan_args),
    ]


def main(argv=None):
    """Process every prompt directory found under ``--test_data_dir``."""
    opt = parse_args(argv)

    # Repository root: the parent of the directory holding this file.
    root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    print(root)

    # Resolve once, up front, so nothing depends on the current directory.
    test_data_dir = os.path.abspath(opt.test_data_dir)

    for prompt_file in sorted(glob.glob(os.path.join(test_data_dir, '*', 'prompt.txt'))):
        with open(prompt_file, 'r') as f:
            prompt = f.read().replace('\n', '')

        dir_ = os.path.dirname(prompt_file)
        name = os.path.basename(dir_)

        # Absolute path: the old relative check looked in whichever directory
        # the previous iteration had chdir'ed into.
        final_video = os.path.join(root, _SD_DIRNAME, 'output', exp_name, name,
                                   'results_final', 'small_pose_final.mp4')
        if os.path.exists(final_video):
            continue

        trigrid_list = sorted(glob.glob(os.path.join(
            dir_, 'samples_new_crop', opt.inversion_name, '*', 'inversion_trigrid.pkl')))
        if not trigrid_list:
            continue

        stages = build_pipeline(
            root=root,
            name=name,
            prompt=prompt,
            inversion_trigrid=trigrid_list[0],
            trigrid_decoder_ckpt=opt.trigrid_decoder_ckpt,
            network_path=opt.network_path,
            df_ckpt=opt.df_ckpt,
            test_data_dir=test_data_dir,
        )
        for cwd, cmd in stages:
            print(' '.join(shlex.quote(part) for part in cmd))
            returncode = subprocess.run(cmd, cwd=cwd).returncode
            if returncode != 0:
                # Later stages consume this stage's output; skip them for
                # this prompt but keep going with the remaining prompts.
                print(f'[ERROR] {cmd[1]} exited with code {returncode}; skipping {name}')
                break


if __name__ == '__main__':
    main()
class BackgroundRemoval():
    """Callable wrapper around carvekit's ``HiInterface`` for one image."""

    def __init__(self, device='cuda'):
        # Imported lazily so carvekit is only required when this is used.
        from carvekit.api.high import HiInterface

        settings = dict(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device=device,
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=True,
        )
        self.interface = HiInterface(**settings)

    @torch.no_grad()
    def __call__(self, image):
        """Run the interface on one image array and return the result as an array.

        ``image`` is documented by the caller as an [H, W, 3] array in
        [0, 255]; the script's main section treats the result as [H, W, 4].
        """
        pil_input = Image.fromarray(image)
        carved = self.interface([pil_input])[0]
        return np.array(carved)
class BLIP2():
    """Image captioner backed by the ``Salesforce/blip2-opt-2.7b`` checkpoint."""

    def __init__(self, device='cuda'):
        # Imported lazily so transformers is only required for captioning.
        from transformers import AutoProcessor, Blip2ForConditionalGeneration

        self.device = device
        self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
        model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
        self.model = model.to(device)

    @torch.no_grad()
    def __call__(self, image):
        """Return a caption (at most 20 new tokens) for one image array."""
        pil_image = Image.fromarray(image)
        batch = self.processor(pil_image, return_tensors="pt")
        batch = batch.to(self.device, torch.float16)

        token_ids = self.model.generate(**batch, max_new_tokens=20)
        captions = self.processor.batch_decode(token_ids, skip_special_tokens=True)
        return captions[0].strip()
class DPT():
    """Omnidata DPT predictor for depth (``task='depth'``) or surface normals.

    Any ``task`` value other than ``'depth'`` selects the normal model.
    Checkpoints are read from ``pretrained/omnidata/`` relative to the
    current working directory.
    """

    def __init__(self, task='depth', device='cuda'):

        self.task = task
        self.device = device

        # Imported lazily so the dpt package is only required when used.
        from dpt import DPTDepthModel

        if task == 'depth':
            path = 'pretrained/omnidata/omnidata_dpt_depth_v2.ckpt'
            self.model = DPTDepthModel(backbone='vitb_rn50_384')
            self.aug = transforms.Compose([
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(mean=0.5, std=0.5)
            ])

        else:  # normal
            path = 'pretrained/omnidata/omnidata_dpt_normal_v2.ckpt'
            self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3)
            self.aug = transforms.Compose([
                transforms.Resize((384, 384)),
                transforms.ToTensor()
            ])

        # load model
        checkpoint = torch.load(path, map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = self._strip_key_prefix(checkpoint['state_dict'])
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)
        self.model.eval().to(device)

    @staticmethod
    def _strip_key_prefix(state_dict, prefix='model.'):
        """Return a copy of ``state_dict`` with ``prefix`` removed from keys.

        Keys that do not start with ``prefix`` are kept unchanged; the
        previous code cut six characters off every key unconditionally.
        """
        cut = len(prefix)
        return {
            (key[cut:] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    @torch.no_grad()
    def __call__(self, image):
        """Predict depth or normals for one image at its original resolution.

        Args:
            image: np.ndarray, uint8, [H, W, 3].

        Returns:
            A numpy array clamped to [0, 1] before bicubic resizing back to
            (H, W). The main section of this script indexes ``[0]`` to drop
            the leading batch dimension.
        """
        H, W = image.shape[:2]
        image = Image.fromarray(image)

        image = self.aug(image).unsqueeze(0).to(self.device)

        if self.task == 'depth':
            depth = self.model(image).clamp(0, 1)
            # Add a channel dimension for interpolate, then drop it again.
            depth = F.interpolate(depth.unsqueeze(1), size=(H, W), mode='bicubic', align_corners=False)
            depth = depth.squeeze(1).cpu().numpy()
            return depth
        else:
            normal = self.model(image).clamp(0, 1)
            normal = F.interpolate(normal, size=(H, W), mode='bicubic', align_corners=False)
            normal = normal.cpu().numpy()
            return normal
(depth[mask].max() - depth[mask].min() + 1e-9) + depth[~mask] = 0 + depth = (depth * 255).astype(np.uint8) + del dpt_depth_model + + # predict normal + print(f'[INFO] normal estimation...') + dpt_normal_model = DPT(task='normal') + normal = dpt_normal_model(image)[0] + normal = (normal * 255).astype(np.uint8).transpose(1, 2, 0) + normal[~mask] = 0 + del dpt_normal_model + + # recenter + if opt.recenter: + print(f'[INFO] recenter...') + final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8) + final_depth = np.zeros((opt.size, opt.size), dtype=np.uint8) + final_normal = np.zeros((opt.size, opt.size, 3), dtype=np.uint8) + + coords = np.nonzero(mask) + x_min, x_max = coords[0].min(), coords[0].max() + y_min, y_max = coords[1].min(), coords[1].max() + h = x_max - x_min + w = y_max - y_min + desired_size = int(opt.size * (1 - opt.border_ratio)) + scale = desired_size / max(h, w) + h2 = int(h * scale) + w2 = int(w * scale) + x2_min = (opt.size - h2) // 2 + x2_max = x2_min + h2 + y2_min = (opt.size - w2) // 2 + y2_max = y2_min + w2 + final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) + final_depth[x2_min:x2_max, y2_min:y2_max] = cv2.resize(depth[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) + final_normal[x2_min:x2_max, y2_min:y2_max] = cv2.resize(normal[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) + + else: + final_rgba = carved_image + final_depth = depth + final_normal = normal + + # write output + cv2.imwrite(out_rgba, cv2.cvtColor(final_rgba, cv2.COLOR_RGBA2BGRA)) + cv2.imwrite(out_depth, final_depth) + cv2.imwrite(out_normal, final_normal) + + # predict caption (it's too slow... 
# Prints, for each sample id, the shell commands of the quality-improvement
# pass. Nothing is executed; the output is meant to be copied into a shell.
# Ids handled earlier:
# 190,140,22,129,104,113,133,164,15,31,72,135,83,149,85,169
for name in (91, 111, 96, 53, 143):
    script = [
        # stage 1: render training data from the tri-plane (GAN environment)
        'cd F:/high_quality_3DPortraitGAN/exp/3DPortraitGAN-hierarchy-v2',
        'activate 3dportraitgan',
        f'python gen_quality_improve_data_from_triplane.py --data_dir=F:/high_quality_3DPortraitGAN/exp/stable-dreamfusion-hierarchy-v2/output/better_direction_prompt/{name} --grid=1x1 --network=F:/high_quality_3DPortraitGAN/exp/3DPortraitGAN-hierarchy-v2/models/model.pkl',
        # stage 2: SDEdit refinement (diffusion environment)
        'cd F:/high_quality_3DPortraitGAN/exp/stable-dreamfusion-hierarchy-v2',
        'activate ldm_3dgan_kaolin',
        f'python guidance/sdedit.py --data_dir F:/high_quality_3DPortraitGAN/exp/stable-dreamfusion-hierarchy-v2/output/better_direction_prompt/{name} --hf_key F:/high_quality_3DPortraitGAN/exp/stable-dreamfusion/pretrained/SG161222Realistic_Vision_V5.1_noVAE -H 512 -W 512 --seed 42',
        # stage 3: optimise the tri-grid, then export videos and shapes
        'cd F:/high_quality_3DPortraitGAN/exp/3DPortraitGAN-hierarchy-v2',
        'activate 3dportraitgan',
        f'python optimize_trigrid.py --data_dir=F:/high_quality_3DPortraitGAN/exp/stable-dreamfusion-hierarchy-v2/output/better_direction_prompt/{name} --grid=1x1 --network=F:/high_quality_3DPortraitGAN/exp/3DPortraitGAN-hierarchy-v2/models/model.pkl',
        f'python gen_videos_shapes_from_optimized_triplane.py --data_dir=F:/high_quality_3DPortraitGAN/exp/stable-dreamfusion-hierarchy-v2/output/better_direction_prompt/{name} --grid=1x1 --network=F:/high_quality_3DPortraitGAN/exp/3DPortraitGAN-hierarchy-v2/models/model.pkl',
    ]
    for line in script:
        print(line)
"""JIT-build the ``_raymarching`` CUDA extension and expose it as ``_backend``."""
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3',
    '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__',
    '-U__CUDA_NO_HALF_CONVERSIONS__',
    '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    def find_cl_path():
        """Return the newest MSVC ``Hostx64\\x64`` directory found, or None."""
        import glob
        pattern = r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
        roots = [r"C:\\Program Files (x86)", r"C:\\Program Files"]
        editions = ["Enterprise", "Professional", "BuildTools", "Community"]
        for program_files in roots:
            for edition in editions:
                matches = glob.glob(pattern % (program_files, edition))
                if matches:
                    # Reverse-sorted first entry, i.e. the highest version.
                    return sorted(matches, reverse=True)[0]
        return None

    # Only search for the compiler when cl.exe is not already on PATH.
    cl_missing = os.system("where cl.exe >nul 2>nul") != 0
    if cl_missing:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_source_files = ['raymarching.cu', 'bindings.cpp']

_backend = load(
    name='_raymarching',
    extra_cflags=c_flags,
    extra_cuda_cflags=nvcc_flags,
    sources=[os.path.join(_src_path, 'src', f) for f in _source_files],
)

__all__ = ['_backend']
+ if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_raymarching', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'raymarching.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/raymarching/raymarching.py b/stable-dreamfusion-3DPortrait/raymarching/raymarching.py new file mode 100644 index 0000000..760d730 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/raymarching/raymarching.py @@ -0,0 +1,398 @@ +import numpy as np +import time + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +# lazy building: +# `import raymarching` will not immediately build the extension, only if you actually call any functions. 
BACKEND = None


def get_backend():
    """Return the compiled raymarching extension, loading it on first use.

    A pre-installed ``_raymarching`` module is preferred; if it cannot be
    imported, the ``.backend`` module is imported instead.
    """
    global BACKEND

    if BACKEND is not None:
        return BACKEND

    try:
        import _raymarching as _backend
    except ImportError:
        from .backend import _backend

    BACKEND = _backend
    return BACKEND

# ----------------------------------------
# utils
# ----------------------------------------

class _near_far_from_aabb(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
        ''' near_far_from_aabb, CUDA implementation
        Compute, per ray, the near and far intersection times with an
        axis-aligned bounding box.
        Args:
            rays_o: float, [N, 3]
            rays_d: float, [N, 3]
            aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
            min_near: float, scalar
        Returns:
            nears: float, [N]
            fars: float, [N]
        '''
        rays_o = rays_o if rays_o.is_cuda else rays_o.cuda()
        rays_d = rays_d if rays_d.is_cuda else rays_d.cuda()

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        num_rays = rays_o.shape[0]
        out_kwargs = dict(dtype=rays_o.dtype, device=rays_o.device)

        # Output buffers; the backend call receives them as its last arguments.
        nears = torch.empty(num_rays, **out_kwargs)
        fars = torch.empty(num_rays, **out_kwargs)

        get_backend().near_far_from_aabb(rays_o, rays_d, aabb, num_rays, min_near, nears, fars)

        return nears, fars

near_far_from_aabb = _near_far_from_aabb.apply
class _sph_from_ray(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, radius):
        ''' sph_from_ray, CUDA implementation
        Get spherical coordinates on the background sphere for each ray.
        Assumes rays_o lie inside Sphere(radius).
        Args:
            rays_o: [N, 3]
            rays_d: [N, 3]
            radius: scalar, float
        Return:
            coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
        '''
        rays_o = rays_o if rays_o.is_cuda else rays_o.cuda()
        rays_d = rays_d if rays_d.is_cuda else rays_d.cuda()

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        num_rays = rays_o.shape[0]

        coords = torch.empty(num_rays, 2, dtype=rays_o.dtype, device=rays_o.device)
        get_backend().sph_from_ray(rays_o, rays_d, radius, num_rays, coords)

        return coords

sph_from_ray = _sph_from_ray.apply


class _morton3D(Function):
    @staticmethod
    def forward(ctx, coords):
        ''' morton3D, CUDA implementation
        Args:
            coords: [N, 3], int32, in [0, 128) (torch has no uint32 tensor)
            TODO: check if the coord range is valid! (current 128 is safe)
        Returns:
            indices: [N], int32, in [0, 128^3)
        '''
        coords = coords if coords.is_cuda else coords.cuda()

        count = coords.shape[0]
        indices = torch.empty(count, dtype=torch.int32, device=coords.device)

        get_backend().morton3D(coords.int(), count, indices)

        return indices

morton3D = _morton3D.apply

class _morton3D_invert(Function):
    @staticmethod
    def forward(ctx, indices):
        ''' morton3D_invert, CUDA implementation
        Args:
            indices: [N], int32, in [0, 128^3)
        Returns:
            coords: [N, 3], int32, in [0, 128)
        '''
        indices = indices if indices.is_cuda else indices.cuda()

        count = indices.shape[0]
        coords = torch.empty(count, 3, dtype=torch.int32, device=indices.device)

        get_backend().morton3D_invert(indices.int(), count, coords)

        return coords

morton3D_invert = _morton3D_invert.apply
class _packbits(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, grid, thresh, bitfield=None):
        ''' packbits, CUDA implementation
        Pack the density grid into a bit field to accelerate ray marching.
        Args:
            grid: float, [C, H * H * H], assume H % 2 == 0
            thresh: float, threshold
            bitfield: optional uint8 output buffer; allocated when None
        Returns:
            bitfield: uint8, [C, H * H * H / 8]
        '''
        grid = grid if grid.is_cuda else grid.cuda()
        grid = grid.contiguous()

        num_levels, cells_per_level = grid.shape[0], grid.shape[1]
        num_bytes = num_levels * cells_per_level // 8  # eight cells per byte

        if bitfield is None:
            bitfield = torch.empty(num_bytes, dtype=torch.uint8, device=grid.device)

        get_backend().packbits(grid, num_bytes, thresh, bitfield)

        return bitfield

packbits = _packbits.apply


class _flatten_rays(Function):
    @staticmethod
    def forward(ctx, rays, M):
        ''' flatten rays
        Args:
            rays: [N, 2], all rays' (point_offset, point_count),
            M: scalar, int, count of points (not recoverable from rays)
        Returns:
            res: [M], flattened ray index.
        '''
        rays = rays if rays.is_cuda else rays.cuda()
        rays = rays.contiguous()

        num_rays = rays.shape[0]

        # Zero-initialised output buffer filled by the backend call.
        res = torch.zeros(M, dtype=torch.int, device=rays.device)
        get_backend().flatten_rays(rays, num_rays, M, res)

        return res

flatten_rays = _flatten_rays.apply
# ----------------------------------------
# train functions
# ----------------------------------------

class _march_rays_train(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, perturb=False, dt_gamma=0, max_steps=1024, contract=False):
        ''' march rays to generate points (forward only)
        The backend is called twice: first with no output buffers, to count
        the points, then again to fill buffers of exactly that size.
        Args:
            rays_o/d: float, [N, 3]
            bound: float, scalar
            density_bitfield: uint8: [CHHH // 8]
            C: int
            H: int
            nears/fars: float, [N]
            perturb: bool, draw one uniform random noise value per ray instead of zeros
            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
            max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
            contract: passed through to the backend unchanged
        Returns:
            xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)
            dirs: float, [M, 3], all generated points' view dirs.
            ts: float, [M, 2], all generated points' ts.
            rays: int32, [N, 2], all rays' (point_offset, point_count), e.g., xyzs[rays[i, 0]:(rays[i, 0] + rays[i, 1])] --> points belonging to ray i
        '''

        rays_o = rays_o if rays_o.is_cuda else rays_o.cuda()
        rays_d = rays_d if rays_d.is_cuda else rays_d.cuda()
        density_bitfield = density_bitfield if density_bitfield.is_cuda else density_bitfield.cuda()

        rays_o = rays_o.float().contiguous().view(-1, 3)
        rays_d = rays_d.float().contiguous().view(-1, 3)
        density_bitfield = density_bitfield.contiguous()

        N = rays_o.shape[0]  # num rays
        dtype, device = rays_o.dtype, rays_o.device

        # Single-element counter; its value after the first pass is M.
        step_counter = torch.zeros(1, dtype=torch.int32, device=device)

        make_noise = torch.rand if perturb else torch.zeros
        noises = make_noise(N, dtype=dtype, device=device)

        rays = torch.empty(N, 2, dtype=torch.int32, device=device)

        backend = get_backend()
        shared_args = (rays_o, rays_d, density_bitfield, bound, contract, dt_gamma, max_steps, N, C, H, nears, fars)

        # first pass: write rays, get total number of points M to render
        backend.march_rays_train(*shared_args, None, None, None, rays, step_counter, noises)

        # allocate based on M
        M = step_counter.item()

        xyzs = torch.zeros(M, 3, dtype=dtype, device=device)
        dirs = torch.zeros(M, 3, dtype=dtype, device=device)
        ts = torch.zeros(M, 2, dtype=dtype, device=device)

        # second pass: write outputs
        backend.march_rays_train(*shared_args, xyzs, dirs, ts, rays, step_counter, noises)

        return xyzs, dirs, ts, rays

march_rays_train = _march_rays_train.apply
class _composite_rays_train(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, sigmas, rgbs, ts, rays, T_thresh=1e-4, binarize=False):
        ''' composite rays' rgbs, according to the ray marching formula.
        Args:
            sigmas: float, [M,]
            rgbs: float, [M, 3]
            ts: float, [M, 2]
            rays: int32, one row per ray (the original note said [N, 3];
                  march_rays_train in this file produces [N, 2] -- TODO confirm)
        Returns:
            weights: float, [M]
            weights_sum: float, [N,], the alpha channel
            depth: float, [N, ], the Depth
            image: float, [N, 3], the RGB channel (after multiplying alpha!)
        '''

        sigmas = sigmas.float().contiguous()
        rgbs = rgbs.float().contiguous()

        M = sigmas.shape[0]
        N = rays.shape[0]
        out_kwargs = dict(dtype=sigmas.dtype, device=sigmas.device)

        # weights may be left partly unwritten by the backend, so start at 0.
        weights = torch.zeros(M, **out_kwargs)
        weights_sum = torch.empty(N, **out_kwargs)
        depth = torch.empty(N, **out_kwargs)
        image = torch.empty(N, 3, **out_kwargs)

        get_backend().composite_rays_train_forward(sigmas, rgbs, ts, rays, M, N, T_thresh, binarize, weights, weights_sum, depth, image)

        ctx.save_for_backward(sigmas, rgbs, ts, rays, weights_sum, depth, image)
        ctx.dims = [M, N, T_thresh, binarize]

        return weights, weights_sum, depth, image

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_weights, grad_weights_sum, grad_depth, grad_image):
        ''' Gradients w.r.t. sigmas and rgbs; every other input gets None. '''
        incoming = [g.contiguous() for g in (grad_weights, grad_weights_sum, grad_depth, grad_image)]

        sigmas, rgbs, ts, rays, weights_sum, depth, image = ctx.saved_tensors
        M, N, T_thresh, binarize = ctx.dims

        grad_sigmas = torch.zeros_like(sigmas)
        grad_rgbs = torch.zeros_like(rgbs)

        get_backend().composite_rays_train_backward(*incoming, sigmas, rgbs, ts, rays, weights_sum, depth, image, M, N, T_thresh, binarize, grad_sigmas, grad_rgbs)

        # One entry per forward input: sigmas, rgbs, ts, rays, T_thresh, binarize.
        return grad_sigmas, grad_rgbs, None, None, None, None


composite_rays_train = _composite_rays_train.apply
# ----------------------------------------
# infer functions
# ----------------------------------------

class _march_rays(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, perturb=False, dt_gamma=0, max_steps=1024, contract=False):
        ''' march rays to generate points (forward only, for inference)
        Args:
            n_alive: int, number of alive rays
            n_step: int, how many steps we march
            rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
            rays_t: float, [N], the alive rays' time, we only use the first n_alive.
            rays_o/d: float, [N, 3]
            bound: float, scalar
            density_bitfield: uint8: [CHHH // 8]
            C: int
            H: int
            near/far: float, [N]
            perturb: bool/int, any truthy value enables per-ray random noise
            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
            max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
            contract: passed through to the backend unchanged
        Returns:
            xyzs: float, [n_alive * n_step, 3], all generated points' coords
            dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
            ts: float, [n_alive * n_step, 2], all generated points' ts
        '''

        rays_o = rays_o if rays_o.is_cuda else rays_o.cuda()
        rays_d = rays_d if rays_d.is_cuda else rays_d.cuda()

        rays_o = rays_o.float().contiguous().view(-1, 3)
        rays_d = rays_d.float().contiguous().view(-1, 3)

        dtype, device = rays_o.dtype, rays_o.device
        num_points = n_alive * n_step

        xyzs = torch.zeros(num_points, 3, dtype=dtype, device=device)
        dirs = torch.zeros(num_points, 3, dtype=dtype, device=device)
        ts = torch.zeros(num_points, 2, dtype=dtype, device=device)  # 2 vals, one for rgb, one for depth

        make_noise = torch.rand if perturb else torch.zeros
        noises = make_noise(n_alive, dtype=dtype, device=device)

        get_backend().march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, contract, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, ts, noises)

        return xyzs, dirs, ts

march_rays = _march_rays.apply


class _composite_rays(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # need to cast sigmas & rgbs to float
    def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh=1e-2, binarize=False):
        ''' composite rays' rgbs, according to the ray marching formula. (for inference)
        Args:
            n_alive: int, number of alive rays
            n_step: int, how many steps we march
            rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
            rays_t: float, [N], the alive rays' time
            sigmas: float, [n_alive * n_step,]
            rgbs: float, [n_alive * n_step, 3]
            ts: float, [n_alive * n_step, 2]
        In-place Outputs:
            weights_sum: float, [N,], the alpha channel
            depth: float, [N,], the depth value
            image: float, [N, 3], the RGB channel (after multiplying alpha!)
        Returns:
            an empty tuple; results are in the in-place outputs above.
        '''
        sigmas_f = sigmas.float().contiguous()
        rgbs_f = rgbs.float().contiguous()
        get_backend().composite_rays(n_alive, n_step, T_thresh, binarize, rays_alive, rays_t, sigmas_f, rgbs_f, ts, weights_sum, depth, image)
        return tuple()


composite_rays = _composite_rays.apply
'''
Usage:

python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)

python setup.py install # build extensions and install (copy) to PATH.
pip install . # ditto but better (e.g., dependency & metadata handling)

python setup.py develop # build extensions and install (symbolic) to PATH.
pip install -e . # ditto but better (e.g., dependency & metadata handling)

'''
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3',
    '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__',
    '-U__CUDA_NO_HALF_CONVERSIONS__',
    '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    def find_cl_path():
        """Return the newest MSVC ``Hostx64\\x64`` directory found, or None."""
        import glob
        pattern = r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
        roots = [r"C:\\Program Files (x86)", r"C:\\Program Files"]
        editions = ["Enterprise", "Professional", "BuildTools", "Community"]
        for program_files in roots:
            for edition in editions:
                matches = glob.glob(pattern % (program_files, edition))
                if matches:
                    # Reverse-sorted first entry, i.e. the highest version.
                    return sorted(matches, reverse=True)[0]
        return None

    # Only search for the compiler when cl.exe is not already on PATH.
    cl_missing = os.system("where cl.exe >nul 2>nul") != 0
    if cl_missing:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_source_files = ['raymarching.cu', 'bindings.cpp']

_extension = CUDAExtension(
    name='_raymarching',  # extension name, import this to use CUDA API
    sources=[os.path.join(_src_path, 'src', f) for f in _source_files],
    extra_compile_args={
        'cxx': c_flags,
        'nvcc': nvcc_flags,
    },
)

setup(
    name='raymarching',  # package name, import this to use python API
    ext_modules=[_extension],
    cmdclass={
        'build_ext': BuildExtension,
    },
)
+pip install -e . # ditto but better (e.g., dependency & metadata handling) + +''' +setup( + name='raymarching', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_raymarching', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'raymarching.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/raymarching/src/bindings.cpp b/stable-dreamfusion-3DPortrait/raymarching/src/bindings.cpp new file mode 100644 index 0000000..eb8f122 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/raymarching/src/bindings.cpp @@ -0,0 +1,20 @@ +#include + +#include "raymarching.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // utils + m.def("flatten_rays", &flatten_rays, "flatten_rays (CUDA)"); + m.def("packbits", &packbits, "packbits (CUDA)"); + m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); + m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); + m.def("morton3D", &morton3D, "morton3D (CUDA)"); + m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); + // train + m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); + m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); + m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); + // infer + m.def("march_rays", &march_rays, "march rays (CUDA)"); + m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); +} \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/raymarching/src/raymarching.cu b/stable-dreamfusion-3DPortrait/raymarching/src/raymarching.cu new file mode 100644 index 0000000..0292f1c --- /dev/null +++ 
b/stable-dreamfusion-3DPortrait/raymarching/src/raymarching.cu @@ -0,0 +1,934 @@ +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + +/* compile-time float constants: sqrt(3), 1/sqrt(3), pi, 1/pi */ +inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; } +inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; } +inline constexpr __device__ float PI() { return 3.141592653589793f; } +inline constexpr __device__ float RPI() { return 0.3183098861837907f; } + +/* ceiling division: (val + divisor - 1) / divisor */ +template +inline __host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +inline __host__ __device__ float signf(const float x) { /* +1 or -1 carrying the sign of x (via copysignf) */ + return copysignf(1.0, x); +} + +inline __host__ __device__ float clamp(const float x, const float min, const float max) { /* limits x to [min, max] */ + return fminf(max, fmaxf(min, x)); +} + +inline __host__ __device__ void swapf(float& a, float& b) { /* swaps a and b in place */ + float c = a; a = b; b = c; +} + +inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) { /* cascade level derived from the largest absolute coordinate */ + const float mx = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z))); + int exponent; + frexpf(mx, &exponent); // binary exponent of mx: [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...; inputs below 0.5 give 
exponents <= 0
+ return fminf(max_cascade - 1, fmaxf(0, exponent)); /* clamp the level to [0, max_cascade - 1] */ +} + +inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) { /* cascade level derived from the step size: same frexpf/clamp scheme applied to dt * H * 0.5 */ + const float mx = dt * H * 0.5; + int exponent; + frexpf(mx, &exponent); + return fminf(max_cascade - 1, fmaxf(0, exponent)); +} + +inline __host__ __device__ uint32_t __expand_bits(uint32_t v) /* spreads the low bits of v so that they land on every third bit (final mask 0x49249249) */ +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + +inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z) /* interleaves x, y, z into one code: x on bits 0,3,6,..., y on bits 1,4,7,..., z on bits 2,5,8,... */ +{ + uint32_t xx = __expand_bits(x); + uint32_t yy = __expand_bits(y); + uint32_t zz = __expand_bits(z); + return xx | (yy << 1) | (zz << 2); +} + +inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x) /* gathers every third bit of x, starting at bit 0, into the low bits of the result */ +{ + x = x & 0x49249249; + x = (x | (x >> 2)) & 0xc30c30c3; + x = (x | (x >> 4)) & 0x0f00f00f; + x = (x | (x >> 8)) & 0xff0000ff; + x = (x | (x >> 16)) & 0x0000ffff; + return x; +} + + +//////////////////////////////////////////////////// +///////////// utils ///////////// +//////////////////////////////////////////////////// + +// rays_o/d: [N, 3] +// nears/fars: [N] +// scalar_t should always be float in use.
+template +__global__ void kernel_near_far_from_aabb( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const scalar_t * __restrict__ aabb, + const uint32_t N, + const float min_near, + scalar_t * nears, scalar_t * fars +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + + // get near far (assume cube scene): per-axis slab test, aabb[0..2] and aabb[3..5] are the two opposite corners + float near = (aabb[0] - ox) * rdx; + float far = (aabb[3] - ox) * rdx; + if (near > far) swapf(near, far); + + float near_y = (aabb[1] - oy) * rdy; + float far_y = (aabb[4] - oy) * rdy; + if (near_y > far_y) swapf(near_y, far_y); + + if (near > far_y || near_y > far) { + nears[n] = fars[n] = std::numeric_limits::max(); /* x and y intervals do not overlap: the ray misses the box */ + return; + } + + if (near_y > near) near = near_y; + if (far_y < far) far = far_y; + + float near_z = (aabb[2] - oz) * rdz; + float far_z = (aabb[5] - oz) * rdz; + if (near_z > far_z) swapf(near_z, far_z); + + if (near > far_z || near_z > far) { + nears[n] = fars[n] = std::numeric_limits::max(); /* z interval does not overlap: the ray misses the box */ + return; + } + + if (near_z > near) near = near_z; + if (far_z < far) far = far_z; + + if (near < min_near) near = min_near; /* never start closer than min_near */ + + nears[n] = near; + fars[n] = far; +} + + +void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) { /* host launcher, dispatched on the dtype of rays_o */ + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), 
"near_far_from_aabb", ([&] { + kernel_near_far_from_aabb<<>>(rays_o.data_ptr(), rays_d.data_ptr(), aabb.data_ptr(), N, min_near, nears.data_ptr(), fars.data_ptr()); + })); +} + + +// rays_o/d: [N, 3] +// radius: float +// coords: [N, 2] +template +__global__ void kernel_sph_from_ray( +
const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const float radius, + const uint32_t N, + scalar_t * coords +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + coords += n * 2; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + // const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + + // solve t from || o + td || = radius, i.e. A t^2 + 2 B t + C = 0 + const float A = dx * dx + dy * dy + dz * dz; + const float B = ox * dx + oy * dy + oz * dz; // half of the quadratic's linear coefficient + const float C = ox * ox + oy * oy + oz * oz - radius * radius; + + const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive); NOTE(review): the discriminant is not checked, so a ray that misses the sphere yields NaN + + // solve theta, phi (assume y is the up axis) + const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz; + const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI] + const float phi = atan2(z, x); // [-PI, PI] + + // normalize to [-1, 1] + coords[0] = 2 * theta * RPI() - 1; + coords[1] = phi * RPI(); +} + + +void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "sph_from_ray", ([&] { + kernel_sph_from_ray<<>>(rays_o.data_ptr(), rays_d.data_ptr(), radius, N, coords.data_ptr()); + })); +} + + +// coords: int32, [N, 3] +// indices: int32, [N] +__global__ void kernel_morton3D( + const int * __restrict__ coords, + const uint32_t N, + int * indices +) { + // parallel per coordinate triple + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + coords += n * 3; + indices[n] = 
__morton3D(coords[0], coords[1], coords[2]); +} + + +void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) { + static constexpr uint32_t N_THREAD = 128; +
kernel_morton3D<<>>(coords.data_ptr(), N, indices.data_ptr()); +} + + +// indices: int32, [N] +// coords: int32, [N, 3] +__global__ void kernel_morton3D_invert( + const int * __restrict__ indices, + const uint32_t N, + int * coords +) { + // parallel per index + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + coords += n * 3; + + const int ind = indices[n]; + + coords[0] = __morton3D_invert(ind >> 0); /* x: bits 0, 3, 6, ... */ + coords[1] = __morton3D_invert(ind >> 1); /* y: bits 1, 4, 7, ... */ + coords[2] = __morton3D_invert(ind >> 2); /* z: bits 2, 5, 8, ... */ +} + + +void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) { + static constexpr uint32_t N_THREAD = 128; + kernel_morton3D_invert<<>>(indices.data_ptr(), N, coords.data_ptr()); +} + + +// grid: float, [C, H, H, H] +// N: int, C * H * H * H / 8 +// density_thresh: float +// bitfield: uint8, [N] +template +__global__ void kernel_packbits( + const scalar_t * __restrict__ grid, + const uint32_t N, + const float density_thresh, + uint8_t * bitfield +) { + // parallel per output byte: each thread packs 8 consecutive grid values + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + grid += n * 8; + + uint8_t bits = 0; + + #pragma unroll + for (uint8_t i = 0; i < 8; i++) { + bits |= (grid[i] > density_thresh) ?
((uint8_t)1 << i) : 0; + } + + bitfield[n] = bits; /* bit i is set when grid[n * 8 + i] > density_thresh */ +} + + +void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grid.scalar_type(), "packbits", ([&] { + kernel_packbits<<>>(grid.data_ptr(), N, density_thresh, bitfield.data_ptr()); + })); +} + + +__global__ void kernel_flatten_rays( + const int * __restrict__ rays, + const uint32_t N, const uint32_t M, + int * res +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate this ray's (offset, num_steps) record + uint32_t offset = rays[n * 2]; + uint32_t num_steps = rays[n * 2 + 1]; + + // write to res: fill the ray's num_steps slots, starting at offset, with the ray index (M is not used here, so no bounds check is made) + res += offset; + for (int i = 0; i < num_steps; i++) res[i] = n; +} + +void flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res) { + + static constexpr uint32_t N_THREAD = 128; + + kernel_flatten_rays<<>>(rays.data_ptr(), N, M, res.data_ptr()); +} + +//////////////////////////////////////////////////// +///////////// training ///////////// +//////////////////////////////////////////////////// + +// rays_o/d: [N, 3] +// grid: [CHHH / 8] +// xyzs, dirs, ts: [M, 3], [M, 3], [M, 2] +// nears, fars, noises: [N] +// rays: [N, 2], offset, num_steps +template +__global__ void kernel_march_rays_train( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const uint8_t * __restrict__ grid, + const float bound, const bool contract, + const float dt_gamma, const uint32_t max_steps, + const uint32_t N, const uint32_t C, const uint32_t H, + const scalar_t* __restrict__ nears, + const scalar_t* __restrict__ fars, + scalar_t * xyzs, scalar_t * dirs, scalar_t * ts, + int * rays, + int * counter, + 
const scalar_t* __restrict__ noises +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // first-pass check: a null xyzs means this launch only counts samples per ray.
+ const bool first_pass = (xyzs == nullptr); + + // locate this ray's origin, direction and (offset, num_steps) record + rays_o += n * 3; + rays_d += n * 3; + rays += n * 2; + + uint32_t num_steps = max_steps; + + if (!first_pass) { + uint32_t point_index = rays[0]; + num_steps = rays[1]; + xyzs += point_index * 3; + dirs += point_index * 3; + ts += point_index * 2; + } + + // ray marching + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + const float rH = 1 / (float)H; + const float H3 = H * H * H; /* NOTE(review): cell count is kept in float and later used to build a uint32_t bit index; presumably exact for the grid sizes in use, confirm */ + + const float near = nears[n]; + const float far = fars[n]; + const float noise = noises[n]; + + const float dt_min = 2 * SQRT3() / max_steps; + const float dt_max = 2 * SQRT3() * bound / H; + // const float dt_max = 1e10f; + + float t0 = near; + t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise; /* offset the start by a noise-scaled fraction of one step */ + float t = t0; + uint32_t step = 0; + + //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far); + + while (t < far && step < num_steps) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1.0f, level), bound); + const float mip_rbound = 1 / mip_bound; + + // contraction + float cx = x, cy = y, cz = z; + const float mag = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z))); + if (contract && mag > 1) { + // L-INF norm contraction: rescales the point so its largest absolute coordinate becomes 2 - 1 / mag + 
const float Linf_scale = (2 - 1 / mag) / mag; + cx *= Linf_scale; + cy *= Linf_scale; + cz *= Linf_scale; + } + + // convert to nearest grid position + const int nx = clamp(0.5 * (cx * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (cy * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (cz *
mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); /* one occupancy bit per cell, 8 cells per byte */ + + // if occupied, advance a small step, and write to output + //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, step); + + if (occ) { + step++; + t += dt; + if (!first_pass) { + xyzs[0] = cx; // write contracted coordinates! + xyzs[1] = cy; + xyzs[2] = cz; + dirs[0] = dx; + dirs[1] = dy; + dirs[2] = dz; + ts[0] = t; // t has already been advanced by dt at this point + ts[1] = dt; + xyzs += 3; + dirs += 3; + ts += 2; + } + // contraction case: cannot apply voxel skipping. + } else if (contract && mag > 1) { + t += dt; + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - cx) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - cy) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - cz) * rdz; + + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + dt = clamp(t * dt_gamma, dt_min, dt_max); + t += dt; + } while (t < tt); + } + } + + //printf("[n=%d] step=%d, near=%f, far=%f, dt=%f, num_steps=%f\n", n, step, near, far, dt_min, (far - near) / dt_min); + + // write rays: the first pass reserves `step` slots from the shared counter and records (offset, count) for this ray + if (first_pass) { + uint32_t point_index = atomicAdd(counter, step); + rays[0] = point_index; + rays[1] = step; + } +} + +void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional 
xyzs, at::optional dirs, at::optional ts, at::Tensor rays, at::Tensor counter, at::Tensor noises) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( +
rays_o.scalar_type(), "march_rays_train", ([&] { + kernel_march_rays_train<<>>(rays_o.data_ptr(), rays_d.data_ptr(), grid.data_ptr(), bound, contract, dt_gamma, max_steps, N, C, H, nears.data_ptr(), fars.data_ptr(), + xyzs.has_value() ? xyzs.value().data_ptr() : nullptr, /* an absent xyzs becomes nullptr, which selects the kernel's counting pass */ + dirs.has_value() ? dirs.value().data_ptr() : nullptr, + ts.has_value() ? ts.value().data_ptr() : nullptr, + rays.data_ptr(), counter.data_ptr(), noises.data_ptr()); + })); +} + + +// sigmas: [M] +// rgbs: [M, 3] +// ts: [M, 2], read as (t, dt) per sample +// rays: [N, 2], offset, num_steps +// weights: [M] +// weights_sum: [N], final pixel alpha +// depth: [N,] +// image: [N, 3] +template +__global__ void kernel_composite_rays_train_forward( + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ ts, + const int * __restrict__ rays, + const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, + scalar_t * weights, + scalar_t * weights_sum, + scalar_t * depth, + scalar_t * image +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t offset = rays[n * 2]; + uint32_t num_steps = rays[n * 2 + 1]; + + // empty ray, or ray whose samples would run past the M-sample buffers: write zeros and stop. + if (num_steps == 0 || offset + num_steps > M) { + weights_sum[n] = 0; + depth[n] = 0; + image[n * 3] = 0; + image[n * 3 + 1] = 0; + image[n * 3 + 2] = 0; + return; + } + + ts += offset * 2; + weights += offset; + sigmas += offset; + rgbs += offset * 3; + + // accumulate front to back; T is the transmittance remaining before the current sample + uint32_t step = 0; + + float T = 1.0f; + float r = 0, g = 0, b = 0, ws = 0, d = 0; + + while (step < num_steps) { + + const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]); + const float alpha = binarize ? (real_alpha > 0.5 ?
1.0 : 0.0) : real_alpha; + const float weight = alpha * T; + + weights[0] = weight; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + ws += weight; + d += weight * ts[0]; + + T *= 1.0f - alpha; + + // stop once the remaining transmittance falls below T_thresh + if (T < T_thresh) break; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // advance to the next sample + weights++; + sigmas++; + rgbs += 3; + ts += 2; + + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // write the per-ray results + weights_sum[n] = ws; // weights_sum + depth[n] = d; + image[n * 3] = r; + image[n * 3 + 1] = g; + image[n * 3 + 2] = b; +} + + +void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + sigmas.scalar_type(), "composite_rays_train_forward", ([&] { + kernel_composite_rays_train_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), ts.data_ptr(), rays.data_ptr(), M, N, T_thresh, binarize, weights.data_ptr(), weights_sum.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} + + +// grad_weights: [M,] +// grad_weights_sum: [N,] +// grad_image: [N, 3] +// grad_depth: [N,] +// sigmas: [M] +// rgbs: [M, 3] +// ts: [M, 2] +// rays: [N, 2], offset, num_steps +// weights_sum: [N,], forward result (depth is passed the same way) +// image: [N, 3], forward result +// grad_sigmas: [M] +// grad_rgbs: [M, 3] +template +__global__ void kernel_composite_rays_train_backward( + const scalar_t * __restrict__ grad_weights, + const scalar_t * __restrict__ grad_weights_sum, + const scalar_t * __restrict__ grad_depth, + const scalar_t * __restrict__ grad_image, + 
const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__
ts, + const int * __restrict__ rays, + const scalar_t * __restrict__ weights_sum, + const scalar_t * __restrict__ depth, + const scalar_t * __restrict__ image, + const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, + scalar_t * grad_sigmas, + scalar_t * grad_rgbs +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t offset = rays[n * 2]; + uint32_t num_steps = rays[n * 2 + 1]; + + if (num_steps == 0 || offset + num_steps > M) return; /* same skip rule as the forward kernel; nothing is written to grad_sigmas / grad_rgbs for such rays */ + + grad_weights += offset; + grad_weights_sum += n; + grad_depth += n; + grad_image += n * 3; + weights_sum += n; + depth += n; + image += n * 3; + sigmas += offset; + rgbs += offset * 3; + ts += offset * 2; + grad_sigmas += offset; + grad_rgbs += offset * 3; + + // accumulate: replay the forward loop; the *_final values are the forward results, the running sums hold what has been composited so far + uint32_t step = 0; + + float T = 1.0f; + const float r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], d_final = depth[0]; + float r = 0, g = 0, b = 0, ws = 0, d = 0; + + while (step < num_steps) { + + const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]); + const float alpha = binarize ? (real_alpha > 0.5 ? 1.0 : 0.0) : real_alpha; + const float weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + ws += weight; + d += weight * ts[0]; + + T *= 1.0f - alpha; + + // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
+ // write grad_rgbs: each colour channel receives the image gradient scaled by this sample's weight + grad_rgbs[0] = grad_image[0] * weight; + grad_rgbs[1] = grad_image[1] * weight; + grad_rgbs[2] = grad_image[2] * weight; + + // write grad_sigmas: T already includes this sample's (1 - alpha); (x_final - x) is what the samples behind this one contribute + grad_sigmas[0] = ts[1] * ( + grad_image[0] * (T * rgbs[0] - (r_final - r)) + + grad_image[1] * (T * rgbs[1] - (g_final - g)) + + grad_image[2] * (T * rgbs[2] - (b_final - b)) + + (grad_weights_sum[0] + grad_weights[0]) * (T - (ws_final - ws)) + + grad_depth[0] * (T * ts[0] - (d_final - d)) + ); + + //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); + // stop once the remaining transmittance falls below T_thresh (same cut-off as the forward kernel) + if (T < T_thresh) break; + + // advance to the next sample + sigmas++; + rgbs += 3; + ts += 2; + grad_weights++; + grad_sigmas++; + grad_rgbs += 3; + + step++; + } +} + + +void composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_image.scalar_type(), "composite_rays_train_backward", ([&] { + kernel_composite_rays_train_backward<<>>(grad_weights.data_ptr(), grad_weights_sum.data_ptr(), grad_depth.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), ts.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), depth.data_ptr(), image.data_ptr(), M, N, T_thresh, binarize, grad_sigmas.data_ptr(), grad_rgbs.data_ptr()); + 
})); +} + + +//////////////////////////////////////////////////// +///////////// inference ///////////// +//////////////////////////////////////////////////// + +template +__global__ void kernel_march_rays( + const uint32_t n_alive, + const uint32_t n_step, + const
int* __restrict__ rays_alive, + const scalar_t* __restrict__ rays_t, + const scalar_t* __restrict__ rays_o, + const scalar_t* __restrict__ rays_d, + const float bound, const bool contract, + const float dt_gamma, const uint32_t max_steps, + const uint32_t C, const uint32_t H, + const uint8_t * __restrict__ grid, + const scalar_t* __restrict__ nears, + const scalar_t* __restrict__ fars, + scalar_t* xyzs, scalar_t* dirs, scalar_t* ts, + const scalar_t* __restrict__ noises +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id; note that noises is indexed by the alive slot n, not by ray id + const float noise = noises[n]; + + // locate: per-ray inputs are indexed by ray id, outputs by alive slot (n_step samples per slot) + rays_o += index * 3; + rays_d += index * 3; + xyzs += n * n_step * 3; + dirs += n * n_step * 3; + ts += n * n_step * 2; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + const float rH = 1 / (float)H; + const float H3 = H * H * H; + + const float near = nears[index], far = fars[index]; + + const float dt_min = 2 * SQRT3() / max_steps; + const float dt_max = 2 * SQRT3() * bound / H; + // const float dt_max = 1e10f; + + // march for n_step steps, record points; marching resumes from the per-ray t stored in rays_t + float t = rays_t[index]; + t += clamp(t * dt_gamma, dt_min, dt_max) * noise; + uint32_t step = 0; + + while (t < far && step < n_step) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1, 
level), bound); + const float mip_rbound = 1 / mip_bound; + + // contraction + float cx = x, cy = y, cz = z; + const float mag = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z))); + if (contract && mag > 1) { + //
L-INF norm + const float Linf_scale = (2 - 1 / mag) / mag; + cx *= Linf_scale; + cy *= Linf_scale; + cz *= Linf_scale; + } + + // convert to nearest grid position + const int nx = clamp(0.5 * (cx * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (cy * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (cz * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); /* bit index into grid; shadows the ray-id `index` declared at the top of the kernel */ + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occupied, advance a small step, and write to output + if (occ) { + // write step + xyzs[0] = cx; + xyzs[1] = cy; + xyzs[2] = cz; + dirs[0] = dx; + dirs[1] = dy; + dirs[2] = dz; + // advance, then record the advanced t together with the step size dt + t += dt; + ts[0] = t; + ts[1] = dt; + // step + xyzs += 3; + dirs += 3; + ts += 2; + step++; + + // contraction case: no voxel skipping, take one regular step + } else if (contract && mag > 1) { + t += dt; + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - cx) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - cy) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - cz) * rdz; + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + dt = clamp(t * dt_gamma, dt_min, dt_max); + t += dt; + } while (t < tt); + } + } +} + + +void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, 
at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises) { + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), 
"march_rays", ([&] { + kernel_march_rays<<>>(n_alive, n_step, rays_alive.data_ptr(), rays_t.data_ptr(), rays_o.data_ptr(), rays_d.data_ptr(), bound, contract, dt_gamma, max_steps, C, H, grid.data_ptr(), near.data_ptr(), far.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), ts.data_ptr(), noises.data_ptr()); + })); +} + + +template +__global__ void kernel_composite_rays( + const uint32_t n_alive, + const uint32_t n_step, + const float T_thresh, const bool binarize, + int* rays_alive, + scalar_t* rays_t, + const scalar_t* __restrict__ sigmas, + const scalar_t* __restrict__ rgbs, + const scalar_t* __restrict__ ts, + scalar_t* weights_sum, scalar_t* depth, scalar_t* image +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + + // locate: samples are indexed by alive slot, per-ray accumulators by ray id + sigmas += n * n_step; + rgbs += n * n_step * 3; + ts += n * n_step * 2; + + rays_t += index; + weights_sum += index; + depth += index; + image += index * 3; + + float t; /* NOTE(review): only assigned inside the loop but stored to rays_t when step == n_step, so n_step == 0 would store an uninitialised value; confirm callers never pass 0 */ + float d = depth[0], r = image[0], g = image[1], b = image[2], weight_sum = weights_sum[0]; + + // accumulate on top of the values already stored for this ray + uint32_t step = 0; + while (step < n_step) { + + // ray is terminated if t == 0 (presumably ts is zero-filled by the caller so unused sample slots read 0; confirm) + if (ts[0] == 0) break; + + const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]); + const float alpha = binarize ? (real_alpha > 0.5 ?
1.0 : 0.0) : real_alpha; + + /* + T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) + w_i = alpha_i * T_i + --> + T_i = 1 - \sum_{j=0}^{i-1} w_j + */ + const float T = 1 - weight_sum; + const float weight = alpha * T; + weight_sum += weight; + + t = ts[0]; + d += weight * t; // real depth + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // ray is terminated if T is too small (T here is the transmittance before this sample) + // use a larger bound to further accelerate inference + if (T < T_thresh) break; + + // advance to the next sample + sigmas++; + rgbs += 3; + ts += 2; + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // rays_alive = -1 means ray is terminated early; otherwise the last sample's t is stored in rays_t. + if (step < n_step) { + rays_alive[n] = -1; + } else { + rays_t[0] = t; + } + + weights_sum[0] = weight_sum; // accumulated weight, written back to the per-ray buffer it was loaded from + depth[0] = d; + image[0] = r; + image[1] = g; + image[2] = b; +} + + +void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool binarize, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights, at::Tensor depth, at::Tensor image) { /* `weights` is forwarded as the kernel's weights_sum argument */ + static constexpr uint32_t N_THREAD = 128; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + image.scalar_type(), "composite_rays", ([&] { + kernel_composite_rays<<>>(n_alive, n_step, T_thresh, binarize, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), ts.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/raymarching/src/raymarching.h b/stable-dreamfusion-3DPortrait/raymarching/src/raymarching.h 
new file mode 100644 index 0000000..a9994d3 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/raymarching/src/raymarching.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + + +void near_far_from_aabb(const
at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); +void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); +void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); +void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); +void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); +void flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res); +/* training entry points */ +void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional xyzs, at::optional dirs, at::optional ts, at::Tensor rays, at::Tensor counter, at::Tensor noises); +void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); +void composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs); +/* inference entry points */ +void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, 
const float bound, const bool contract, const float
dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises); +void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool binarize, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/readme.md b/stable-dreamfusion-3DPortrait/readme.md new file mode 100644 index 0000000..1d9068d --- /dev/null +++ b/stable-dreamfusion-3DPortrait/readme.md @@ -0,0 +1,356 @@ +# Stable-Dreamfusion + +A PyTorch implementation of the text-to-3D model **Dreamfusion**, powered by the [Stable Diffusion](https://github.com/CompVis/stable-diffusion) text-to-2D model. + +**ADVERTISEMENT: Please check out [threestudio](https://github.com/threestudio-project/threestudio) for recent improvements and better implementation in 3D content generation!** + +**NEWS (2023.6.12)**: + +* Support of [Perp-Neg](https://perp-neg.github.io/) to alleviate the multi-head problem in Text-to-3D. +* Support of Perp-Neg for both [Stable Diffusion](https://github.com/CompVis/stable-diffusion) and [DeepFloyd-IF](https://github.com/deep-floyd/IF).
+ +https://user-images.githubusercontent.com/25863658/236712982-9f93bd32-83bf-423a-bb7c-f73df7ece2e3.mp4 + +https://user-images.githubusercontent.com/25863658/232403162-51b69000-a242-4b8c-9cd9-4242b09863fa.mp4 + +### [Update Logs](assets/update_logs.md) + +### Colab notebooks: +* Instant-NGP backbone (`-O`): [![Instant-NGP Backbone](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MXT3yfOFvO0ooKEfiUUvTKwUkrrlCHpF?usp=sharing) + +* Vanilla NeRF backbone (`-O2`): [![Vanilla Backbone](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1mvfxG-S_n_gZafWoattku7rLJ2kPoImL?usp=sharing) + +# Important Notice +This project is a **work-in-progress**, and contains lots of differences from the paper. **The current generation quality cannot match the results from the original paper, and many prompts still fail badly!** + +## Notable differences from the paper +* Since the Imagen model is not publicly available, we use [Stable Diffusion](https://github.com/CompVis/stable-diffusion) to replace it (implementation from [diffusers](https://github.com/huggingface/diffusers)). Different from Imagen, Stable-Diffusion is a latent diffusion model, which diffuses in a latent space instead of the original image space. Therefore, we need the loss to propagate back from the VAE's encoder part too, which introduces extra time cost in training. 
+* We use the [multi-resolution grid encoder](https://github.com/NVlabs/instant-ngp/) to implement the NeRF backbone (implementation from [torch-ngp](https://github.com/ashawkey/torch-ngp)), which enables much faster rendering (~10FPS at 800x800). +* We use the [Adan](https://github.com/sail-sg/Adan) optimizer as default. + +# Install + +```bash +git clone https://github.com/ashawkey/stable-dreamfusion.git +cd stable-dreamfusion +``` + +### Optional: create a python virtual environment + +To avoid python package conflicts, we recommend using a virtual environment, e.g.: using conda or venv: + +```bash +python -m venv venv_stable-dreamfusion +source venv_stable-dreamfusion/bin/activate # you need to repeat this step for every new terminal +``` + +### Install with pip + +```bash +pip install -r requirements.txt +``` + +### Download pre-trained models + +To use image-conditioned 3D generation, you need to download some pretrained checkpoints manually: +* [Zero-1-to-3](https://github.com/cvlab-columbia/zero123) for diffusion backend. + We use `zero123-xl.ckpt` by default, and it is hard-coded in `guidance/zero123_utils.py`. + ```bash + cd pretrained/zero123 + wget https://zero123.cs.columbia.edu/assets/zero123-xl.ckpt + ``` +* [Omnidata](https://github.com/EPFL-VILAB/omnidata/tree/main/omnidata_tools/torch) for depth and normal prediction. + These ckpts are hardcoded in `preprocess_image.py`. 
+ ```bash + mkdir pretrained/omnidata + cd pretrained/omnidata + # assume gdown is installed + gdown '1Jrh-bRnJEjyMCS7f-WsaFlccfPjJPPHI&confirm=t' # omnidata_dpt_depth_v2.ckpt + gdown '1wNxVO4vVbDEMEpnAi_jwQObf2MFodcBR&confirm=t' # omnidata_dpt_normal_v2.ckpt + ``` + +To use [DeepFloyd-IF](https://github.com/deep-floyd/IF), you need to accept the usage conditions from [hugging face](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0), and login with `huggingface-cli login` in command line. + +For DMTet, we port the pre-generated `32/64/128` resolution tetrahedron grids under `tets`. +The 256 resolution one can be found [here](https://drive.google.com/file/d/1lgvEKNdsbW5RS4gVxJbgBS4Ac92moGSa/view?usp=sharing). + +### Build extension (optional) +By default, we use [`load`](https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load) to build the extension at runtime. +We also provide the `setup.py` to build each extension: +```bash +cd stable-dreamfusion + +# install all extension modules +bash scripts/install_ext.sh + +# if you want to install manually, here is an example: +pip install ./raymarching # install to python path (you still need the raymarching/ folder, since this only installs the built extension.) +``` + +### Taichi backend (optional) +Use [Taichi](https://github.com/taichi-dev/taichi) backend for Instant-NGP. It achieves comparable performance to CUDA implementation while **No CUDA** build is required. 
Install Taichi with pip: +```bash +pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly +``` + +### Trouble Shooting: +* we assume working with the latest version of all dependencies, if you meet any problems from a specific dependency, please try to upgrade it first (e.g., `pip install -U diffusers`). If the problem still holds, [reporting a bug issue](https://github.com/ashawkey/stable-dreamfusion/issues/new?assignees=&labels=bug&template=bug_report.yaml&title=%3Ctitle%3E) will be appreciated! +* `[F glutil.cpp:338] eglInitialize() failed Aborted (core dumped)`: this usually indicates problems in OpenGL installation. Try to re-install Nvidia driver, or use nvidia-docker as suggested in https://github.com/ashawkey/stable-dreamfusion/issues/131 if you are using a headless server. +* `TypeError: xxx_forward(): incompatible function arguments`: this happens when we update the CUDA source and you used `setup.py` to install the extensions earlier. Try to re-install the corresponding extension (e.g., `pip install ./gridencoder`). + +### Tested environments +* Ubuntu 22 with torch 1.12 & CUDA 11.6 on a V100. + +# Usage + +First time running will take some time to compile the CUDA extensions. + +```bash +#### stable-dreamfusion setting + +### Instant-NGP NeRF Backbone +# + faster rendering speed +# + less GPU memory (~16G) +# - need to build CUDA extensions (a CUDA-free Taichi backend is available) + +## train with text prompt (with the default settings) +# `-O` equals `--cuda_ray --fp16` +# `--cuda_ray` enables instant-ngp-like occupancy grid based acceleration. 
+python main.py --text "a hamburger" --workspace trial -O + +# reduce stable-diffusion memory usage with `--vram_O` +# enable various vram savings (https://huggingface.co/docs/diffusers/optimization/fp16). +python main.py --text "a hamburger" --workspace trial -O --vram_O + +# You can collect arguments in a file. You can override arguments by specifying them after `--file`. Note that quoted strings can't be loaded from .args files... +python main.py --file scripts/res64.args --workspace trial_awesome_hamburger --text "a photo of an awesome hamburger" + +# use CUDA-free Taichi backend with `--backbone grid_taichi` +python3 main.py --text "a hamburger" --workspace trial -O --backbone grid_taichi + +# choose stable-diffusion version (support 1.5, 2.0 and 2.1, default is 2.1 now) +python main.py --text "a hamburger" --workspace trial -O --sd_version 1.5 + +# use a custom stable-diffusion checkpoint from hugging face: +python main.py --text "a hamburger" --workspace trial -O --hf_key andite/anything-v4.0 + +# use DeepFloyd-IF for guidance (experimental): +python main.py --text "a hamburger" --workspace trial -O --IF +python main.py --text "a hamburger" --workspace trial -O --IF --vram_O # requires ~24G GPU memory + +# we also support negative text prompt now: +python main.py --text "a rose" --negative "red" --workspace trial -O + +## after the training is finished: +# test (exporting 360 degree video) +python main.py --workspace trial -O --test +# also save a mesh (with obj, mtl, and png texture) +python main.py --workspace trial -O --test --save_mesh +# test with a GUI (free view control!) +python main.py --workspace trial -O --test --gui + +### Vanilla NeRF backbone +# + pure pytorch, no need to build extensions! 
+# - slow rendering speed +# - more GPU memory + +## train +# `-O2` equals `--backbone vanilla` +python main.py --text "a hotdog" --workspace trial2 -O2 + +# if CUDA OOM, try to reduce NeRF sampling steps (--num_steps and --upsample_steps) +python main.py --text "a hotdog" --workspace trial2 -O2 --num_steps 64 --upsample_steps 0 + +## test +python main.py --workspace trial2 -O2 --test +python main.py --workspace trial2 -O2 --test --save_mesh +python main.py --workspace trial2 -O2 --test --gui # not recommended, FPS will be low. + +### DMTet finetuning + +## use --dmtet and --init_with to finetune the mesh at higher resolution +python main.py -O --text "a hamburger" --workspace trial_dmtet --dmtet --iters 5000 --init_with trial/checkpoints/df.pth + +## init dmtet with a mesh to generate texture +# require install of cubvh: pip install git+https://github.com/ashawkey/cubvh +# remove --lock_geo to also finetune geometry, but performance may be bad. +python main.py -O --text "a white bunny with red eyes" --workspace trial_dmtet_mesh --dmtet --iters 5000 --init_with ./data/bunny.obj --lock_geo + +## test & export the mesh +python main.py -O --text "a hamburger" --workspace trial_dmtet --dmtet --iters 5000 --test --save_mesh + +## gui to visualize dmtet +python main.py -O --text "a hamburger" --workspace trial_dmtet --dmtet --iters 5000 --test --gui + +### Image-conditioned 3D Generation + +## preprocess input image +# note: the results of image-to-3D are dependent on zero-1-to-3's capability. For best performance, the input image should contain a single front-facing object, it should have square aspect ratio, with <1024 pixel resolution. Check the examples under ./data. +# this will export `<image>_rgba.png`, `<image>_depth.png`, and `<image>_normal.png` to the directory containing the input image. 
+python preprocess_image.py <image>.png +python preprocess_image.py <image>.png --border_ratio 0.4 # increase border_ratio if the center object appears too large and results are unsatisfying. + +## zero123 train +# pass in the processed <image>_rgba.png by --image and do NOT pass in --text to enable zero-1-to-3 backend. +python main.py -O --image <image>_rgba.png --workspace trial_image --iters 5000 + +# if the image is not exactly front-view (elevation = 0), adjust default_polar (we use polar from 0 to 180 to represent elevation from 90 to -90) +python main.py -O --image <image>_rgba.png --workspace trial_image --iters 5000 --default_polar 80 + +# by default we leverage monocular depth estimation to aid image-to-3d, but if you find the depth estimation inaccurate and harms results, turn it off by: +python main.py -O --image <image>_rgba.png --workspace trial_image --iters 5000 --lambda_depth 0 + +python main.py -O --image <image>_rgba.png --workspace trial_image_dmtet --dmtet --init_with trial_image/checkpoints/df.pth + +## zero123 with multiple images +python main.py -O --image_config config/<config>.csv --workspace trial_image --iters 5000 + +## render images per batch (default 1) +python main.py -O --image_config config/<config>.csv --workspace trial_image --iters 5000 --batch_size 4 + +# providing both --text and --image enables stable-diffusion backend (similar to make-it-3d) +python main.py -O --image hamburger_rgba.png --text "a DSLR photo of a delicious hamburger" --workspace trial_image_text --iters 5000 + +python main.py -O --image hamburger_rgba.png --text "a DSLR photo of a delicious hamburger" --workspace trial_image_text_dmtet --dmtet --init_with trial_image_text/checkpoints/df.pth + +## test / visualize +python main.py -O --image <image>_rgba.png --workspace trial_image_dmtet --dmtet --test --save_mesh +python main.py -O --image <image>_rgba.png --workspace trial_image_dmtet --dmtet --test --gui + +### Debugging + +# Can save guidance images for debugging purposes. These get saved in trial_hamburger/guidance. 
+# Warning: this slows down training considerably and consumes lots of disk space! +python main.py --text "a hamburger" --workspace trial_hamburger -O --vram_O --save_guidance --save_guidance_interval 5 # save every 5 steps +``` + +For example commands, check [`scripts`](./scripts). + +For advanced tips and other developing stuff, check [Advanced Tips](./assets/advanced.md). + +# Evaluation + +Reproduce the paper CLIP R-precision evaluation + +After the testing part in the usage, the validation set containing projections from different angles is generated. Test the R-precision between the prompt and the image (R=1). + +```bash +python r_precision.py --text "a snake is flying in the sky" --workspace snake_HQ --latest ep0100 --mode depth --clip clip-ViT-B-16 +``` + +# Acknowledgement + +This work is based on an increasing list of amazing research works and open-source projects, thanks a lot to all the authors for sharing! + +* [DreamFusion: Text-to-3D using 2D Diffusion](https://dreamfusion3d.github.io/) + ``` + @article{poole2022dreamfusion, + author = {Poole, Ben and Jain, Ajay and Barron, Jonathan T. 
and Mildenhall, Ben}, + title = {DreamFusion: Text-to-3D using 2D Diffusion}, + journal = {arXiv}, + year = {2022}, + } + ``` + +* [Magic3D: High-Resolution Text-to-3D Content Creation](https://research.nvidia.com/labs/dir/magic3d/) + ``` + @inproceedings{lin2023magic3d, + title={Magic3D: High-Resolution Text-to-3D Content Creation}, + author={Lin, Chen-Hsuan and Gao, Jun and Tang, Luming and Takikawa, Towaki and Zeng, Xiaohui and Huang, Xun and Kreis, Karsten and Fidler, Sanja and Liu, Ming-Yu and Lin, Tsung-Yi}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition ({CVPR})}, + year={2023} + } + ``` + +* [Zero-1-to-3: Zero-shot One Image to 3D Object](https://github.com/cvlab-columbia/zero123) + ``` + @misc{liu2023zero1to3, + title={Zero-1-to-3: Zero-shot One Image to 3D Object}, + author={Ruoshi Liu and Rundi Wu and Basile Van Hoorick and Pavel Tokmakov and Sergey Zakharov and Carl Vondrick}, + year={2023}, + eprint={2303.11328}, + archivePrefix={arXiv}, + primaryClass={cs.CV} + } + ``` + +* [Perp-Neg: Re-imagine the Negative Prompt Algorithm: Transform 2D Diffusion into 3D, alleviate Janus problem and Beyond](https://perp-neg.github.io/) + ``` + @article{armandpour2023re, + title={Re-imagine the Negative Prompt Algorithm: Transform 2D Diffusion into 3D, alleviate Janus problem and Beyond}, + author={Armandpour, Mohammadreza and Zheng, Huangjie and Sadeghian, Ali and Sadeghian, Amir and Zhou, Mingyuan}, + journal={arXiv preprint arXiv:2304.04968}, + year={2023} + } + ``` + +* [RealFusion: 360° Reconstruction of Any Object from a Single Image](https://github.com/lukemelas/realfusion) + ``` + @inproceedings{melaskyriazi2023realfusion, + author = {Melas-Kyriazi, Luke and Rupprecht, Christian and Laina, Iro and Vedaldi, Andrea}, + title 
= {RealFusion: 360 Reconstruction of Any Object from a Single Image}, + booktitle={CVPR}, + year = {2023}, + url = {https://arxiv.org/abs/2302.10663}, + } + ``` + +* [Fantasia3D: Disentangling Geometry and Appearance for High-quality Text-to-3D Content Creation](https://fantasia3d.github.io/) + ``` + @article{chen2023fantasia3d, + title={Fantasia3D: Disentangling Geometry and Appearance for High-quality Text-to-3D Content Creation}, + author={Rui Chen and Yongwei Chen and Ningxin Jiao and Kui Jia}, + journal={arXiv preprint arXiv:2303.13873}, + year={2023} + } + ``` + +* [Make-It-3D: High-Fidelity 3D Creation from A Single Image with Diffusion Prior](https://make-it-3d.github.io/) + ``` + @article{tang2023make, + title={Make-It-3D: High-Fidelity 3D Creation from A Single Image with Diffusion Prior}, + author={Tang, Junshu and Wang, Tengfei and Zhang, Bo and Zhang, Ting and Yi, Ran and Ma, Lizhuang and Chen, Dong}, + journal={arXiv preprint arXiv:2303.14184}, + year={2023} + } + ``` + +* [Stable Diffusion](https://github.com/CompVis/stable-diffusion) and the [diffusers](https://github.com/huggingface/diffusers) library. 
+ + ``` + @misc{rombach2021highresolution, + title={High-Resolution Image Synthesis with Latent Diffusion Models}, + author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer}, + year={2021}, + eprint={2112.10752}, + archivePrefix={arXiv}, + primaryClass={cs.CV} + } + + @misc{von-platen-etal-2022-diffusers, + author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Thomas Wolf}, + title = {Diffusers: State-of-the-art diffusion models}, + year = {2022}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/huggingface/diffusers}} + } + ``` + +* The GUI is developed with [DearPyGui](https://github.com/hoffstadt/DearPyGui). + +* Puppy image from : https://www.pexels.com/photo/high-angle-photo-of-a-corgi-looking-upwards-2664417/ + +* Anya images from : https://www.goodsmile.info/en/product/13301/POP+UP+PARADE+Anya+Forger.html + +# Citation + +If you find this work useful, a citation will be appreciated via: +``` +@misc{stable-dreamfusion, + Author = {Jiaxiang Tang}, + Year = {2022}, + Note = {https://github.com/ashawkey/stable-dreamfusion}, + Title = {Stable-dreamfusion: Text-to-3D with Stable-diffusion} +} +``` diff --git a/stable-dreamfusion-3DPortrait/requirements.txt b/stable-dreamfusion-3DPortrait/requirements.txt new file mode 100644 index 0000000..ea12bcc --- /dev/null +++ b/stable-dreamfusion-3DPortrait/requirements.txt @@ -0,0 +1,56 @@ +tqdm +rich +ninja +numpy +pandas +scipy +scikit-learn +matplotlib +opencv-python +imageio +imageio-ffmpeg + +torch +torch-ema +einops +tensorboard +tensorboardX + +# for gui +dearpygui + +# for 
"""Print the RGB fine-tuning command for each inverted test portrait.

For every ``prompt.txt`` found under the test-data directory, the script looks
for tri-grid inversion results produced by the run named ``inversion_name``
and prints one ``main_3DPortraitGAN.py`` command line for the first result.
The commands are only printed, never executed, so the output can be reviewed
or redirected into a batch file.
"""
import glob
import os

# Name of the inversion run whose results are used as initialization.
inversion_name = 'hierarchy_inversion_4000'

# Raw strings: these are Windows paths, and sequences such as ``\h`` or ``\e``
# are invalid escapes that Python only tolerates with a warning.
trigrid_decoder_ckpt = r'F:\high_quality_3DPortraitGAN\exp/3DPortraitGAN-hierarchy\models/network-snapshot-004000_decoder.ckpt'
hf_key = r'F:\high_quality_3DPortraitGAN\exp\stable-dreamfusion\pretrained\SG161222Realistic_Vision_V5.1_noVAE'

# One prompt file per test subject.
prompt_glob = 'F:/high_quality_3DPortraitGAN/exp/test_data/*/prompt.txt'


def read_prompt(prompt_file):
    """Return the text of *prompt_file* with all newlines removed.

    The prompt is embedded inside a double-quoted ``--text`` argument, so a
    stray newline would split the generated command line in two.
    """
    with open(prompt_file, 'r') as f:
        # The original replaced '/n' (forward slash), which never matched.
        return f.read().replace('\n', '')


def build_command(name_, prompt, inversion_trigrid):
    """Return the fine-tuning command line for one inversion result.

    Args:
        name_: Subject name plus index suffix, used in the workspace name.
        prompt: Text prompt passed to ``--text``.
        inversion_trigrid: Path of the ``inversion_trigrid.pkl`` to start from.
    """
    return (
        f'python main_3DPortraitGAN.py --workspace output/2023-11-22-{name_}_{inversion_name}'
        ' --save_guidance --backbone trigrid_heirarchy_aggregate --latent_iter_ratio 0'
        ' --t_range 0.02 0.4 --vram_O --w 128 --h 128 --H 512 --W 512 --iters 3000'
        f' --text "{prompt}" --hf_key {hf_key}'
        f' --trigrid_path {inversion_trigrid} --trigrid_decoder_ckpt {trigrid_decoder_ckpt}'
    )


def main():
    """Print one command per test subject that has an inversion result."""
    for prompt_file in glob.glob(prompt_glob):
        prompt = read_prompt(prompt_file)

        dir_ = os.path.dirname(prompt_file)
        # glob may return either separator on Windows, so split on both.
        name = dir_.split('/')[-1].split('\\')[-1]

        pattern = f'{dir_}/samples_new_crop/{inversion_name}/*/inversion_trigrid.pkl'
        # Only the first inversion result per subject is used, hence the
        # fixed ``_0`` suffix and the slice.
        for inversion_trigrid in glob.glob(pattern)[:1]:
            print(build_command(f'{name}_0', prompt, inversion_trigrid))


if __name__ == '__main__':
    main()
/bin/bash +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a delicious hamburger" --workspace trial_hamburger --iters 5000 +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a delicious hamburger" --workspace trial2_hamburger --dmtet --iters 5000 --init_with trial_hamburger/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a highly detailed stone bust of Theodoros Kolokotronis" --workspace trial_stonehead --iters 5000 +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a highly detailed stone bust of Theodoros Kolokotronis" --workspace trial2_stonehead --dmtet --iters 5000 --init_with trial_stonehead/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "an astronaut, full body" --workspace trial_astronaut --iters 5000 +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "an astronaut, full body" --workspace trial2_astronaut --dmtet --iters 5000 --init_with trial_astronaut/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a squirrel-octopus hybrid" --workspace trial_squrrel_octopus --iters 5000 +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a squirrel-octopus hybrid" --workspace trial2_squrrel_octopus --dmtet --iters 5000 --init_with trial_squrrel_octopus/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a baby bunny sitting on top of a stack of pancakes" --workspace trial_rabbit_pancake --iters 5000 +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a metal bunny sitting on top of a stack of chocolate cookies" --workspace trial2_rabbit_pancake --dmtet --iters 5000 --init_with trial_rabbit_pancake/checkpoints/df.pth \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/scripts/run2.sh b/stable-dreamfusion-3DPortrait/scripts/run2.sh new file mode 100644 index 0000000..a958383 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run2.sh @@ -0,0 +1,10 @@ +#! 
/bin/bash + +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a DSLR photo of a shiba inu playing golf wearing tartan golf clothes and hat" --workspace trial_shiba --iters 10000 +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a DSLR photo of a shiba inu playing golf wearing tartan golf clothes and hat" --workspace trial2_shiba --dmtet --iters 5000 --init_with trial_shiba/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a banana peeling itself" --workspace trial_banana --iters 10000 +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a banana peeling itself" --workspace trial2_banana --dmtet --iters 5000 --init_with trial_banana/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a capybara wearing a top hat, low poly" --workspace trial_capybara --iters 10000 +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a capybara wearing a top hat, low poly" --workspace trial2_capybara --dmtet --iters 5000 --init_with trial_capybara/checkpoints/df.pth \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/scripts/run3.sh b/stable-dreamfusion-3DPortrait/scripts/run3.sh new file mode 100644 index 0000000..32d3457 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run3.sh @@ -0,0 +1,13 @@ +#! 
/bin/bash + +CUDA_VISIBLE_DEVICES=7 python main.py -O --text "ironman, full body" --workspace trial_ironman --iters 10000 +CUDA_VISIBLE_DEVICES=7 python main.py -O --text "ironman, full body" --workspace trial2_ironman --dmtet --iters 5000 --init_with trial_ironman/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of an ice cream sundae" --workspace trial_icecream --iters 10000 +CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of an ice cream sundae" --workspace trial2_icecream --dmtet --iters 5000 --init_with trial_icecream/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of a kingfisher bird" --workspace trial_bird --iters 10000 +CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a DSLR photo of a kingfisher bird" --workspace trial2_bird --dmtet --iters 5000 --init_with trial_bird/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a car made of sushi" --workspace trial_sushi --iters 10000 +CUDA_VISIBLE_DEVICES=7 python main.py -O --text "a car made of sushi" --workspace trial2_sushi --dmtet --iters 5000 --init_with trial_sushi/checkpoints/df.pth diff --git a/stable-dreamfusion-3DPortrait/scripts/run4.sh b/stable-dreamfusion-3DPortrait/scripts/run4.sh new file mode 100644 index 0000000..2308d0c --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run4.sh @@ -0,0 +1,13 @@ +#! 
/bin/bash + +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a rabbit, animated movie character, high detail 3d model" --workspace trial_rabbit2 --iters 10000 +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a rabbit, animated movie character, high detail 3d model" --workspace trial2_rabbit2 --dmtet --iters 5000 --init_with trial_rabbit2/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a corgi dog, highly detailed 3d model" --workspace trial_corgi --iters 10000 +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "a corgi dog, highly detailed 3d model" --workspace trial2_corgi --dmtet --iters 5000 --init_with trial_corgi/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=5 python main.py -O --text " a small saguaro cactus planted in a clay pot" --workspace trial_cactus --iters 10000 +CUDA_VISIBLE_DEVICES=5 python main.py -O --text " a small saguaro cactus planted in a clay pot" --workspace trial2_cactus --dmtet --iters 5000 --init_with trial_cactus/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "the leaning tower of Pisa" --workspace trial_pisa --iters 10000 +CUDA_VISIBLE_DEVICES=5 python main.py -O --text "the leaning tower of Pisa" --workspace trial2_pisa --dmtet --iters 5000 --init_with trial_pisa/checkpoints/df.pth \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/scripts/run5.sh b/stable-dreamfusion-3DPortrait/scripts/run5.sh new file mode 100644 index 0000000..fc2b7d1 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run5.sh @@ -0,0 +1,13 @@ +#! 
/bin/bash + +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Perched blue jay bird" --workspace trial_jay --iters 10000 +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Perched blue jay bird" --workspace trial2_jay --dmtet --iters 5000 --init_with trial_jay/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "angel statue wings out" --workspace trial_angle --iters 10000 +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "angel statue wings out" --workspace trial2_angle --dmtet --iters 5000 --init_with trial_angle/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "devil statue" --workspace trial_devil --iters 10000 +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "devil statue" --workspace trial2_devil --dmtet --iters 5000 --init_with trial_devil/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Einstein statue" --workspace trial_einstein --iters 10000 +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "Einstein statue" --workspace trial2_einstein --dmtet --iters 5000 --init_with trial_einstein/checkpoints/df.pth diff --git a/stable-dreamfusion-3DPortrait/scripts/run6.sh b/stable-dreamfusion-3DPortrait/scripts/run6.sh new file mode 100644 index 0000000..ad2b946 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run6.sh @@ -0,0 +1,18 @@ +#! 
/bin/bash +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a baby bunny sitting on top of a stack of pancakes" --workspace trial_rabbit_pancake --iters 5000 +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a metal bunny sitting on top of a stack of chocolate cookies" --workspace trial2_rabbit_pancake --dmtet --iters 5000 --init_with trial_rabbit_pancake/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial_jay --iters 5000 +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial2_jay --dmtet --iters 5000 --init_with trial_jay/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial_fox --iters 5000 +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial2_fox --dmtet --iters 5000 --init_with trial_fox/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial_peacock --iters 5000 +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial2_peacock --dmtet --iters 5000 --init_with trial_peacock/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a flower made out of metal" --workspace trial_metal_flower --iters 5000 +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a flower made out of metal" --workspace trial2_metal_flower --dmtet --iters 5000 --init_with trial_metal_flower/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it" --workspace trial_chicken --iters 5000 +CUDA_VISIBLE_DEVICES=4 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn 
chick hatching out of it" --workspace trial2_chicken --dmtet --iters 5000 --init_with trial_chicken/checkpoints/df.pth \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/scripts/run_if.sh b/stable-dreamfusion-3DPortrait/scripts/run_if.sh new file mode 100644 index 0000000..07bb17f --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run_if.sh @@ -0,0 +1,18 @@ +#! /bin/bash +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a baby bunny sitting on top of a stack of pancakes" --workspace trial_if_rabbit_pancake --iters 5000 --IF +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a metal bunny sitting on top of a stack of chocolate cookies" --workspace trial_if2_rabbit_pancake --dmtet --iters 5000 --init_with trial_if_rabbit_pancake/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial_if_jay --iters 5000 --IF +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a blue jay standing on a large basket of rainbow macarons" --workspace trial_if2_jay --dmtet --iters 5000 --init_with trial_if_jay/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial_if_fox --iters 5000 --IF +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a fox taking a photograph using a DSLR" --workspace trial_if2_fox --dmtet --iters 5000 --init_with trial_if_fox/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial_if_peacock --iters 5000 --IF +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a DSLR photo of a peacock on a surfboard" --workspace trial_if2_peacock --dmtet --iters 5000 --init_with trial_if_peacock/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a flower made out of metal" --workspace trial_if_metal_flower --iters 5000 --IF 
+CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a flower made out of metal" --workspace trial_if2_metal_flower --dmtet --iters 5000 --init_with trial_if_metal_flower/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it" --workspace trial_if_chicken --iters 5000 --IF +CUDA_VISIBLE_DEVICES=2 python main.py -O --text "a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it" --workspace trial_if2_chicken --dmtet --iters 5000 --init_with trial_if_chicken/checkpoints/df.pth \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/scripts/run_if2.sh b/stable-dreamfusion-3DPortrait/scripts/run_if2.sh new file mode 100644 index 0000000..b363925 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run_if2.sh @@ -0,0 +1,18 @@ +#! /bin/bash +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a corgi taking a selfie" --workspace trial_if_corgi --iters 5000 --IF +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a corgi taking a selfie" --workspace trial_if2_corgi --dmtet --iters 5000 --init_with trial_if_corgi/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of a ghost eating a hamburger" --workspace trial_if_ghost --iters 5000 --IF +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of a ghost eating a hamburger" --workspace trial_if2_ghost --dmtet --iters 5000 --init_with trial_if_ghost/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of an origami motorcycle" --workspace trial_if_motor --iters 5000 --IF +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of an origami motorcycle" --workspace trial_if2_motor --dmtet --iters 5000 --init_with trial_if_motor/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a DSLR photo of a Space Shuttle" --workspace trial_if_spaceshuttle --iters 5000 --IF +CUDA_VISIBLE_DEVICES=3 
python main.py -O --text "a DSLR photo of a Space Shuttle" --workspace trial_if2_spaceshuttle --dmtet --iters 5000 --init_with trial_if_spaceshuttle/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a palm tree, low poly 3d model" --workspace trial_if_palm --iters 5000 --IF +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a palm tree, low poly 3d model" --workspace trial_if2_palm --dmtet --iters 5000 --init_with trial_if_palm/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a zoomed out DSLR photo of a marble bust of a cat, a real mouse is sitting on its head" --workspace trial_if_cat_mouse --iters 5000 --IF +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a zoomed out DSLR photo of a marble bust of a cat, a real mouse is sitting on its head" --workspace trial_if2_cat_mouse --dmtet --iters 5000 --init_with trial_if_cat_mouse/checkpoints/df.pth \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/scripts/run_if2_perpneg.sh b/stable-dreamfusion-3DPortrait/scripts/run_if2_perpneg.sh new file mode 100644 index 0000000..2261027 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run_if2_perpneg.sh @@ -0,0 +1,18 @@ +#! /bin/bash +# To avoid the Janus problem caused by the diffusion model's front view bias, utilize the Perp-Neg algorithm. To maximize its benefits, +# increase the absolute value of "negative_w" for improved Janus problem mitigation. If you encounter flat faces or divergence, consider +# reducing the absolute value of "negative_w". The value of "negative_w" should vary for each prompt due to the diffusion model's varying +# bias towards generating front views for different objects. Vary the weights within the range of 0 to -4. 
+CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a lion bust" --workspace trial_perpneg_if_lion --iters 5000 --IF --batch_size 1 --perpneg +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a marble lion head" --workspace trial_perpneg_if2_lion_p --dmtet --iters 5000 --perpneg --init_with trial_perpneg_if_lion/checkpoints/df.pth +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a marble lion head" --workspace trial_perpneg_if2_lion_nop --dmtet --iters 5000 --init_with trial_perpneg_if_lion/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a tiger cub" --workspace trial_perpneg_if_tiger --iters 5000 --IF --batch_size 1 --perpneg +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "tiger" --workspace trial_perpneg_if2_tiger_p --dmtet --iters 5000 --perpneg --init_with trial_perpneg_if_tiger/checkpoints/df.pth +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "tiger" --workspace trial_perpneg_if2_tiger_nop --dmtet --iters 5000 --init_with trial_perpneg_if_tiger/checkpoints/df.pth + +# a larger absolute value of negative_w is used for the following commands because the default negative weight of -2 is not enough to make the diffusion model produce the desired views +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "a shiba dog wearing sunglasses" --workspace trial_perpneg_if_shiba --iters 5000 --IF --batch_size 1 --perpneg --negative_w -3.0 +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "shiba wearing sunglasses" --workspace trial_perpneg_if2_shiba_p --dmtet --iters 5000 --perpneg --negative_w -3.0 --init_with trial_perpneg_if_shiba/checkpoints/df.pth +CUDA_VISIBLE_DEVICES=3 python main.py -O --text "shiba wearing sunglasses" --workspace trial_perpneg_if2_shiba_nop --dmtet --iters 5000 --init_with trial_perpneg_if_shiba/checkpoints/df.pth + diff --git a/stable-dreamfusion-3DPortrait/scripts/run_image.sh b/stable-dreamfusion-3DPortrait/scripts/run_image.sh new file mode 100644 index 0000000..5c885e3 --- /dev/null +++ 
b/stable-dreamfusion-3DPortrait/scripts/run_image.sh @@ -0,0 +1,25 @@ +# zero123 backend (single object, images like 3d model rendering) + +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/teddy_rgba.png --workspace trial_image_teddy --iters 5000 +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/teddy_rgba.png --workspace trial2_image_teddy --iters 5000 --dmtet --init_with trial_image_teddy/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/catstatue_rgba.png --workspace trial_image_catstatue --iters 5000 +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/catstatue_rgba.png --workspace trial2_image_catstatue --iters 5000 --dmtet --init_with trial_image_catstatue/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/firekeeper_rgba.png --workspace trial_image_firekeeper --iters 5000 +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/firekeeper_rgba.png --workspace trial2_image_firekeeper --iters 5000 --dmtet --init_with trial_image_firekeeper/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/hamburger_rgba.png --workspace trial_image_hamburger --iters 5000 +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/hamburger_rgba.png --workspace trial2_image_hamburger --iters 5000 --dmtet --init_with trial_image_hamburger/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/corgi_rgba.png --workspace trial_image_corgi --iters 5000 +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/corgi_rgba.png --workspace trial2_image_corgi --iters 5000 --dmtet --init_with trial_image_corgi/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cactus_rgba.png --workspace trial_image_cactus --iters 5000 +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cactus_rgba.png --workspace trial2_image_cactus --iters 5000 --dmtet --init_with trial_image_cactus/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cake_rgba.png 
--workspace trial_image_cake --iters 5000 +CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cake_rgba.png --workspace trial2_image_cake --iters 5000 --dmtet --init_with trial_image_cake/checkpoints/df.pth + +# CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/warrior_rgba.png --workspace trial_image_warrior --iters 5000 +# CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/warrior_rgba.png --workspace trial2_image_warrior --iters 5000 --dmtet --init_with trial_image_warrior/checkpoints/df.pth \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/scripts/run_image_anya.sh b/stable-dreamfusion-3DPortrait/scripts/run_image_anya.sh new file mode 100644 index 0000000..e0d63ac --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run_image_anya.sh @@ -0,0 +1,35 @@ +# Phase 1 - barely fits in A100 40GB. +# Conclusion: results in concave-ish face, no neck, excess hair in the back +CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage \ + --iters 10000 --save_guidance --save_guidance_interval 10 --ckpt scratch --batch_size 2 --test_interval 2 \ + --h 128 --w 128 --zero123_grad_scale None + +# Phase 2 - barely fits in A100 40GB. +# 20X smaller lambda_3d_normal_smooth, --known_view_interval 2, 3X LR +# Much higher jitter to increase disparity (and eliminate some of the flatness)... 
not too high either (to avoid cropping the face) +CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage_B_GPU2_reproduction1_GPU2 \ + --text "A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" \ + --iters 12500 --ckpt trial_anya_1_refimage/checkpoints/df_ep0100.pth --save_guidance --save_guidance_interval 1 \ + --h 256 --w 256 --albedo_iter_ratio 0.0 --t_range 0.2 0.6 --batch_size 4 --radius_range 2.2 2.6 --test_interval 2 \ + --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.1 --jitter_target 0.1 --jitter_up 0.05 \ + --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --progressive_view --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 1 \ + --exp_start_iter 10000 --exp_end_iter 12500 + +# Phase 3 - increase resolution to 512 +# Disable textureless since they can cause catastrophic divergence +# Since radius range is inconsistent, increase it, and reduce the jitter to avoid excessively cropped renders. +# Learning rate may be set too high, since `--batch_size 1`. 
+CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage_B_GPU2_reproduction1_GPU2_refinedGPU2 \ + --text "A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" \ + --iters 25000 --ckpt trial_anya_1_refimage_B_GPU2_reproduction1_GPU2/checkpoints/df_ep0125.pth --save_guidance --save_guidance_interval 1 \ + --h 512 --w 512 --albedo_iter_ratio 0.0 --t_range 0.0 0.5 --batch_size 1 --radius_range 3.2 3.6 --test_interval 2 \ + --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.015 --jitter_target 0.015 --jitter_up 0.05 \ + --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 0.5 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \ + --exp_start_iter 12500 --exp_end_iter 25000 + +# Generate 6 views +CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --ckpt trial_anya_1_refimage_B_GPU2_reproduction1_GPU2_refinedGPU2/checkpoints/df_ep0250.pth --six_views + +# Phase 4 - untested, need to adjust +# CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage --iters 5000 --dmtet --init_with trial_anya_1_refimage/checkpoints/df.pth + diff --git a/stable-dreamfusion-3DPortrait/scripts/run_image_hard_examples.sh b/stable-dreamfusion-3DPortrait/scripts/run_image_hard_examples.sh new file mode 100644 index 0000000..4a8dd04 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run_image_hard_examples.sh @@ -0,0 +1,11 @@ +bash scripts/run_image_procedure.sh 0 30 90 anya_front "A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" +bash scripts/run_image_procedure.sh 1 30 70 baby_phoenix_on_ice "A DSLR 3D photo of an adorable baby phoenix made in Swarowski crystal highly detailed intricate 
concept art 8K ( unreal engine 5 trending on Artstation )" +bash scripts/run_image_procedure.sh 2 30 90 bollywood_actress "A DSLR 3D photo of a beautiful bollywood indian actress, pretty eyes, full body shot composition, sunny outdoor, seen from far away ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" +bash scripts/run_image_procedure.sh 3 30 40 beach_house_1 "A DSLR 3D photo of a very beautiful small house on a beach ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" +bash scripts/run_image_procedure.sh 4 30 60 beach_house_2 "A DSLR 3D photo of a very beautiful high-tech small house with solar panels and wildflowers on a beach ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" +bash scripts/run_image_procedure.sh 5 30 90 mona_lisa "A DSLR 3D photo of a beautiful young woman dressed like Mona Lisa ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" +bash scripts/run_image_procedure.sh 6 30 80 futuristic_car "A DSLR 3D photo of a crazily futuristic electric car ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" +# the church ruins probably require a wider field of view... e.g. 90 degrees, maybe even more... so may not work with Zero123 etc. 
+bash scripts/run_image_procedure.sh 7 30 90 church_ruins "A DSLR 3D photo of the remains of an isolated old church ruin covered in ivy ( highly detailed intricate 8K unreal engine 5 trending on Artstation )" + +# young woman dressed like mona lisa \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/scripts/run_image_procedure.sh b/stable-dreamfusion-3DPortrait/scripts/run_image_procedure.sh new file mode 100644 index 0000000..722da3f --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run_image_procedure.sh @@ -0,0 +1,71 @@ +# Perform a 2D-to-3D reconstruction, similar to the Anya case study: https://github.com/ashawkey/stable-dreamfusion/issues/263 +# Args: +# bash scripts/run_image_procedure.sh GPU_ID guidance_interval default_polar image_name "prompt" +# e.g.: +# bash scripts/run_image_procedure.sh 1 30 70 baby_phoenix_on_ice "An adorable baby phoenix made in Swarowski crystal highly detailed intricate concept art 8K" +GPU_ID=$1 +GUIDANCE_INTERVAL=$2 +DEFAULT_POLAR=$3 +PREFIX=$4 +PROMPT=$5 +EPOCHS1=100 +EPOCHS2=200 +EPOCHS3=300 +IMAGE=data/$PREFIX.png +IMAGE_RGBA=data/${PREFIX}_rgba.png +WS_PH1=trial_$PREFIX-ph1 +WS_PH2=trial_$PREFIX-ph2 +WS_PH3=trial_$PREFIX-ph3 +CKPT1=$WS_PH1/checkpoints/df_ep0${EPOCHS1}.pth +CKPT2=$WS_PH2/checkpoints/df_ep0${EPOCHS2}.pth +CKPT3=$WS_PH3/checkpoints/df_ep0${EPOCHS3}.pth + +# Can uncomment to clear up trial folders. Be careful - mistakes could erase important work! +# rm -r $WS_PH1 $WS_PH2 $WS_PH3 + +# Preprocess +if [ ! -f $IMAGE_RGBA ] +then + python preprocess_image.py $IMAGE +fi + +if [ ! -f $CKPT1 ] +then + # Phase 1 - zero123-guidance + # WARNING: claforte: constantly runs out of VRAM with resolution of 128x128 and batch_size 2... no longer able to reproduce Anya result because of this... + # I added these to try to reduce mem usage, but this might degrade the quality... 
`--lambda_depth 0 --lambda_3d_normal_smooth 0` + # Remove: --ckpt scratch + CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH1 --default_polar $DEFAULT_POLAR \ + --iters ${EPOCHS1}00 --save_guidance --save_guidance_interval $GUIDANCE_INTERVAL --batch_size 1 --test_interval 2 \ + --h 96 --w 96 --zero123_grad_scale None --lambda_3d_normal_smooth 0 --dont_override_stuff \ + --fovy_range 20 20 --guidance_scale 5 +fi + +GUIDANCE_INTERVAL=7 +if [ ! -f $CKPT2 ] +then + # Phase 2 - SD-guidance at 256x256 + CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH2 \ + --text "${PROMPT}" --default_polar $DEFAULT_POLAR \ + --iters ${EPOCHS2}00 --ckpt $CKPT1 --save_guidance --save_guidance_interval 7 \ + --h 128 --w 128 --albedo_iter_ratio 0.0 --t_range 0.2 0.6 --batch_size 4 --radius_range 2.2 2.6 --test_interval 2 \ + --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.1 --jitter_target 0.1 --jitter_up 0.05 \ + --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --progressive_view --progressive_view_init_ratio 0.05 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 1 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \ + --exp_start_iter ${EPOCHS1}00 --exp_end_iter ${EPOCHS2}00 +fi + +if [ ! 
-f $CKPT3 ] +then + # # Phase 3 - increase resolution to 512 + CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH3 \ + --text "${PROMPT}" --default_polar $DEFAULT_POLAR \ + --iters ${EPOCHS3}00 --ckpt $CKPT2 --save_guidance --save_guidance_interval 7 \ + --h 512 --w 512 --albedo_iter_ratio 0.0 --t_range 0.0 0.5 --batch_size 1 --radius_range 3.2 3.6 --test_interval 2 \ + --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.015 --jitter_target 0.015 --jitter_up 0.05 \ + --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 0.5 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \ + --exp_start_iter ${EPOCHS2}00 --exp_end_iter ${EPOCHS3}00 +fi + +# Generate 6 views +CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --ckpt $CKPT3 --six_views + diff --git a/stable-dreamfusion-3DPortrait/scripts/run_image_text.sh b/stable-dreamfusion-3DPortrait/scripts/run_image_text.sh new file mode 100644 index 0000000..711dbf6 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run_image_text.sh @@ -0,0 +1,13 @@ +# sd backend (realistic images) + +CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/teddy_rgba.png --text "a brown teddy bear sitting on a ground" --workspace trial_imagetext_teddy --iters 5000 +CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/teddy_rgba.png --text "a brown teddy bear sitting on a ground" --workspace trial2_imagetext_teddy --iters 10000 --dmtet --init_with trial_imagetext_teddy/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/corgi_rgba.png --text "a corgi running" --workspace trial_imagetext_corgi --iters 5000 +CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/corgi_rgba.png --text "a corgi running" --workspace trial2_imagetext_corgi --iters 10000 --dmtet --init_with trial_imagetext_corgi/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/hamburger_rgba.png 
--text "a DSLR photo of a delicious hamburger" --workspace trial_imagetext_hamburger --iters 5000 +CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/hamburger_rgba.png --text "a DSLR photo of a delicious hamburger" --workspace trial2_imagetext_hamburger --iters 10000 --dmtet --init_with trial_imagetext_hamburger/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/cactus_rgba.png --text "a potted cactus plant" --workspace trial_imagetext_cactus --iters 5000 +CUDA_VISIBLE_DEVICES=4 python main.py -O --image data/cactus_rgba.png --text "a potted cactus plant" --workspace trial2_imagetext_cactus --iters 10000 --dmtet --init_with trial_imagetext_cactus/checkpoints/df.pth diff --git a/stable-dreamfusion-3DPortrait/scripts/run_images.sh b/stable-dreamfusion-3DPortrait/scripts/run_images.sh new file mode 100644 index 0000000..e41c981 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/scripts/run_images.sh @@ -0,0 +1,10 @@ +# zero123 backend (single object, images like 3d model rendering) + +CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/corgi.csv --workspace trial_images_corgi --iters 5000 +CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/corgi.csv --workspace trial2_images_corgi --iters 10000 --dmtet --init_with trial_images_corgi/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/car.csv --workspace trial_images_car --iters 5000 +CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/car.csv --workspace trial2_images_car --iters 10000 --dmtet --init_with trial_images_car/checkpoints/df.pth + +CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/anya.csv --workspace trial_images_anya --iters 5000 +CUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/anya.csv --workspace trial2_images_anya --iters 10000 --dmtet --init_with trial_images_anya/checkpoints/df.pth \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/shencoder/__init__.py 
b/stable-dreamfusion-3DPortrait/shencoder/__init__.py new file mode 100644 index 0000000..2b55c96 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/shencoder/__init__.py @@ -0,0 +1 @@ +from .sphere_harmonics import SHEncoder \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/shencoder/backend.py b/stable-dreamfusion-3DPortrait/shencoder/backend.py new file mode 100644 index 0000000..4971d5e --- /dev/null +++ b/stable-dreamfusion-3DPortrait/shencoder/backend.py @@ -0,0 +1,41 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. 
+ if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_sh_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'shencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/shencoder/setup.py b/stable-dreamfusion-3DPortrait/shencoder/setup.py new file mode 100644 index 0000000..4633ebd --- /dev/null +++ b/stable-dreamfusion-3DPortrait/shencoder/setup.py @@ -0,0 +1,51 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. 
+ if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='shencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_shencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'shencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/shencoder/sphere_harmonics.py b/stable-dreamfusion-3DPortrait/shencoder/sphere_harmonics.py new file mode 100644 index 0000000..7bab24e --- /dev/null +++ b/stable-dreamfusion-3DPortrait/shencoder/sphere_harmonics.py @@ -0,0 +1,87 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _shencoder as _backend +except ImportError: + from .backend import _backend + +class _sh_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, calc_grad_inputs=False): + # inputs: [B, input_dim], float in [-1, 1] + # RETURN: [B, F], float + + inputs = inputs.contiguous() + B, input_dim = inputs.shape # batch size, coord dim + output_dim = degree ** 2 + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + if calc_grad_inputs: + dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) + else: + dy_dx = None + + _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) + + ctx.save_for_backward(inputs, dy_dx) + ctx.dims = [B, input_dim, degree] + + return outputs + + 
@staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C], gradient w.r.t. the forward outputs (C is the SH degree) + + inputs, dy_dx = ctx.saved_tensors + + if dy_dx is not None: + grad = grad.contiguous() + B, input_dim, degree = ctx.dims + grad_inputs = torch.zeros_like(inputs) # zero-filled: the backend accumulates into it + _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) + return grad_inputs, None, None + else: + return None, None, None # forward ran with calc_grad_inputs=False, so no derivatives were stored + + + +sh_encode = _sh_encoder.apply + + +class SHEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim # coord dims, must be 3 + self.degree = degree # 1 ~ 8 (enforced by the assert below) + self.output_dim = degree ** 2 + + assert self.input_dim == 3, "SH encoder only support input dim == 3" + assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" + + def __repr__(self): + return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" + + def forward(self, inputs, size=1): + # inputs: [..., input_dim], normalized real world positions in [-size, size] + # return: [..., degree^2] + + inputs = inputs / size # [-1, 1] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) # flatten all leading dims into one batch dim + + outputs = sh_encode(inputs, self.degree, inputs.requires_grad) # derivatives are only computed when inputs need a gradient + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/shencoder/src/bindings.cpp b/stable-dreamfusion-3DPortrait/shencoder/src/bindings.cpp new file mode 100644 index 0000000..595b5b3 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/shencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "shencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); + m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); +} \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/shencoder/src/shencoder.cu
b/stable-dreamfusion-3DPortrait/shencoder/src/shencoder.cu new file mode 100644 index 0000000..a92e4ab --- /dev/null +++ b/stable-dreamfusion-3DPortrait/shencoder/src/shencoder.cu @@ -0,0 +1,439 @@ +#include + +#include +#include +#include + +#include +#include + +#include +#include + +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +template <typename T> +__host__ __device__ T div_round_up(T val, T divisor) { // integer ceil(val / divisor) + return (val + divisor - 1) / divisor; +} + +template <typename scalar_t> +__global__ void kernel_sh( // one thread per sample b: writes C*C basis values to outputs and, when dy_dx is non-null, their x/y/z derivatives + const scalar_t * __restrict__ inputs, + scalar_t * outputs, + uint32_t B, uint32_t D, uint32_t C, + scalar_t * dy_dx +) { + const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x; + if (b >= B) return; + + const uint32_t C2 = C * C; + + // locate + inputs += b * D; + outputs += b * C2; + + scalar_t x = inputs[0], y = inputs[1], z = inputs[2]; // only the first three coordinates of each sample are read + + scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z; + scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2; + scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2; + + auto write_sh = [&]() { + outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi)) + if (C <= 1) { return; } + outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi)) + outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi)) + outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi)) + if (C <= 2) { return; } + outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi)) + outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi)) + outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ; //
sqrt(5)*(3*z2 - 1)/(4*sqrt(pi)) + outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi)) + outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi)) + if (C <= 3) { return; } + outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi)) + outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi)) + outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi)) + outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi)) + outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi)) + outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + if (C <= 4) { return; } + outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi)) + outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi)) + outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi)) + outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi)) + outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi)) + outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi)) + outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi)) + outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi)) + outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + if (C <= 5) { return; } + outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 
- 5*x4 - y4)/(32*sqrt(pi)) + outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi)) + outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi)) + outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi)) + outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi)) + outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi)) + outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + if (C <= 6) { return; } + outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + outputs[42] = 
6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi)) + outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi)) + outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi)) + outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + if (C <= 7) { return; } + outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi)) + outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi)) + outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 
429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi)) + outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi)) + outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi)) + outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi)) + }; + + write_sh(); + + if (dy_dx) { + scalar_t *dx = dy_dx + b * D * C2; + scalar_t *dy = dx + C2; + scalar_t *dz = dy + C2; + + auto write_sh_dx = [&]() { + dx[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dx[1] = 0.0f ; // 0 + dx[2] = 0.0f ; // 0 + dx[3] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) + if (C <= 2) { return; } + dx[4] = 1.0925484305920792f*y ; // sqrt(15)*y/(2*sqrt(pi)) + dx[5] = 0.0f ; // 0 + dx[6] = 0.0f ; // 0 + dx[7] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) + dx[8] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) + if (C <= 3) { return; } + dx[9] = -3.5402615395598609f*xy ; // -3*sqrt(70)*xy/(4*sqrt(pi)) + dx[10] = 
2.8906114426405538f*yz ; // sqrt(105)*yz/(2*sqrt(pi)) + dx[11] = 0.0f ; // 0 + dx[12] = 0.0f ; // 0 + dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + dx[14] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) + dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + if (C <= 4) { return; } + dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ; // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi)) + dx[17] = -10.620784618679583f*xy*z ; // -9*sqrt(70)*xy*z/(4*sqrt(pi)) + dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi)) + dx[19] = 0.0f ; // 0 + dx[20] = 0.0f ; // 0 + dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + dx[23] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + if (C <= 5) { return; } + dx[25] = 13.127641136803401f*xy*(-x2 + y2) ; // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi)) + dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ; // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi)) + dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ; // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi)) + dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi)) + dx[29] = 0.0f ; // 0 + dx[30] = 0.0f ; // 0 + dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi)) + dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + 
dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + if (C <= 6) { return; } + dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ; // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi)) + dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ; // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi)) + dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dx[41] = 0.0f ; // 0 + dx[42] = 0.0f ; // 0 + dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + if (C <= 7) { return; } + dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ; // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi)) + dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + dx[53] 
= 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ; // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi)) + dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dx[55] = 0.0f ; // 0 + dx[56] = 0.0f ; // 0 + dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi)) + dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ; // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi)) + dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + }; + + auto write_sh_dy = [&]() { + dy[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dy[1] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) + dy[2] = 0.0f ; // 0 + dy[3] = 0.0f ; // 0 + if (C <= 2) { return; } + dy[4] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) + dy[5] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) + dy[6] = 0.0f ; // 0 + dy[7] = 0.0f ; // 0 + dy[8] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) + if (C <= 3) { return; } + dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 
3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + dy[10] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) + dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + dy[12] = 0.0f ; // 0 + dy[13] = 0.0f ; // 0 + dy[14] = -2.8906114426405538f*yz ; // -sqrt(105)*yz/(2*sqrt(pi)) + dy[15] = 3.5402615395598609f*xy ; // 3*sqrt(70)*xy/(4*sqrt(pi)) + if (C <= 4) { return; } + dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + dy[17] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + dy[20] = 0.0f ; // 0 + dy[21] = 0.0f ; // 0 + dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ; // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi)) + dy[23] = 10.620784618679583f*xy*z ; // 9*sqrt(70)*xy*z/(4*sqrt(pi)) + dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ; // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi)) + if (C <= 5) { return; } + dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + dy[30] = 0.0f ; // 0 + dy[31] = 0.0f ; // 0 + dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ; // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi)) + dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ; // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi)) + dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi)) + dy[35] = 
13.127641136803401f*xy*(x2 - y2) ; // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi)) + if (C <= 6) { return; } + dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + dy[42] = 0.0f ; // 0 + dy[43] = 0.0f ; // 0 + dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi)) + dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi)) + dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ; // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi)) + dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + if (C <= 7) { return; } + dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ; // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi)) + dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 
3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + dy[56] = 0.0f ; // 0 + dy[57] = 0.0f ; // 0 + dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + }; + + auto write_sh_dz = [&]() { + dz[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dz[1] = 0.0f ; // 0 + dz[2] = 0.48860251190291992f ; // sqrt(3)/(2*sqrt(pi)) + dz[3] = 0.0f ; // 0 + if (C <= 2) { return; } + dz[4] = 0.0f ; // 0 + dz[5] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) + dz[6] = 1.8923493915151202f*z ; // 3*sqrt(5)*z/(2*sqrt(pi)) + dz[7] = -1.0925484305920792f*x ; // -sqrt(15)*x/(2*sqrt(pi)) + dz[8] = 0.0f ; // 0 + if (C <= 3) { return; } + dz[9] = 0.0f ; // 0 + dz[10] = 2.8906114426405538f*xy ; // sqrt(105)*xy/(2*sqrt(pi)) + dz[11] = -4.5704579946446566f*yz ; // -5*sqrt(42)*yz/(4*sqrt(pi)) + dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ; // 3*sqrt(7)*(5*z2 - 
1)/(4*sqrt(pi)) + dz[13] = -4.5704579946446566f*xz ; // -5*sqrt(42)*xz/(4*sqrt(pi)) + dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ; // sqrt(105)*(x2 - y2)/(4*sqrt(pi)) + dz[15] = 0.0f ; // 0 + if (C <= 4) { return; } + dz[16] = 0.0f ; // 0 + dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ; // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + dz[18] = 13.246445740605839f*xy*z ; // 21*sqrt(5)*xy*z/(2*sqrt(pi)) + dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi)) + dz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ; // (105*z**3 - 45*z)/(4*sqrt(pi)) + dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi)) + dz[22] = 6.6232228703029197f*z*(x2 - y2) ; // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi)) + dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + dz[24] = 0.0f ; // 0 + if (C <= 5) { return; } + dz[25] = 0.0f ; // 0 + dz[26] = 8.3026492595241645f*xy*(x2 - y2) ; // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi)) + dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ; // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi)) + dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ; // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi)) + dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi)) + dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi)) + dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi)) + dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi)) + dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ; // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi)) + dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + dz[35] = 0.0f ; // 0 + if (C <= 6) { return; } + dz[36] = 0.0f ; // 0 + dz[37] = 
2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ; // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi)) + dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi)) + dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi)) + dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ; // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi)) + dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi)) + dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi)) + dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ; // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + dz[48] = 0.0f ; // 0 + if (C <= 7) { return; } + dz[49] = 0.0f ; // 0 + dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi)) + dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi)) + dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 
15.0f) ; // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi)) + dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi)) + dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi)) + dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + dz[63] = 0.0f ; // 0 + }; + write_sh_dx(); + write_sh_dy(); + write_sh_dz(); + } +} + + +template <typename scalar_t> +__global__ void kernel_sh_backward( // chain rule: grad_inputs[b, d] += sum over ch of grad[b, ch] * dy_dx[b, d, ch] + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ inputs, + uint32_t B, uint32_t D, uint32_t C, + const scalar_t * __restrict__ dy_dx, + scalar_t * grad_inputs +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; // one thread per (sample, input coordinate) pair + const uint32_t b = t / D; + if (b >= B) return; + + const uint32_t d = t - b * D; + const uint32_t C2 = C * C; + + // locate + grad += b * C2; + dy_dx += b * D * C2 + d * C2; + scalar_t acc = grad_inputs[t]; // accumulate in a register on top of the existing value instead of touching global memory every iteration + for (uint32_t ch = 0; ch < C2; ch++) { + acc += grad[ch] * dy_dx[ch]; + } + grad_inputs[t] = acc; + +} + +// inputs: [B, D], float, in [-1, 1] +// outputs: [B, C * C], float +template <typename scalar_t> +void sh_encode_forward_cuda(const
scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) { + static constexpr uint32_t N_THREADS = 256; + kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, dy_dx); // kernel_sh uses one thread per sample, so B threads in total +} + + +template <typename scalar_t> +void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) { + static constexpr uint32_t N_THREADS = 256; + kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs); // kernel_sh_backward uses one thread per (sample, coordinate), so B * D threads in total +} + + +void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx) { + CHECK_CUDA(inputs); + CHECK_CUDA(outputs); + if (dy_dx.has_value()) { CHECK_CUDA(dy_dx.value()); } + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(outputs); + if (dy_dx.has_value()) { CHECK_CONTIGUOUS(dy_dx.value()); } + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(outputs); + if (dy_dx.has_value()) { CHECK_IS_FLOATING(dy_dx.value()); } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputs.scalar_type(), "sh_encode_forward_cuda", ([&] { + sh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, dy_dx.has_value() ?
dy_dx.value().data_ptr() : nullptr); + })); +} + +void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(dy_dx); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(dy_dx); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(dy_dx); + CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "sh_encode_backward_cuda", ([&] { + sh_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), B, D, C, dy_dx.data_ptr(), grad_inputs.data_ptr()); + })); +} \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/shencoder/src/shencoder.h b/stable-dreamfusion-3DPortrait/shencoder/src/shencoder.h new file mode 100644 index 0000000..f9e89fa --- /dev/null +++ b/stable-dreamfusion-3DPortrait/shencoder/src/shencoder.h @@ -0,0 +1,10 @@ +# pragma once + +#include +#include + +// inputs: [B, D], float, in [-1, 1] +// outputs: [B, F], float + +void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional dy_dx); +void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); \ No newline at end of file diff --git a/stable-dreamfusion-3DPortrait/taichi_modules/__init__.py b/stable-dreamfusion-3DPortrait/taichi_modules/__init__.py new file mode 100644 index 0000000..3270636 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/taichi_modules/__init__.py @@ -0,0 +1,5 @@ +from .ray_march import RayMarcherTaichi, raymarching_test +from .volume_train import VolumeRendererTaichi +from .intersection import RayAABBIntersector +from .volume_render_test import composite_test +from .utils import packbits \ No newline at end of 
file diff --git a/stable-dreamfusion-3DPortrait/taichi_modules/hash_encoder.py b/stable-dreamfusion-3DPortrait/taichi_modules/hash_encoder.py new file mode 100644 index 0000000..9a1b7a7 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/taichi_modules/hash_encoder.py @@ -0,0 +1,305 @@ +import numpy as np +import taichi as ti +import torch +from taichi.math import uvec3 +from torch.cuda.amp import custom_bwd, custom_fwd + +from .utils import (data_type, ti2torch, ti2torch_grad, ti2torch_grad_vec, + ti2torch_vec, torch2ti, torch2ti_grad, torch2ti_grad_vec, + torch2ti_vec, torch_type) + +half2 = ti.types.vector(n=2, dtype=ti.f16) + + +@ti.kernel +def random_initialize(data: ti.types.ndarray()): + for I in ti.grouped(data): + data[I] = (ti.random() * 2.0 - 1.0) * 1e-4 + + +@ti.kernel +def ti_copy(data1: ti.template(), data2: ti.template()): + for I in ti.grouped(data1): + data1[I] = data2[I] + + +@ti.kernel +def ti_copy_array(data1: ti.types.ndarray(), data2: ti.types.ndarray()): + for I in ti.grouped(data1): + data1[I] = data2[I] + + +@ti.kernel +def ti_copy_field_array(data1: ti.template(), data2: ti.types.ndarray()): + for I in ti.grouped(data1): + data1[I] = data2[I] + + +@ti.func +def fast_hash(pos_grid_local): + result = ti.uint32(0) + # primes = uvec3(ti.uint32(1), ti.uint32(1958374283), ti.uint32(2654435761)) + primes = uvec3(ti.uint32(1), ti.uint32(2654435761), ti.uint32(805459861)) + for i in ti.static(range(3)): + result ^= ti.uint32(pos_grid_local[i]) * primes[i] + return result + + +@ti.func +def under_hash(pos_grid_local, resolution): + result = ti.uint32(0) + stride = ti.uint32(1) + for i in ti.static(range(3)): + result += ti.uint32(pos_grid_local[i] * stride) + stride *= resolution + return result + + +@ti.func +def grid_pos2hash_index(indicator, pos_grid_local, resolution, map_size): + hash_result = ti.uint32(0) + if indicator == 1: + hash_result = under_hash(pos_grid_local, resolution) + else: + hash_result = fast_hash(pos_grid_local) + + return 
hash_result % map_size + + +@ti.kernel +def hash_encode_kernel( + xyzs: ti.template(), table: ti.template(), + xyzs_embedding: ti.template(), hash_map_indicator: ti.template(), + hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32, + per_level_scale: ti.f32): + + # get hash table embedding + ti.loop_config(block_dim=16) + for i, level in ti.ndrange(B, 16): + xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]]) + + scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0 + resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1 + + offset = offsets[level] * 2 + + pos = xyz * scale + 0.5 + pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32) + pos -= pos_grid_uint + + indicator = hash_map_indicator[level] + map_size = hash_map_sizes_field[level] + + local_feature_0 = 0.0 + local_feature_1 = 0.0 + + for idx in ti.static(range(8)): + w = 1. + pos_grid_local = uvec3(0) + + for d in ti.static(range(3)): + if (idx & (1 << d)) == 0: + pos_grid_local[d] = pos_grid_uint[d] + w *= 1 - pos[d] + else: + pos_grid_local[d] = pos_grid_uint[d] + 1 + w *= pos[d] + + index = grid_pos2hash_index(indicator, pos_grid_local, resolution, + map_size) + index_table = offset + index * 2 + index_table_int = ti.cast(index_table, ti.int32) + local_feature_0 += w * table[index_table_int] + local_feature_1 += w * table[index_table_int + 1] + + xyzs_embedding[i, level * 2] = local_feature_0 + xyzs_embedding[i, level * 2 + 1] = local_feature_1 + + +@ti.kernel +def hash_encode_kernel_half2( + xyzs: ti.template(), table: ti.template(), + xyzs_embedding: ti.template(), hash_map_indicator: ti.template(), + hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32, + per_level_scale: ti.f16): + + # get hash table embedding + ti.loop_config(block_dim=32) + for i, level in ti.ndrange(B, 16): + xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]]) + + scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0 + resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1 + + offset = 
offsets[level] + + pos = xyz * scale + 0.5 + pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32) + pos -= pos_grid_uint + + indicator = hash_map_indicator[level] + map_size = hash_map_sizes_field[level] + + local_feature = half2(0.0) + for idx in ti.static(range(8)): + w = ti.f32(1.0) + pos_grid_local = uvec3(0) + + for d in ti.static(range(3)): + if (idx & (1 << d)) == 0: + pos_grid_local[d] = pos_grid_uint[d] + w *= 1 - pos[d] + else: + pos_grid_local[d] = pos_grid_uint[d] + 1 + w *= pos[d] + + index = grid_pos2hash_index(indicator, pos_grid_local, resolution, + map_size) + + index_table = offset + index + index_table_int = ti.cast(index_table, ti.int32) + + local_feature += w * table[index_table_int] + xyzs_embedding[i, level] = local_feature + + +class HashEncoderTaichi(torch.nn.Module): + + def __init__(self, + b=1.3195079565048218, + batch_size=8192, + data_type=data_type, + half2_opt=False): + super(HashEncoderTaichi, self).__init__() + + self.per_level_scale = b + if batch_size < 2048: + batch_size = 2048 + + # per_level_scale = 1.3195079565048218 + print("per_level_scale: ", b) + self.offsets = ti.field(ti.i32, shape=(16, )) + self.hash_map_sizes_field = ti.field(ti.uint32, shape=(16, )) + self.hash_map_indicator = ti.field(ti.i32, shape=(16, )) + base_res = 16 + max_params = 2**19 + offset_ = 0 + hash_map_sizes = [] + for i in range(16): + resolution = int( + np.ceil(base_res * np.exp(i * np.log(self.per_level_scale)) - + 1.0)) + 1 + params_in_level = resolution**3 + params_in_level = int(resolution** + 3) if params_in_level % 8 == 0 else int( + (params_in_level + 8 - 1) / 8) * 8 + params_in_level = min(max_params, params_in_level) + self.offsets[i] = offset_ + hash_map_sizes.append(params_in_level) + self.hash_map_indicator[ + i] = 1 if resolution**3 <= params_in_level else 0 + offset_ += params_in_level + print("offset_: ", offset_) + size = np.uint32(np.array(hash_map_sizes)) + self.hash_map_sizes_field.from_numpy(size) + + self.total_hash_size = offset_ 
* 2 + print("total_hash_size: ", self.total_hash_size) + + self.hash_table = torch.nn.Parameter(torch.zeros(self.total_hash_size, + dtype=torch_type), + requires_grad=True) + random_initialize(self.hash_table) + + if half2_opt: + assert self.total_hash_size % 2 == 0 + self.parameter_fields = half2.field(shape=(self.total_hash_size // + 2, ), + needs_grad=True) + self.output_fields = half2.field(shape=(batch_size * 1024, 16), + needs_grad=True) + + self.torch2ti = torch2ti_vec + self.ti2torch = ti2torch_vec + self.ti2torch_grad = ti2torch_grad_vec + self.torch2ti_grad = torch2ti_grad_vec + + self._hash_encode_kernel = hash_encode_kernel_half2 + else: + self.parameter_fields = ti.field(data_type, + shape=(self.total_hash_size, ), + needs_grad=True) + self.output_fields = ti.field(dtype=data_type, + shape=(batch_size * 1024, 32), + needs_grad=True) + self.torch2ti = torch2ti + self.ti2torch = ti2torch + self.ti2torch_grad = ti2torch_grad + self.torch2ti_grad = torch2ti_grad + + self._hash_encode_kernel = hash_encode_kernel + + self.input_fields = ti.field(dtype=data_type, + shape=(batch_size * 1024, 3), + needs_grad=True) + self.output_dim = 32 # the output dim: num levels (16) x level num (2) + self.register_buffer( + 'hash_grad', torch.zeros(self.total_hash_size, dtype=torch_type)) + self.register_buffer( + 'output_embedding', + torch.zeros(batch_size * 1024, 32, dtype=torch_type)) + + class _module_function(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch_type) + def forward(ctx, input_pos, params): + output_embedding = self.output_embedding[:input_pos. 
+ shape[0]].contiguous( + ) + torch2ti(self.input_fields, input_pos.contiguous()) + self.torch2ti(self.parameter_fields, params.contiguous()) + + self._hash_encode_kernel( + self.input_fields, + self.parameter_fields, + self.output_fields, + self.hash_map_indicator, + self.hash_map_sizes_field, + self.offsets, + input_pos.shape[0], + self.per_level_scale, + ) + self.ti2torch(self.output_fields, output_embedding) + + return output_embedding + + @staticmethod + @custom_bwd + def backward(ctx, doutput): + + self.zero_grad() + + self.torch2ti_grad(self.output_fields, doutput.contiguous()) + self._hash_encode_kernel.grad( + self.input_fields, + self.parameter_fields, + self.output_fields, + self.hash_map_indicator, + self.hash_map_sizes_field, + self.offsets, + doutput.shape[0], + self.per_level_scale, + ) + self.ti2torch_grad(self.parameter_fields, + self.hash_grad.contiguous()) + return None, self.hash_grad + + self._module_function = _module_function + + def zero_grad(self): + self.parameter_fields.grad.fill(0.) 
@ti.kernel
def simple_ray_aabb_intersec_taichi_forward(
        hits_t: ti.types.ndarray(ndim=2),
        rays_o: ti.types.ndarray(ndim=2),
        rays_d: ti.types.ndarray(ndim=2),
        centers: ti.types.ndarray(ndim=2),
        half_sizes: ti.types.ndarray(ndim=2)):
    # Slab test of every ray against ONE axis-aligned box.
    #
    # hits_t:      output. For a ray r that hits, hits_t[r, 0, 0] receives the
    #              entry distance (clamped to NEAR_DISTANCE) and
    #              hits_t[r, 0, 1] the exit distance. Rows of rays that miss
    #              are left untouched, so the caller's pre-filled value stays.
    # rays_o/d:    per-ray origin / direction, three components each.
    # centers,
    # half_sizes:  only row 0 is read, i.e. a single box shared by all rays.
    #
    # NOTE(review): hits_t is annotated ndim=2 but indexed with three
    # subscripts below (RayAABBIntersector.forward allocates it as
    # (N_rays, 1, 2)) - confirm the annotation against the Taichi version
    # in use.
    for r in ti.ndrange(hits_t.shape[0]):
        ray_o = vec3([rays_o[r, 0], rays_o[r, 1], rays_o[r, 2]])
        ray_d = vec3([rays_d[r, 0], rays_d[r, 1], rays_d[r, 2]])
        # A zero direction component yields +/-inf here, which the min/max
        # reduction below tolerates.
        inv_d = 1.0 / ray_d

        center = vec3([centers[0, 0], centers[0, 1], centers[0, 2]])
        # Fixed: the z half-extent used to be read from column 1 (the y
        # extent), which is only correct for boxes with equal y and z sizes.
        half_size = vec3(
            [half_sizes[0, 0], half_sizes[0, 1], half_sizes[0, 2]])

        # Per-axis distances to the two bounding planes of each slab.
        t_min = (center - half_size - ray_o) * inv_d
        t_max = (center + half_size - ray_o) * inv_d

        _t1 = ti.min(t_min, t_max)
        _t2 = ti.max(t_min, t_max)
        t1 = _t1.max()  # latest entry over the three slabs
        t2 = _t2.min()  # earliest exit over the three slabs

        # Only the exit distance is tested; t1 <= t2 is not checked here.
        if t2 > 0.0:
            hits_t[r, 0, 0] = ti.max(t1, NEAR_DISTANCE)
            hits_t[r, 0, 1] = t2
+ + Inputs: + rays_o: (N_rays, 3) ray origins + rays_d: (N_rays, 3) ray directions + centers: (N_voxels, 3) voxel centers + half_sizes: (N_voxels, 3) voxel half sizes + max_hits: maximum number of intersected voxels to keep for one ray + (for a cubic scene, this is at most 3*N_voxels^(1/3)-2) + + Outputs: + hits_cnt: (N_rays) number of hits for each ray + (followings are from near to far) + hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit) + hits_voxel_idx: (N_rays, max_hits) hit voxel indices (-1 if no hit) + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, center, half_size, max_hits): + hits_t = (torch.zeros( + rays_o.size(0), 1, 2, device=rays_o.device, dtype=torch.float32) - + 1).contiguous() + + simple_ray_aabb_intersec_taichi_forward(hits_t, rays_o, rays_d, center, + half_size) + + return None, hits_t, None diff --git a/stable-dreamfusion-3DPortrait/taichi_modules/ray_march.py b/stable-dreamfusion-3DPortrait/taichi_modules/ray_march.py new file mode 100644 index 0000000..d159d03 --- /dev/null +++ b/stable-dreamfusion-3DPortrait/taichi_modules/ray_march.py @@ -0,0 +1,340 @@ +import taichi as ti +import torch +from taichi.math import vec3 +from torch.cuda.amp import custom_fwd + +from .utils import __morton3D, calc_dt, mip_from_dt, mip_from_pos + + +@ti.kernel +def raymarching_train(rays_o: ti.types.ndarray(ndim=2), + rays_d: ti.types.ndarray(ndim=2), + hits_t: ti.types.ndarray(ndim=2), + density_bitfield: ti.types.ndarray(ndim=1), + noise: ti.types.ndarray(ndim=1), + counter: ti.types.ndarray(ndim=1), + rays_a: ti.types.ndarray(ndim=2), + xyzs: ti.types.ndarray(ndim=2), + dirs: ti.types.ndarray(ndim=2), + deltas: ti.types.ndarray(ndim=1), + ts: ti.types.ndarray(ndim=1), cascades: int, + grid_size: int, scale: float, exp_step_factor: float, + max_samples: float): + + # ti.loop_config(block_dim=256) + for r in noise: + ray_o = vec3(rays_o[r, 0], rays_o[r, 1], rays_o[r, 2]) + ray_d = vec3(rays_d[r, 0], 
rays_d[r, 1], rays_d[r, 2]) + d_inv = 1.0 / ray_d + + t1, t2 = hits_t[r, 0], hits_t[r, 1] + + grid_size3 = grid_size**3 + grid_size_inv = 1.0 / grid_size + + if t1 >= 0: + dt = calc_dt(t1, exp_step_factor, grid_size, scale) + t1 += dt * noise[r] + + t = t1 + N_samples = 0 + + while (0 <= t) & (t < t2) & (N_samples < max_samples): + xyz = ray_o + t * ray_d + dt = calc_dt(t, exp_step_factor, grid_size, scale) + mip = ti.max(mip_from_pos(xyz, cascades), + mip_from_dt(dt, grid_size, cascades)) + + # mip_bound = 0.5 + # mip_bound = ti.min(ti.pow(2., mip - 1), scale) + mip_bound = scale + mip_bound_inv = 1 / mip_bound + + nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * grid_size, + 0.0, grid_size - 1.0) + # nxyz = ti.ceil(nxyz) + + idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32)) + occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8)) + # idx = __morton3D(ti.cast(nxyz, ti.uint32)) + # occ = density_bitfield[mip, idx//8] & (1 << ti.cast(idx%8, ti.uint32)) + + if occ: + t += dt + N_samples += 1 + else: + # t += dt + txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) * + grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv + + t_target = t + ti.max(0, txyz.min()) + t += calc_dt(t, exp_step_factor, grid_size, scale) + while t < t_target: + t += calc_dt(t, exp_step_factor, grid_size, scale) + + start_idx = ti.atomic_add(counter[0], N_samples) + ray_count = ti.atomic_add(counter[1], 1) + + rays_a[ray_count, 0] = r + rays_a[ray_count, 1] = start_idx + rays_a[ray_count, 2] = N_samples + + t = t1 + samples = 0 + + while (t < t2) & (samples < N_samples): + xyz = ray_o + t * ray_d + dt = calc_dt(t, exp_step_factor, grid_size, scale) + mip = ti.max(mip_from_pos(xyz, cascades), + mip_from_dt(dt, grid_size, cascades)) + + # mip_bound = 0.5 + # mip_bound = ti.min(ti.pow(2., mip - 1), scale) + mip_bound = scale + mip_bound_inv = 1 / mip_bound + + nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * grid_size, + 0.0, grid_size - 1.0) + # nxyz = ti.ceil(nxyz) + 
+ idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32)) + occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8)) + # idx = __morton3D(ti.cast(nxyz, ti.uint32)) + # occ = density_bitfield[mip, idx//8] & (1 << ti.cast(idx%8, ti.uint32)) + + if occ: + s = start_idx + samples + xyzs[s, 0] = xyz[0] + xyzs[s, 1] = xyz[1] + xyzs[s, 2] = xyz[2] + dirs[s, 0] = ray_d[0] + dirs[s, 1] = ray_d[1] + dirs[s, 2] = ray_d[2] + ts[s] = t + deltas[s] = dt + t += dt + samples += 1 + else: + # t += dt + txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) * + grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv + + t_target = t + ti.max(0, txyz.min()) + t += calc_dt(t, exp_step_factor, grid_size, scale) + while t < t_target: + t += calc_dt(t, exp_step_factor, grid_size, scale) + + +@ti.kernel +def raymarching_train_backword(segments: ti.types.ndarray(ndim=2), + ts: ti.types.ndarray(ndim=1), + dL_drays_o: ti.types.ndarray(ndim=2), + dL_drays_d: ti.types.ndarray(ndim=2), + dL_dxyzs: ti.types.ndarray(ndim=2), + dL_ddirs: ti.types.ndarray(ndim=2)): + + for s in segments: + index = segments[s] + dxyz = dL_dxyzs[index] + ddir = dL_ddirs[index] + + dL_drays_o[s] = dxyz + dL_drays_d[s] = dxyz * ts[index] + ddir + + +class RayMarcherTaichi(torch.nn.Module): + + def __init__(self, batch_size=8192): + super(RayMarcherTaichi, self).__init__() + + self.register_buffer('rays_a', + torch.zeros(batch_size, 3, dtype=torch.int32)) + self.register_buffer( + 'xyzs', torch.zeros(batch_size * 1024, 3, dtype=torch.float32)) + self.register_buffer( + 'dirs', torch.zeros(batch_size * 1024, 3, dtype=torch.float32)) + self.register_buffer( + 'deltas', torch.zeros(batch_size * 1024, dtype=torch.float32)) + self.register_buffer( + 'ts', torch.zeros(batch_size * 1024, dtype=torch.float32)) + + # self.register_buffer('dL_drays_o', torch.zeros(batch_size, dtype=torch.float32)) + # self.register_buffer('dL_drays_d', torch.zeros(batch_size, dtype=torch.float32)) + + class 
_module_function(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, hits_t, density_bitfield, + cascades, scale, exp_step_factor, grid_size, + max_samples): + # noise to perturb the first sample of each ray + noise = torch.rand_like(rays_o[:, 0]) + counter = torch.zeros(2, + device=rays_o.device, + dtype=torch.int32) + + raymarching_train(\ + rays_o, rays_d, + hits_t.contiguous(), + density_bitfield, noise, counter, + self.rays_a.contiguous(), + self.xyzs.contiguous(), + self.dirs.contiguous(), + self.deltas.contiguous(), + self.ts.contiguous(), + cascades, grid_size, scale, + exp_step_factor, max_samples) + + # ti.sync() + + total_samples = counter[0] # total samples for all rays + # remove redundant output + xyzs = self.xyzs[:total_samples] + dirs = self.dirs[:total_samples] + deltas = self.deltas[:total_samples] + ts = self.ts[:total_samples] + + return self.rays_a, xyzs, dirs, deltas, ts, total_samples + + # @staticmethod + # @custom_bwd + # def backward(ctx, dL_drays_a, dL_dxyzs, dL_ddirs, dL_ddeltas, dL_dts, + # dL_dtotal_samples): + # rays_a, ts = ctx.saved_tensors + # # rays_a = rays_a.contiguous() + # ts = ts.contiguous() + # segments = torch.cat([rays_a[:, 1], rays_a[-1:, 1] + rays_a[-1:, 2]]) + # dL_drays_o = torch.zeros_like(rays_a[:, 0]) + # dL_drays_d = torch.zeros_like(rays_a[:, 0]) + # raymarching_train_backword(segments.contiguous(), ts, dL_drays_o, + # dL_drays_d, dL_dxyzs, dL_ddirs) + # # ti.sync() + # # dL_drays_o = segment_csr(dL_dxyzs, segments) + # # dL_drays_d = \ + # # segment_csr(dL_dxyzs*rearrange(ts, 'n -> n 1')+dL_ddirs, segments) + + # return dL_drays_o, dL_drays_d, None, None, None, None, None, None, None + + self._module_function = _module_function + + def forward(self, rays_o, rays_d, hits_t, density_bitfield, cascades, + scale, exp_step_factor, grid_size, max_samples): + return self._module_function.apply(rays_o, rays_d, hits_t, + density_bitfield, cascades, 
scale, + exp_step_factor, grid_size, + max_samples) + + +@ti.kernel +def raymarching_test_kernel( + rays_o: ti.types.ndarray(ndim=2), + rays_d: ti.types.ndarray(ndim=2), + hits_t: ti.types.ndarray(ndim=2), + alive_indices: ti.types.ndarray(ndim=1), + density_bitfield: ti.types.ndarray(ndim=1), + cascades: int, + grid_size: int, + scale: float, + exp_step_factor: float, + N_samples: int, + max_samples: int, + xyzs: ti.types.ndarray(ndim=2), + dirs: ti.types.ndarray(ndim=2), + deltas: ti.types.ndarray(ndim=1), + ts: ti.types.ndarray(ndim=1), + N_eff_samples: ti.types.ndarray(ndim=1), +): + + for n in alive_indices: + r = alive_indices[n] + grid_size3 = grid_size**3 + grid_size_inv = 1.0 / grid_size + + ray_o = vec3(rays_o[r, 0], rays_o[r, 1], rays_o[r, 2]) + ray_d = vec3(rays_d[r, 0], rays_d[r, 1], rays_d[r, 2]) + d_inv = 1.0 / ray_d + + t = hits_t[r, 0] + t2 = hits_t[r, 1] + + s = 0 + + while (0 <= t) & (t < t2) & (s < N_samples): + xyz = ray_o + t * ray_d + dt = calc_dt(t, exp_step_factor, grid_size, scale) + mip = ti.max(mip_from_pos(xyz, cascades), + mip_from_dt(dt, grid_size, cascades)) + + # mip_bound = 0.5 + # mip_bound = ti.min(ti.pow(2., mip - 1), scale) + mip_bound = scale + mip_bound_inv = 1 / mip_bound + + nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * grid_size, + 0.0, grid_size - 1.0) + # nxyz = ti.ceil(nxyz) + + idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32)) + occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8)) + + if occ: + xyzs[n, s, 0] = xyz[0] + xyzs[n, s, 1] = xyz[1] + xyzs[n, s, 2] = xyz[2] + dirs[n, s, 0] = ray_d[0] + dirs[n, s, 1] = ray_d[1] + dirs[n, s, 2] = ray_d[2] + ts[n, s] = t + deltas[n, s] = dt + t += dt + hits_t[r, 0] = t + s += 1 + + else: + txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) * + grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv + + t_target = t + ti.max(0, txyz.min()) + t += calc_dt(t, exp_step_factor, grid_size, scale) + while t < t_target: + t += calc_dt(t, exp_step_factor, 
def raymarching_test(rays_o, rays_d, hits_t, alive_indices, density_bitfield,
                     cascades, scale, exp_step_factor, grid_size, max_samples,
                     N_samples):
    """March the still-alive rays for up to ``N_samples`` occupied samples each.

    Allocates zero-filled output tensors (one row per entry of
    ``alive_indices``, matching ``rays_o`` in device and dtype) and lets
    ``raymarching_test_kernel`` fill them. The kernel also advances
    ``hits_t[r, 0]`` in place for every sample it emits.

    Returns:
        ``(xyzs, dirs, deltas, ts, N_eff_samples)`` where ``xyzs``/``dirs``
        are ``(N_rays, N_samples, 3)``, ``deltas``/``ts`` are
        ``(N_rays, N_samples)`` and ``N_eff_samples`` is an int32 vector with
        the number of samples actually written per ray.
    """
    ray_count = alive_indices.size(0)
    like_rays = dict(device=rays_o.device, dtype=rays_o.dtype)

    def _sample_buffer(*trailing_dims):
        # One slot per (ray, sample), optionally with trailing components.
        return torch.zeros(ray_count, N_samples, *trailing_dims, **like_rays)

    xyzs = _sample_buffer(3)
    dirs = _sample_buffer(3)
    deltas = _sample_buffer()
    ts = _sample_buffer()
    N_eff_samples = torch.zeros(ray_count,
                                device=rays_o.device,
                                dtype=torch.int32)

    # Note the kernel's argument order: grid_size comes before scale, and
    # N_samples before max_samples.
    raymarching_test_kernel(
        rays_o, rays_d, hits_t, alive_indices, density_bitfield,
        cascades, grid_size, scale, exp_step_factor,
        N_samples, max_samples,
        xyzs, dirs, deltas, ts, N_eff_samples,
    )

    # ti.sync()

    return xyzs, dirs, deltas, ts, N_eff_samples
def morton3D_invert(indices):
    """Decode Morton codes into integer 3-D grid coordinates.

    Allocates an int32 tensor with one row of three coordinates per entry of
    ``indices`` (on the same device), has ``morton3D_invert_kernel`` fill it,
    and synchronizes Taichi before handing it back.
    """
    n_codes = indices.size(0)
    decoded = torch.zeros((n_codes, 3),
                          dtype=torch.int32,
                          device=indices.device)
    codes = indices.contiguous()
    morton3D_invert_kernel(codes, decoded)
    ti.sync()
    return decoded
def morton3D(coords1):
    """Encode rows of 3-D grid coordinates as Morton codes.

    Allocates an int32 vector with one slot per row of ``coords1`` (on the
    same device), has ``morton3D_kernel`` fill it, and synchronizes Taichi
    before returning it.
    """
    n_points = coords1.size(0)
    codes = torch.zeros((n_points, ),
                        dtype=torch.int32,
                        device=coords1.device)
    points = coords1.contiguous()
    morton3D_kernel(points, codes)
    ti.sync()
    return codes
def depth2img(depth):
    """Turn a depth map into a TURBO-colormapped 8-bit image.

    The map is min-max normalized to [0, 1], scaled to 0..255, cast to
    ``uint8`` and passed through ``cv2.applyColorMap``.

    Args:
        depth: array of depth values; expected to be a NumPy array, since
            ``.astype`` is called on the scaled result.

    Returns:
        The colour image produced by ``cv2.applyColorMap`` with
        ``cv2.COLORMAP_TURBO``.
    """
    # Imported here because this module's header only brings in taichi and
    # torch; the previous body referenced `np` and `cv2` without any import
    # and therefore raised NameError on every call.
    import cv2
    import numpy as np

    depth_min = depth.min()
    depth_span = depth.max() - depth_min
    if depth_span > 0:
        normalized = (depth - depth_min) / depth_span
    else:
        # Constant (or all-NaN) map: avoid the 0/0 division and render a
        # uniform image instead of casting NaNs to uint8.
        normalized = np.zeros(np.shape(depth), dtype=np.float32)

    depth_img = cv2.applyColorMap((normalized * 255).astype(np.uint8),
                                  cv2.COLORMAP_TURBO)

    return depth_img
range(samples): + a = 1.0 - ti.exp(-sigmas[n, s] * deltas[n, s]) + w = a * T + + rgb_temp_0 += w * rgbs[n, s, 0] + rgb_temp_1 += w * rgbs[n, s, 1] + rgb_temp_2 += w * rgbs[n, s, 2] + depth[r] += w * ts[n, s] + opacity[r] += w + T *= 1.0 - a + + if T <= T_threshold: + alive_indices[n] = -1 + break + + rgb[r, 0] += rgb_temp_0 + rgb[r, 1] += rgb_temp_1 + rgb[r, 2] += rgb_temp_2 + depth[r] += depth_temp + opacity[r] += opacity_temp diff --git a/stable-dreamfusion-3DPortrait/taichi_modules/volume_train.py b/stable-dreamfusion-3DPortrait/taichi_modules/volume_train.py new file mode 100644 index 0000000..7a52bfe --- /dev/null +++ b/stable-dreamfusion-3DPortrait/taichi_modules/volume_train.py @@ -0,0 +1,239 @@ +import taichi as ti +import torch +from torch.cuda.amp import custom_bwd, custom_fwd + +from .utils import (data_type, ti2torch, ti2torch_grad, torch2ti, + torch2ti_grad, torch_type) + + +@ti.kernel +def composite_train_fw_array( + sigmas: ti.types.ndarray(), + rgbs: ti.types.ndarray(), + deltas: ti.types.ndarray(), + ts: ti.types.ndarray(), + rays_a: ti.types.ndarray(), + T_threshold: float, + total_samples: ti.types.ndarray(), + opacity: ti.types.ndarray(), + depth: ti.types.ndarray(), + rgb: ti.types.ndarray(), + ws: ti.types.ndarray(), +): + + for n in opacity: + ray_idx = rays_a[n, 0] + start_idx = rays_a[n, 1] + N_samples = rays_a[n, 2] + + T = 1.0 + samples = 0 + while samples < N_samples: + s = start_idx + samples + a = 1.0 - ti.exp(-sigmas[s] * deltas[s]) + w = a * T + + rgb[ray_idx, 0] += w * rgbs[s, 0] + rgb[ray_idx, 1] += w * rgbs[s, 1] + rgb[ray_idx, 2] += w * rgbs[s, 2] + depth[ray_idx] += w * ts[s] + opacity[ray_idx] += w + ws[s] = w + T *= 1.0 - a + + # if T T_threshold: + # s = start_idx + sample_ + a = 1.0 - ti.exp(-sigmas[s] * deltas[s]) + w = a * T_ + rgb[ray_idx, 0] += w * rgbs[s, 0] + rgb[ray_idx, 1] += w * rgbs[s, 1] + rgb[ray_idx, 2] += w * rgbs[s, 2] + depth[ray_idx] += w * ts[s] + opacity[ray_idx] += w + ws[s] = w + # T_ *= (1.0-a) + T[s + 
# Element-wise comparison kernel: for every index of `array`, writes 1 into
# `checker` at that index when the Taichi field and the ndarray hold equal
# values there. Indices whose values differ are left untouched.
@ti.kernel
def check_value(
        fields: ti.template(),
        array: ti.types.ndarray(),
        checker: ti.types.ndarray(),
):
    for idx in ti.grouped(array):
        if fields[idx] == array[idx]:
            checker[idx] = 1


class VolumeRendererTaichi(torch.nn.Module):
    """Torch module wrapping the Taichi ``composite_train_fw`` kernel.

    All Taichi fields and torch output buffers are allocated once in
    ``__init__`` with capacity for ``batch_size`` rays and
    ``batch_size * 1024`` samples. Each call copies the inputs into the
    fields, launches the kernel, and copies the per-ray results back into the
    preallocated buffers, which are what ``forward`` returns.
    """

    def __init__(self, batch_size=8192, data_type=data_type):
        super().__init__()
        max_samples = batch_size * 1024

        def grad_field(shape):
            # Every differentiable field shares dtype and needs_grad=True.
            return ti.field(dtype=data_type, shape=shape, needs_grad=True)

        # samples level
        self.sigmas_fields = grad_field((max_samples, ))
        self.rgbs_fields = grad_field((max_samples, 3))
        self.deltas_fields = grad_field((max_samples, ))
        self.ts_fields = grad_field((max_samples, ))
        self.ws_fields = grad_field((max_samples, ))
        self.T = grad_field(max_samples)

        # rays level
        self.rays_a_fields = ti.field(dtype=ti.i64, shape=(batch_size, 3))
        self.total_samples_fields = ti.field(dtype=ti.i64,
                                             shape=(batch_size, ))
        self.opacity_fields = grad_field((batch_size, ))
        self.depth_fields = grad_field((batch_size, ))
        self.rgb_fields = grad_field((batch_size, 3))

        # preallocate tensor (registration order is kept as-is)
        buffer_specs = (
            ('total_samples', (batch_size, ), torch.int64),
            ('rgb', (batch_size, 3), torch_type),
            ('opacity', (batch_size, ), torch_type),
            ('depth', (batch_size, ), torch_type),
            ('ws', (max_samples, ), torch_type),
            ('sigma_grad', (max_samples, ), torch_type),
            ('rgb_grad', (max_samples, 3), torch_type),
        )
        for buffer_name, buffer_shape, buffer_dtype in buffer_specs:
            self.register_buffer(
                buffer_name, torch.zeros(buffer_shape, dtype=buffer_dtype))

        # The autograd function is defined here so that its static methods can
        # reach this module's fields and buffers through the closure.
        renderer = self

        class _module_function(torch.autograd.Function):

            @staticmethod
            @custom_fwd(cast_inputs=torch_type)
            def forward(ctx, sigmas, rgbs, deltas, ts, rays_a, T_threshold):
                n_samples = sigmas.shape[0]

                # Stash what backward needs to re-run the kernel gradient.
                ctx.T_threshold = T_threshold
                ctx.samples_size = n_samples

                # NOTE(review): this is a slice of the preallocated `ws`
                # buffer; nothing in this method copies `ws_fields` back into
                # it -- confirm whether callers rely on its contents.
                ws = renderer.ws[:n_samples]

                # Upload the inputs into the Taichi fields.
                uploads = (
                    (renderer.sigmas_fields, sigmas),
                    (renderer.rgbs_fields, rgbs),
                    (renderer.deltas_fields, deltas),
                    (renderer.ts_fields, ts),
                    (renderer.rays_a_fields, rays_a),
                )
                for field, tensor in uploads:
                    torch2ti(field, tensor.contiguous())

                composite_train_fw(renderer.sigmas_fields,
                                   renderer.rgbs_fields,
                                   renderer.deltas_fields,
                                   renderer.ts_fields,
                                   renderer.rays_a_fields, T_threshold,
                                   renderer.T,
                                   renderer.total_samples_fields,
                                   renderer.opacity_fields,
                                   renderer.depth_fields,
                                   renderer.rgb_fields, renderer.ws_fields)

                # Download the per-ray results into the torch buffers.
                downloads = (
                    (renderer.total_samples_fields, renderer.total_samples),
                    (renderer.opacity_fields, renderer.opacity),
                    (renderer.depth_fields, renderer.depth),
                    (renderer.rgb_fields, renderer.rgb),
                )
                for field, buffer in downloads:
                    ti2torch(field, buffer)

                total = renderer.total_samples.sum()
                return (total, renderer.opacity, renderer.depth,
                        renderer.rgb, ws)

            @staticmethod
            @custom_bwd
            def backward(ctx, dL_dtotal_samples, dL_dopacity, dL_ddepth,
                         dL_drgb, dL_dws):
                T_threshold = ctx.T_threshold
                n_samples = ctx.samples_size

                # Leading slices of the preallocated gradient buffers.
                sigma_grad = renderer.sigma_grad[:n_samples].contiguous()
                rgb_grad = renderer.rgb_grad[:n_samples].contiguous()

                # Clear the accumulated input-side Taichi gradients first.
                renderer.zero_grad()

                # Seed the output-side gradients from the incoming torch grads.
                seeds = (
                    (renderer.opacity_fields, dL_dopacity),
                    (renderer.depth_fields, dL_ddepth),
                    (renderer.rgb_fields, dL_drgb),
                    (renderer.ws_fields, dL_dws),
                )
                for field, upstream in seeds:
                    torch2ti_grad(field, upstream.contiguous())

                composite_train_fw.grad(renderer.sigmas_fields,
                                        renderer.rgbs_fields,
                                        renderer.deltas_fields,
                                        renderer.ts_fields,
                                        renderer.rays_a_fields, T_threshold,
                                        renderer.T,
                                        renderer.total_samples_fields,
                                        renderer.opacity_fields,
                                        renderer.depth_fields,
                                        renderer.rgb_fields,
                                        renderer.ws_fields)

                ti2torch_grad(renderer.sigmas_fields, sigma_grad)
                ti2torch_grad(renderer.rgbs_fields, rgb_grad)

                # Only sigmas and rgbs receive gradients; the other four
                # forward inputs get None.
                return sigma_grad, rgb_grad, None, None, None, None

        self._module_function = _module_function

    def zero_grad(self):
        """Fill the Taichi gradients of ``sigmas``, ``rgbs`` and ``T`` with 0.

        NOTE(review): this shadows ``torch.nn.Module.zero_grad`` with a
        narrower signature -- confirm no caller passes ``set_to_none``.
        """
        for field in (self.sigmas_fields, self.rgbs_fields, self.T):
            field.grad.fill(0.)

    def forward(self, sigmas, rgbs, deltas, ts, rays_a, T_threshold):
        """Apply the autograd function built in ``__init__`` to the inputs."""
        render_fn = self._module_function
        return render_fn.apply(sigmas, rgbs, deltas, ts, rays_a, T_threshold)
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import os
import subprocess

import numpy as np


'''
This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet,
to generate a tet grid
1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet`
2) Run the function below to generate a file `cube_<res>_tet.tet`, where <res> is
   formatted with `%f` (e.g. `cube_32.000000_tet.tet` for res=32)
'''

def generate_tetrahedron_grid_file(res=32, root='..'):
    """Run Quartet on ``meshes/cube.obj`` with cell size ``1.0 / res``.

    The ``quartet_release`` binary in ``<root>/quartet`` is executed with that
    directory as its working directory. It is invoked with an argument list
    rather than a shell string, so ``root`` may contain spaces or shell
    metacharacters.

    Args:
        res: Grid resolution; the cell size passed to Quartet is ``1.0 / res``.
        root: Directory that contains the ``quartet`` checkout.

    Returns:
        Path (under ``root``) of the ``.tet`` file Quartet was asked to write.
        Its name formats ``res`` with ``%f``, e.g. ``cube_32.000000_tet.tet``.

    Raises:
        subprocess.CalledProcessError: If Quartet exits with a non-zero status.
        OSError: If the working directory or the binary cannot be used.
    """
    frac = 1.0 / res
    quartet_dir = os.path.abspath(os.path.join(root, 'quartet'))
    tet_name = 'cube_%f_tet.tet' % res
    command = [
        os.path.join(quartet_dir, 'quartet_release'),
        'meshes/cube.obj',
        '%f' % frac,
        'meshes/%s' % tet_name,
        '-s',
        'meshes/cube_boundary_%f.obj' % res,
    ]
    # check=True surfaces a failed run here instead of as a missing-file error
    # in the conversion step.
    subprocess.run(command, cwd=quartet_dir, check=True)
    return os.path.join(root, 'quartet', 'meshes', tet_name)


'''
This code segment shows how to convert from a quartet .tet file to compressed npz file
'''
def convert_from_quartet_to_npz(quartetfile = 'cube_32_tet.tet', npzfile = '32_tets.npz'):
    """Convert a Quartet ``.tet`` text file into a compressed ``.npz`` archive.

    The first line is a header whose second and third whitespace-separated
    tokens are the vertex count and the tet count. It is followed by that many
    vertex rows and then that many index rows. Vertices are shifted by -0.5
    before saving.

    Args:
        quartetfile: Path of the ``.tet`` file to read.
        npzfile: Path of the archive to write, with arrays ``vertices`` and
            ``indices``.
    """
    # Only the header is read through the file handle; close it right away.
    with open(quartetfile, 'r') as tet_file:
        header = tet_file.readline()
    header_fields = header.split()
    numvertices = int(header_fields[1])
    numtets = int(header_fields[2])
    print(numvertices, numtets)

    # load vertices (ndmin=2 keeps a single-row section two-dimensional)
    vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices, ndmin=2)
    vertices = vertices - 0.5
    print(vertices.shape, vertices.min(), vertices.max())

    # load indices
    indices = np.loadtxt(quartetfile, dtype=int, skiprows=1+numvertices, max_rows=numtets, ndmin=2)
    print(indices.shape)

    np.savez_compressed(npzfile, vertices=vertices, indices=indices)

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--res', type=int, default=32)
    parser.add_argument('--root', type=str, default='..')
    args = parser.parse_args()

    # Use the path reported by the generator instead of re-deriving the
    # '%f'-formatted file name by hand.
    tet_path = generate_tetrahedron_grid_file(res=args.res, root=args.root)
    convert_from_quartet_to_npz(quartetfile=tet_path, npzfile=os.path.join('./tets', f'{args.res}_tets.npz'))