You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I'm trying to equal np.save and np.load times from file but I'm pretty far. is there any way to optimize my class?
importpathlibimporttimefromtypingimportType, TypeVarimportmsgspecimportmsgspec.jsonimportmsgspec.msgpackimportnumpyasnpfromtyping_extensionsimportBuffer# Type variable for return type hintsT=TypeVar("T", bound="MsgpackModel")
# ------------------------------------------------------------------------------# Custom hooks for MessagePack# ------------------------------------------------------------------------------defmsgpack_enc_hook(obj: object) ->object:
ifisinstance(obj, np.ndarray):
# obj.data is memoryviewreturn [obj.data, str(obj.dtype), obj.shape]
elifisinstance(obj, pathlib.Path):
returnobj.as_posix()
raiseNotImplementedError(f"Object of type {type(obj)} is not supported in {msgpack_enc_hook.__name__}")
defmsgpack_dec_hook(expected_type: Type, obj: object) ->object:
ifexpected_typeispathlib.Path:
returnpathlib.Path(obj)
elifexpected_typeisnp.ndarray:
returnnp.frombuffer(obj[0], dtype=obj[1]).reshape(obj[2])
returnobj# ------------------------------------------------------------------------------# Base class for models with multiple serialization methods# ------------------------------------------------------------------------------classMsgpackModel(msgspec.Struct):
""" Abstract base class for models that support serialization/deserialization via MessagePack, JSON, YAML, and TOML. It has built-in support for: - np.ndarray (using binary extension in MessagePack; list conversion in JSON) - pathlib.Path (encoded as POSIX strings) This class is meant to be subclassed only. """def__post_init__(self):
# Prevent direct instantiation of the base class.iftype(self) isMsgpackModel:
raiseTypeError("MsgpackModel is an abstract base class; please subclass it.")
# -- MessagePack serialization --defto_msgpack(self) ->bytes:
""" Serialize the instance to MessagePack bytes. """encoder=msgspec.msgpack.Encoder(enc_hook=msgpack_enc_hook)
returnencoder.encode(self)
defto_msgpack_path(self, path: pathlib.Path|str) ->None:
""" Serialize the instance to MessagePack bytes and save to a file. """withopen(path, "wb") asf:
f.write(self.to_msgpack())
@classmethoddeffrom_msgpack(cls: Type[T], data: Buffer) ->T:
""" Deserialize MessagePack bytes into an instance of the calling class. """decoder=msgspec.msgpack.Decoder(cls, dec_hook=msgpack_dec_hook)
returndecoder.decode(data)
@classmethoddeffrom_msgpack_path(cls: Type[T], path: pathlib.Path|str) ->T:
""" Deserialize MessagePack bytes from a file into an instance of the calling class without unnecessary copies. """withopen(path, "rb") asf:
data=f.read()
returncls.from_msgpack(data)
if__name__=="__main__":
# Create a random 1000x1000 numpy array of floatsrandom_array=np.random.rand(20000, 20000)
# Time saving and loading using numpy's save and loadnp_save_path="random_array.npy"start_time=time.time()
np.save(np_save_path, random_array)
np_save_duration=time.time() -start_timeprint(f"NumPy save duration: {np_save_duration:.6f} seconds")
start_time=time.time()
loaded_array_np=np.load(np_save_path)
np_load_duration=time.time() -start_timeassertnp.array_equal(random_array, loaded_array_np)
print(f"NumPy load duration: {np_load_duration:.6f} seconds")
# Define a model with a single field for the numpy arrayclassArrayModel(MsgpackModel):
array: np.ndarraymodel_instance=ArrayModel(array=random_array)
# Time saving and loading using the custom modelstart_time=time.time()
model_bytes=model_instance.to_msgpack_path("random_array.msgpack")
model_save_duration=time.time() -start_timeprint(f"Model save duration: {model_save_duration:.6f} seconds")
start_time=time.time()
loaded_model_instance=ArrayModel.from_msgpack_path("random_array.msgpack")
model_load_duration=time.time() -start_timeassertnp.array_equal(random_array, loaded_model_instance.array)
print(f"Model load duration: {model_load_duration:.6f} seconds")
NumPy save duration: 1.980130 seconds
NumPy load duration: 0.932738 seconds
Model save duration: 2.645414 seconds
Model load duration: 1.598974 seconds
The text was updated successfully, but these errors were encountered:
Question
I'm trying to equal np.save and np.load times from file but I'm pretty far. is there any way to optimize my class?
The text was updated successfully, but these errors were encountered: