diff --git a/include/nanobind/stl/complex.h b/include/nanobind/stl/complex.h new file mode 100644 index 00000000..2dfad554 --- /dev/null +++ b/include/nanobind/stl/complex.h @@ -0,0 +1,52 @@ +/* + nanobind/stl/complex.h: type caster for std::complex<...> + + Copyright (c) 2023 Degottex Gilles + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include +#include + +NAMESPACE_BEGIN(NB_NAMESPACE) +NAMESPACE_BEGIN(detail) + +template struct type_caster> { + + NB_TYPE_CASTER(std::complex, const_name("complex") ) + + bool from_python(handle src, uint8_t flags, + cleanup_list *cleanup) noexcept { + (void)flags; + (void)cleanup; + + PyObject* obj = src.ptr(); + + // TODO Faster way to get real part without string mapping? (and imag part below) + PyObject* obj_real = PyObject_GetAttrString(obj, "real"); + // TODO If T1==float32 and obj==numpy.float32, PyFloat_AsDouble implies 2 useless conversions + value.real(PyFloat_AsDouble(obj_real)); + PyObject* obj_imag = PyObject_GetAttrString(obj, "imag"); + value.imag(PyFloat_AsDouble(obj_imag)); + + return true; + } + + template + static handle from_cpp(T &&value, rv_policy policy, + cleanup_list *cleanup) noexcept { + (void)policy; + (void)cleanup; + + // There is no such float32 in Python, so always build as double. + // We could build a numpy.float32, though it would force dependency to numpy. + return PyComplex_FromDoubles(value.real(), value.imag()); + } +}; + +NAMESPACE_END(detail) +NAMESPACE_END(NB_NAMESPACE) diff --git a/tests/test_stl.cpp b/tests/test_stl.cpp index 9115654b..bb88a84c 100644 --- a/tests/test_stl.cpp +++ b/tests/test_stl.cpp @@ -12,6 +12,7 @@ #include #include #include +#include NB_MAKE_OPAQUE(std::vector>) @@ -422,4 +423,20 @@ NB_MODULE(test_stl_ext, m) { vec.flip(); return vec; }); + + + m.def("complex_value_float", [](const std::complex& x){ + return x; + }); + m.def("complex_value_double", [](const std::complex& x){ + return x; + }); + + m.def("complex_array_float", [](const std::vector>& x){ + return x; + }); + m.def("complex_array_double", [](const std::vector>& x){ + return x; + }); + } diff --git a/tests/test_stl.py b/tests/test_stl.py index 8da3d8e8..c7d47a22 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -3,6 +3,7 @@ import sys from common import collect, skip_on_pypy +import numpy as np @pytest.fixture def clean(): @@ -759,3 +760,47 @@ def test67_vector_bool(): bool_vector = [True, False, True, False] result = t.flip_vector_bool(bool_vector) assert result == [not x for x in bool_vector] + + +def test68_complex_value(): + + # double: 64bits + assert t.complex_value_double(1.0) == 1.0 + assert t.complex_value_double(1.0j) == 1.0j + assert t.complex_value_double(0.0) == 0.0 + assert t.complex_value_double(0.0j) == 0.0j + assert t.complex_value_double(2.7+3.2j) == (2.7+3.2j) + assert t.complex_value_double(2.7-3.2j) == (2.7-3.2j) + assert t.complex_value_double(np.complex128(2.7-3.2j)) == (2.7-3.2j) + assert t.complex_value_double(np.complex64(2.7-3.2j)) == (2.700000047683716-3.200000047683716j) # written as double, converted to complex of float32 (difference introduced), converted to complex of float64 (no difference introduced) -> difference introduced + + # float: 32bits + assert t.complex_value_float(1.0) == 1.0 + assert t.complex_value_float(1.0j) == 1.0j + assert t.complex_value_float(0.0) == 0.0 + assert t.complex_value_float(0.0j) == 0.0j + # written as double, converted to complex of float32 (difference introduced) -> difference introduced + assert t.complex_value_float(2.7+3.2j) == (2.700000047683716+3.200000047683716j) + # same as above + assert t.complex_value_float(2.7-3.2j) == (2.700000047683716-3.200000047683716j) + + # written as double, converted to complex of float64 (no difference introduced), converted to complex of float32 (difference introduced) -> difference introduced + assert t.complex_value_float(np.complex128(2.7-3.2j)) == (2.700000047683716-3.200000047683716j) + + # written as double, converted to complex of float32 (difference introduced), converted to complex of float32 (no difference introduced) -> difference introduced + assert t.complex_value_float(np.complex64(2.7-3.2j)) == (2.700000047683716-3.200000047683716j) + +def test69_complex_array(): + + # double: 64bits + assert t.complex_array_double([2.7-3.2j, -1j, +3.1415]) == [(2.7-3.2j), (-0-1j), (+3.1415)] + assert t.complex_array_double(np.array([2.7-3.2j, -1j, +3.1415])) == [(2.7-3.2j), (-0-1j), (+3.1415)] + assert t.complex_array_double(np.array([2.7-3.2j, -1j, +3.1415],dtype=np.complex128)) == [(2.7-3.2j), (-0-1j), (+3.1415)] + assert t.complex_array_double(np.array([2.7-3.2j, -1j, +3.1415],dtype=np.complex64)) == [(2.700000047683716-3.200000047683716j), (-0-1j), (3.1414999961853027+0j)] + + # float: 32bits + # Always go through double to float conversion bcs python's complex numbers are always written as double (there is no such thing as 1.0f notation) + assert t.complex_array_float([2.7-3.2j, -1j, +3.1415]) == [(2.700000047683716-3.200000047683716j), (-0-1j), (3.1414999961853027+0j)] + assert t.complex_array_float(np.array([2.7-3.2j, -1j, +3.1415])) == [(2.700000047683716-3.200000047683716j), (-0-1j), (3.1414999961853027+0j)] + assert t.complex_array_float(np.array([2.7-3.2j, -1j, +3.1415],dtype=np.complex128)) == [(2.700000047683716-3.200000047683716j), (-0-1j), (3.1414999961853027+0j)] + assert t.complex_array_float(np.array([2.7-3.2j, -1j, +3.1415],dtype=np.complex64)) == [(2.700000047683716-3.200000047683716j), (-0-1j), (3.1414999961853027+0j)]