
About uint16 support #8359

@NicolasHug


PyTorch 2.3 is introducing unsigned integer dtypes like uint16, uint32 and uint64 in pytorch/pytorch#116594.

Quoting Ed:

The dtypes are very useless right now (not even fill works), but it makes torch.uint16, uint32 and uint64 available as a dtype.

I tried uint16 on some of the transforms and the following would work:

import torch
import torchvision.transforms.v2 as T

x = torch.randint(0, 256, size=(1, 3, 10, 10), dtype=torch.uint16)
transforms = T.Compose(
    [
        T.Pad(2),
        T.Resize(5),
        T.CenterCrop(3),
        # T.RandomHorizontalFlip(p=1),
        # T.ColorJitter(2, 2, 2, .1),
        T.ToDtype(torch.float32, scale=True),
    ]
)
transforms(x)

but transforms like RandomHorizontalFlip or ColorJitter won't work. In general, it's safe to assume that uint16 doesn't really work in eager mode.


What to do about F.to_tensor() and F.pil_to_tensor().

Up until 2.3, passing a uint16 PIL image (mode = "I;16") to those would produce:

  • to_tensor(): an int16 tensor as output. This is completely wrong and a bug: int16 tops out at 32767 while uint16 goes up to 65535, so the resulting tensor is incorrect and has tons of negative values (coming from overflow).
  • pil_to_tensor(): an error - this is OK.
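To make the overflow concrete, here is a minimal NumPy sketch. It assumes that reading "I;16" pixel data as int16 amounts to reinterpreting the same 16 bits as signed; the sample values are made up for illustration and this is not the actual to_tensor() code.

```python
import numpy as np

# Sample pixel values spanning the full uint16 range (illustrative only).
pixels = np.array([0, 1000, 32767, 32768, 65535], dtype=np.uint16)

# Reinterpret the same 16 bits as signed integers.
# Assumption: this mirrors what happens when uint16 image data is read as int16.
as_int16 = pixels.view(np.int16)

print(as_int16.tolist())  # [0, 1000, 32767, -32768, -1]
```

Every value above 32767 wraps around to a negative number, which matches the symptom described above.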

Now with 2.3 (or more precisely with the nightlies/RC):

  • to_tensor(): still outputs an int16 tensor, which is still incorrect.
  • pil_to_tensor(): outputs a uint16 tensor, which is correct, but that tensor won't work with a lot of the transforms.
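Until uint16 is better supported, one possible user-side workaround is to widen the pixel data to float32 before applying any transforms. The sketch below is only an illustration under that assumption, not the fix proposed in this issue; uint16_to_unit_float is a hypothetical helper name, not a torchvision function.

```python
import numpy as np


def uint16_to_unit_float(arr: np.ndarray) -> np.ndarray:
    """Hypothetical helper: scale uint16 pixel data to float32 in [0, 1]."""
    if arr.dtype != np.uint16:
        raise TypeError(f"expected uint16 input, got {arr.dtype}")
    # Divide by the largest uint16 value so that 65535 maps to 1.0.
    return arr.astype(np.float32) / np.float32(np.iinfo(np.uint16).max)


pixels = np.array([[0, 32768, 65535]], dtype=np.uint16)
scaled = uint16_to_unit_float(pixels)
print(scaled.dtype, scaled.min(), scaled.max())
```

The resulting array can be wrapped with torch.from_numpy() to get a float32 tensor, which the transforms handle well. The input array would typically come from np.asarray() on the "I;16" PIL image, assuming a little-endian platform.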

Proposed fix


Dirty notebook to play with:

#%%
%load_ext autoreload
%autoreload 2
import numpy as np
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as F
from PIL import Image
import torch

torch.__version__
#%%
x = torch.randint(100, (512, 512), dtype=torch.int16)
#%%
x_pil = F.to_pil_image(x)
x_pil.mode  # I;16
#%%
F.pil_to_tensor(x_pil).dtype  # torch.uint16
# %%
F.to_tensor(x_pil).dtype  # torch.int16
# %%
x = np.random.randint(0, np.iinfo(np.uint16).max, (10, 10), dtype=np.uint16)
x_pil = Image.fromarray(x, mode="I;16")
x_pil.mode  # I;16
# %%
F.pil_to_tensor(x_pil).dtype # torch.uint16
# %%
torch.testing.assert_close(torch.from_numpy(x)[None], F.pil_to_tensor(x_pil))

# %%
F.to_tensor(x_pil).dtype # torch.int16
# %%
torch.testing.assert_close(torch.from_numpy(x)[None].float(), F.to_tensor(x_pil).float())
# %%
x = torch.randint(0, 256, size=(1, 3, 10, 10), dtype=torch.uint16)
transforms = T.Compose(
    [
        T.Pad(2),
        T.Resize(5),
        T.CenterCrop(3),
        # T.RandomHorizontalFlip(p=1),
        # T.ColorJitter(2, 2, 2, .1),
        T.ToDtype(torch.float32, scale=True),
    ]
)
transforms(x)
#
