Skip to content

Commit

Permalink
stubgen: reformatted
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Feb 27, 2024
1 parent 8cefc1b commit 0bdac61
Showing 1 changed file with 26 additions and 19 deletions.
45 changes: 26 additions & 19 deletions src/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,36 +517,44 @@ def put(self, value, module=None, name=None, parent=None):
query_v[1] += 1
for l in query_v[0]:
ls = l.strip()
if ls == '\\doc':
if ls == "\\doc":
# Docstring reference
tp = type(value)
if tp.__module__ == 'nanobind' and \
tp.__name__ in ('nb_func', 'nb_method'):
if tp.__module__ == "nanobind" and tp.__name__ in (
"nb_func",
"nb_method",
):
for tp_i in value.__nb_signature__:
doc = tp_i[1]
if doc:
break
else:
doc = getattr(value, '__doc__', None)
doc = getattr(value, "__doc__", None)
self.depth += 1
if doc and self.include_docstrings:
self.put_docstr(doc)
else:
self.write_ln('pass')
self.write_ln("pass")
self.depth -= 1
continue
elif ls.startswith('\\from '):
items = ls[5:].split(' import ')
elif ls.startswith("\\from "):
items = ls[5:].split(" import ")
if len(items) != 2:
raise RuntimeError(f"Could not parse import declaration {ls}")
for item in items[1].strip('()').split(','):
item = item.split(' as ')
import_module, import_name = items[0].strip(), item[0].strip()
raise RuntimeError(
f"Could not parse import declaration {ls}"
)
for item in items[1].strip("()").split(","):
item = item.split(" as ")
import_module, import_name = (
items[0].strip(),
item[0].strip(),
)
import_as = item[1].strip() if len(item) > 1 else None
self.import_object(import_module, import_name, import_as)
self.import_object(
import_module, import_name, import_as
)
continue


groups = match.groups()
for i in reversed(range(len(groups))):
l = l.replace(f"\\{i+1}", groups[i])
Expand Down Expand Up @@ -595,7 +603,6 @@ def put(self, value, module=None, name=None, parent=None):
):
return

tp = type(value)
tp_mod, tp_name = tp.__module__, tp.__name__

if inspect.ismodule(value):
Expand Down Expand Up @@ -638,9 +645,9 @@ def import_object(self, module, name, as_name=None):
# Rewrite module name if this is relative import from a submodule
if self.module and module.startswith(self.module.__name__):
module_short = module[len(self.module.__name__) :]
if not name and as_name and module_short[0] == '.':
if not name and as_name and module_short[0] == ".":
name = as_name = module_short[1:]
module_short = '.'
module_short = "."
else:
module_short = module

Expand All @@ -667,7 +674,7 @@ def import_object(self, module, name, as_name=None):
break
value = getattr(self.module, final_name)
try:
if module == '.':
if module == ".":
mod_o = self.module
else:
mod_o = importlib.import_module(module)
Expand Down Expand Up @@ -765,7 +772,7 @@ def get(self):
imports = self.imports[module]
items = []

for ((k, v1), v2) in imports.items():
for (k, v1), v2 in imports.items():
if k == None:
if v1 and v1 != module:
s += f"import {module} as {v1}\n"
Expand Down Expand Up @@ -942,7 +949,7 @@ def add_pattern(query, pattern):
# Exactly 1 empty line at the end
while pattern and pattern[-1].isspace():
pattern.pop()
pattern.append('')
pattern.append("")

# Identify deletions (replacement by only whitespace)
if all((p.isspace() or len(p) == 0 for p in pattern)):
Expand Down

0 comments on commit 0bdac61

Please sign in to comment.