Skip to content

Commit

Permalink
Add type caster for std::complex<T> (#292)
Browse files Browse the repository at this point in the history
  • Loading branch information
gillesdegottex authored and wjakob committed Sep 29, 2023
1 parent 901ea21 commit dcbed4f
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 1 deletion.
67 changes: 67 additions & 0 deletions include/nanobind/stl/complex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
nanobind/stl/complex.h: type caster for std::complex<...>
Copyright (c) 2023 Degottex Gilles and Wenzel Jakob
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include <nanobind/nanobind.h>
#include <complex>

NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

template <typename T> struct type_caster<std::complex<T>> {
NB_TYPE_CASTER(std::complex<T>, const_name("complex") )

template <bool Recursive = true>
bool from_python(handle src, uint8_t flags,
cleanup_list *cleanup) noexcept {
(void) flags;
(void) cleanup;

if (PyComplex_Check(src.ptr())) {
value = std::complex<T>(
(T) PyComplex_RealAsDouble(src.ptr()),
(T) PyComplex_ImagAsDouble(src.ptr())
);
return true;
}

if (Recursive && !PyFloat_CheckExact(src.ptr()) &&
!PyLong_CheckExact(src.ptr()) &&
PyObject_HasAttrString(src.ptr(), "imag")) {
try {
object tmp = handle(&PyComplex_Type)(src);
return from_python<false>(tmp, flags, cleanup);
} catch (...) {
return false;
}
}

make_caster<T> caster;
if (caster.from_python(src, flags, cleanup)) {
value = std::complex<T>(caster.operator cast_t<T>());
return true;
}

return true;
}

template <typename T2>
static handle from_cpp(T2 &&value, rv_policy policy,
cleanup_list *cleanup) noexcept {
(void) policy;
(void) cleanup;

return PyComplex_FromDoubles((double) value.real(),
(double) value.imag());
}
};

NAMESPACE_END(detail)
NAMESPACE_END(NB_NAMESPACE)
17 changes: 17 additions & 0 deletions tests/test_stl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <nanobind/stl/unordered_set.h>
#include <nanobind/stl/set.h>
#include <nanobind/stl/filesystem.h>
#include <nanobind/stl/complex.h>

NB_MAKE_OPAQUE(std::vector<float, std::allocator<float>>)

Expand Down Expand Up @@ -422,4 +423,20 @@ NB_MODULE(test_stl_ext, m) {
vec.flip();
return vec;
});


m.def("complex_value_float", [](const std::complex<float>& x){
return x;
});
m.def("complex_value_double", [](const std::complex<double>& x){
return x;
});

m.def("complex_array_float", [](const std::vector<std::complex<float>>& x){
return x;
});
m.def("complex_array_double", [](const std::vector<std::complex<double>>& x){
return x;
});

}
64 changes: 63 additions & 1 deletion tests/test_stl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import sys
from common import collect, skip_on_pypy


@pytest.fixture
def clean():
collect()
Expand Down Expand Up @@ -758,3 +757,66 @@ def test67_vector_bool():
bool_vector = [True, False, True, False]
result = t.flip_vector_bool(bool_vector)
assert result == [not x for x in bool_vector]


def test68_complex_value():
# double: 64bits
assert t.complex_value_double(1.0) == 1.0
assert t.complex_value_double(1.0j) == 1.0j
assert t.complex_value_double(0.0) == 0.0
assert t.complex_value_double(0.0j) == 0.0j
assert t.complex_value_double(0) == 0
assert t.complex_value_float(1.0) == 1.0
assert t.complex_value_float(1.0j) == 1.0j
assert t.complex_value_float(0.0) == 0.0
assert t.complex_value_float(0.0j) == 0.0j
assert t.complex_value_float(0) == 0

val_64 = 2.7-3.2j
val_32 = 2.700000047683716-3.200000047683716j
assert val_64 != val_32

assert t.complex_value_float(val_32) == val_32
assert t.complex_value_float(val_64) == val_32
assert t.complex_value_double(val_32) == val_32
assert t.complex_value_double(val_64) == val_64

try:
import numpy as np
assert t.complex_value_float(np.complex64(val_32)) == val_32
assert t.complex_value_float(np.complex64(val_64)) == val_32
assert t.complex_value_double(np.complex64(val_32)) == val_32
assert t.complex_value_double(np.complex64(val_64)) == val_32
assert t.complex_value_float(np.complex128(val_32)) == val_32
assert t.complex_value_float(np.complex128(val_64)) == val_32
assert t.complex_value_double(np.complex128(val_32)) == val_32
assert t.complex_value_double(np.complex128(val_64)) == val_64
except ImportError:
pass

def test69_complex_array():
val1_64 = 2.7-3.2j
val1_32 = 2.700000047683716-3.200000047683716j
val2_64 = 3.1415
val2_32 = 3.1414999961853027+0j

# test 64 bit casts
assert t.complex_array_double([val1_64, -1j, val2_64]) == [val1_64, -0-1j, val2_64]

# test 32 bit casts
assert t.complex_array_float([val1_64, -1j, val2_64]) == [val1_32, (-0-1j), val2_32]

try:
import numpy as np

# test 64 bit casts
assert t.complex_array_double(np.array([val1_64, -1j, val2_64])) == [val1_64, -0-1j, val2_64]
assert t.complex_array_double(np.array([val1_64, -1j, val2_64],dtype=np.complex128)) == [val1_64, -0-1j, val2_64]
assert t.complex_array_double(np.array([val1_64, -1j, val2_64],dtype=np.complex64)) == [val1_32, -0-1j, val2_32]

# test 32 bit casts
assert t.complex_array_float(np.array([val1_64, -1j, val2_64])) == [val1_32, (-0-1j), val2_32]
assert t.complex_array_float(np.array([val1_64, -1j, val2_64],dtype=np.complex128)) == [val1_32, (-0-1j), val2_32]
assert t.complex_array_float(np.array([val1_64, -1j, val2_64],dtype=np.complex64)) == [val1_32, (-0-1j), val2_32]
except ImportError:
pass

0 comments on commit dcbed4f

Please sign in to comment.