Skip to content

Commit

Permalink
free-threading: Locking for internal data structures
Browse files Browse the repository at this point in the history
This commit enables free-threaded extension builds on Python 3.13+,
which involves the following changes:

- nanobind must notify Python that an extension supports free-threading.

- All internal data structures must be protected from concurrent
  modification. The approach taken varies with respect to the specific
  data structure, and a long comment in ``nb_internals.h`` explains the
  design decisions all of the changes. In general, the implementation
  avoids centralized locks as much as possible to improve scalability.

- Adopting safe versions of certain operations where needed, e.g.
  ``PyList_GetItemRef()``.

- Switching non-object allocation from ``PyObject_Allo()`` to
  ``PyMem_Alloc()``.
  • Loading branch information
wjakob committed Sep 20, 2024
1 parent fd972b6 commit 6174eb8
Show file tree
Hide file tree
Showing 13 changed files with 588 additions and 217 deletions.
21 changes: 15 additions & 6 deletions include/nanobind/nb_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,30 @@ struct num_item {
};

struct num_item_list {
static constexpr bool cache_dec_ref = false;
#if defined(Py_GIL_DISABLED)
static constexpr bool cache_dec_ref = true;
#else
static constexpr bool cache_dec_ref = false;
#endif

using key_type = Py_ssize_t;

NB_INLINE static void get(PyObject *obj, Py_ssize_t index, PyObject **cache) {
*cache = NB_LIST_GET_ITEM(obj, index);
#if defined(Py_GIL_DISABLED)
*cache = PyList_GetItemRef(obj, index);
#else
*cache = NB_LIST_GET_ITEM(obj, index);
#endif
}

NB_INLINE static void set(PyObject *obj, Py_ssize_t index, PyObject *v) {
#if !defined(Py_LIMITED_API)
// Handle differences between PyList_SetItem and PyList_SET_ITEM
#if defined(Py_LIMITED_API) || defined(NB_FREE_THREADED)
Py_INCREF(v);
PyList_SetItem(obj, index, v);
#else
PyObject *old = NB_LIST_GET_ITEM(obj, index);
#endif
Py_INCREF(v);
NB_LIST_SET_ITEM(obj, index, v);
#if !defined(Py_LIMITED_API)
Py_DECREF(old);
#endif
}
Expand Down
2 changes: 2 additions & 0 deletions include/nanobind/stl/detail/nb_dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ template <typename Dict, typename Key, typename Val> struct dict_caster {
return false;
}

// 'items' is safe to access without locking and reference counting, it
// is unique to this thread
Py_ssize_t size = NB_LIST_GET_SIZE(items);
bool success = size >= 0;

Expand Down
31 changes: 20 additions & 11 deletions src/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ PyObject *module_new(const char *name, PyModuleDef *def) noexcept {
def->m_name = name;
def->m_size = -1;
PyObject *m = PyModule_Create(def);

#ifdef NB_FREE_THREADED
PyUnstable_Module_SetGIL(m, Py_MOD_GIL_NOT_USED);
#endif

check(m, "nanobind::detail::module_new(): allocation failed!");
return m;
}
Expand Down Expand Up @@ -669,13 +674,15 @@ PyObject **seq_get(PyObject *seq, size_t *size_out, PyObject **temp_out) noexcep
still trigger a segfault if dereferenced. */
if (size == 0)
result = (PyObject **) 1;
# if !defined(NB_FREE_THREADED) // Require immutable holder in free-threaded mode
} else if (PyList_CheckExact(seq)) {
size = (size_t) PyList_GET_SIZE(seq);
result = ((PyListObject *) seq)->ob_item;
if (size == 0) // ditto
result = (PyObject **) 1;
# endif
} else if (PySequence_Check(seq)) {
temp = PySequence_Fast(seq, "");
temp = PySequence_Tuple(seq);

if (temp)
result = seq_get(temp, &size, temp_out);
Expand All @@ -689,8 +696,8 @@ PyObject **seq_get(PyObject *seq, size_t *size_out, PyObject **temp_out) noexcep
Py_ssize_t size_seq = PySequence_Length(seq);

if (size_seq >= 0) {
result = (PyObject **) PyObject_Malloc(sizeof(PyObject *) *
(size_seq + 1));
result = (PyObject **) PyMem_Malloc(sizeof(PyObject *) *
(size_seq + 1));

if (result) {
result[size_seq] = nullptr;
Expand All @@ -704,7 +711,7 @@ PyObject **seq_get(PyObject *seq, size_t *size_out, PyObject **temp_out) noexcep
for (Py_ssize_t j = 0; j < i; ++j)
Py_DECREF(result[j]);

PyObject_Free(result);
PyMem_Free(result);
result = nullptr;
break;
}
Expand All @@ -716,7 +723,7 @@ PyObject **seq_get(PyObject *seq, size_t *size_out, PyObject **temp_out) noexcep
PyObject **ptr = (PyObject **) PyCapsule_GetPointer(o, nullptr);
for (size_t i = 0; ptr[i] != nullptr; ++i)
Py_DECREF(ptr[i]);
PyObject_Free(ptr);
PyMem_Free(ptr);
});

if (temp) {
Expand All @@ -726,7 +733,7 @@ PyObject **seq_get(PyObject *seq, size_t *size_out, PyObject **temp_out) noexcep
for (Py_ssize_t i = 0; i < size_seq; ++i)
Py_DECREF(result[i]);

PyObject_Free(result);
PyMem_Free(result);
result = nullptr;
}
}
Expand Down Expand Up @@ -763,14 +770,16 @@ PyObject **seq_get_with_size(PyObject *seq, size_t size,
if (size == 0)
result = (PyObject **) 1;
}
# if !defined(NB_FREE_THREADED) // Require immutable holder in free-threaded mode
} else if (PyList_CheckExact(seq)) {
if (size == (size_t) PyList_GET_SIZE(seq)) {
result = ((PyListObject *) seq)->ob_item;
if (size == 0) // ditto
result = (PyObject **) 1;
}
# endif
} else if (PySequence_Check(seq)) {
temp = PySequence_Fast(seq, "");
temp = PySequence_Tuple(seq);

if (temp)
result = seq_get_with_size(temp, size, temp_out);
Expand All @@ -785,7 +794,7 @@ PyObject **seq_get_with_size(PyObject *seq, size_t size,

if (size == (size_t) size_seq) {
result =
(PyObject **) PyObject_Malloc(sizeof(PyObject *) * (size + 1));
(PyObject **) PyMem_Malloc(sizeof(PyObject *) * (size + 1));

if (result) {
result[size] = nullptr;
Expand All @@ -799,7 +808,7 @@ PyObject **seq_get_with_size(PyObject *seq, size_t size,
for (Py_ssize_t j = 0; j < i; ++j)
Py_DECREF(result[j]);

PyObject_Free(result);
PyMem_Free(result);
result = nullptr;
break;
}
Expand All @@ -811,15 +820,15 @@ PyObject **seq_get_with_size(PyObject *seq, size_t size,
PyObject **ptr = (PyObject **) PyCapsule_GetPointer(o, nullptr);
for (size_t i = 0; ptr[i] != nullptr; ++i)
Py_DECREF(ptr[i]);
PyObject_Free(ptr);
PyMem_Free(ptr);
});

if (!temp) {
PyErr_Clear();
for (Py_ssize_t i = 0; i < size_seq; ++i)
Py_DECREF(result[i]);

PyObject_Free(result);
PyMem_Free(result);
result = nullptr;
}
}
Expand Down
9 changes: 6 additions & 3 deletions src/error.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

// Protected by internals->mutex in free-threaded builds
Buffer buf(128);

NAMESPACE_END(detail)
Expand Down Expand Up @@ -116,13 +117,15 @@ void python_error::restore() noexcept {
#endif

const char *python_error::what() const noexcept {
using detail::buf;
using namespace nanobind::detail;

// Return the existing error message if already computed once
if (m_what)
return m_what;

gil_scoped_acquire acq;
// 'buf' is protected by internals->mutex in free-threaded builds
lock_internals guard(internals);

// Try again with GIL held
if (m_what)
Expand All @@ -147,7 +150,7 @@ const char *python_error::what() const noexcept {
#if defined(Py_LIMITED_API) || defined(PYPY_VERSION)
object mod = module_::import_("traceback"),
result = mod.attr("format_exception")(exc_type, exc_value, exc_traceback);
m_what = detail::strdup_check(borrow<str>(str("\n").attr("join")(result)).c_str());
m_what = strdup_check(borrow<str>(str("\n").attr("join")(result)).c_str());
#else
buf.clear();
if (exc_traceback.is_valid()) {
Expand All @@ -160,7 +163,7 @@ const char *python_error::what() const noexcept {
PyFrameObject *frame = to->tb_frame;
Py_XINCREF(frame);

std::vector<PyFrameObject *, detail::py_allocator<PyFrameObject *>> frames;
std::vector<PyFrameObject *, py_allocator<PyFrameObject *>> frames;

while (frame) {
frames.push_back(frame);
Expand Down
16 changes: 10 additions & 6 deletions src/implicit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ NAMESPACE_BEGIN(detail)

void implicitly_convertible(const std::type_info *src,
const std::type_info *dst) noexcept {
type_data *t = nb_type_c2p(internals, dst);
nb_internals *internals_ = internals;
type_data *t = nb_type_c2p(internals_, dst);
check(t, "nanobind::detail::implicitly_convertible(src=%s, dst=%s): "
"destination type unknown!", type_name(src), type_name(dst));

lock_internals guard(internals_);
size_t size = 0;

if (t->flags & (uint32_t) type_flags::has_implicit_conversions) {
Expand All @@ -30,23 +32,25 @@ void implicitly_convertible(const std::type_info *src,
t->flags |= (uint32_t) type_flags::has_implicit_conversions;
}

void **data = (void **) malloc(sizeof(void *) * (size + 2));
void **data = (void **) PyMem_Malloc(sizeof(void *) * (size + 2));

if (size)
memcpy(data, t->implicit.cpp, size * sizeof(void *));
data[size] = (void *) src;
data[size + 1] = nullptr;
free(t->implicit.cpp);
PyMem_Free(t->implicit.cpp);
t->implicit.cpp = (decltype(t->implicit.cpp)) data;
}

void implicitly_convertible(bool (*predicate)(PyTypeObject *, PyObject *,
cleanup_list *),
const std::type_info *dst) noexcept {
type_data *t = nb_type_c2p(internals, dst);
nb_internals *internals_ = internals;
type_data *t = nb_type_c2p(internals_, dst);
check(t, "nanobind::detail::implicitly_convertible(src=<predicate>, dst=%s): "
"destination type unknown!", type_name(dst));

lock_internals guard(internals_);
size_t size = 0;

if (t->flags & (uint32_t) type_flags::has_implicit_conversions) {
Expand All @@ -58,12 +62,12 @@ void implicitly_convertible(bool (*predicate)(PyTypeObject *, PyObject *,
t->flags |= (uint32_t) type_flags::has_implicit_conversions;
}

void **data = (void **) malloc(sizeof(void *) * (size + 2));
void **data = (void **) PyMem_Malloc(sizeof(void *) * (size + 2));
if (size)
memcpy(data, t->implicit.py, size * sizeof(void *));
data[size] = (void *) predicate;
data[size + 1] = nullptr;
free(t->implicit.py);
PyMem_Free(t->implicit.py);
t->implicit.py = (decltype(t->implicit.py)) data;
}

Expand Down
33 changes: 24 additions & 9 deletions src/nb_enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,21 @@ using enum_map = tsl::robin_map<int64_t, int64_t, int64_hash>;

PyObject *enum_create(enum_init_data *ed) noexcept {
// Update hash table that maps from std::type_info to Python type
auto [it, success] = internals->type_c2p_slow.try_emplace(ed->type, nullptr);
if (!success) {
PyErr_WarnFormat(PyExc_RuntimeWarning, 1, "nanobind: type '%s' was already registered!\n", ed->name);
PyObject *tp = (PyObject *) it->second->type_py;
Py_INCREF(tp);
return tp;
nb_internals *internals_ = internals;
bool success;
nb_type_map_slow::iterator it;

{
lock_internals guard(internals_);
std::tie(it, success) = internals->type_c2p_slow.try_emplace(ed->type, nullptr);
if (!success) {
PyErr_WarnFormat(PyExc_RuntimeWarning, 1,
"nanobind: type '%s' was already registered!\n",
ed->name);
PyObject *tp = (PyObject *) it->second->type_py;
Py_INCREF(tp);
return tp;
}
}

handle scope(ed->scope);
Expand Down Expand Up @@ -77,15 +86,21 @@ PyObject *enum_create(enum_init_data *ed) noexcept {

it.value() = t;

internals->type_c2p_fast[ed->type] = t;
internals->type_c2p_slow[ed->type] = t;
{
lock_internals guard(internals_);
internals_->type_c2p_slow[ed->type] = t;

#if !defined(NB_FREE_THREADED)
internals_->type_c2p_fast[ed->type] = t;
#endif
}

result.attr("__nb_enum__") = capsule(t, [](void *p) noexcept {
type_init_data *t = (type_init_data *) p;
delete (enum_map *) t->enum_tbl.fwd;
delete (enum_map *) t->enum_tbl.rev;
nb_type_unregister(t);
free((char*)t->name);
free((char*) t->name);
delete t;
});

Expand Down
Loading

0 comments on commit 6174eb8

Please sign in to comment.