Commit
feat: support loading huggingface image dataset and convert image to PIL (lancedb#2684)
eddyxu authored Aug 3, 2024
1 parent d141cc8 commit 712405e
Showing 3 changed files with 91 additions and 3 deletions.
43 changes: 43 additions & 0 deletions python/python/lance/hf.py
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors

import io
from typing import TYPE_CHECKING, Any, Optional, Union

import pyarrow as pa

if TYPE_CHECKING:
import PIL.Image
import torch


class HuggingFaceConverter:
"""
    Utility class for converting a PyArrow RecordBatch to Huggingface internal types.
"""

def __init__(self, ds_info: dict[str, Any]):
"""Create HuggingFaceConverter from Huggingface dataset info"""
self.ds_info = ds_info

def _to_pil_image(self, scalar: pa.StructScalar) -> "PIL.Image.Image":
import PIL.Image

row = scalar.as_py()
if row.get("bytes") is None:
return PIL.Image.open(row["path"])
return PIL.Image.open(io.BytesIO(row["bytes"]))

def to_pytorch(
self, col: str, array: pa.Array
) -> Optional[Union["torch.Tensor", list["PIL.Image.Image"]]]:
try:
feature = self.ds_info["info"]["features"][col]
except KeyError:
# Not covered in the features
return None
if feature["_type"] == "Image":
return [self._to_pil_image(x) for x in array]
raise NotImplementedError(
f"Conversion to {feature['_type']} is not implemented"
)
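
For reference, a minimal usage sketch of the new converter (not part of this commit). The ds_info dict below only mimics the {"info": {"features": ...}} layout that to_pytorch() reads; metadata written by real Huggingface datasets carries additional fields, and the column name "img" is invented for the example.

import io

import pyarrow as pa
from PIL import Image

from lance.hf import HuggingFaceConverter

# Encode one tiny image the way Huggingface stores Image features:
# a struct of {"bytes": binary, "path": string}.
buf = io.BytesIO()
Image.new("RGB", (4, 4)).save(buf, format="PNG")
images = pa.array(
    [{"bytes": buf.getvalue(), "path": None}],
    type=pa.struct([("bytes", pa.binary()), ("path", pa.string())]),
)

converter = HuggingFaceConverter({"info": {"features": {"img": {"_type": "Image"}}}})
pil_images = converter.to_pytorch("img", images)
print(type(pil_images[0]))  # a PIL.Image.Image subclass
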
27 changes: 24 additions & 3 deletions python/python/lance/torch/data.py
@@ -6,6 +6,7 @@
# PEP-585. Can be removed after deprecating python 3.8 support.
from __future__ import annotations

import json
import math
import warnings
from pathlib import Path
@@ -30,15 +31,20 @@


def _to_tensor(
-    batch: pa.RecordBatch, *, uint64_as_int64: bool = True
+    batch: pa.RecordBatch,
+    *,
+    uint64_as_int64: bool = True,
+    hf_converter: Optional[dict] = None,
) -> Union[dict[str, torch.Tensor], torch.Tensor]:
"""Convert a pyarrow RecordBatch to torch Tensor."""
ret = {}

for col in batch.schema.names:
arr: pa.Array = batch[col]
if pa.types.is_uint64(arr.type) and uint64_as_int64:
arr = arr.cast(pa.int64())

tensor: torch.Tensor = None
if (
pa.types.is_fixed_size_list(arr.type)
or isinstance(arr.type, pa.FixedShapeTensorType)
@@ -57,11 +63,15 @@ def _to_tensor(
or pa.types.is_boolean(arr.type)
):
tensor = torch.from_numpy(arr.to_numpy(zero_copy_only=False))
-        else:
+        elif hf_converter is not None:
+            tensor = hf_converter.to_pytorch(col, arr)

if tensor is None:
raise ValueError(
"Only support FixedSizeList<f16/f32/f64> or "
+ f"numeric values, got: {arr.type}"
)

del arr
ret[col] = tensor
if len(ret) == 1:
@@ -203,6 +213,7 @@ def __init__(
if to_tensor_fn is None:
to_tensor_fn = _to_tensor
self._to_tensor_fn = to_tensor_fn
self._hf_converter = None

# As Shared Dataset
self.rank = rank
@@ -230,6 +241,16 @@ def __init__(

self.sampler: Sampler = sampler

# Dataset with huggingface metadata
if (
dataset.schema.metadata is not None
and (hf_meta := dataset.schema.metadata.get(b"huggingface")) is not None
):
from ..hf import HuggingFaceConverter

hf_ds_info = json.loads(hf_meta)
self._hf_converter = HuggingFaceConverter(hf_ds_info)

self.cache = cache
self.cached_ds: Optional[CachedDataset] = None

@@ -266,6 +287,6 @@ def __iter__(self):

for batch in stream:
if self._to_tensor_fn is not None:
-                batch = self._to_tensor_fn(batch)
+                batch = self._to_tensor_fn(batch, hf_converter=self._hf_converter)
yield batch
del batch
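
The new hf_converter branch only fires for columns that none of the fixed-size-list or numeric branches handled. A rough illustration of the dispatch, calling the private _to_tensor directly (not public API); the column names and the feature dict are invented and only mirror what to_pytorch() reads:

import io

import pyarrow as pa
from PIL import Image

from lance.hf import HuggingFaceConverter
from lance.torch.data import _to_tensor

buf = io.BytesIO()
Image.new("RGB", (2, 2)).save(buf, format="PNG")
batch = pa.RecordBatch.from_arrays(
    [
        pa.array([1.0], type=pa.float32()),
        pa.array(
            [{"bytes": buf.getvalue(), "path": None}],
            type=pa.struct([("bytes", pa.binary()), ("path", pa.string())]),
        ),
    ],
    names=["label", "img"],
)
converter = HuggingFaceConverter({"info": {"features": {"img": {"_type": "Image"}}}})

out = _to_tensor(batch, hf_converter=converter)
# "label" hits the numeric branch and becomes a torch.Tensor;
# "img" falls through to the converter and becomes a list of PIL images.
print(type(out["label"]), type(out["img"][0]))
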
24 changes: 24 additions & 0 deletions python/python/tests/test_huggingface.py
@@ -4,9 +4,12 @@
from pathlib import Path

import lance
import lance.torch.data
import numpy as np
import pytest

datasets = pytest.importorskip("datasets")
pil = pytest.importorskip("PIL")


def test_write_hf_dataset(tmp_path: Path):
@@ -21,3 +24,24 @@ def test_write_hf_dataset(tmp_path: Path):
assert ds.count_rows() == 50

assert ds.schema == hf_ds.features.arrow_schema


def test_image_hf_dataset(tmp_path: Path):
ds = datasets.Dataset.from_dict(
{"i": [np.zeros(shape=(16, 16, 3), dtype=np.uint8)]},
features=datasets.Features({"i": datasets.Image()}),
)

ds = lance.write_dataset(ds, tmp_path)

dataset = lance.torch.data.LanceDataset(
ds,
columns=["i"],
batch_size=8,
)
batch = next(iter(dataset))
assert len(batch) == 1
assert all(
(isinstance(img, pil.Image.Image) and np.all(np.array(img) == 0))
for img in batch
)
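
Beyond the synthetic test above, the same path works for datasets pulled from the Hugging Face Hub. A sketch, assuming the public "beans" dataset (which exposes an Image feature named "image") and a writable local output path:

import datasets
import lance
import lance.torch.data

hf_ds = datasets.load_dataset("beans", split="train")
ds = lance.write_dataset(hf_ds, "beans.lance")

# LanceDataset picks up the "huggingface" schema metadata and converts the
# image column back to PIL on the fly.
torch_ds = lance.torch.data.LanceDataset(ds, columns=["image"], batch_size=4)
batch = next(iter(torch_ds))
print(batch[0])  # a PIL.Image.Image

With a single requested column, the yielded batch is the list of PIL images itself, mirroring the assertion in test_image_hf_dataset above.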
