style: Renamed variables to avoid built-in names (#73)
* style: Renamed variables to avoid built-in names

* style: Fixed black
frgfm authored Aug 3, 2022
1 parent d8f0d68 commit 4e327c2
Showing 5 changed files with 129 additions and 131 deletions.
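The rename matters because input is a Python built-in: using it as a parameter name shadows the built-in inside the function body and trips linters such as flake8-builtins; output was presumably renamed to out alongside it for consistency. A minimal, hypothetical sketch of the pitfall (not taken from this commit):

# Hypothetical illustration: the parameter name shadows the built-in input().
def scale_shadowing(input):
    # Calling input("factor? ") here would hit the argument, not the built-in.
    return input * 2

# After the rename, the built-in stays reachable and the intent is clearer.
def scale_renamed(inp):
    return inp * 2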
16 changes: 8 additions & 8 deletions torchscan/crawler.py
@@ -81,7 +81,7 @@ def crawl_module(

# Hook definition
def _hook_info(module: Module, name: str) -> None:
def _pre_hook(module: Module, input: torch.Tensor) -> None:
def _pre_hook(module: Module, inp: torch.Tensor) -> None:
"""Pre-forward hook"""
# Check that another hook has not been triggered at this forward stage
if not pre_hook_tracker[id(module)]["is_used"] and (
@@ -123,7 +123,7 @@ def _pre_hook(module: Module, input: torch.Tensor) -> None:
name=name.rpartition(".")[-1],
depth=len(name.split(".")) - 1,
type=module.__class__.__name__,
input_shape=(-1, *input[0][0].shape[1:]),
input_shape=(-1, *inp[0][0].shape[1:]),
output_shape=None,
grad_params=grad_params,
nograd_params=nograd_params,
@@ -150,7 +150,7 @@ def _pre_hook(module: Module, input: torch.Tensor) -> None:
pre_hook_tracker[id(module)]["current"] = 0
pre_hook_tracker[id(module)]["is_used"] = False

def _fwd_hook(module: Module, inputs: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
def _fwd_hook(module: Module, inputs: Tuple[torch.Tensor, ...], out: torch.Tensor) -> None:
"""Post-forward hook"""

# Check that another hook has not been triggered at this forward stage
@@ -173,13 +173,13 @@ def _fwd_hook(module: Module, inputs: Tuple[torch.Tensor, ...], output: torch.Te
current_rf, current_stride, current_padding = 1.0, 1.0, 0.0
else:
# Compute stats for standalone layers
tot_flops = module_flops(module, inputs, output)
tot_macs = module_macs(module, inputs[0], output)
tot_dmas = module_dmas(module, inputs[0], output)
current_rf, current_stride, current_padding = module_rf(module, inputs[0], output)
tot_flops = module_flops(module, inputs, out)
tot_macs = module_macs(module, inputs[0], out)
tot_dmas = module_dmas(module, inputs[0], out)
current_rf, current_stride, current_padding = module_rf(module, inputs[0], out)

# Update layer information
info[fw_idx]["output_shape"] = (-1, *output.shape[1:])
info[fw_idx]["output_shape"] = (-1, *out.shape[1:])
# Add them, since some modules can be used several times
info[fw_idx]["flops"] = tot_flops
info[fw_idx]["macs"] = tot_macs
48 changes: 24 additions & 24 deletions torchscan/modules/flops.py
@@ -18,13 +18,13 @@
__all__ = ["module_flops"]


def module_flops(module: Module, inputs: Tuple[Tensor, ...], output: Tensor) -> int:
def module_flops(module: Module, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
"""Estimate the number of floating point operations performed by the module
Args:
module: PyTorch module
inputs: input to the module
output: output of the module
out: output of the module
Returns:
number of FLOPs
"""
@@ -46,19 +46,19 @@ def module_flops(module: Module, inputs: Tuple[Tensor, ...], output: Tensor) ->
elif isinstance(module, nn.Sigmoid):
return flops_sigmoid(module, inputs)
elif isinstance(module, _ConvTransposeNd):
return flops_convtransposend(module, inputs, output)
return flops_convtransposend(module, inputs, out)
elif isinstance(module, _ConvNd):
return flops_convnd(module, inputs, output)
return flops_convnd(module, inputs, out)
elif isinstance(module, _BatchNorm):
return flops_bn(module, inputs)
elif isinstance(module, _MaxPoolNd):
return flops_maxpool(module, inputs, output)
return flops_maxpool(module, inputs, out)
elif isinstance(module, _AvgPoolNd):
return flops_avgpool(module, inputs, output)
return flops_avgpool(module, inputs, out)
elif isinstance(module, _AdaptiveMaxPoolNd):
return flops_adaptive_maxpool(module, inputs, output)
return flops_adaptive_maxpool(module, inputs, out)
elif isinstance(module, _AdaptiveAvgPoolNd):
return flops_adaptive_avgpool(module, inputs, output)
return flops_adaptive_avgpool(module, inputs, out)
elif isinstance(module, nn.Dropout):
return flops_dropout(module, inputs)
elif isinstance(module, nn.Transformer):
@@ -131,20 +131,20 @@ def flops_dropout(module: nn.Dropout, inputs: Tuple[Tensor, ...]) -> int:
return 0


def flops_convtransposend(module: _ConvTransposeNd, inputs: Tuple[Tensor, ...], output: Tensor) -> int:
def flops_convtransposend(module: _ConvTransposeNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
"""FLOPs estimation for `torch.nn.modules.conv._ConvTranposeNd`"""

# Padding (# cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532)
# Define min and max sizes
padding_flops = len(module.kernel_size) * 8

# Once padding is determined, the operations are almost identical to those of a convolution
conv_flops = flops_convnd(module, inputs, output)
conv_flops = flops_convnd(module, inputs, out)

return padding_flops + conv_flops


def flops_convnd(module: _ConvNd, inputs: Tuple[Tensor, ...], output: Tensor) -> int:
def flops_convnd(module: _ConvNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
"""FLOPs estimation for `torch.nn.modules.conv._ConvNd`"""

# For each position, # mult = kernel size, # adds = kernel size - 1
@@ -153,10 +153,10 @@ def flops_convnd(module: _ConvNd, inputs: Tuple[Tensor, ...], output: Tensor) ->
effective_in_chan = inputs[0].shape[1] // module.groups
# N * flops + (N - 1) additions
window_flops = effective_in_chan * window_flops_per_chan + (effective_in_chan - 1)
conv_flops = output.numel() * window_flops
conv_flops = out.numel() * window_flops

# Each output element gets a bias addition
bias_flops = output.numel() if module.bias is not None else 0
bias_flops = out.numel() if module.bias is not None else 0

return conv_flops + bias_flops

@@ -189,48 +189,48 @@ def flops_bn(module: _BatchNorm, inputs: Tuple[Tensor, ...]) -> int:
return bn_flops + tracking_flops


def flops_maxpool(module: _MaxPoolNd, inputs: Tuple[Tensor, ...], output: Tensor) -> int:
def flops_maxpool(module: _MaxPoolNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
"""FLOPs estimation for `torch.nn.modules.pooling._MaxPoolNd`"""

k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size

# for each spatial output element, check max element in kernel scope
return output.numel() * (k_size - 1)
return out.numel() * (k_size - 1)


def flops_avgpool(module: _AvgPoolNd, inputs: Tuple[Tensor, ...], output: Tensor) -> int:
def flops_avgpool(module: _AvgPoolNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
"""FLOPs estimation for `torch.nn.modules.pooling._AvgPoolNd`"""

k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size

# for each spatial output element, sum elements in kernel scope and div by kernel size
return output.numel() * (k_size - 1 + inputs[0].ndim - 2)
return out.numel() * (k_size - 1 + inputs[0].ndim - 2)


def flops_adaptive_maxpool(module: _AdaptiveMaxPoolNd, inputs: Tuple[Tensor, ...], output: Tensor) -> int:
def flops_adaptive_maxpool(module: _AdaptiveMaxPoolNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
"""FLOPs estimation for `torch.nn.modules.pooling._AdaptiveMaxPoolNd`"""

# Approximate kernel_size using ratio of spatial shapes between input and output
kernel_size = tuple(
i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
for i_size, o_size in zip(inputs[0].shape[2:], output.shape[2:])
for i_size, o_size in zip(inputs[0].shape[2:], out.shape[2:])
)

# for each spatial output element, check max element in kernel scope
return output.numel() * (reduce(mul, kernel_size) - 1)
return out.numel() * (reduce(mul, kernel_size) - 1)


def flops_adaptive_avgpool(module: _AdaptiveAvgPoolNd, inputs: Tuple[Tensor, ...], output: Tensor) -> int:
def flops_adaptive_avgpool(module: _AdaptiveAvgPoolNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
"""FLOPs estimation for `torch.nn.modules.pooling._AdaptiveAvgPoolNd`"""

# Approximate kernel_size using ratio of spatial shapes between input and output
kernel_size = tuple(
i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
for i_size, o_size in zip(inputs[0].shape[2:], output.shape[2:])
for i_size, o_size in zip(inputs[0].shape[2:], out.shape[2:])
)

# for each spatial output element, sum elements in kernel scope and div by kernel size
return output.numel() * (reduce(mul, kernel_size) - 1 + len(kernel_size))
return out.numel() * (reduce(mul, kernel_size) - 1 + len(kernel_size))


def flops_layernorm(module: nn.LayerNorm, inputs: Tuple[Tensor, ...]) -> int:
@@ -254,7 +254,7 @@ def flops_mha(module: nn.MultiheadAttention, inputs: Tuple[Tensor, ...]) -> int:
"""FLOPs estimation for `torch.nn.MultiheadAttention`"""

# Input projection
q, k, v = inputs[:3]
q, k, _ = inputs[:3]
batch_size = q.shape[1]
if module._qkv_same_embed_dim:
tot_flops = 3 * flops_linear(
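A hedged usage sketch for the renamed dispatcher: module_flops(module, inputs, out) takes the positional inputs as a tuple plus the forward result, matching the call site in crawler.py; the layer and shapes below are illustrative only.

import torch
from torch import nn
from torchscan.modules.flops import module_flops

conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
x = torch.rand(1, 3, 32, 32)
out = conv(x)
# Dispatches to flops_convnd: kernel multiplies/adds per output element, plus one bias add each.
print(module_flops(conv, (x,), out))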
64 changes: 32 additions & 32 deletions torchscan/modules/macs.py
@@ -16,79 +16,79 @@
__all__ = ["module_macs"]


def module_macs(module: Module, input: Tensor, output: Tensor) -> int:
def module_macs(module: Module, inp: Tensor, out: Tensor) -> int:
"""Estimate the number of multiply-accumulation operations performed by the module
Args:
module (torch.nn.Module): PyTorch module
input (torch.Tensor): input to the module
output (torch.Tensor): output of the module
inp (torch.Tensor): input to the module
out (torch.Tensor): output of the module
Returns:
int: number of MACs
"""
if isinstance(module, nn.Linear):
return macs_linear(module, input, output)
return macs_linear(module, inp, out)
elif isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, nn.Flatten)):
return 0
elif isinstance(module, _ConvTransposeNd):
return macs_convtransposend(module, input, output)
return macs_convtransposend(module, inp, out)
elif isinstance(module, _ConvNd):
return macs_convnd(module, input, output)
return macs_convnd(module, inp, out)
elif isinstance(module, _BatchNorm):
return macs_bn(module, input, output)
return macs_bn(module, inp, out)
elif isinstance(module, _MaxPoolNd):
return macs_maxpool(module, input, output)
return macs_maxpool(module, inp, out)
elif isinstance(module, _AvgPoolNd):
return macs_avgpool(module, input, output)
return macs_avgpool(module, inp, out)
elif isinstance(module, _AdaptiveMaxPoolNd):
return macs_adaptive_maxpool(module, input, output)
return macs_adaptive_maxpool(module, inp, out)
elif isinstance(module, _AdaptiveAvgPoolNd):
return macs_adaptive_avgpool(module, input, output)
return macs_adaptive_avgpool(module, inp, out)
elif isinstance(module, nn.Dropout):
return 0
else:
warnings.warn(f"Module type not supported: {module.__class__.__name__}")
return 0


def macs_linear(module: nn.Linear, input: Tensor, output: Tensor) -> int:
def macs_linear(module: nn.Linear, inp: Tensor, out: Tensor) -> int:
"""MACs estimation for `torch.nn.Linear`"""

# batch size * out_chan * macs_per_elt (bias already counted in accumulation)
mm_mac = module.in_features * reduce(mul, output.shape)
mm_mac = module.in_features * reduce(mul, out.shape)

return mm_mac


def macs_convtransposend(module: _ConvTransposeNd, input: Tensor, output: Tensor) -> int:
def macs_convtransposend(module: _ConvTransposeNd, inp: Tensor, out: Tensor) -> int:
"""MACs estimation for `torch.nn.modules.conv._ConvTransposeNd`"""

# Padding (# cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532)
# Define min and max sizes, then subtract them
padding_macs = len(module.kernel_size) * 4

# Rest of the operations are almost identical to a convolution (given the padding)
conv_macs = macs_convnd(module, input, output)
conv_macs = macs_convnd(module, inp, out)

return padding_macs + conv_macs


def macs_convnd(module: _ConvNd, input: Tensor, output: Tensor) -> int:
def macs_convnd(module: _ConvNd, inp: Tensor, out: Tensor) -> int:
"""MACs estimation for `torch.nn.modules.conv._ConvNd`"""

# For each position, # mult = kernel size, # adds = kernel size - 1
window_macs_per_chan = reduce(mul, module.kernel_size)
# Connections to input channels is controlled by the group parameter
effective_in_chan = input.shape[1] // module.groups
effective_in_chan = inp.shape[1] // module.groups
# N * mac
window_mac = effective_in_chan * window_macs_per_chan
conv_mac = output.numel() * window_mac
conv_mac = out.numel() * window_mac

# bias already counted in accumulation
return conv_mac


def macs_bn(module: _BatchNorm, input: Tensor, output: Tensor) -> int:
def macs_bn(module: _BatchNorm, inp: Tensor, out: Tensor) -> int:
"""MACs estimation for `torch.nn.modules.batchnorm._BatchNorm`"""

# sub mean, div by denom
@@ -97,13 +97,13 @@ def macs_bn(module: _BatchNorm, input: Tensor, output: Tensor) -> int:
scale_mac = 1 if module.affine else 0

# Sum everything up
bn_mac = input.numel() * (norm_mac + scale_mac)
bn_mac = inp.numel() * (norm_mac + scale_mac)

# Count tracking stats update ops
# cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L94-L101
tracking_mac = 0
b = input.shape[0]
num_spatial_elts = input.shape[2:].numel()
b = inp.shape[0]
num_spatial_elts = inp.shape[2:].numel()
if module.track_running_stats and module.training:
# running_mean: by channel, sum value and div by batch size
tracking_mac += module.num_features * (b * num_spatial_elts - 1)
@@ -116,45 +116,45 @@ def macs_bn(module: _BatchNorm, input: Tensor, output: Tensor) -> int:
return bn_mac + tracking_mac


def macs_maxpool(module: _MaxPoolNd, input: Tensor, output: Tensor) -> int:
def macs_maxpool(module: _MaxPoolNd, inp: Tensor, out: Tensor) -> int:
"""MACs estimation for `torch.nn.modules.pooling._MaxPoolNd`"""

k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size

# for each spatial output element, check max element in kernel scope
return output.numel() * (k_size - 1)
return out.numel() * (k_size - 1)


def macs_avgpool(module: _AvgPoolNd, input: Tensor, output: Tensor) -> int:
def macs_avgpool(module: _AvgPoolNd, inp: Tensor, out: Tensor) -> int:
"""MACs estimation for `torch.nn.modules.pooling._AvgPoolNd`"""

k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size

# for each spatial output element, sum elements in kernel scope and div by kernel size
return output.numel() * (k_size - 1 + input.ndim - 2)
return out.numel() * (k_size - 1 + inp.ndim - 2)


def macs_adaptive_maxpool(module: _AdaptiveMaxPoolNd, input: Tensor, output: Tensor) -> int:
def macs_adaptive_maxpool(module: _AdaptiveMaxPoolNd, inp: Tensor, out: Tensor) -> int:
"""MACs estimation for `torch.nn.modules.pooling._AdaptiveMaxPoolNd`"""

# Approximate kernel_size using ratio of spatial shapes between input and output
kernel_size = tuple(
i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
for i_size, o_size in zip(input.shape[2:], output.shape[2:])
for i_size, o_size in zip(inp.shape[2:], out.shape[2:])
)

# for each spatial output element, check max element in kernel scope
return output.numel() * (reduce(mul, kernel_size) - 1)
return out.numel() * (reduce(mul, kernel_size) - 1)


def macs_adaptive_avgpool(module: _AdaptiveAvgPoolNd, input: Tensor, output: Tensor) -> int:
def macs_adaptive_avgpool(module: _AdaptiveAvgPoolNd, inp: Tensor, out: Tensor) -> int:
"""MACs estimation for `torch.nn.modules.pooling._AdaptiveAvgPoolNd`"""

# Approximate kernel_size using ratio of spatial shapes between input and output
kernel_size = tuple(
i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
for i_size, o_size in zip(input.shape[2:], output.shape[2:])
for i_size, o_size in zip(inp.shape[2:], out.shape[2:])
)

# for each spatial output element, sum elements in kernel scope and div by kernel size
return output.numel() * (reduce(mul, kernel_size) - 1 + len(kernel_size))
return out.numel() * (reduce(mul, kernel_size) - 1 + len(kernel_size))
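Note that the MAC dispatcher keeps a slightly different convention: module_macs(module, inp, out) takes a single input tensor rather than a tuple, which is why crawler.py passes inputs[0]. An illustrative sketch:

import torch
from torch import nn
from torchscan.modules.macs import module_macs

fc = nn.Linear(128, 64)
x = torch.rand(4, 128)
# Dispatches to macs_linear: in_features multiply-accumulates per output element.
print(module_macs(fc, x, fc(x)))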
108 changes: 54 additions & 54 deletions torchscan/modules/memory.py
@@ -17,44 +17,44 @@
__all__ = ["module_dmas"]


def module_dmas(module: Module, input: Tensor, output: Tensor) -> int:
def module_dmas(module: Module, inp: Tensor, out: Tensor) -> int:
"""Estimate the number of direct memory accesses by the module.
The implementation overhead is neglected.
Args:
module (torch.nn.Module): PyTorch module
input (torch.Tensor): input to the module
output (torch.Tensor): output of the module
inp (torch.Tensor): input to the module
out (torch.Tensor): output of the module
Returns:
int: number of DMAs
"""

if isinstance(module, nn.Identity):
return dmas_identity(module, input, output)
return dmas_identity(module, inp, out)
elif isinstance(module, nn.Flatten):
return dmas_flatten(module, input, output)
return dmas_flatten(module, inp, out)
elif isinstance(module, nn.Linear):
return dmas_linear(module, input, output)
return dmas_linear(module, inp, out)
elif isinstance(module, (nn.ReLU, nn.ReLU6)):
return dmas_relu(module, input, output)
return dmas_relu(module, inp, out)
elif isinstance(module, (nn.ELU, nn.LeakyReLU)):
return dmas_act_single_param(module, input, output)
return dmas_act_single_param(module, inp, out)
elif isinstance(module, nn.Sigmoid):
return dmas_sigmoid(module, input, output)
return dmas_sigmoid(module, inp, out)
elif isinstance(module, nn.Tanh):
return dmas_tanh(module, input, output)
return dmas_tanh(module, inp, out)
elif isinstance(module, _ConvTransposeNd):
return dmas_convtransposend(module, input, output)
return dmas_convtransposend(module, inp, out)
elif isinstance(module, _ConvNd):
return dmas_convnd(module, input, output)
return dmas_convnd(module, inp, out)
elif isinstance(module, _BatchNorm):
return dmas_bn(module, input, output)
return dmas_bn(module, inp, out)
elif isinstance(module, (_MaxPoolNd, _AvgPoolNd)):
return dmas_pool(module, input, output)
return dmas_pool(module, inp, out)
elif isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)):
return dmas_adaptive_pool(module, input, output)
return dmas_adaptive_pool(module, inp, out)
elif isinstance(module, nn.Dropout):
return dmas_dropout(module, input, output)
return dmas_dropout(module, inp, out)
else:
warnings.warn(f"Module type not supported: {module.__class__.__name__}")
return 0
@@ -72,83 +72,83 @@ def num_params(module: Module) -> int:
return sum(p.data.numel() for p in module.parameters())


def dmas_identity(module: nn.Identity, input: Tensor, output: Tensor) -> int:
def dmas_identity(module: nn.Identity, inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for `torch.nn.Identity`"""

return input.numel()
return inp.numel()


def dmas_flatten(module: nn.Flatten, input: Tensor, output: Tensor) -> int:
def dmas_flatten(module: nn.Flatten, inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for `torch.nn.Flatten`"""

return 2 * input.numel()
return 2 * inp.numel()


def dmas_linear(module: nn.Linear, input: Tensor, output: Tensor) -> int:
def dmas_linear(module: nn.Linear, inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for `torch.nn.Linear`"""

input_dma = input.numel()
input_dma = inp.numel()
# Access weight and bias
ops_dma = num_params(module)
output_dma = output.numel()
output_dma = out.numel()

return input_dma + ops_dma + output_dma


def dmas_relu(module: Union[nn.ReLU, nn.ReLU6], input: Tensor, output: Tensor) -> int:
def dmas_relu(module: Union[nn.ReLU, nn.ReLU6], inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for `torch.nn.ReLU`"""

input_dma = input.numel()
output_dma = 0 if module.inplace else output.numel()
input_dma = inp.numel()
output_dma = 0 if module.inplace else out.numel()

return input_dma + output_dma


def dmas_act_single_param(module: Union[nn.ELU, nn.LeakyReLU], input: Tensor, output: Tensor) -> int:
def dmas_act_single_param(module: Union[nn.ELU, nn.LeakyReLU], inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for activations with single parameter"""

input_dma = input.numel()
input_dma = inp.numel()
# Access alpha, slope or other
ops_dma = 1
output_dma = 0 if module.inplace else output.numel()
output_dma = 0 if module.inplace else out.numel()

return input_dma + ops_dma + output_dma


def dmas_sigmoid(module: nn.Sigmoid, input: Tensor, output: Tensor) -> int:
def dmas_sigmoid(module: nn.Sigmoid, inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for `torch.nn.Sigmoid`"""

# Access for both exp
input_dma = input.numel()
output_dma = output.numel()
input_dma = inp.numel()
output_dma = out.numel()

return input_dma + output_dma


def dmas_tanh(module: nn.Tanh, input: Tensor, output: Tensor) -> int:
def dmas_tanh(module: nn.Tanh, inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for `torch.nn.Tanh`"""

# Access for both exp
input_dma = input.numel() * 2
output_dma = output.numel()
input_dma = inp.numel() * 2
output_dma = out.numel()

return input_dma + output_dma


def dmas_dropout(module: nn.Dropout, input: Tensor, output: Tensor) -> int:
def dmas_dropout(module: nn.Dropout, inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for `torch.nn.Dropout`"""

input_dma = input.numel()
input_dma = inp.numel()

# Access sampling probability
ops_dma = 1

output_dma = 0 if module.inplace else output.numel()
output_dma = 0 if module.inplace else out.numel()

return input_dma + ops_dma + output_dma


def dmas_convtransposend(module: _ConvTransposeNd, input: Tensor, output: Tensor) -> int:
def dmas_convtransposend(module: _ConvTransposeNd, inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for `torch.nn.modules.conv._ConvTransposeNd`"""

# Padding (# cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532)
@@ -157,29 +157,29 @@ def dmas_convtransposend(module: _ConvTransposeNd, input: Tensor, output: Tensor
out_padding = len(module.kernel_size)

# The rest is like a classic convolution
conv_dmas = dmas_convnd(module, input, output)
conv_dmas = dmas_convnd(module, inp, out)

return in_padding + out_padding + conv_dmas


def dmas_convnd(module: _ConvNd, input: Tensor, output: Tensor) -> int:
def dmas_convnd(module: _ConvNd, inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for `torch.nn.modules.conv._ConvNd`"""

# Each output element required K ** 2 memory access of each input channel
input_dma = module.in_channels * reduce(mul, module.kernel_size) * output.numel()
input_dma = module.in_channels * reduce(mul, module.kernel_size) * out.numel()
# Correct with groups
input_dma //= module.groups

# Access weight & bias
ops_dma = num_params(module)
output_dma = output.numel()
output_dma = out.numel()

return input_dma + ops_dma + output_dma


def dmas_bn(module: _BatchNorm, input: Tensor, output: Tensor) -> int:
def dmas_bn(module: _BatchNorm, inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for `torch.nn.modules.batchnorm._BatchNorm`"""
input_dma = input.numel()
input_dma = inp.numel()

# Access running_mean, running_var and eps
ops_dma = module.running_mean.numel() + module.running_var.numel() + 1 # type: ignore[union-attr]
@@ -195,39 +195,39 @@ def dmas_bn(module: _BatchNorm, input: Tensor, output: Tensor) -> int:
# Update num of batches and running stats
ops_dma += 1 + module.running_mean.numel() + module.running_var.numel() # type: ignore[union-attr]

output_dma = output.numel()
output_dma = out.numel()

return input_dma + ops_dma + output_dma


def dmas_pool(module: Union[_MaxPoolNd, _AvgPoolNd], input: Tensor, output: Tensor) -> int:
def dmas_pool(module: Union[_MaxPoolNd, _AvgPoolNd], inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for spatial pooling modules"""

# Resolve kernel size and stride size (can be stored as a single integer or a tuple)
if isinstance(module.kernel_size, tuple):
kernel_size = module.kernel_size
elif isinstance(module.kernel_size, int):
kernel_size = (module.kernel_size,) * (input.ndim - 2)
kernel_size = (module.kernel_size,) * (inp.ndim - 2)

# Each output element required K ** 2 memory accesses
input_dma = reduce(mul, kernel_size) * output.numel()
input_dma = reduce(mul, kernel_size) * out.numel()

output_dma = output.numel()
output_dma = out.numel()

return input_dma + output_dma


def dmas_adaptive_pool(module: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], input: Tensor, output: Tensor) -> int:
def dmas_adaptive_pool(module: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor) -> int:
"""DMAs estimation for adaptive spatial pooling modules"""

# Approximate kernel_size using ratio of spatial shapes between input and output
kernel_size = tuple(
i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
for i_size, o_size in zip(input.shape[2:], output.shape[2:])
for i_size, o_size in zip(inp.shape[2:], out.shape[2:])
)
# Each output element required K ** 2 memory accesses
input_dma = reduce(mul, kernel_size) * output.numel()
input_dma = reduce(mul, kernel_size) * out.numel()

output_dma = output.numel()
output_dma = out.numel()

return input_dma + output_dma
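An illustrative sketch for the DMA dispatcher, which follows the same (module, inp, out) convention:

import torch
from torch import nn
from torchscan.modules.memory import module_dmas

pool = nn.MaxPool2d(2)
x = torch.rand(1, 8, 16, 16)
# Dispatches to dmas_pool: one read per kernel-window element per output, plus one write per output.
print(module_dmas(pool, x, pool(x)))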
24 changes: 11 additions & 13 deletions torchscan/modules/receptive.py
@@ -16,13 +16,13 @@
__all__ = ["module_rf"]


def module_rf(module: Module, input: Tensor, output: Tensor) -> Tuple[float, float, float]:
def module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, float, float]:
"""Estimate the spatial receptive field of the module
Args:
module (torch.nn.Module): PyTorch module
input (torch.Tensor): input to the module
output (torch.Tensor): output of the module
inp (torch.Tensor): input to the module
out (torch.Tensor): output of the module
Returns:
receptive field
effective stride
@@ -46,25 +46,23 @@ def module_rf(module: Module, input: Tensor, output: Tensor) -> Tuple[float, flo
):
return 1.0, 1.0, 0.0
elif isinstance(module, _ConvTransposeNd):
return rf_convtransposend(module, input, output)
return rf_convtransposend(module, inp, out)
elif isinstance(module, (_ConvNd, _MaxPoolNd, _AvgPoolNd)):
return rf_aggregnd(module, input, output)
return rf_aggregnd(module, inp, out)
elif isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)):
return rf_adaptive_poolnd(module, input, output)
return rf_adaptive_poolnd(module, inp, out)
else:
warnings.warn(f"Module type not supported: {module.__class__.__name__}")
return 1.0, 1.0, 0.0


def rf_convtransposend(module: _ConvTransposeNd, intput: Tensor, output: Tensor) -> Tuple[float, float, float]:
def rf_convtransposend(module: _ConvTransposeNd, intput: Tensor, out: Tensor) -> Tuple[float, float, float]:
k = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size
s = module.stride[0] if isinstance(module.stride, tuple) else module.stride
return -k, 1.0 / s, 0.0


def rf_aggregnd(
module: Union[_ConvNd, _MaxPoolNd, _AvgPoolNd], input: Tensor, output: Tensor
) -> Tuple[float, float, float]:
def rf_aggregnd(module: Union[_ConvNd, _MaxPoolNd, _AvgPoolNd], inp: Tensor, out: Tensor) -> Tuple[float, float, float]:
k = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size
if hasattr(module, "dilation"):
d = module.dilation[0] if isinstance(module.dilation, tuple) else module.dilation
@@ -75,11 +73,11 @@ def rf_aggregnd(


def rf_adaptive_poolnd(
module: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], input: Tensor, output: Tensor
module: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor
) -> Tuple[int, int, float]:

stride = math.ceil(input.shape[-1] / output.shape[-1])
stride = math.ceil(inp.shape[-1] / out.shape[-1])
kernel_size = stride
padding = (input.shape[-1] - kernel_size * stride) / 2
padding = (inp.shape[-1] - kernel_size * stride) / 2

return kernel_size, stride, padding
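Finally, an illustrative sketch for the receptive-field helper: module_rf(module, inp, out) returns the layer's receptive field, effective stride, and padding, which the crawler aggregates across layers.

import torch
from torch import nn
from torchscan.modules.receptive import module_rf

conv = nn.Conv2d(3, 8, kernel_size=5, stride=2, padding=2)
x = torch.rand(1, 3, 32, 32)
rf, stride, padding = module_rf(conv, x, conv(x))
print(rf, stride, padding)  # single-layer receptive field, effective stride, padding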
