Skip to content

Commit

Permalink
Fix extraction of class name with new parameter syntax (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
yosh-matsuda authored Aug 29, 2024
1 parent 928e3b1 commit 7c94789
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/nb_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,11 +828,16 @@ NB_NOINLINE char *extract_name(const char *cmd, const char *prefix, const char *
cmd, s, prefix);
p += prefix_len;

// Find the opening parenthesis
// Find the opening parenthesis or bracket
const char *p2 = strchr(p, '(');
const char *p3 = strchr(p, '[');
if (p2 == nullptr)
p2 = p3;
else if (p3 != nullptr)
p2 = p2 < p3 ? p2 : p3;
check(p2 != nullptr,
"%s(): last line of custom signature \"%s\" must contain an opening "
"parenthesis (\"(\")!", cmd, s);
"parenthesis (\"(\") or bracket (\"[\")!", cmd, s);

// A few sanity checks
size_t len = strlen(p);
Expand Down
7 changes: 7 additions & 0 deletions tests/test_typing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ NB_MODULE(test_typing_ext, m) {
nb::class_<WrapperFoo>(m, "WrapperFoo", wrapper[nb::type<Foo>()]);
#endif

// Type parameter syntax for Python 3.12+
struct WrapperTypeParam { };
nb::class_<WrapperTypeParam>(m, "WrapperTypeParam",
nb::sig("class WrapperTypeParam[T]"));
m.def("list_front", [](nb::list l) { return l[0]; },
nb::sig("def list_front[T](arg: list[T], /) -> T"));

// Some statements that will be modified by the pattern file
m.def("remove_me", []{});
m.def("tweak_me", [](nb::object o) { return o; }, "prior docstring\nremains preserved");
Expand Down
5 changes: 5 additions & 0 deletions tests/test_typing_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,15 @@ class Wrapper(Generic[T]):
class WrapperFoo(Wrapper[Foo]):
pass

class WrapperTypeParam[T]:
pass

def f() -> None: ...

f_alias = f

def list_front[T](arg: list[T], /) -> T: ...

def makeNestedClass() -> py_stub_test.AClass.NestedClass: ...

pytree: dict = {'a' : ('b', [123])}
Expand Down

0 comments on commit 7c94789

Please sign in to comment.