Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimizing extended numpy encoder decoder #811

Open
YoniChechik opened this issue Feb 4, 2025 · 0 comments
Open

optimizing extended numpy encoder decoder #811

YoniChechik opened this issue Feb 4, 2025 · 0 comments

Comments

@YoniChechik
Copy link

Question

I'm trying to match the save/load times of np.save and np.load when writing to and reading from a file, but I'm still pretty far off. Is there any way to optimize my class?

import pathlib
import time
from typing import Type, TypeVar

import msgspec
import msgspec.json
import msgspec.msgpack
import numpy as np
from typing_extensions import Buffer

# Type variable for return type hints
T = TypeVar("T", bound="MsgpackModel")

# ------------------------------------------------------------------------------
# Custom hooks for MessagePack
# ------------------------------------------------------------------------------


def msgpack_enc_hook(obj: object) -> object:
    """Convert objects the MessagePack encoder cannot handle natively.

    Supported inputs:
      - ``np.ndarray``: returned as ``[buffer, dtype string, shape]`` where
        ``buffer`` is a memoryview over the elements in C (row-major) order.
      - ``pathlib.Path``: returned as its POSIX-style string.

    Raises:
        NotImplementedError: if ``obj`` is of any other type.
    """
    if isinstance(obj, np.ndarray):
        # msgpack_dec_hook rebuilds the array with frombuffer(...).reshape(shape),
        # which is only correct for C-ordered bytes. Transposed, sliced or
        # Fortran-ordered arrays are therefore normalised first; an array that
        # is already C-contiguous is passed through without copying.
        contiguous = np.ascontiguousarray(obj)
        # The shape is taken from the original array because ascontiguousarray
        # promotes 0-d arrays to 1-d.
        return [contiguous.data, str(obj.dtype), obj.shape]
    elif isinstance(obj, pathlib.Path):
        return obj.as_posix()
    raise NotImplementedError(f"Object of type {type(obj)} is not supported in {msgpack_enc_hook.__name__}")


def msgpack_dec_hook(expected_type: Type, obj: object) -> object:
    """Rebuild custom types from their decoded MessagePack representation.

    Args:
        expected_type: The type the decoder wants to produce.
        obj: The already-decoded value. For ``np.ndarray`` this is indexed as
            ``obj[0]`` (raw buffer), ``obj[1]`` (dtype) and ``obj[2]`` (shape);
            for ``pathlib.Path`` it is passed straight to the ``Path``
            constructor.

    Returns:
        An ``np.ndarray`` or ``pathlib.Path`` for those expected types;
        otherwise ``obj`` unchanged.
    """
    if expected_type is np.ndarray:
        # Interpret the raw buffer as a flat array of the stored dtype, then
        # restore the recorded shape.
        flat = np.frombuffer(obj[0], dtype=obj[1])
        return flat.reshape(obj[2])
    if expected_type is pathlib.Path:
        return pathlib.Path(obj)
    # Any other type is handed back untouched.
    return obj


# ------------------------------------------------------------------------------
# Base class for models with multiple serialization methods
# ------------------------------------------------------------------------------


class MsgpackModel(msgspec.Struct):
    """
    Base class for models that can be serialized to and from MessagePack,
    either in memory or via a file.

    Encoding and decoding go through ``msgpack_enc_hook`` and
    ``msgpack_dec_hook``, which add handling for:
      - np.ndarray
      - pathlib.Path

    The base class itself refuses to be instantiated; subclass it instead.
    """

    def __post_init__(self):
        # Subclasses pass straight through; only the bare base class is rejected.
        if type(self) is not MsgpackModel:
            return
        raise TypeError("MsgpackModel is an abstract base class; please subclass it.")

    # -- MessagePack serialization --

    def to_msgpack(self) -> bytes:
        """
        Encode this instance as MessagePack and return the resulting bytes.
        """
        return msgspec.msgpack.Encoder(enc_hook=msgpack_enc_hook).encode(self)

    def to_msgpack_path(self, path: pathlib.Path | str) -> None:
        """
        Encode this instance as MessagePack and write the bytes to ``path``.

        The file is opened in binary write mode, so existing content is
        replaced.
        """
        with open(path, "wb") as sink:
            payload = self.to_msgpack()
            sink.write(payload)

    @classmethod
    def from_msgpack(cls: Type[T], data: Buffer) -> T:
        """
        Decode MessagePack ``data`` into an instance of the class this is
        called on.
        """
        return msgspec.msgpack.Decoder(cls, dec_hook=msgpack_dec_hook).decode(data)

    @classmethod
    def from_msgpack_path(cls: Type[T], path: pathlib.Path | str) -> T:
        """
        Read the whole file at ``path`` into memory and decode it as
        MessagePack into an instance of the class this is called on.
        """
        with open(path, "rb") as source:
            raw = source.read()
        return cls.from_msgpack(raw)


if __name__ == "__main__":
    # Benchmark: compare np.save/np.load against the MessagePack-backed model
    # for writing one large array to disk and reading it back.

    # Create a random 20000x20000 array of float64 values (about 3.2 GB).
    random_array = np.random.rand(20000, 20000)

    # Time saving and loading using numpy's save and load.
    # perf_counter is used for all intervals: it is monotonic and
    # high-resolution, whereas time.time() follows the adjustable wall clock.
    np_save_path = "random_array.npy"
    start_time = time.perf_counter()
    np.save(np_save_path, random_array)
    np_save_duration = time.perf_counter() - start_time
    print(f"NumPy save duration: {np_save_duration:.6f} seconds")

    start_time = time.perf_counter()
    loaded_array_np = np.load(np_save_path)
    np_load_duration = time.perf_counter() - start_time

    # Verify the round trip outside the timed region.
    assert np.array_equal(random_array, loaded_array_np)
    print(f"NumPy load duration: {np_load_duration:.6f} seconds")

    # Define a model with a single field for the numpy array
    class ArrayModel(MsgpackModel):
        array: np.ndarray

    model_instance = ArrayModel(array=random_array)

    # Time saving and loading using the custom model.
    # to_msgpack_path returns None, so its result is not kept.
    start_time = time.perf_counter()
    model_instance.to_msgpack_path("random_array.msgpack")
    model_save_duration = time.perf_counter() - start_time
    print(f"Model save duration: {model_save_duration:.6f} seconds")

    start_time = time.perf_counter()
    loaded_model_instance = ArrayModel.from_msgpack_path("random_array.msgpack")
    model_load_duration = time.perf_counter() - start_time

    # Verify the round trip outside the timed region.
    assert np.array_equal(random_array, loaded_model_instance.array)
    print(f"Model load duration: {model_load_duration:.6f} seconds")
NumPy save duration: 1.980130 seconds
NumPy load duration: 0.932738 seconds
Model save duration: 2.645414 seconds
Model load duration: 1.598974 seconds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant