Skip to content

Commit

Permalink
Do not repeat identical docstrings in long overload chains
Browse files Browse the repository at this point in the history
In auto-generated C++ bindings, we might sometimes have a long chain of
overloads that bind a function with differently typed arguments, and
which are all created with the same per-overload docstring.

Previously, nanobind extended this into a long and highly redundant
combined docstring that showed both type signatures and docstrings
multiple times.

```
f(.. signature 1..)
f(.. signature 2..)

This is an overloaded function.

1. ``f(.. signature 1..)``

Docstring

1. ``f(.. signature 2..)``

Docstring
```

Nanobind now detects situations where the per-overload docstrings
are all uniform, in which case it generates

```
f(.. signature 1..)
f(.. signature 2..)

Docstring
```
  • Loading branch information
wjakob committed Sep 27, 2024
1 parent 8ce0dee commit b1088be
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 24 deletions.
72 changes: 48 additions & 24 deletions src/nb_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,16 +283,18 @@ PyObject *nb_func_new(const void *in_) noexcept {
}

// Create a new function and destroy the old one
Py_ssize_t to_copy = func_prev ? Py_SIZE(func_prev) : 0;
Py_ssize_t prev_overloads = func_prev ? Py_SIZE(func_prev) : 0;
nb_func *func = (nb_func *) PyType_GenericAlloc(
is_method ? internals_->nb_method : internals_->nb_func, to_copy + 1);
is_method ? internals_->nb_method : internals_->nb_func, prev_overloads + 1);
check(func, "nb::detail::nb_func_new(\"%s\"): alloc. failed (1).",
name_cstr);

maybe_make_immortal((PyObject *) func);

// Check if the complex dispatch loop is needed
bool complex_call = has_keep_alive || has_var_kwargs || has_var_args || f->nargs >= NB_MAXARGS_SIMPLE;
bool complex_call = has_keep_alive || has_var_kwargs || has_var_args ||
f->nargs >= NB_MAXARGS_SIMPLE;

if (has_args) {
for (size_t i = is_method; i < f->nargs; ++i) {
arg_data &a = args_in[i - is_method];
Expand All @@ -302,15 +304,22 @@ PyObject *nb_func_new(const void *in_) noexcept {
}

uint32_t max_nargs = f->nargs;

const char *prev_doc = nullptr;

if (func_prev) {
complex_call |= ((nb_func *) func_prev)->complex_call;
max_nargs = std::max(max_nargs, ((nb_func *) func_prev)->max_nargs);
nb_func *nb_func_prev = (nb_func *) func_prev;
complex_call |= nb_func_prev->complex_call;
max_nargs = std::max(max_nargs, nb_func_prev->max_nargs);

func_data *cur = nb_func_data(func),
*prev = nb_func_data(func_prev);

memcpy(cur, prev, sizeof(func_data) * to_copy);
memset(prev, 0, sizeof(func_data) * to_copy);
if (nb_func_prev->doc_uniform)
prev_doc = prev->doc;

memcpy(cur, prev, sizeof(func_data) * prev_overloads);
memset(prev, 0, sizeof(func_data) * prev_overloads);

((PyVarObject *) func_prev)->ob_size = 0;

Expand All @@ -335,14 +344,25 @@ PyObject *nb_func_new(const void *in_) noexcept {
"nanobind::detail::nb_func_new(): internal update failed (2)!");
#endif

func_data *fc = nb_func_data(func) + to_copy;
func_data *fc = nb_func_data(func) + prev_overloads;
memcpy(fc, f, sizeof(func_data_prelim<0>));
if (has_doc) {
if (fc->doc[0] == '\n')
fc->doc++;
fc->doc = strdup_check(fc->doc);
if (fc->doc[0] == '\0') {
fc->doc = nullptr;
fc->flags &= ~(uint32_t) func_flags::has_doc;
has_doc = true;
} else {
fc->doc = strdup_check(fc->doc);
}
}

// Detect when an entire overload chain has the dame docstring
func->doc_uniform =
(has_doc && ((prev_overloads == 0) ||
(prev_doc && strcmp(fc->doc, prev_doc) == 0)));

if (is_constructor)
fc->flags |= (uint32_t) func_flags::is_constructor;
if (has_args)
Expand Down Expand Up @@ -1241,7 +1261,8 @@ PyObject *nb_func_get_nb_signature(PyObject *self, void *) {
docstr = item = sigstr = defaults = nullptr;

const func_data *fi = f + i;
if (fi->flags & (uint32_t) func_flags::has_doc && fi->doc[0] != '\0') {
if ((fi->flags & (uint32_t) func_flags::has_doc) &&
(!((nb_func *) self)->doc_uniform || i == 0)) {
docstr = PyUnicode_FromString(fi->doc);
} else {
docstr = Py_None;
Expand Down Expand Up @@ -1314,33 +1335,36 @@ PyObject *nb_func_get_doc(PyObject *self, void *) {

buf.clear();

size_t doc_count = 0;
bool doc_found = false;

for (uint32_t i = 0; i < count; ++i) {
const func_data *fi = f + i;
nb_func_render_signature(fi);
buf.put('\n');
if ((fi->flags & (uint32_t) func_flags::has_doc) && fi->doc[0] != '\0')
doc_count++;
doc_found |= fi->flags & (uint32_t) func_flags::has_doc;
}

if (doc_count > 1)
buf.put("\nOverloaded function.\n");

for (uint32_t i = 0; i < count; ++i) {
const func_data *fi = f + i;

if ((fi->flags & (uint32_t) func_flags::has_doc) && fi->doc[0] != '\0') {
if (doc_found) {
if (((nb_func *) self)->doc_uniform) {
buf.put('\n');
buf.put_dstr(f->doc);
buf.put('\n');
} else {
buf.put("\nOverloaded function.\n");
for (uint32_t i = 0; i < count; ++i) {
const func_data *fi = f + i;

if (doc_count > 1) {
buf.put('\n');
buf.put_uint32(i + 1);
buf.put(". ``");
nb_func_render_signature(fi);
buf.put("``\n\n");
}

buf.put_dstr(fi->doc);
buf.put('\n');
if (fi->flags & (uint32_t) func_flags::has_doc) {
buf.put_dstr(fi->doc);
buf.put('\n');
}
}
}
}

Expand Down
1 change: 1 addition & 0 deletions src/nb_internals.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ struct nb_func {
PyObject* (*vectorcall)(PyObject *, PyObject * const*, size_t, PyObject *);
uint32_t max_nargs; // maximum value of func_data::nargs for any overload
bool complex_call;
bool doc_uniform;
};

/// Python object representing a `nb_ndarray` (which wraps a DLPack ndarray)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ NB_MODULE(test_functions_ext, m) {
#endif
first_overload.reset();

// Test an overload chain that always repeats the same docstring
m.def("test_05b", [](int) -> int { return 1; }, "doc_1");
m.def("test_05b", [](float) -> int { return 2; }, "doc_1");

/// Function raising an exception
m.def("test_06", []() { throw std::runtime_error("oops!"); });

Expand Down
7 changes: 7 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ def test05_signature():
"doc_2"
)

assert t.test_05b.__doc__ == (
"test_05b(arg: int, /) -> int\n"
"test_05b(arg: float, /) -> int\n"
"\n"
"doc_1"
)

assert t.test_07.__doc__ == (
f"test_07(arg0: int, arg1: int, /, *args, **kwargs) -> {TYPING_TUPLE}[int, int]\n"
f"test_07(a: int, b: int, *myargs, **mykwargs) -> {TYPING_TUPLE}[int, int]"
Expand Down
7 changes: 7 additions & 0 deletions tests/test_functions_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def test_05(arg: int, /) -> int:
def test_05(arg: float, /) -> int:
"""doc_2"""

@overload
def test_05b(arg: int, /) -> int:
"""doc_1"""

@overload
def test_05b(arg: float, /) -> int: ...

def test_06() -> None: ...

@overload
Expand Down

0 comments on commit b1088be

Please sign in to comment.